use {
    figment::{
        providers::{Format, Toml},
        Figment,
    },
    serde::{Deserialize, Serialize},
    std::path::Path,
};

#[derive(Debug, Deserialize, Serialize, Clone)]
11
pub enum IngesterConfig {
12
    RpcBlockSubscription {
13
        wss_rpc_url: String,
14
    },
15
    GrpcSubscription {
16
        grpc_url: String,
17
        connection_timeout_secs: u32,
18
        timeout_secs: u32,
19
        token: String,
20
    },
21
    WebsocketSub, //not implemented
22
}
23

            
24
#[derive(Debug, Deserialize, Serialize, Clone)]
25
pub enum TransactionSenderConfig {
26
    Rpc { rpc_url: String },
27
    //--- below not implemented yet
28
    Tpu,
29
}
30

            
31
#[derive(Debug, Deserialize, Serialize, Clone)]
32
pub enum SignerConfig {
33
    KeypairFile { path: String }, //--- below not implemented yet maybe hsm, signer server or some weird sig agg shiz
34
}
35

            
36
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
37
pub enum MissingImageStrategy {
38
    #[default]
39
    DownloadAndClaim,
40
    DownloadAndMiss,
41
    Fail,
42
}
43

            
44
#[derive(Debug, Deserialize, Serialize, Clone)]
45
pub struct ProverNodeConfig {
46
    pub env: Option<String>,
47
    #[serde(default = "default_bonsol_program")]
48
    pub bonsol_program: String,
49
    #[serde(default = "default_risc0_image_folder")]
50
    pub risc0_image_folder: String,
51
    #[serde(default = "default_max_image_size_mb")]
52
    pub max_image_size_mb: u32,
53
    #[serde(default = "default_image_compression_ttl_hours")]
54
    pub image_compression_ttl_hours: u32,
55
    #[serde(default = "default_max_input_size_mb")]
56
    pub max_input_size_mb: u32,
57
    #[serde(default = "default_image_download_timeout_secs")]
58
    pub image_download_timeout_secs: u32,
59
    #[serde(default = "default_input_download_timeout_secs")]
60
    pub input_download_timeout_secs: u32,
61
    #[serde(default = "default_maximum_concurrent_proofs")]
62
    pub maximum_concurrent_proofs: u32,
63
    #[serde(default = "default_ingester_config")]
64
    pub ingester_config: IngesterConfig,
65
    #[serde(default = "default_transaction_sender_config")]
66
    pub transaction_sender_config: TransactionSenderConfig,
67
    #[serde(default = "default_signer_config")]
68
    pub signer_config: SignerConfig,
69
    #[serde(default = "default_stark_compression_tools_path")]
70
    pub stark_compression_tools_path: String,
71
    #[serde(default = "default_metrics_config")]
72
    pub metrics_config: MetricsConfig,
73
    #[serde(default)]
74
    pub missing_image_strategy: MissingImageStrategy,
75
}
76

            
77
#[derive(Debug, Deserialize, Serialize, Clone)]
78
pub enum MetricsConfig {
79
    Prometheus {},
80
    None,
81
}
82

            
83
// Default-value providers referenced by the `#[serde(default = "...")]`
// attributes on `ProverNodeConfig` and by its `Default` implementation.

const fn default_metrics_config() -> MetricsConfig {
86
    MetricsConfig::None
87
}
88

            
89
/// Default for `stark_compression_tools_path`: the `stark` directory under
/// the current working directory, or `./stark` when the working directory
/// cannot be determined. Non-UTF-8 path components are converted lossily.
fn default_stark_compression_tools_path() -> String {
    std::env::current_dir()
        // Build the fallback lazily; it is only needed when the lookup fails.
        .unwrap_or_else(|_| Path::new("./").to_path_buf())
        .join("stark")
        .to_string_lossy()
        .into_owned()
}

/// Default for `bonsol_program`.
fn default_bonsol_program() -> String {
    String::from("BoNsHRcyLLNdtnoDf8hiCNZpyehMC4FDMxs6NTxFi3ew")
}

/// Default for `risc0_image_folder`.
fn default_risc0_image_folder() -> String {
    String::from("./elf")
}

/// Default for `max_image_size_mb`, in megabytes.
const fn default_max_image_size_mb() -> u32 {
    10
}

/// Default for `image_compression_ttl_hours`, in hours.
const fn default_image_compression_ttl_hours() -> u32 {
    5
}

/// Default for `max_input_size_mb`, in megabytes.
const fn default_max_input_size_mb() -> u32 {
    1
}

/// Default for `image_download_timeout_secs`, in seconds.
const fn default_image_download_timeout_secs() -> u32 {
    120
}

/// Default for `input_download_timeout_secs`, in seconds.
const fn default_input_download_timeout_secs() -> u32 {
    30
}

/// Default for `maximum_concurrent_proofs`.
const fn default_maximum_concurrent_proofs() -> u32 {
    100
}

