1
use std::fs::File;
2
use std::path::PathBuf;
3
use std::process::Command;
4
use std::str::FromStr;
5

            
6
use anyhow::{Context, Result};
7
use bonsol_prover::input_resolver::{ProgramInput, ResolvedInput};
8
use bonsol_sdk::instructions::CallbackConfig;
9
use bonsol_sdk::{InputT, InputType, ProgramInputType};
10
use clap::Args;
11
use rand::distributions::Alphanumeric;
12
use rand::Rng;
13
use serde::{Deserialize, Serialize};
14
use solana_rpc_client::nonblocking::rpc_client;
15
use solana_sdk::instruction::AccountMeta;
16
use solana_sdk::pubkey::Pubkey;
17

            
18
use crate::error::{BonsolCliError, ParseConfigError};
19

            
20
// NOTE(review): these constants are used outside this view; the descriptions
// below are inferred from their names and values — confirm against callers.

/// File name of a program manifest (presumably a serialized `ZkProgramManifest`).
pub(crate) const MANIFEST_JSON: &str = "manifest.json";

/// Name of the `cargo` executable.
pub(crate) const CARGO_COMMAND: &str = "cargo";

/// File name of a Cargo package manifest.
pub(crate) const CARGO_TOML: &str = "Cargo.toml";

/// Name of Cargo's default build output directory.
pub(crate) const TARGET_DIR: &str = "target";

