1
use std::str::from_utf8;
2
use std::sync::Arc;
3
use std::time::{Duration, SystemTime, UNIX_EPOCH};
4

            
5
use anyhow::Result;
6
use arrayref::array_ref;
7
use async_trait::async_trait;
8
use bonsol_schema::{InputT, InputType, ProgramInputType};
9
use reqwest::Url;
10
use serde::{Deserialize, Serialize};
11
use solana_sdk::pubkey::Pubkey;
12
use solana_sdk::signer::Signer;
13
use tokio::task::{JoinHandle, JoinSet};
14

            
15
use crate::util::get_body_max_size;
16

            
17
#[derive(Debug, Clone, PartialEq)]
18
pub enum ProgramInput {
19
    Empty,
20
    Resolved(ResolvedInput),
21
    Unresolved(UnresolvedInput),
22
}
23

            
24
#[derive(Debug, Clone, PartialEq)]
25
pub struct UnresolvedInput {
26
    pub index: u8,
27
    pub url: Url,
28
    pub input_type: ProgramInputType,
29
}
30

            
31
#[derive(Debug, Clone, PartialEq)]
32
pub struct ResolvedInput {
33
    pub index: u8,
34
    pub data: Vec<u8>,
35
    pub input_type: ProgramInputType,
36
}
37

            
38
impl ProgramInput {
39
    pub fn index(&self) -> u8 {
40
        match self {
41
            ProgramInput::Resolved(ri) => ri.index,
42
            ProgramInput::Unresolved(ui) => ui.index,
43
            _ => 0,
44
        }
45
    }
46
}
47

            
48
/// Input resolvers are responsible for downloading and resolving inputs
49
/// Private inputs must be resoloved post claim and therefore are separated from public inputs
50
/// Public inputs are resolved in parallel and are resolved as soon as possible, Private inputs are currently always remote.
51
/// The output of resolve_public_inputs is a vec of ProgramInputs and that must be passed to the private input resolver if any private inputs are present in the execution request
52
#[async_trait]
53
pub trait InputResolver: Send + Sync {
54
    /// Returns true if the input resolver supports the input type
55
    fn supports(&self, input_type: InputType) -> bool;
56
    /// Resolves public inputs by parsing them or if remote downloading them
57
    async fn resolve_public_inputs(
58
        &self,
59
        inputs: Vec<InputT>,
60
    ) -> Result<Vec<ProgramInput>, anyhow::Error>;
61

            
62
    /// Resolves private inputs by sigining the request and attempting to download the inputs
63
    async fn resolve_private_inputs(
64
        &self,
65
        execution_id: &str,
66
        inputs: &mut Vec<ProgramInput>,
67
        signer: Arc<&(dyn Signer + Send + Sync)>,
68
    ) -> Result<(), anyhow::Error>;
69
}
70

            
71
// naive resolver that downloads inputs just in time
72
pub struct DefaultInputResolver {
73
    http_client: Arc<reqwest::Client>,
74
    solana_rpc_client: Arc<solana_rpc_client::nonblocking::rpc_client::RpcClient>,
75
    max_input_size_mb: u32,
76
    timeout: Duration,
77
}
78

            
79
impl DefaultInputResolver {
80
    pub fn new(
81
        http_client: Arc<reqwest::Client>,
82
        solana_rpc_client: Arc<solana_rpc_client::nonblocking::rpc_client::RpcClient>,
83
    ) -> Self {
84
        DefaultInputResolver {
85
            http_client,
86
            solana_rpc_client,
87
            max_input_size_mb: 10,
88
            timeout: Duration::from_secs(30),
89
        }
90
    }
91

            
92
    pub fn new_with_opts(
93
        http_client: Arc<reqwest::Client>,
94
        solana_rpc_client: Arc<solana_rpc_client::nonblocking::rpc_client::RpcClient>,
95
        max_input_size_mb: Option<u32>,
96
        timeout: Option<Duration>,
97
    ) -> Self {
98
        DefaultInputResolver {
99
            http_client,
100
            solana_rpc_client,
101
            max_input_size_mb: max_input_size_mb.unwrap_or(10),
102
            timeout: timeout.unwrap_or(Duration::from_secs(30)),
103
        }
104
    }
105

            
106
    fn par_resolve_input(
107
        &self,
108
        client: Arc<reqwest::Client>,
109
        index: u8,
110
        input: InputT,
111
        task_set: &mut JoinSet<Result<ResolvedInput>>,
112
    ) -> Result<ProgramInput> {
113
        match input.input_type {
114
            InputType::PublicUrl => {
115
                let url = input.data.ok_or(anyhow::anyhow!("Invalid data"))?;
116
                let url = from_utf8(&url)?;
117
                let url = Url::parse(url)?;
118
                task_set.spawn(download_public_input(
119
                    client,
120
                    index,
121
                    url.clone(),
122
                    self.max_input_size_mb as usize,
123
                    ProgramInputType::Public,
124
                    self.timeout,
125
                ));
126
                Ok(ProgramInput::Unresolved(UnresolvedInput {
127
                    index,
128
                    url,
129
                    input_type: ProgramInputType::Public,
130
                }))
131
            }
132
            InputType::Private => {
133
                let url = input.data.ok_or(anyhow::anyhow!("Invalid data"))?;
134
                let url = from_utf8(&url)?;
135
                let url = Url::parse(url)?;
136
                Ok(ProgramInput::Unresolved(UnresolvedInput {
137
                    index,
138
                    url,
139
                    input_type: ProgramInputType::Private,
140
                }))
141
            }
142
            InputType::PublicData => {
143
                let data = input.data.ok_or(anyhow::anyhow!("Invalid data"))?;
144
                let data = data.to_vec();
145
                Ok(ProgramInput::Resolved(ResolvedInput {
146
                    index,
147
                    data,
148
                    input_type: ProgramInputType::Public,
149
                }))
150
            }
151
            InputType::PublicProof => {
152
                let url = input.data.ok_or(anyhow::anyhow!("Invalid data"))?;
153
                let url = from_utf8(&url)?;
154
                let url = Url::parse(url)?;
155
                task_set.spawn(download_public_input(
156
                    client,
157
                    index,
158
                    url.clone(),
159
                    self.max_input_size_mb as usize,
160
                    ProgramInputType::PublicProof,
161
                    self.timeout,
162
                ));
163
                Ok(ProgramInput::Unresolved(UnresolvedInput {
164
                    index,
165
                    url,
166
                    input_type: ProgramInputType::PublicProof,
167
                }))
168
            }
169
            InputType::PublicAccountData => {
170
                let pubkey = input.data.ok_or(anyhow::anyhow!("Invalid data"))?;
171
                if pubkey.len() != 32 {
172
                    return Err(anyhow::anyhow!("Invalid pubkey"));
173
                }
174
                let pubkey = Pubkey::new_from_array(*array_ref!(pubkey, 0, 32));
175
                let rpc_client_clone = self.solana_rpc_client.clone();
176
                task_set.spawn(download_public_account(
177
                    rpc_client_clone,
178
                    index,
179
                    pubkey,
180
                    self.max_input_size_mb as usize,
181
                ));
182
                Ok(ProgramInput::Unresolved(UnresolvedInput {
183
                    index,
184
                    url: format!("solana://{}", pubkey).parse()?,
185
                    input_type: ProgramInputType::Public,
186
                }))
187
            }
188
            _ => {
189
                // not implemented yet / or unknown
190
                Err(anyhow::anyhow!("Invalid input type"))
191
            }
192
        }
193
    }
194
}
195

            
196
#[async_trait]
197
impl InputResolver for DefaultInputResolver {
198
    fn supports(&self, input_type: InputType) -> bool {
199
        match input_type {
200
            InputType::PublicUrl => true,
201
            InputType::PublicData => true,
202
            InputType::PublicAccountData => true,
203
            InputType::Private => true,
204
            InputType::PublicProof => true,
205
            _ => false,
206
        }
207
    }
208

            
209
    async fn resolve_public_inputs(
210
        &self,
211
        inputs: Vec<InputT>,
212
    ) -> Result<Vec<ProgramInput>, anyhow::Error> {
213
        let mut url_set = JoinSet::new();
214
        let mut res = vec![ProgramInput::Empty; inputs.len()];
215
        for (index, input) in inputs.into_iter().enumerate() {
216
            let client = self.http_client.clone();
217
            res[index] = self.par_resolve_input(client, index as u8, input, &mut url_set)?;
218
        }
219
        while let Some(url) = url_set.join_next().await {
220
            match url {
221
                Ok(Ok(ri)) => {
222
                    let index = ri.index as usize;
223
                    res[index] = ProgramInput::Resolved(ri);
224
                }
225
                e => {
226
                    return Err(anyhow::anyhow!("Error downloading input: {:?}", e));
227
                }
228
            }
229
        }
230
        Ok(res)
231
    }
232

            
233
    async fn resolve_private_inputs(
234
        &self,
235
        execution_id: &str,
236
        inputs: &mut Vec<ProgramInput>,
237
        signer: Arc<&(dyn Signer + Send + Sync)>,
238
    ) -> Result<(), anyhow::Error> {
239
        let mut url_set = JoinSet::new();
240
        for (index, input) in inputs.iter().enumerate() {
241
            let client = self.http_client.clone();
242
            if let ProgramInput::Unresolved(ui) = input {
243
                let pir = PrivateInputRequest {
244
                    identity: signer.pubkey(),
245
                    claim_id: execution_id.to_string(),
246
                    input_index: ui.index,
247
                    now_utc: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(),
248
                };
249
                let pir_str = serde_json::to_string(&pir)?;
250
                let claim_authorization = signer.sign_message(pir_str.as_bytes());
251
                url_set.spawn(download_private_input(
252
                    client,
253
                    index as u8,
254
                    ui.url.clone(),
255
                    self.max_input_size_mb as usize,
256
                    pir_str,
257
                    claim_authorization.to_string(), // base58 encoded string
258
                    self.timeout,
259
                ));
260
            }
261
        }
262
        while let Some(url) = url_set.join_next().await {
263
            match url {
264
                Ok(Ok(ri)) => {
265
                    let index = ri.index as usize;
266
                    inputs[index] = ProgramInput::Resolved(ri);
267
                }
268
                e => {
269
                    return Err(anyhow::anyhow!("Error downloading input: {:?}", e));
270
                }
271
            }
272
        }
273
        Ok(())
274
    }
275
}
276

            
277
pub fn resolve_public_data(index: usize, data: &[u8]) -> Result<ProgramInput> {
278
    let data = data.to_vec();
279
    Ok(ProgramInput::Resolved(ResolvedInput {
280
        index: index as u8,
281
        data,
282
        input_type: ProgramInputType::Public,
283
    }))
284
}
285

            
286
pub fn resolve_remote_public_data(
287
    client: Arc<reqwest::Client>,
288
    max_input_size_mb: u64,
289
    index: usize,
290
    data: &[u8],
291
    timeout: Duration,
292
) -> Result<JoinHandle<Result<ResolvedInput>>> {
293
    let url = from_utf8(data)?;
294
    let url = Url::parse(url)?;
295
    Ok(tokio::task::spawn(download_public_input(
296
        client,
297
        index as u8,
298
        url,
299
        max_input_size_mb as usize,
300
        ProgramInputType::Public,
301
        timeout,
302
    )))
303
}
304

            
305
#[derive(Debug, Serialize, Deserialize)]
306
pub struct PrivateInputRequest {
307
    identity: Pubkey,
308
    claim_id: String,
309
    input_index: u8,
310
    now_utc: u64,
311
}
312

            
313
2
async fn download_public_input(
314
2
    client: Arc<reqwest::Client>,
315
2
    index: u8,
316
2
    url: Url,
317
2
    max_size_mb: usize,
318
2
    input_type: ProgramInputType,
319
2
    timeout: Duration,
320
2
) -> Result<ResolvedInput> {
321
2
    let resp = client
322
2
        .get(url)
323
2
        .timeout(timeout)
324
2
        .send()
325
6
        .await?
326
2
        .error_for_status()?;
327
8
    let byte = get_body_max_size(resp.bytes_stream(), max_size_mb * 1024 * 1024).await?;
328
1
    Ok(ResolvedInput {
329
1
        index,
330
1
        data: byte.to_vec(),
331
1
        input_type,
332
1
    })
333
2
}
334

            
335
async fn download_public_account(
336
    solana_client: Arc<solana_rpc_client::nonblocking::rpc_client::RpcClient>,
337
    index: u8,
338
    pubkey: Pubkey,
339
    max_size_mb: usize,
340
) -> Result<ResolvedInput> {
341
    let resp = solana_client.get_account_data(&pubkey).await?;
342
    if resp.len() > max_size_mb * 1024 * 1024 {
343
        return Err(anyhow::anyhow!("Max size exceeded"));
344
    }
345
    Ok(ResolvedInput {
346
        index,
347
        data: resp,
348
        input_type: ProgramInputType::Public,
349
    })
350
}
351

            
352
async fn download_private_input(
353
    client: Arc<reqwest::Client>,
354
    index: u8,
355
    url: Url,
356
    max_size_mb: usize,
357
    body: String,
358
    claim_authorization: String,
359
    timeout: Duration,
360
) -> Result<ResolvedInput> {
361
    let resp = client
362
        .post(url)
363
        .body(body)
364
        .timeout(timeout)
365
        // Signature of the json payload
366
        .header("Authorization", format!("Bearer {}", claim_authorization))
367
        .header("Content-Type", "application/json")
368
        .send()
369
        .await?
370
        .error_for_status()?;
371
    let byte = get_body_max_size(resp.bytes_stream(), max_size_mb * 1024 * 1024).await?;
372
    Ok(ResolvedInput {
373
        index,
374
        data: byte.to_vec(),
375
        input_type: ProgramInputType::Private,
376
    })
377
}
378

            
#[cfg(test)]
mod test {
    use super::*;
    use mockito::Mock;
    use reqwest::{Client, Url};

