1
use std::time::Duration;
2

            
3
use anyhow::Result;
4

            
5
use bytes::Bytes;
6
use futures_util::TryFutureExt;
7
use num_traits::FromPrimitive;
8

            
9
use solana_rpc_client::nonblocking::rpc_client::RpcClient;
10
use solana_rpc_client_api::config::RpcSendTransactionConfig;
11
use solana_sdk::account::Account;
12
use solana_sdk::commitment_config::CommitmentConfig;
13
use solana_sdk::compute_budget::ComputeBudgetInstruction;
14
use solana_sdk::instruction::Instruction;
15
use solana_sdk::message::{v0, VersionedMessage};
16
use solana_sdk::pubkey::Pubkey;
17
use solana_sdk::signer::Signer;
18
use solana_sdk::transaction::VersionedTransaction;
19

            
20
use tokio::time::Instant;
21

            
22
use bonsol_interface::bonsol_schema::{root_as_deploy_v1, root_as_execution_request_v1};
23
pub use bonsol_interface::bonsol_schema::{
24
    ClaimV1T, DeployV1T, ExecutionRequestV1T, ExitCode, InputT, InputType, ProgramInputType,
25
    StatusTypes,
26
};
27
use bonsol_interface::claim_state::ClaimStateHolder;
28
use bonsol_interface::prover_version::ProverVersion;
29
pub use bonsol_interface::util::*;
30
pub use bonsol_interface::{instructions, ID};
31
use instructions::{CallbackConfig, ExecutionConfig, InputRef};
32

            
33
pub use flatbuffers;
34

            
35
pub struct BonsolClient {
36
    rpc_client: RpcClient,
37
}
38

            
39
pub enum ExecutionAccountStatus {
40
    Completed(ExitCode),
41
    Pending(ExecutionRequestV1T),
42
}
43

            
44
impl BonsolClient {
45
    pub fn new(rpc_url: String) -> Self {
46
        BonsolClient {
47
            rpc_client: RpcClient::new(rpc_url),
48
        }
49
    }
50

            
51
    pub async fn get_current_slot(&self) -> Result<u64> {
52
        self.rpc_client
53
            .get_slot()
54
            .map_err(|_| anyhow::anyhow!("Failed to get slot"))
55
            .await
56
    }
57

            
58
    pub fn with_rpc_client(rpc_client: RpcClient) -> Self {
59
        BonsolClient { rpc_client }
60
    }
61

            
62
    pub async fn get_deployment_v1(&self, image_id: &str) -> Result<DeployV1T> {
63
        let (deployment_account, _) = deployment_address(image_id);
64
        let account = self
65
            .rpc_client
66
            .get_account_with_commitment(&deployment_account, CommitmentConfig::confirmed())
67
            .await
68
            .map_err(|e| anyhow::anyhow!("Failed to get account: {:?}", e))?
69
            .value
70
            .ok_or(anyhow::anyhow!("Invalid deployment account"))?;
71
        let deployment = root_as_deploy_v1(&account.data)
72
            .map_err(|_| anyhow::anyhow!("Invalid deployment account"))?;
73
        Ok(deployment.unpack())
74
    }
75

            
76
    pub async fn get_execution_request_v1(
77
        &self,
78
        requester_pubkey: &Pubkey,
79
        execution_id: &str,
80
    ) -> Result<ExecutionAccountStatus> {
81
        let (er, _) = execution_address(requester_pubkey, execution_id.as_bytes());
82
        let account = self
83
            .rpc_client
84
            .get_account_with_commitment(&er, CommitmentConfig::confirmed())
85
            .await
86
            .map_err(|e| anyhow::anyhow!("Failed to get account: {:?}", e))?
87
            .value
88
            .ok_or(anyhow::anyhow!("Invalid execution request account"))?;
89
        if account.data.len() == 1 {
90
            let ec =
91
                ExitCode::from_u8(account.data[0]).ok_or(anyhow::anyhow!("Invalid exit code"))?;
92
            return Ok(ExecutionAccountStatus::Completed(ec));
93
        }
94
        let er = root_as_execution_request_v1(&account.data)
95
            .map_err(|_| anyhow::anyhow!("Invalid execution request account"))?;
96
        Ok(ExecutionAccountStatus::Pending(er.unpack()))
97
    }
98

            
99
    pub async fn get_claim_state_v1<'a>(
100
        &self,
101
        requester_pubkey: &Pubkey,
102
        execution_id: &str,
103
    ) -> Result<ClaimStateHolder> {
104
        let (exad, _) = execution_address(requester_pubkey, execution_id.as_bytes());
105
        let (eca, _) = execution_claim_address(exad.as_ref());
106
        let account = self
107
            .rpc_client
108
            .get_account_with_commitment(&eca, CommitmentConfig::confirmed())
109
            .await
110
            .map_err(|e| anyhow::anyhow!("Failed to get account: {:?}", e))?
111
            .value
112
            .ok_or(anyhow::anyhow!("Invalid claim account"))?;
113
        Ok(ClaimStateHolder::new(account.data))
114
    }