/// `cargo-risczero` version string used by the CLI.
/// NOTE(review): confirm whether this is a pinned install version or a minimum requirement.
pub(crate) const CARGO_RISCZERO_VERSION: &str = "1.2.1";
25

            
26
/// Returns `true` if `cargo --list` reports a subcommand named exactly `plugin_name`.
///
/// `cargo --list` prints one command per line, with the command name as the first
/// whitespace-separated token. The name is compared against that token for equality.
/// A prefix test would let `"risc"` match `risczero`, and `""` would match every line.
///
/// Any failure to run `cargo` is treated as "plugin not installed" and yields `false`.
pub fn cargo_has_plugin(plugin_name: &str) -> bool {
    Command::new("cargo")
        .arg("--list")
        .output()
        .map(|output| {
            String::from_utf8_lossy(&output.stdout)
                .lines()
                .filter_map(|line| line.split_whitespace().next())
                .any(|command| command == plugin_name)
        })
        .unwrap_or(false)
}
37

            
38
/// Reports whether the `which` utility can locate `executable`.
///
/// Returns `false` when `which` exits unsuccessfully or cannot be spawned at all.
pub fn has_executable(executable: &str) -> bool {
    match Command::new("which").arg(executable).output() {
        Ok(result) => result.status.success(),
        Err(_) => false,
    }
}
45

            
46
/// Metadata describing a built zk program.
///
/// (De)serialized with camelCase keys (`name`, `binaryPath`, `imageId`,
/// `inputOrder`, `signature`, `size`). Every field is required.
///
/// NOTE(review): presumably this is the content of the `MANIFEST_JSON` file, and
/// `size` is the binary size in bytes — confirm against the code that writes it.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ZkProgramManifest {
    pub name: String,
    pub binary_path: String,
    pub image_id: String,
    pub input_order: Vec<String>,
    pub signature: String,
    pub size: u64,
}
56

            
57
// A single program input, as given on the command line (it derives `clap::Args`)
// or in an inputs file (serde, camelCase keys `inputType` / `data`).
// Plain `//` comments are used here because clap turns `///` doc comments into help text.
#[derive(Debug, Deserialize, Serialize, Clone, Args)]
#[serde(rename_all = "camelCase")]
pub struct CliInput {
    // Input type name, parsed by `CliInputType::from_str` (e.g. "PublicData", "Private").
    pub input_type: String,
    // Raw input data. For `PublicData`, a `0x` prefix marks hex-encoded bytes.
    pub data: String,
}
63

            
64
#[derive(Debug, Clone)]
65
pub struct CliInputType(InputType);
66
impl ToString for CliInputType {
67
    fn to_string(&self) -> String {
68
        match self.0 {
69
            InputType::PublicData => "PublicData".to_string(),
70
            InputType::PublicAccountData => "PublicAccountData".to_string(),
71
            InputType::PublicUrl => "PublicUrl".to_string(),
72
            InputType::Private => "Private".to_string(),
73
            InputType::PublicProof => "PublicProof".to_string(),
74
            InputType::PrivateLocal => "PrivateUrl".to_string(),
75
            _ => "InvalidInputType".to_string(),
76
        }
77
    }
78
}
79

            
80
impl FromStr for CliInputType {
81
    type Err = anyhow::Error;
82

            
83
5
    fn from_str(s: &str) -> Result<Self, Self::Err> {
84
5
        match s {
85
5
            "PublicData" => Ok(CliInputType(InputType::PublicData)),
86
            "PublicAccountData" => Ok(CliInputType(InputType::PublicAccountData)),
87
            "PublicUrl" => Ok(CliInputType(InputType::PublicUrl)),
88
            "Private" => Ok(CliInputType(InputType::Private)),
89
            "PublicProof" => Ok(CliInputType(InputType::PublicProof)),
90
            "PrivateUrl" => Ok(CliInputType(InputType::PrivateLocal)),
91
            _ => Err(anyhow::anyhow!("Invalid input type")),
92
        }
93
5
    }
94
}
95

            
96
/// An execution request as read from a request file.
///
/// (De)serialized with camelCase keys. `executionConfig` is required; every other
/// field is optional.
///
/// NOTE(review): units of `tip` and `expiry` are not visible here — presumably
/// lamports and a slot/block count respectively; confirm against the SDK.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExecutionRequestFile {
    pub image_id: Option<String>,
    pub execution_config: CliExecutionConfig,
    pub execution_id: Option<String>,
    pub tip: Option<u64>,
    pub expiry: Option<u64>,
    pub inputs: Option<Vec<CliInput>>,
    pub callback_config: Option<CliCallbackConfig>,
}
107

            
108
/// Execution options within an execution request file.
///
/// All fields are optional and (de)serialized with camelCase keys
/// (`verifyInputHash`, `inputHash`, `forwardOutput`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CliExecutionConfig {
    pub verify_input_hash: Option<bool>,
    pub input_hash: Option<String>,
    pub forward_output: Option<bool>,
}
115

            
116
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
117
#[serde(rename_all = "camelCase")]
118
pub struct CliCallbackConfig {
119
    #[serde(with = "bonsol_sdk::instructions::serde_helpers::optpubkey")]
120
    pub program_id: Option<Pubkey>,
121
    pub instruction_prefix: Option<Vec<u8>>,
122
    pub extra_accounts: Option<Vec<CliAccountMeta>>,
123
}
124

            
125
impl From<CliCallbackConfig> for CallbackConfig {
126
    fn from(val: CliCallbackConfig) -> Self {
127
        CallbackConfig {
128
            program_id: val.program_id.unwrap_or_default(),
129
            instruction_prefix: val.instruction_prefix.unwrap_or_default(),
130
            extra_accounts: val
131
                .extra_accounts
132
                .map(|v| v.into_iter().map(|a| a.into()).collect())
133
                .unwrap_or_default(),
134
        }
135
    }
136
}
137

            
138
/// One extra account for a callback, convertible into a Solana `AccountMeta`.
///
/// (De)serialized with camelCase keys. `pubkey` may be omitted (serde `default`), in
/// which case it is `Pubkey::default()`. `isSigner` and `isWritable` are required.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct CliAccountMeta {
    // Encoded through the SDK's `pubkey` serde helper
    // (presumably a base58 string — confirm in `bonsol_sdk`).
    #[serde(default, with = "bonsol_sdk::instructions::serde_helpers::pubkey")]
    pub pubkey: Pubkey,
    pub is_signer: bool,
    pub is_writable: bool,
}
146

            
147
impl From<CliAccountMeta> for AccountMeta {
148
    fn from(val: CliAccountMeta) -> Self {
149
        AccountMeta {
150
            pubkey: val.pubkey,
151
            is_signer: val.is_signer,
152
            is_writable: val.is_writable,
153
        }
154
    }
155
}
156

            
157
/// Top-level inputs document: `{"inputs": [ ... ]}`, where each element is a `CliInput`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InputFile {
    pub inputs: Vec<CliInput>,
}
162

            
163
/// Attempt to load the RPC URL and keypair file from a solana `config.yaml`.
164
pub(crate) fn try_load_from_config(config: Option<String>) -> anyhow::Result<(String, String)> {
165
    let whoami = String::from_utf8_lossy(&std::process::Command::new("whoami").output()?.stdout)
166
        .trim_end()
167
        .to_string();
168
    let default_config_path = solana_cli_config::CONFIG_FILE.as_ref();
169

            
170
    let config_file = config.as_ref().map_or_else(
171
        || -> anyhow::Result<&String> {
172
            let inner_err = ParseConfigError::DefaultConfigNotFound {
173
                whoami: whoami.clone(),
174
            };
175
            let context = inner_err.context(None);
176

            
177
            // If no config is given, try to find it at the default location.
178
            default_config_path
179
                .and_then(|s| PathBuf::from_str(s).is_ok_and(|p| p.exists()).then_some(s))
180
                .ok_or(BonsolCliError::ParseConfigError(inner_err))
181
                .context(context)
182
        },
183
        |config| -> anyhow::Result<&String> {
184
            // Here we throw an error if the user provided a path to a config that does not exist.
185
            // Instead of using the default location, it's better to show the user the path they
186
            // expected to use was not valid.
187
            if !PathBuf::from_str(config)?.exists() {
188
                let inner_err = ParseConfigError::ConfigNotFound {
189
                    path: config.into(),
190
                };
191
                let context = inner_err.context(None);
192
                let err: anyhow::Error = BonsolCliError::ParseConfigError(inner_err).into();
193
                return Err(err.context(context));
194
            }
195
            Ok(config)
196
        },
197
    )?;
198
    let config = {
199
        let mut inner_err = ParseConfigError::Uninitialized;
200

            
201
        let mut maybe_config = solana_cli_config::Config::load(config_file).map_err(|err| {
202
            let err = ParseConfigError::FailedToLoad {
203
                path: config.unwrap_or(default_config_path.cloned().unwrap()),
204
                err: format!("{err:?}"),
205
            };
206
            inner_err = err.clone();
207
            BonsolCliError::ParseConfigError(err).into()
208
        });
209
        if maybe_config.is_err() {
210
            maybe_config = maybe_config.context(inner_err.context(Some(whoami)));
211
        }
212
        maybe_config
213
    }?;
214
    Ok((config.json_rpc_url, config.keypair_path))
215
}
216

            
217
pub(crate) fn load_solana_config(
218
    config: Option<String>,
219
    rpc_url: Option<String>,
220
    keypair: Option<String>,
221
) -> anyhow::Result<(String, solana_sdk::signer::keypair::Keypair)> {
222
    let (rpc_url, keypair_file) = match rpc_url.zip(keypair) {
223
        Some(config) => config,
224
        None => try_load_from_config(config)?,
225
    };
226
    Ok((
227
        rpc_url,
228
        solana_sdk::signature::read_keypair_file(std::path::Path::new(&keypair_file)).map_err(
229
            |err| BonsolCliError::FailedToReadKeypair {
230
                file: keypair_file,
231
                err: format!("{err:?}"),
232
            },
233
        )?,
234
    ))
235
}
236

            
237
pub async fn sol_check(rpc_client: String, pubkey: Pubkey) -> bool {
238
    let rpc_client = rpc_client::RpcClient::new(rpc_client);
239
    if let Ok(account) = rpc_client.get_account(&pubkey).await {
240
        return account.lamports > 0;
241
    }
242
    false
243
}
244

            
245
pub fn execute_get_inputs(
246
    inputs_file: Option<String>,
247
    stdin: Option<String>,
248
) -> Result<Vec<CliInput>> {
249
    if let Some(std) = stdin {
250
        let parsed = serde_json::from_str::<InputFile>(&std)
251
            .map_err(|e| anyhow::anyhow!("Error parsing stdin: {:?}", e))?;
252
        return Ok(parsed.inputs);
253
    }
254

            
255
    if let Some(istr) = inputs_file {
256
        let ifile = File::open(istr)?;
257
        let parsed: InputFile = serde_json::from_reader(&ifile)
258
            .map_err(|e| anyhow::anyhow!("Error parsing inputs file: {:?}", e))?;
259
        return Ok(parsed.inputs);
260
    }
261

            
262
    Err(anyhow::anyhow!("No inputs provided"))
263
}
264

            
265
pub fn proof_get_inputs(
266
    inputs_file: Option<String>,
267
    stdin: Option<String>,
268
) -> Result<Vec<ProgramInput>> {
269
    if let Some(std) = stdin {
270
        return proof_parse_stdin(&std);
271
    }
272
    if let Some(istr) = inputs_file {
273
        return proof_parse_input_file(&istr);
274
    }
275
    Err(anyhow::anyhow!("No inputs provided"))
276
}
277

            
278
1
pub fn execute_transform_cli_inputs(inputs: Vec<CliInput>) -> Result<Vec<InputT>> {
279
1
    let mut res = vec![];
280
5
    for input in inputs.into_iter() {
281
5
        let input_type = CliInputType::from_str(&input.input_type)?.0;
282
5
        match input_type {
283
            InputType::PublicData => {
284
5
                let has_hex_prefix = input.data.starts_with("0x");
285
5
                if has_hex_prefix {
286
1
                    let (is_valid, data) = is_valid_hex(&input.data[2..]);
287
1
                    if is_valid {
288
1
                        res.push(InputT::public(data));
289
1
                    }
290
1
                    continue;
291
4
                }
292
4
                if let Some(n) = is_valid_number(&input.data) {
293
3
                    let data = n.into_bytes();
294
3
                    res.push(InputT::public(data));
295
3
                    continue;
296
1
                }
297
1
                res.push(InputT::public(input.data.into_bytes()));
298
            }
299
            _ => res.push(InputT::new(input_type, Some(input.data.into_bytes()))),
300
        }
301
    }
302
1
    Ok(res)
303
1
}
304

            
305
2
fn is_valid_hex(s: &str) -> (bool, Vec<u8>) {
306
2
    if s.len() % 4 != 0 {
307
        return (false, vec![]);
308
2
    }
309
72
    let is_hex_char = |c: char| c.is_ascii_hexdigit();
310
2
    if !s.chars().all(is_hex_char) {
311
        return (false, vec![]);
312
2
    }
313
2
    let out = hex::decode(s);
314
2
    (out.is_ok(), out.unwrap_or_default())
315
2
}
316

            
317
/// A numeric literal recognised in input data, tagged with the Rust type it parsed as.
#[derive(Debug, PartialEq)]
pub enum NumberType {
    Float(f64),
    Unsigned(u64),
    Integer(i64),
    // TODO: add BigInt
}