    use std::sync::Arc;

    /// Starts a mock HTTP server that answers `GET url_path` with status 200 and `response`
    /// as an `application/octet-stream` body.
    ///
    /// Returns the mock (to assert it was hit), the absolute URL to request, and the server
    /// guard. The guard is handed back so the server stays alive for the rest of the test.
    pub async fn get_server(url_path: &str, response: &[u8]) -> (Mock, Url, mockito::ServerGuard) {
        let mut server = mockito::Server::new_async().await;
        let url = Url::parse(&format!("{}{}", server.url(), url_path)).unwrap();

        // `download_public_input` issues a GET, so that is the method mocked here.
        let mock = server
            .mock("GET", url_path)
            .with_status(200)
            .with_header("content-type", "application/octet-stream")
            .with_body(response)
            .create_async()
            .await;

        (mock, url, server)
    }

    #[tokio::test]
    async fn test_download_public_input_success() {
        // 1 MB limit, 10 KB payload: comfortably within bounds.
        let max_size_mb = 1;
        let input_data = vec![1u8; 1024 * 10];

        let (mock, url, _server) = get_server("/download", &input_data).await;
        let client = Arc::new(Client::new());

        let resolved_input = download_public_input(
            client,
            1u8,
            url,
            max_size_mb,
            ProgramInputType::Public,
            Duration::from_secs(30),
        )
        .await
        .expect("a payload under the size limit should download");

        assert_eq!(resolved_input.index, 1);
        assert_eq!(resolved_input.data, input_data);
        assert!(matches!(
            resolved_input.input_type,
            ProgramInputType::Public
        ));

        mock.assert();
    }

    #[tokio::test]
    async fn test_download_public_input_oversized() {
        // 1 MB limit, 2 MB payload: must be rejected.
        let max_size_mb = 1;
        let input_data = vec![1u8; 1024 * 1024 * 2];

        let (mock, url, _server) = get_server("/download", &input_data).await;
        let client = Arc::new(Client::new());

        let err = download_public_input(
            client,
            1u8,
            url,
            max_size_mb,
            ProgramInputType::Public,
            Duration::from_secs(30),
        )
        .await
        .expect_err("a payload over the size limit should be rejected");

        assert_eq!(err.to_string(), "Max size exceeded");

        mock.assert();
    }
}