115

            
116
    pub async fn download_program(&self, image_id: &str) -> Result<Bytes> {
117
        let deployment = self.get_deployment_v1(image_id).await?;
118
        let url = deployment
119
            .url
120
            .ok_or(anyhow::anyhow!("Invalid deployment"))?;
121
        let resp = reqwest::get(url)
122
            .await
123
            .map_err(|e| anyhow::anyhow!("Failed to download program: {:?}", e))?;
124
        resp.bytes()
125
            .await
126
            .map_err(|e| anyhow::anyhow!("Failed to download program: {:?}", e))
127
    }
128

            
129
    pub async fn get_deployment(&self, image_id: &str) -> Result<Option<Account>> {
130
        let (deployment_account, _) = deployment_address(image_id);
131
        let account = self
132
            .rpc_client
133
            .get_account_with_commitment(&deployment_account, CommitmentConfig::confirmed())
134
            .await
135
            .map_err(|e| anyhow::anyhow!("Failed to get account: {:?}", e))?;
136
        Ok(account.value)
137
    }
138

            
139
    pub async fn get_fees(&self, signer: &Pubkey) -> Result<u64> {
140
        let fee_accounts = vec![signer.to_owned(), bonsol_interface::ID];
141
        let compute_fees = self
142
            .rpc_client
143
            .get_recent_prioritization_fees(&fee_accounts)
144
            .await?;
145
        Ok(if compute_fees.is_empty() {
146
            5
147
        } else {
148
            compute_fees[0].prioritization_fee
149
        })
150
    }
151

            
152
    pub async fn deploy_v1(
153
        &self,
154
        signer: &Pubkey,
155
        image_id: &str,
156
        image_size: u64,
157
        program_name: &str,
158
        url: &str,
159
        inputs: Vec<ProgramInputType>,
160
    ) -> Result<Vec<Instruction>> {
161
        let compute_price_val = self.get_fees(signer).await?;
162
        let instruction =
163
            instructions::deploy_v1(signer, image_id, image_size, program_name, url, inputs)?;
164
        let compute = ComputeBudgetInstruction::set_compute_unit_limit(20_000);
165
        let compute_price = ComputeBudgetInstruction::set_compute_unit_price(compute_price_val);
166
        Ok(vec![compute, compute_price, instruction])
167
    }
168

            
169
    pub async fn execute_v1<'a>(
170
        &self,
171
        signer: &Pubkey,
172
        image_id: &str,
173
        execution_id: &str,
174
        inputs: Vec<InputRef<'a>>,
175
        tip: u64,
176
        expiration: u64,
177
        config: ExecutionConfig<'a>,
178
        callback: Option<CallbackConfig>,
179
        prover_version: Option<ProverVersion>,
180
    ) -> Result<Vec<Instruction>> {
181
        let compute_price_val = self.get_fees(signer).await?;
182

            
183
        let fbs_version_or_none = match prover_version {
184
            Some(version) => {
185
                let fbs_version = version.try_into().expect("Unknown prover version");
186
                Some(fbs_version)
187
            }
188
            None => None,
189
        };
190

            
191
        let instruction = instructions::execute_v1(
192
            signer,
193
            signer,
194
            image_id,
195
            execution_id,
196
            inputs,
197
            tip,
198
            expiration,
199
            config,
200
            callback,
201
            fbs_version_or_none,
202
        )?;
203
        let compute = ComputeBudgetInstruction::set_compute_unit_limit(20_000);
204
        let compute_price = ComputeBudgetInstruction::set_compute_unit_price(compute_price_val);
205
        Ok(vec![compute, compute_price, instruction])
206
    }