fn default_ingester_config() -> IngesterConfig {
130
    IngesterConfig::RpcBlockSubscription {
131
        wss_rpc_url: "ws://localhost:8900".to_string(),
132
    }
133
}
134

            
135
fn default_transaction_sender_config() -> TransactionSenderConfig {
136
    TransactionSenderConfig::Rpc {
137
        rpc_url: "http://localhost:8899".to_string(),
138
    }
139
}
140

            
141
fn default_signer_config() -> SignerConfig {
142
    SignerConfig::KeypairFile {
143
        path: "./node-keypair.json".to_string(),
144
    }
145
}
146

            
147
impl Default for ProverNodeConfig {
148
    fn default() -> Self {
149
        ProverNodeConfig {
150
            env: Some("dev".to_string()),
151
            bonsol_program: default_bonsol_program(),
152
            risc0_image_folder: default_risc0_image_folder(),
153
            max_image_size_mb: default_max_image_size_mb(),
154
            image_compression_ttl_hours: default_image_compression_ttl_hours(),
155
            max_input_size_mb: default_max_input_size_mb(),
156
            image_download_timeout_secs: default_image_download_timeout_secs(),
157
            input_download_timeout_secs: default_input_download_timeout_secs(),
158
            maximum_concurrent_proofs: default_maximum_concurrent_proofs(),
159
            ingester_config: default_ingester_config(),
160
            transaction_sender_config: default_transaction_sender_config(),
161
            signer_config: default_signer_config(),
162
            stark_compression_tools_path: default_stark_compression_tools_path(),
163
            metrics_config: default_metrics_config(),
164
            missing_image_strategy: MissingImageStrategy::default(),
165
        }
166
    }
167
}
168

            
169
pub fn load_config(config_path: &str) -> ProverNodeConfig {
170
    let figment = Figment::new().merge(Toml::file(config_path));
171
    figment.extract().unwrap()
172
}
173

            
174
// #[cfg(test)]
// mod tests {
//     use super::*;

//     #[test]
//     fn test_config_serialization() -> anyhow::Result<()> {
//         let config_content = r#"
//  risc0_image_folder = "/elf"
// max_input_size_mb = 10
// image_download_timeout_secs = 60
// input_download_timeout_secs = 60
// maximum_concurrent_proofs = 10
// max_image_size_mb = 4
// image_compression_ttl_hours = 24
// stark_compression_tools_path = "./stark/"
// env = "dev"

// [transaction_sender_config]
// Rpc = { rpc_url = "http://localhost:8899" }

// [signer_config]
// KeypairFile = { path = "node_keypair.json" }"#;
//         let config: ProverNodeConfig = toml::from_str(config_content).unwrap();
//         assert_eq!(config.risc0_image_folder, "/elf");
//         assert_eq!(config.max_input_size_mb, 10);
//         assert_eq!(config.image_download_timeout_secs, 60);
//         assert_eq!(config.input_download_timeout_secs, 60);
//         assert_eq!(config.maximum_concurrent_proofs, 10);
//         assert_eq!(config.max_image_size_mb, 4);
//         assert_eq!(config.image_compression_ttl_hours, 24);
//         assert_eq!(config.stark_compression_tools_path, "./stark/");
//         assert_eq!(config.env, Some("dev".to_string()));
//         let serialized = toml::to_string(&config)?;
//         let deserialized: ProverNodeConfig = toml::from_str(&serialized)?;
//         assert_eq!(deserialized.risc0_image_folder, config.risc0_image_folder);
//         assert_eq!(deserialized.max_input_size_mb, config.max_input_size_mb);
//         assert_eq!(
//             deserialized.image_download_timeout_secs,
//             config.image_download_timeout_secs
//         );
//         assert_eq!(
//             deserialized.input_download_timeout_secs,
//             config.input_download_timeout_secs
//         );
//         assert_eq!(
//             deserialized.maximum_concurrent_proofs,
//             config.maximum_concurrent_proofs
//         );
//         assert_eq!(deserialized.max_image_size_mb, config.max_image_size_mb);
//         assert_eq!(
//             deserialized.image_compression_ttl_hours,
//             config.image_compression_ttl_hours
//         );
//         assert_eq!(
//             deserialized.stark_compression_tools_path,
//             config.stark_compression_tools_path
//         );
//         assert_eq!(deserialized.env, config.env);
//         match &config.transaction_sender_config {
//             TransactionSenderConfig::Rpc { rpc_url } => {
//                 assert_eq!(rpc_url, "http://localhost:8899");
//             }
//             _ => panic!("Expected Rpc transaction sender config"),
//         }
//         match &config.signer_config {
//             SignerConfig::KeypairFile { path } => {
//                 assert_eq!(path, "node_keypair.json");
//             }
//         }
//         assert_eq!(config.risc0_image_folder, "/elf");
//         assert_eq!(config.max_input_size_mb, 10);
//         assert_eq!(config.image_download_timeout_secs, 60);
//         assert_eq!(config.input_download_timeout_secs, 60);
//         assert_eq!(config.maximum_concurrent_proofs, 10);
//         assert_eq!(config.max_image_size_mb, 4);
//         assert_eq!(config.image_compression_ttl_hours, 24);
//         assert_eq!(config.stark_compression_tools_path, "./stark/");
//         assert_eq!(config.env, Some("dev".to_string()));
//         match config.transaction_sender_config {
//             TransactionSenderConfig::Rpc { rpc_url } => {
//                 assert_eq!(rpc_url, "http://localhost:8899");
//             }
//             _ => panic!("Expected Rpc transaction sender config"),
//         }
//         match config.signer_config {
//             SignerConfig::KeypairFile { path } => {
//                 assert_eq!(path, "node_keypair.json");
//             }
//         }
//         Ok(())
//     }
// }