impl NumberType {
    /// Encodes the number as its 8-byte little-endian representation.
    fn into_bytes(&self) -> Vec<u8> {
        let le: [u8; 8] = match *self {
            Self::Float(value) => value.to_le_bytes(),
            Self::Unsigned(value) => value.to_le_bytes(),
            Self::Integer(value) => value.to_le_bytes(),
        };
        le.to_vec()
    }
}

/// Parses `s` as a number, trying `u64`, then `i64`, then `f64`.
/// Returns `None` when none of them accept `s`.
///
/// Note that `f64` parsing also accepts forms such as `1e3`, `inf` and `NaN`.
fn is_valid_number(s: &str) -> Option<NumberType> {
    s.parse::<u64>()
        .map(NumberType::Unsigned)
        .or_else(|_| s.parse::<i64>().map(NumberType::Integer))
        .or_else(|_| s.parse::<f64>().map(NumberType::Float))
        .ok()
}
347

            
348
6
fn proof_parse_entry(index: u8, s: &str) -> Result<ProgramInput> {
349
6
    if let Ok(num) = s.parse::<i64>() {
350
2
        return Ok(ProgramInput::Resolved(ResolvedInput {
351
2
            index,
352
2
            data: num.to_le_bytes().to_vec(),
353
2
            input_type: ProgramInputType::Private,
354
2
        }));
355
4
    }
356
4
    if let Ok(num) = s.parse::<f64>() {
357
1
        return Ok(ProgramInput::Resolved(ResolvedInput {
358
1
            index,
359
1
            data: num.to_le_bytes().to_vec(),
360
1
            input_type: ProgramInputType::Private,
361
1
        }));
362
3
    }
363
3
    if let Ok(num) = s.parse::<u64>() {
364
        return Ok(ProgramInput::Resolved(ResolvedInput {
365
            index,
366
            data: num.to_le_bytes().to_vec(),
367
            input_type: ProgramInputType::Private,
368
        }));
369
3
    }
370
3
    let has_hex_prefix = s.starts_with("0x");
371
3
    if has_hex_prefix {
372
1
        let (is_valid, data) = is_valid_hex(&s[2..]);
373
1
        if is_valid {
374
1
            return Ok(ProgramInput::Resolved(ResolvedInput {
375
1
                index,
376
1
                data,
377
1
                input_type: ProgramInputType::Private,
378
1
            }));
379
        } else {
380
            return Err(anyhow::anyhow!("Invalid hex data"));
381
        }
382
2
    }
383
2
    return Ok(ProgramInput::Resolved(ResolvedInput {
384
2
        index,
385
2
        data: s.as_bytes().to_vec(),
386
2
        input_type: ProgramInputType::Private,
387
2
    }));
388
6
}
389

            
390
fn proof_parse_input_file(input_file: &str) -> Result<Vec<ProgramInput>> {
391
    if let Ok(ifile) = serde_json::from_str::<InputFile>(input_file) {
392
        let len = ifile.inputs.len();
393
        let parsed: Vec<ProgramInput> = ifile
394
            .inputs
395
            .into_iter()
396
            .enumerate()
397
            .flat_map(|(index, input)| proof_parse_entry(index as u8, &input.data).ok())
398
            .collect();
399
        if parsed.len() != len {
400
            return Err(anyhow::anyhow!("Invalid input file"));
401
        }
402
        return Ok(parsed);
403
    }
404
    Err(anyhow::anyhow!("Invalid input file"))
405
}
406

            
407
1
fn proof_parse_stdin(input: &str) -> Result<Vec<ProgramInput>> {
408
1
    let mut entries = Vec::new();
409
1
    let mut current_entry = String::new();
410
1
    let mut in_quotes = false;
411
1
    let mut in_brackets = 0;
412
93
    for c in input.chars() {
413
2
        match c {
414
2
            '"' if !in_quotes => in_quotes = true,
415
2
            '"' if in_quotes => in_quotes = false,
416
1
            '{' | '[' if !in_quotes => in_brackets += 1,
417
1
            '}' | ']' if !in_quotes => in_brackets -= 1,
418
5
            ' ' if !in_quotes && in_brackets == 0 && !current_entry.is_empty() => {
419
5
                let index = entries.len() as u8;
420
5
                entries.push(proof_parse_entry(index, &current_entry)?);
421
5
                current_entry.clear();
422
5
                continue;
423
            }
424
82
            _ => {}
425
        }
426
88
        current_entry.push(c);
427
    }
428
1
    if !current_entry.is_empty() {
429
1
        entries.push(proof_parse_entry(entries.len() as u8, &current_entry)?);
430
    }
431
1
    Ok(entries)
432
1
}
433

            
434
pub fn rand_id(chars: usize) -> String {
435
    let mut rng = rand::thread_rng();
436
    (&mut rng)
437
        .sample_iter(Alphanumeric)
438
        .take(chars)
439
        .map(char::from)
440
        .collect()
441
}
442

            
443
#[cfg(test)]
mod test {
    use super::*;