207

            
208
    pub async fn send_txn_standard(
209
        &self,
210
        signer: impl Signer,
211
        instructions: Vec<Instruction>,
212
    ) -> Result<()> {
213
        self.send_txn(signer, instructions, false, 1, 5).await
214
    }
215

            
216
    pub async fn send_txn(
217
        &self,
218
        signer: impl Signer,
219
        instructions: Vec<Instruction>,
220
        skip_preflight: bool,
221
        retry_timeout: u64,
222
        retry_count: usize,
223
    ) -> Result<()> {
224
        let mut rt = retry_count;
225
        loop {
226
            let blockhash = self.rpc_client.get_latest_blockhash().await?;
227
            let message =
228
                v0::Message::try_compile(&signer.pubkey(), &instructions, &[], blockhash)?;
229
            let tx = VersionedTransaction::try_new(VersionedMessage::V0(message), &[&signer])?;
230
            let sig = self
231
                .rpc_client
232
                .send_transaction_with_config(
233
                    &tx,
234
                    RpcSendTransactionConfig {
235
                        skip_preflight,
236
                        max_retries: Some(0),
237
                        preflight_commitment: Some(self.rpc_client.commitment().commitment),
238
                        ..Default::default()
239
                    },
240
                )
241
                .await?;
242

            
243
            let now = Instant::now();
244
            let confirm_transaction_initial_timeout = Duration::from_secs(retry_timeout);
245
            let (_, status) = loop {
246
                let status = self.rpc_client.get_signature_status(&sig).await?;
247
                if status.is_none() {
248
                    let blockhash_not_found = !self
249
                        .rpc_client
250
                        .is_blockhash_valid(&blockhash, self.rpc_client.commitment())
251
                        .await?;
252
                    if blockhash_not_found && now.elapsed() >= confirm_transaction_initial_timeout {
253
                        break (sig, status);
254
                    }
255
                } else {
256
                    break (sig, status);
257
                }
258
                tokio::time::sleep(Duration::from_millis(500)).await;
259
            };
260

            
261
            match status {
262
                Some(Ok(())) => {
263
                    return Ok(());
264
                }
265
                Some(Err(e)) => {
266
                    return Err(anyhow::anyhow!(
267
                        "Transaction Failure Cannot Recover {:?}",
268
                        e
269
                    ));
270
                }
271
                None => {
272
                    rt -= 1;
273
                    if rt == 0 {
274
                        return Err(anyhow::anyhow!("Timeout: Failed to confirm transaction"));
275
                    }
276
                }
277
            }
278
        }
279
    }
280

            
281
    pub async fn wait_for_claim(
282
        &self,
283
        requester: Pubkey,
284
        execution_id: &str,
285
        timeout: Option<u64>,
286
    ) -> Result<ClaimStateHolder> {
287
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
288
        let now = Instant::now();
289
        let mut end = false;
290
        loop {
291
            interval.tick().await;
292
            if now.elapsed().as_secs() > timeout.unwrap_or(0) {
293
                end = true;
294
            }
295
            if let Ok(claim_state) = self.get_claim_state_v1(&requester, execution_id).await {
296
                return Ok(claim_state);
297
            }
298
            if end {
299
                return Err(anyhow::anyhow!("Timeout"));
300
            }
301
        }
302
    }
303

            
304
    pub async fn wait_for_proof(
305
        &self,
306
        requester: Pubkey,
307
        execution_id: &str,
308
        timeout: Option<u64>,
309
    ) -> Result<ExitCode> {
310
        let current_block = self.get_current_slot().await?;
311
        let expiry = current_block + 100;
312
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
313
        let now = Instant::now();
314
        loop {
315
            interval.tick().await;
316
            if now.elapsed().as_secs() > timeout.unwrap_or(0) {
317
                return Err(anyhow::anyhow!("Timeout"));
318
            }
319
            let status = self
320
                .get_execution_request_v1(&requester, execution_id)
321
                .await;
322
            match status {
323
                Ok(ExecutionAccountStatus::Pending(req)) => {
324
                    if req.max_block_height < expiry {
325
                        return Err(anyhow::anyhow!("Expired"));
326
                    }
327
                }
328
                Ok(ExecutionAccountStatus::Completed(s)) => {
329
                    return Ok(s);
330
                }
331
                Err(e) => {
332
                    return Err(e);
333
                }
334
            }
335
        }
336
    }
337
}