1use anyhow::{Context, Result, ensure};
3use csv::{ReaderBuilder, Trim, Writer};
4use indexmap::IndexSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8pub struct ModelPatch {
10 base_model_dir: PathBuf,
12 file_patches: Vec<FilePatch>,
14 toml_patch: Option<toml::value::Table>,
16}
17
18impl ModelPatch {
19 pub fn new<P: Into<PathBuf>>(base_model_dir: P) -> Self {
21 ModelPatch {
22 base_model_dir: base_model_dir.into(),
23 file_patches: Vec::new(),
24 toml_patch: None,
25 }
26 }
27
28 pub fn from_example(name: &str) -> Self {
30 let base_model_dir = PathBuf::from("examples").join(name);
31 ModelPatch::new(base_model_dir)
32 }
33
34 pub fn with_file_patch(mut self, patch: FilePatch) -> Self {
36 self.file_patches.push(patch);
37 self
38 }
39
40 pub fn with_file_patches<I>(mut self, patches: I) -> Self
42 where
43 I: IntoIterator<Item = FilePatch>,
44 {
45 self.file_patches.extend(patches);
46 self
47 }
48
49 pub fn with_toml_patch(mut self, patch_str: impl AsRef<str>) -> Self {
52 assert!(
53 self.toml_patch.is_none(),
54 "TOML patch already set for this ModelPatch"
55 );
56 let s = patch_str.as_ref();
57 let patch: toml::value::Table =
58 toml::from_str(s).expect("Failed to parse string passed to with_toml_patch");
59 self.toml_patch = Some(patch);
60 self
61 }
62
63 pub fn build<O: AsRef<Path>>(&self, out_dir: O) -> Result<()> {
65 let base_dir = self.base_model_dir.as_path();
66 let out_path = out_dir.as_ref();
67
68 let base_toml_path = base_dir.join("model.toml");
70 let out_toml_path = out_path.join("model.toml");
71 if let Some(toml_patch) = &self.toml_patch {
72 let toml_content = fs::read_to_string(&base_toml_path)?;
73 let merged_toml = merge_model_toml(&toml_content, toml_patch)?;
74 fs::write(&out_toml_path, merged_toml)?;
75 } else {
76 fs::copy(&base_toml_path, &out_toml_path)?;
77 }
78
79 for entry in fs::read_dir(base_dir)? {
82 let entry = entry?;
83 let src_path = entry.path();
84 if src_path.is_file()
85 && src_path
86 .extension()
87 .and_then(|e| e.to_str())
88 .is_some_and(|ext| ext.eq_ignore_ascii_case("csv"))
89 {
90 let dst_path = out_path.join(entry.file_name());
91 fs::copy(&src_path, &dst_path)?;
92 }
93 }
94
95 for patch in &self.file_patches {
97 patch.apply_and_save(base_dir, out_path)?;
98 }
99
100 Ok(())
101 }
102
103 pub fn build_to_tempdir(&self) -> Result<tempfile::TempDir> {
105 let temp_dir = tempfile::tempdir()?;
106 self.build(temp_dir.path())?;
107 Ok(temp_dir)
108 }
109}
110
111type CSVTable = IndexSet<Vec<String>>;
115
116#[derive(Clone)]
118pub struct FilePatch {
119 filename: String,
121 header_row: Option<Vec<String>>,
123 replacement_content: Option<String>,
125 to_delete: CSVTable,
127 to_add: CSVTable,
129}
130
131impl FilePatch {
132 pub fn new(filename: impl Into<String>) -> Self {
134 FilePatch {
135 filename: filename.into(),
136 header_row: None,
137 replacement_content: None,
138 to_delete: IndexSet::new(),
139 to_add: IndexSet::new(),
140 }
141 }
142
143 pub fn with_header(mut self, header: impl Into<String>) -> Self {
145 assert!(
146 self.replacement_content.is_none(),
147 "Cannot set header when replacement content is set for this FilePatch",
148 );
149 assert!(
150 self.header_row.is_none(),
151 "Header already set for this FilePatch",
152 );
153 let s = header.into();
154 let v = s.split(',').map(|s| s.trim().to_string()).collect();
155 self.header_row = Some(v);
156 self
157 }
158
159 pub fn with_replacement(mut self, lines: &[&str]) -> Self {
165 assert!(
166 self.header_row.is_none(),
167 "Cannot set replacement content when header is set for this FilePatch",
168 );
169 assert!(
170 self.to_delete.is_empty() && self.to_add.is_empty(),
171 "Cannot set replacement content when additions/deletions are set for this FilePatch",
172 );
173 assert!(
174 self.replacement_content.is_none(),
175 "Replacement content already set for this FilePatch",
176 );
177
178 if !lines.is_empty() {
180 let first_col_count = lines[0].matches(',').count() + 1;
181 for (idx, line) in lines.iter().enumerate() {
182 let col_count = line.matches(',').count() + 1;
183 assert_eq!(
184 col_count, first_col_count,
185 "Line {idx} has {col_count} columns but line 0 has {first_col_count}: {line:?}"
186 );
187 }
188 }
189
190 let content = lines.join("\n") + "\n";
191 self.replacement_content = Some(content);
192 self
193 }
194
195 pub fn with_addition(mut self, row: impl Into<String>) -> Self {
197 assert!(
198 self.replacement_content.is_none(),
199 "Cannot add rows when replacement content is set for this FilePatch",
200 );
201 let s = row.into();
202 let v = s.split(',').map(|s| s.trim().to_string()).collect();
203 self.to_add.insert(v);
204 self
205 }
206
207 pub fn with_deletion(mut self, row: impl Into<String>) -> Self {
209 assert!(
210 self.replacement_content.is_none(),
211 "Cannot delete rows when replacement content is set for this FilePatch",
212 );
213 let s = row.into();
214 let v = s.split(',').map(|s| s.trim().to_string()).collect();
215 self.to_delete.insert(v);
216 self
217 }
218
219 fn apply(&self, base_model_dir: &Path) -> Result<String> {
221 let base_path = base_model_dir.join(&self.filename);
223 ensure!(
224 base_path.exists() && base_path.is_file(),
225 "Base file for patching does not exist: {}",
226 base_path.display()
227 );
228
229 if let Some(content) = &self.replacement_content {
232 return Ok(content.clone());
233 }
234
235 let base = fs::read_to_string(&base_path)?;
237
238 let modified = modify_base_with_patch(&base, self)
240 .with_context(|| format!("Error applying patch to file: {}", self.filename))?;
241 Ok(modified)
242 }
243
244 pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> {
246 let modified = self.apply(base_model_dir)?;
247 let new_path = out_model_dir.join(&self.filename);
248 fs::write(&new_path, modified)?;
249 Ok(())
250 }
251}
252
253fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result<String> {
255 let mut base_val: toml::Value = toml::from_str(base_toml)?;
257 let base_tbl = base_val
258 .as_table_mut()
259 .context("Base model TOML must be a table")?;
260
261 for (k, v) in patch {
263 base_tbl.insert(k.clone(), v.clone());
264 }
265
266 let out = toml::to_string_pretty(&base_val)?;
268 Ok(out)
269}
270
271fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result<String> {
274 let mut reader = ReaderBuilder::new()
276 .trim(Trim::All)
277 .from_reader(base.as_bytes());
278
279 let base_header = reader
281 .headers()
282 .context("Failed to read base file header")?;
283 let base_header_vec: Vec<String> = base_header.iter().map(ToString::to_string).collect();
284
285 if let Some(ref header_row_vec) = patch.header_row {
287 ensure!(
288 base_header_vec == *header_row_vec,
289 "Header mismatch: base file has [{}], patch has [{}]",
290 base_header_vec.join(", "),
291 header_row_vec.join(", ")
292 );
293 }
294 let mut base_rows: CSVTable = CSVTable::new();
296 for result in reader.records() {
297 let record = result?;
298
299 let row_vec = record
301 .iter()
302 .map(|s| s.trim().to_string())
303 .collect::<Vec<_>>();
304
305 ensure!(
307 base_rows.insert(row_vec.clone()),
308 "Duplicate row in base file: {row_vec:?}",
309 );
310 }
311
312 for del_row in &patch.to_delete {
314 ensure!(
315 !patch.to_add.contains(del_row),
316 "Row appears in both deletions and additions: {del_row:?}",
317 );
318 }
319
320 for del_row in &patch.to_delete {
322 ensure!(
323 base_rows.contains(del_row),
324 "Row to delete not present in base file: {del_row:?}"
325 );
326 }
327
328 base_rows.retain(|row| !patch.to_delete.contains(row));
330
331 for add_row in &patch.to_add {
333 ensure!(
334 base_rows.insert(add_row.clone()),
335 "Addition already present in base file: {add_row:?}"
336 );
337 }
338
339 let expected_len = base_header_vec.len();
341 for row in &base_rows {
342 ensure!(
343 row.len() == expected_len,
344 "Row has {} columns but header has {expected_len}: {row:?}",
345 row.len(),
346 );
347 }
348
349 let mut wtr = Writer::from_writer(vec![]);
351 wtr.write_record(base_header_vec.iter())?;
352 for row in &base_rows {
353 let row_iter = row.iter().map(String::as_str);
354 wtr.write_record(row_iter)?;
355 }
356 wtr.flush()?;
357 let inner = wtr.into_inner()?;
358 let output = String::from_utf8(inner)?;
359 Ok(output)
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::assert_error;
    use crate::input::read_toml;
    use crate::model::ModelParameters;
    use crate::patch::{FilePatch, ModelPatch};

    #[test]
    fn modify_base_with_patch_works() {
        let base = "col1,col2\nvalue1,value2\nvalue3,value4\nvalue5,value6\n";

        let patch = FilePatch::new("test.csv")
            .with_header("col1,col2")
            .with_deletion("value3,value4")
            .with_addition("value7,value8");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(lines[0], "col1,col2");
        assert_eq!(lines[1], "value1,value2");
        assert_eq!(lines[2], "value5,value6");
        assert_eq!(lines[3], "value7,value8");
        assert!(!modified.contains("value3,value4"));
    }

    #[test]
    fn modify_base_with_patch_mismatched_header() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_header("col1,col3");

        assert_error!(
            modify_base_with_patch(base, &patch),
            "Header mismatch: base file has [col1, col2], patch has [col1, col3]"
        );
    }

    #[test]
    fn modify_base_with_patch_trims_whitespace() {
        let base = "col1 , col2\n value1 ,value2 \n";

        let patch = FilePatch::new("test.csv")
            .with_header(" col1,col2 ")
            .with_addition("value3 , value4");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(lines, ["col1,col2", "value1,value2", "value3,value4"]);
    }

    #[test]
    fn modify_base_with_patch_missing_deletion() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_deletion("value3,value4");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row to delete not present in base file: ["value3", "value4"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_addition_already_present() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Addition already present in base file: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_row_deleted_and_added() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv")
            .with_deletion("value1,value2")
            .with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row appears in both deletions and additions: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_duplicate_base_row() {
        let base = "col1,col2\na,b\na,b\n";

        let patch = FilePatch::new("test.csv");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Duplicate row in base file: ["a", "b"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_column_count_mismatch() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_addition("a,b,c");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row has 3 columns but header has 2: ["a", "b", "c"]"#
        );
    }

    #[test]
    fn merge_model_toml_basic() {
        let base = r#"
            field = "data"
            [section]
            a = 1
        "#;

        let mut patch = toml::value::Table::new();
        patch.insert(
            "field".to_string(),
            toml::Value::String("patched".to_string()),
        );
        patch.insert(
            "new_field".to_string(),
            toml::Value::String("added".to_string()),
        );

        let merged = merge_model_toml(base, &patch).unwrap();
        assert!(merged.contains("field = \"patched\""));
        assert!(merged.contains("[section]"));
        assert!(merged.contains("new_field = \"added\""));
    }

    #[test]
    fn file_patch() {
        let assets_patch = FilePatch::new("assets.csv")
            .with_deletion("GASDRV,GBR,A0_GEX,4002.26,2020")
            .with_addition("GASDRV,GBR,A0_GEX,4003.26,2020");

        let model_dir = ModelPatch::from_example("simple")
            .with_file_patch(assets_patch)
            .build_to_tempdir()
            .unwrap();

        let assets_path = model_dir.path().join("assets.csv");
        let assets_content = std::fs::read_to_string(assets_path).unwrap();
        assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020"));
        assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020"));
    }

    #[test]
    fn file_patch_with_replacement() {
        let expected = "col1,col2\nnew1,new2\n";

        let model_dir = ModelPatch::from_example("simple")
            .with_file_patch(
                FilePatch::new("assets.csv").with_replacement(&["col1,col2", "new1,new2"]),
            )
            .build_to_tempdir()
            .unwrap();

        let assets_path = model_dir.path().join("assets.csv");
        let assets_content = std::fs::read_to_string(assets_path).unwrap();
        assert_eq!(assets_content, expected);
    }

    #[test]
    #[should_panic(expected = "Header already set for this FilePatch")]
    fn file_patch_header_twice_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_header("col1,col2")
            .with_header("col1,col2");
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when header is set for this FilePatch"
    )]
    fn file_patch_replacement_after_header_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_header("col1,col2")
            .with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when additions/deletions are set for this FilePatch"
    )]
    fn file_patch_replacement_after_addition_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_addition("a,b")
            .with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(expected = "Cannot add rows when replacement content is set for this FilePatch")]
    fn file_patch_addition_after_replacement_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_replacement(&["col1,col2", "a,b"])
            .with_addition("c,d");
    }

    #[test]
    fn file_patch_with_replacement_missing_base_file_fails() {
        let model_patch = ModelPatch::from_example("simple").with_file_patch(
            FilePatch::new("not_a_real_file.csv").with_replacement(&["x,y", "1,2"]),
        );

        let expected = format!(
            "Base file for patching does not exist: {}",
            std::path::PathBuf::from("examples")
                .join("simple")
                .join("not_a_real_file.csv")
                .display()
        );

        assert_error!(model_patch.build_to_tempdir(), expected);
    }

    #[test]
    #[should_panic(expected = "Line 1 has 2 columns but line 0 has 3")]
    fn file_patch_replacement_column_count_mismatch_panics() {
        let _ = FilePatch::new("test.csv").with_replacement(&["col1,col2,col3", "a,b"]);
    }

    #[test]
    fn toml_patch() {
        let toml_patch = "milestone_years = [2020, 2030, 2040, 2050]\n";

        let model_dir = ModelPatch::from_example("simple")
            .with_toml_patch(toml_patch)
            .build_to_tempdir()
            .unwrap();

        let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap();
        assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]);
    }
}