    // Exercises every entry kind the stdin parser recognises, in order: plain text,
    // `0x` hex, float, unsigned, negative integer, and a JSON object kept as one entry.
    #[test]
    fn test_proof_parse_stdin() {
        let inputs = r#"1234567890abcdef 0x313233343536373839313061626364656667 2.1 2000 -2000 {"attestation":"test"}"#;
        let inputs_parsed = proof_parse_stdin(inputs).unwrap();

        let expected_inputs = vec![
            // Not a number and no `0x` prefix, so stored as raw UTF-8 bytes.
            ProgramInput::Resolved(ResolvedInput {
                index: 0,
                data: "1234567890abcdef".as_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
            // The hex digits above are the ASCII codes of "12345678910abcdefg".
            ProgramInput::Resolved(ResolvedInput {
                index: 1,
                data: "12345678910abcdefg".as_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
            ProgramInput::Resolved(ResolvedInput {
                index: 2,
                data: 2.1f64.to_le_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
            ProgramInput::Resolved(ResolvedInput {
                index: 3,
                data: 2000u64.to_le_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
            ProgramInput::Resolved(ResolvedInput {
                index: 4,
                data: (-2000i64).to_le_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
            // Braces keep the JSON object together; its text is stored verbatim.
            ProgramInput::Resolved(ResolvedInput {
                index: 5,
                data: "{\"attestation\":\"test\"}".as_bytes().to_vec(),
                input_type: ProgramInputType::Private,
            }),
        ];
        assert_eq!(inputs_parsed, expected_inputs);
    }

    // Checks which numeric type each literal is classified as.
    #[test]
    fn test_is_valid_number() {
        let num = is_valid_number("1234567890abcdef");
        assert!(num.is_none());
        let num = is_valid_number("1234567890abcdefg");
        assert!(num.is_none());
        let num = is_valid_number("2.1");
        assert!(num.is_some());
        assert_eq!(num.unwrap(), NumberType::Float(2.1));
        let num = is_valid_number("2000");
        assert!(num.is_some());
        assert_eq!(num.unwrap(), NumberType::Unsigned(2000));
        let num = is_valid_number("-2000");
        assert!(num.is_some());
        assert_eq!(num.unwrap(), NumberType::Integer(-2000));
    }

    // `PublicData` conversion: raw text, `0x` hex, float, unsigned and negative integer.
    #[test]
    fn test_execute_transform_cli_inputs() {
        let input = CliInput {
            input_type: "PublicData".to_string(),
            data: "1234567890abcdef".to_string(),
        };
        let hex_input = CliInput {
            input_type: "PublicData".to_string(),
            data: "0x313233343536373839313061626364656667".to_string(),
        };
        let hex_input2 = CliInput {
            input_type: "PublicData".to_string(),
            data: "2.1".to_string(),
        };
        let hex_input3 = CliInput {
            input_type: "PublicData".to_string(),
            data: "2000".to_string(),
        };
        let hex_input4 = CliInput {
            input_type: "PublicData".to_string(),
            data: "-2000".to_string(),
        };
        let inputs = vec![input, hex_input, hex_input2, hex_input3, hex_input4];
        let parsed_inputs = execute_transform_cli_inputs(inputs).unwrap();
        assert_eq!(
            parsed_inputs,
            vec![
                InputT::public("1234567890abcdef".as_bytes().to_vec()),
                // Decoded from the `0x` hex input above.
                InputT::public("12345678910abcdefg".as_bytes().to_vec()),
                InputT::public(2.1f64.to_le_bytes().to_vec()),
                InputT::public(2000u64.to_le_bytes().to_vec()),
                InputT::public((-2000i64).to_le_bytes().to_vec()),
            ]
        );
    }
}