muse2/
patch.rs

1//! Code for applying patches to model input files.
2use anyhow::{Context, Result, ensure};
3use csv::{ReaderBuilder, Trim, Writer};
4use indexmap::IndexSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8/// Struct to hold a set of patches to apply to a base model.
9pub struct ModelPatch {
10    // The base model directory path
11    base_model_dir: PathBuf,
12    // The list of file patches to apply
13    file_patches: Vec<FilePatch>,
14    // Optional patch for model.toml (TOML table)
15    toml_patch: Option<toml::value::Table>,
16}
17
18impl ModelPatch {
19    /// Create a new empty `ModelPatch` for a base model at the given directory.
20    pub fn new<P: Into<PathBuf>>(base_model_dir: P) -> Self {
21        ModelPatch {
22            base_model_dir: base_model_dir.into(),
23            file_patches: Vec::new(),
24            toml_patch: None,
25        }
26    }
27
28    /// Create a new empty `ModelPatch` for an example model
29    pub fn from_example(name: &str) -> Self {
30        let base_model_dir = PathBuf::from("examples").join(name);
31        ModelPatch::new(base_model_dir)
32    }
33
34    /// Add a single `FilePatch` to this `ModelPatch`.
35    pub fn with_file_patch(mut self, patch: FilePatch) -> Self {
36        self.file_patches.push(patch);
37        self
38    }
39
40    /// Add multiple `FilePatch` entries to this `ModelPatch`.
41    pub fn with_file_patches<I>(mut self, patches: I) -> Self
42    where
43        I: IntoIterator<Item = FilePatch>,
44    {
45        self.file_patches.extend(patches);
46        self
47    }
48
49    /// Add a TOML patch (provided as a string) to this `ModelPatch`.
50    /// The string will be parsed into a `toml::value::Table`.
51    pub fn with_toml_patch(mut self, patch_str: impl AsRef<str>) -> Self {
52        assert!(
53            self.toml_patch.is_none(),
54            "TOML patch already set for this ModelPatch"
55        );
56        let s = patch_str.as_ref();
57        let patch: toml::value::Table =
58            toml::from_str(s).expect("Failed to parse string passed to with_toml_patch");
59        self.toml_patch = Some(patch);
60        self
61    }
62
63    /// Build this `ModelPatch` into `out_dir` (creating/overwriting files there).
64    pub fn build<O: AsRef<Path>>(&self, out_dir: O) -> Result<()> {
65        let base_dir = self.base_model_dir.as_path();
66        let out_path = out_dir.as_ref();
67
68        // Apply toml patch (if any), or copy model.toml unchanged from the base model
69        let base_toml_path = base_dir.join("model.toml");
70        let out_toml_path = out_path.join("model.toml");
71        if let Some(toml_patch) = &self.toml_patch {
72            let toml_content = fs::read_to_string(&base_toml_path)?;
73            let merged_toml = merge_model_toml(&toml_content, toml_patch)?;
74            fs::write(&out_toml_path, merged_toml)?;
75        } else {
76            fs::copy(&base_toml_path, &out_toml_path)?;
77        }
78
79        // Copy all CSV files from the base model into the output directory
80        // Any files with associated patches will be overwritten later
81        for entry in fs::read_dir(base_dir)? {
82            let entry = entry?;
83            let src_path = entry.path();
84            if src_path.is_file()
85                && src_path
86                    .extension()
87                    .and_then(|e| e.to_str())
88                    .is_some_and(|ext| ext.eq_ignore_ascii_case("csv"))
89            {
90                let dst_path = out_path.join(entry.file_name());
91                fs::copy(&src_path, &dst_path)?;
92            }
93        }
94
95        // Apply file patches
96        for patch in &self.file_patches {
97            patch.apply_and_save(base_dir, out_path)?;
98        }
99
100        Ok(())
101    }
102
103    /// Build the patched model into a temporary directory and return the `TempDir`.
104    pub fn build_to_tempdir(&self) -> Result<tempfile::TempDir> {
105        let temp_dir = tempfile::tempdir()?;
106        self.build(temp_dir.path())?;
107        Ok(temp_dir)
108    }
109}
110
111/// Represents all rows and columns of a CSV file.
112///
113/// Assumes that each row is unique (as it should be for all MUSE2 input files).
114type CSVTable = IndexSet<Vec<String>>;
115
116/// Structure to hold patches for a model csv file.
117#[derive(Clone)]
118pub struct FilePatch {
119    /// The file that this patch applies to (e.g. "agents.csv")
120    filename: String,
121    /// The header row (optional). If `None`, the header is not checked against base files.
122    header_row: Option<Vec<String>>,
123    /// Full replacement content for this file (optional)
124    replacement_content: Option<String>,
125    /// Rows to delete (each row is a vector of fields)
126    to_delete: CSVTable,
127    /// Rows to add (each row is a vector of fields)
128    to_add: CSVTable,
129}
130
131impl FilePatch {
132    /// Create a new empty `Patch` for the given file.
133    pub fn new(filename: impl Into<String>) -> Self {
134        FilePatch {
135            filename: filename.into(),
136            header_row: None,
137            replacement_content: None,
138            to_delete: IndexSet::new(),
139            to_add: IndexSet::new(),
140        }
141    }
142
143    /// Set the header row for this patch (header should be a comma-joined string, e.g. "a,b,c").
144    pub fn with_header(mut self, header: impl Into<String>) -> Self {
145        assert!(
146            self.replacement_content.is_none(),
147            "Cannot set header when replacement content is set for this FilePatch",
148        );
149        assert!(
150            self.header_row.is_none(),
151            "Header already set for this FilePatch",
152        );
153        let s = header.into();
154        let v = s.split(',').map(|s| s.trim().to_string()).collect();
155        self.header_row = Some(v);
156        self
157    }
158
159    /// Set full replacement content for this file from a slice of lines.
160    ///
161    /// Each line is joined with newlines, and a trailing newline is added.
162    /// All lines must have the same number of columns (commas).
163    /// Example: `with_replacement(&["header1,header2", "value1,value2"])`
164    pub fn with_replacement(mut self, lines: &[&str]) -> Self {
165        assert!(
166            self.header_row.is_none(),
167            "Cannot set replacement content when header is set for this FilePatch",
168        );
169        assert!(
170            self.to_delete.is_empty() && self.to_add.is_empty(),
171            "Cannot set replacement content when additions/deletions are set for this FilePatch",
172        );
173        assert!(
174            self.replacement_content.is_none(),
175            "Replacement content already set for this FilePatch",
176        );
177
178        // Validate that all lines have the same number of columns
179        if !lines.is_empty() {
180            let first_col_count = lines[0].matches(',').count() + 1;
181            for (idx, line) in lines.iter().enumerate() {
182                let col_count = line.matches(',').count() + 1;
183                assert_eq!(
184                    col_count, first_col_count,
185                    "Line {idx} has {col_count} columns but line 0 has {first_col_count}: {line:?}"
186                );
187            }
188        }
189
190        let content = lines.join("\n") + "\n";
191        self.replacement_content = Some(content);
192        self
193    }
194
195    /// Add a row to the patch (row should be a comma-joined string, e.g. "a,b,c").
196    pub fn with_addition(mut self, row: impl Into<String>) -> Self {
197        assert!(
198            self.replacement_content.is_none(),
199            "Cannot add rows when replacement content is set for this FilePatch",
200        );
201        let s = row.into();
202        let v = s.split(',').map(|s| s.trim().to_string()).collect();
203        self.to_add.insert(v);
204        self
205    }
206
207    /// Mark a row for deletion from the base (row should be a comma-joined string, e.g. "a,b,c").
208    pub fn with_deletion(mut self, row: impl Into<String>) -> Self {
209        assert!(
210            self.replacement_content.is_none(),
211            "Cannot delete rows when replacement content is set for this FilePatch",
212        );
213        let s = row.into();
214        let v = s.split(',').map(|s| s.trim().to_string()).collect();
215        self.to_delete.insert(v);
216        self
217    }
218
219    /// Apply this patch to a base model and return the modified CSV as a string.
220    fn apply(&self, base_model_dir: &Path) -> Result<String> {
221        // Read and validate the base file path
222        let base_path = base_model_dir.join(&self.filename);
223        ensure!(
224            base_path.exists() && base_path.is_file(),
225            "Base file for patching does not exist: {}",
226            base_path.display()
227        );
228
229        // If this patch is a full replacement, validate the base file exists
230        // (checked above) and return the replacement content
231        if let Some(content) = &self.replacement_content {
232            return Ok(content.clone());
233        }
234
235        // Read the base file to string
236        let base = fs::read_to_string(&base_path)?;
237
238        // Apply the patch
239        let modified = modify_base_with_patch(&base, self)
240            .with_context(|| format!("Error applying patch to file: {}", self.filename))?;
241        Ok(modified)
242    }
243
244    /// Apply this patch to a base model and save the modified CSV to another directory.
245    pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> {
246        let modified = self.apply(base_model_dir)?;
247        let new_path = out_model_dir.join(&self.filename);
248        fs::write(&new_path, modified)?;
249        Ok(())
250    }
251}
252
253/// Merge a TOML patch into a base TOML string and return the merged TOML.
254fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result<String> {
255    // Parse base TOML into a table
256    let mut base_val: toml::Value = toml::from_str(base_toml)?;
257    let base_tbl = base_val
258        .as_table_mut()
259        .context("Base model TOML must be a table")?;
260
261    // Apply patch entries
262    for (k, v) in patch {
263        base_tbl.insert(k.clone(), v.clone());
264    }
265
266    // Serialize merged TOML back to string
267    let out = toml::to_string_pretty(&base_val)?;
268    Ok(out)
269}
270
271/// Modify a string representation of a base CSV file by applying a `FilePatch`.
272/// Preserves the order of rows from the base file, with new rows appended at the end.
273fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result<String> {
274    // Read base string, trimming whitespace
275    let mut reader = ReaderBuilder::new()
276        .trim(Trim::All)
277        .from_reader(base.as_bytes());
278
279    // Extract header from the base string
280    let base_header = reader
281        .headers()
282        .context("Failed to read base file header")?;
283    let base_header_vec: Vec<String> = base_header.iter().map(ToString::to_string).collect();
284
285    // If the patch contains a header, compare it with the base header.
286    if let Some(ref header_row_vec) = patch.header_row {
287        ensure!(
288            base_header_vec == *header_row_vec,
289            "Header mismatch: base file has [{}], patch has [{}]",
290            base_header_vec.join(", "),
291            header_row_vec.join(", ")
292        );
293    }
294    // Read all rows from the base, preserving order and checking for duplicates
295    let mut base_rows: CSVTable = CSVTable::new();
296    for result in reader.records() {
297        let record = result?;
298
299        // Create normalized row vector by trimming fields
300        let row_vec = record
301            .iter()
302            .map(|s| s.trim().to_string())
303            .collect::<Vec<_>>();
304
305        // Check for duplicates
306        ensure!(
307            base_rows.insert(row_vec.clone()),
308            "Duplicate row in base file: {row_vec:?}",
309        );
310    }
311
312    // Check that there's no overlap between additions and deletions
313    for del_row in &patch.to_delete {
314        ensure!(
315            !patch.to_add.contains(del_row),
316            "Row appears in both deletions and additions: {del_row:?}",
317        );
318    }
319
320    // Ensure every row requested for deletion actually exists in the base file.
321    for del_row in &patch.to_delete {
322        ensure!(
323            base_rows.contains(del_row),
324            "Row to delete not present in base file: {del_row:?}"
325        );
326    }
327
328    // Apply deletions
329    base_rows.retain(|row| !patch.to_delete.contains(row));
330
331    // Apply additions (append to end, checking for duplicates)
332    for add_row in &patch.to_add {
333        ensure!(
334            base_rows.insert(add_row.clone()),
335            "Addition already present in base file: {add_row:?}"
336        );
337    }
338
339    // Check all rows match base header length
340    let expected_len = base_header_vec.len();
341    for row in &base_rows {
342        ensure!(
343            row.len() == expected_len,
344            "Row has {} columns but header has {expected_len}: {row:?}",
345            row.len(),
346        );
347    }
348
349    // Serialize CSV output using csv::Writer
350    let mut wtr = Writer::from_writer(vec![]);
351    wtr.write_record(base_header_vec.iter())?;
352    for row in &base_rows {
353        let row_iter = row.iter().map(String::as_str);
354        wtr.write_record(row_iter)?;
355    }
356    wtr.flush()?;
357    let inner = wtr.into_inner()?;
358    let output = String::from_utf8(inner)?;
359    Ok(output)
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::assert_error;
    use crate::input::read_toml;
    use crate::model::ModelParameters;

    #[test]
    fn modify_base_with_patch_works() {
        let base = "col1,col2\nvalue1,value2\nvalue3,value4\nvalue5,value6\n";

        // Create a patch to delete value3,value4 and add value7,value8
        let patch = FilePatch::new("test.csv")
            .with_header("col1,col2")
            .with_deletion("value3,value4")
            .with_addition("value7,value8");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(lines[0], "col1,col2"); // header is present
        assert_eq!(lines[1], "value1,value2"); // unchanged row
        assert_eq!(lines[2], "value5,value6"); // unchanged row
        assert_eq!(lines[3], "value7,value8"); // added row
        assert!(!modified.contains("value3,value4")); // deleted row
    }

    #[test]
    fn modify_base_with_patch_trims_whitespace_and_preserves_order() {
        let base = "col1, col2\n b , a \nd,c\n";
        let patch = FilePatch::new("test.csv").with_addition(" e , f ");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(lines, ["col1,col2", "b,a", "d,c", "e,f"]);
    }

    #[test]
    fn modify_base_with_patch_mismatched_header() {
        let base = "col1,col2\nvalue1,value2\n";

        // Create a patch with a mismatched header
        let patch = FilePatch::new("test.csv").with_header("col1,col3");

        assert_error!(
            modify_base_with_patch(base, &patch),
            "Header mismatch: base file has [col1, col2], patch has [col1, col3]"
        );
    }

    #[test]
    fn modify_base_with_patch_duplicate_base_row() {
        // Rows are compared after trimming, so these two rows are duplicates
        let base = "col1,col2\nvalue1,value2\nvalue1 , value2\n";
        let patch = FilePatch::new("test.csv");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Duplicate row in base file: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_missing_deletion() {
        let base = "col1,col2\nvalue1,value2\n";
        let patch = FilePatch::new("test.csv").with_deletion("value3,value4");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row to delete not present in base file: ["value3", "value4"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_deletion_and_addition_overlap() {
        let base = "col1,col2\nvalue1,value2\n";
        let patch = FilePatch::new("test.csv")
            .with_deletion("value1,value2")
            .with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row appears in both deletions and additions: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_addition_already_present() {
        let base = "col1,col2\nvalue1,value2\n";
        let patch = FilePatch::new("test.csv").with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Addition already present in base file: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_addition_wrong_column_count() {
        let base = "col1,col2\nvalue1,value2\n";
        let patch = FilePatch::new("test.csv").with_addition("value3,value4,value5");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row has 3 columns but header has 2: ["value3", "value4", "value5"]"#
        );
    }

    #[test]
    fn merge_model_toml_basic() {
        let base = r#"
            field = "data"
            [section]
            a = 1
        "#;

        // Create a TOML patch
        let mut patch = toml::value::Table::new();
        patch.insert(
            "field".to_string(),
            toml::Value::String("patched".to_string()),
        );
        patch.insert(
            "new_field".to_string(),
            toml::Value::String("added".to_string()),
        );

        // Apply patch with `merge_model_toml`
        // Should overwrite field and add new_field, but keep section.a
        let merged = merge_model_toml(base, &patch).unwrap();
        assert!(merged.contains("field = \"patched\""));
        assert!(merged.contains("[section]"));
        assert!(merged.contains("a = 1"));
        assert!(merged.contains("new_field = \"added\""));
    }

    #[test]
    fn file_patch() {
        // Patch with a small change to an asset capacity
        let assets_patch = FilePatch::new("assets.csv")
            .with_deletion("GASDRV,GBR,A0_GEX,4002.26,2020")
            .with_addition("GASDRV,GBR,A0_GEX,4003.26,2020");

        // Build patched model into a temporary directory
        let model_dir = ModelPatch::from_example("simple")
            .with_file_patch(assets_patch)
            .build_to_tempdir()
            .unwrap();

        // Check that the appropriate change has been made
        let assets_path = model_dir.path().join("assets.csv");
        let assets_content = std::fs::read_to_string(assets_path).unwrap();
        assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020"));
        assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020"));
    }

    #[test]
    fn file_patch_with_replacement() {
        let expected = "col1,col2\nnew1,new2\n";

        let model_dir = ModelPatch::from_example("simple")
            .with_file_patch(
                FilePatch::new("assets.csv").with_replacement(&["col1,col2", "new1,new2"]),
            )
            .build_to_tempdir()
            .unwrap();

        let assets_path = model_dir.path().join("assets.csv");
        let assets_content = std::fs::read_to_string(assets_path).unwrap();
        assert_eq!(assets_content, expected);
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when header is set for this FilePatch"
    )]
    fn file_patch_replacement_after_header_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_header("col1,col2")
            .with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when additions/deletions are set for this FilePatch"
    )]
    fn file_patch_replacement_after_addition_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_addition("a,b")
            .with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(expected = "Cannot add rows when replacement content is set for this FilePatch")]
    fn file_patch_addition_after_replacement_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_replacement(&["col1,col2", "a,b"])
            .with_addition("c,d");
    }

    #[test]
    #[should_panic(
        expected = "Cannot delete rows when replacement content is set for this FilePatch"
    )]
    fn file_patch_deletion_after_replacement_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_replacement(&["col1,col2", "a,b"])
            .with_deletion("a,b");
    }

    #[test]
    #[should_panic(
        expected = "Cannot set header when replacement content is set for this FilePatch"
    )]
    fn file_patch_header_after_replacement_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_replacement(&["col1,col2", "a,b"])
            .with_header("col1,col2");
    }

    #[test]
    #[should_panic(expected = "Header already set for this FilePatch")]
    fn file_patch_header_twice_panics() {
        let _ = FilePatch::new("assets.csv")
            .with_header("col1,col2")
            .with_header("col1,col2");
    }

    #[test]
    fn file_patch_with_replacement_missing_base_file_fails() {
        let model_patch = ModelPatch::from_example("simple").with_file_patch(
            FilePatch::new("not_a_real_file.csv").with_replacement(&["x,y", "1,2"]),
        );

        let expected = format!(
            "Base file for patching does not exist: {}",
            std::path::PathBuf::from("examples")
                .join("simple")
                .join("not_a_real_file.csv")
                .display()
        );

        assert_error!(model_patch.build_to_tempdir(), expected);
    }

    #[test]
    #[should_panic(expected = "Line 1 has 2 columns but line 0 has 3")]
    fn file_patch_replacement_column_count_mismatch_panics() {
        let _ = FilePatch::new("test.csv").with_replacement(&["col1,col2,col3", "a,b"]);
    }

    #[test]
    fn toml_patch() {
        // Patch to add an extra milestone year (2050)
        let toml_patch = "milestone_years = [2020, 2030, 2040, 2050]\n";

        // Build patched model into a temporary directory
        let model_dir = ModelPatch::from_example("simple")
            .with_toml_patch(toml_patch)
            .build_to_tempdir()
            .unwrap();

        // Check that the appropriate change has been made
        let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap();
        assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]);
    }
}