1use anyhow::{Context, Result, ensure};
3use csv::{ReaderBuilder, Trim, Writer};
4use indexmap::IndexSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8pub struct ModelPatch {
10 base_model_dir: PathBuf,
12 file_patches: Vec<FilePatch>,
14 toml_patch: Option<toml::value::Table>,
16}
17
18impl ModelPatch {
19 pub fn new<P: Into<PathBuf>>(base_model_dir: P) -> Self {
21 ModelPatch {
22 base_model_dir: base_model_dir.into(),
23 file_patches: Vec::new(),
24 toml_patch: None,
25 }
26 }
27
28 pub fn from_example(name: &str) -> Self {
30 let base_model_dir = PathBuf::from("examples").join(name);
31 ModelPatch::new(base_model_dir)
32 }
33
34 pub fn with_file_patch(mut self, patch: FilePatch) -> Self {
36 self.file_patches.push(patch);
37 self
38 }
39
40 pub fn with_file_patches<I>(mut self, patches: I) -> Self
42 where
43 I: IntoIterator<Item = FilePatch>,
44 {
45 self.file_patches.extend(patches);
46 self
47 }
48
49 pub fn with_toml_patch(mut self, patch_str: impl AsRef<str>) -> Self {
52 assert!(
53 self.toml_patch.is_none(),
54 "TOML patch already set for this ModelPatch"
55 );
56 let s = patch_str.as_ref();
57 let patch: toml::value::Table =
58 toml::from_str(s).expect("Failed to parse string passed to with_toml_patch");
59 self.toml_patch = Some(patch);
60 self
61 }
62
63 pub fn build<O: AsRef<Path>>(&self, out_dir: O) -> Result<()> {
65 let base_dir = self.base_model_dir.as_path();
66 let out_path = out_dir.as_ref();
67
68 let base_toml_path = base_dir.join("model.toml");
70 let out_toml_path = out_path.join("model.toml");
71 if let Some(toml_patch) = &self.toml_patch {
72 let toml_content = fs::read_to_string(&base_toml_path)?;
73 let merged_toml = merge_model_toml(&toml_content, toml_patch)?;
74 fs::write(&out_toml_path, merged_toml)?;
75 } else {
76 fs::copy(&base_toml_path, &out_toml_path)?;
77 }
78
79 for entry in fs::read_dir(base_dir)? {
82 let entry = entry?;
83 let src_path = entry.path();
84 if src_path.is_file()
85 && src_path
86 .extension()
87 .and_then(|e| e.to_str())
88 .is_some_and(|ext| ext.eq_ignore_ascii_case("csv"))
89 {
90 let dst_path = out_path.join(entry.file_name());
91 fs::copy(&src_path, &dst_path)?;
92 }
93 }
94
95 for patch in &self.file_patches {
97 patch.apply_and_save(base_dir, out_path)?;
98 }
99
100 Ok(())
101 }
102
103 pub fn build_to_tempdir(&self) -> Result<tempfile::TempDir> {
105 let temp_dir = tempfile::tempdir()?;
106 self.build(temp_dir.path())?;
107 Ok(temp_dir)
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct FilePatch {
114 filename: String,
116 header_row: Option<Vec<String>>,
118 to_delete: IndexSet<Vec<String>>,
120 to_add: IndexSet<Vec<String>>,
122}
123
124impl FilePatch {
125 pub fn new(filename: impl Into<String>) -> Self {
127 FilePatch {
128 filename: filename.into(),
129 header_row: None,
130 to_delete: IndexSet::new(),
131 to_add: IndexSet::new(),
132 }
133 }
134
135 pub fn with_header(mut self, header: impl Into<String>) -> Self {
137 assert!(
138 self.header_row.is_none(),
139 "Header already set for this FilePatch",
140 );
141 let s = header.into();
142 let v = s.split(',').map(|s| s.trim().to_string()).collect();
143 self.header_row = Some(v);
144 self
145 }
146
147 pub fn with_addition(mut self, row: impl Into<String>) -> Self {
149 let s = row.into();
150 let v = s.split(',').map(|s| s.trim().to_string()).collect();
151 self.to_add.insert(v);
152 self
153 }
154
155 pub fn with_deletion(mut self, row: impl Into<String>) -> Self {
157 let s = row.into();
158 let v = s.split(',').map(|s| s.trim().to_string()).collect();
159 self.to_delete.insert(v);
160 self
161 }
162
163 fn apply(&self, base_model_dir: &Path) -> Result<String> {
165 let base_path = base_model_dir.join(&self.filename);
167 ensure!(
168 base_path.exists() && base_path.is_file(),
169 "Base file for patching does not exist: {}",
170 base_path.display()
171 );
172 let base = fs::read_to_string(&base_path)?;
173
174 let modified = modify_base_with_patch(&base, self)
176 .with_context(|| format!("Error applying patch to file: {}", self.filename))?;
177 Ok(modified)
178 }
179
180 pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> {
182 let modified = self.apply(base_model_dir)?;
183 let new_path = out_model_dir.join(&self.filename);
184 fs::write(&new_path, modified)?;
185 Ok(())
186 }
187}
188
189fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result<String> {
191 let mut base_val: toml::Value = toml::from_str(base_toml)?;
193 let base_tbl = base_val
194 .as_table_mut()
195 .context("Base model TOML must be a table")?;
196
197 for (k, v) in patch {
199 base_tbl.insert(k.clone(), v.clone());
200 }
201
202 let out = toml::to_string_pretty(&base_val)?;
204 Ok(out)
205}
206
207fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result<String> {
210 let mut reader = ReaderBuilder::new()
212 .trim(Trim::All)
213 .from_reader(base.as_bytes());
214
215 let base_header = reader
217 .headers()
218 .context("Failed to read base file header")?;
219 let base_header_vec: Vec<String> = base_header.iter().map(ToString::to_string).collect();
220
221 if let Some(ref header_row_vec) = patch.header_row {
223 ensure!(
224 base_header_vec == *header_row_vec,
225 "Header mismatch: base file has [{}], patch has [{}]",
226 base_header_vec.join(", "),
227 header_row_vec.join(", ")
228 );
229 }
230
231 let mut base_rows: IndexSet<Vec<String>> = IndexSet::new();
233 for result in reader.records() {
234 let record = result?;
235
236 let row_vec = record
238 .iter()
239 .map(|s| s.trim().to_string())
240 .collect::<Vec<_>>();
241
242 ensure!(
244 base_rows.insert(row_vec.clone()),
245 "Duplicate row in base file: {row_vec:?}",
246 );
247 }
248
249 for del_row in &patch.to_delete {
251 ensure!(
252 !patch.to_add.contains(del_row),
253 "Row appears in both deletions and additions: {del_row:?}",
254 );
255 }
256
257 for del_row in &patch.to_delete {
259 ensure!(
260 base_rows.contains(del_row),
261 "Row to delete not present in base file: {del_row:?}"
262 );
263 }
264
265 base_rows.retain(|row| !patch.to_delete.contains(row));
267
268 for add_row in &patch.to_add {
270 ensure!(
271 base_rows.insert(add_row.clone()),
272 "Addition already present in base file: {add_row:?}"
273 );
274 }
275
276 let mut wtr = Writer::from_writer(vec![]);
278 wtr.write_record(base_header_vec.iter())?;
279 for row in &base_rows {
280 let row_iter = row.iter().map(String::as_str);
281 wtr.write_record(row_iter)?;
282 }
283 wtr.flush()?;
284 let inner = wtr.into_inner()?;
285 let output = String::from_utf8(inner)?;
286 Ok(output)
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::assert_error;
    use crate::input::read_toml;
    use crate::model::ModelParameters;

    #[test]
    fn modify_base_with_patch_works() {
        let base = "col1,col2\nvalue1,value2\nvalue3,value4\nvalue5,value6\n";

        let patch = FilePatch::new("test.csv")
            .with_header("col1,col2")
            .with_deletion("value3,value4")
            .with_addition("value7,value8");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        // Compare the whole output, so that missing or extra rows give a clear assertion
        // failure rather than an out-of-bounds panic.
        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(
            lines,
            ["col1,col2", "value1,value2", "value5,value6", "value7,value8"]
        );
    }

    #[test]
    fn modify_base_with_patch_trims_whitespace() {
        let base = "col1, col2\n value1 , value2\n";

        let patch = FilePatch::new("test.csv")
            .with_header("col1,col2")
            .with_addition(" value3 ,value4");

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let lines: Vec<&str> = modified.lines().collect();
        assert_eq!(lines, ["col1,col2", "value1,value2", "value3,value4"]);
    }

    #[test]
    fn modify_base_with_patch_mismatched_header() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_header("col1,col3");

        assert_error!(
            modify_base_with_patch(base, &patch),
            "Header mismatch: base file has [col1, col2], patch has [col1, col3]"
        );
    }

    #[test]
    fn modify_base_with_patch_missing_deletion() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_deletion("value3,value4");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row to delete not present in base file: ["value3", "value4"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_existing_addition() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv").with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Addition already present in base file: ["value1", "value2"]"#
        );
    }

    #[test]
    fn modify_base_with_patch_row_in_both() {
        let base = "col1,col2\nvalue1,value2\n";

        let patch = FilePatch::new("test.csv")
            .with_deletion("value1,value2")
            .with_addition("value1,value2");

        assert_error!(
            modify_base_with_patch(base, &patch),
            r#"Row appears in both deletions and additions: ["value1", "value2"]"#
        );
    }

    #[test]
    fn merge_model_toml_basic() {
        let base = r#"
            field = "data"
            [section]
            a = 1
        "#;

        let mut patch = toml::value::Table::new();
        patch.insert(
            "field".to_string(),
            toml::Value::String("patched".to_string()),
        );
        patch.insert(
            "new_field".to_string(),
            toml::Value::String("added".to_string()),
        );

        let merged = merge_model_toml(base, &patch).unwrap();
        assert!(merged.contains("field = \"patched\""));
        assert!(merged.contains("[section]"));
        assert!(merged.contains("new_field = \"added\""));
    }

    #[test]
    fn file_patch() {
        let assets_patch = FilePatch::new("assets.csv")
            .with_deletion("GASDRV,GBR,A0_GEX,4002.26,2020")
            .with_addition("GASDRV,GBR,A0_GEX,4003.26,2020");

        let model_dir = ModelPatch::from_example("simple")
            .with_file_patch(assets_patch)
            .build_to_tempdir()
            .unwrap();

        let assets_path = model_dir.path().join("assets.csv");
        let assets_content = std::fs::read_to_string(assets_path).unwrap();
        assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020"));
        assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020"));
    }

    #[test]
    fn toml_patch() {
        let toml_patch = "milestone_years = [2020, 2030, 2040, 2050]\n";

        let model_dir = ModelPatch::from_example("simple")
            .with_toml_patch(toml_patch)
            .build_to_tempdir()
            .unwrap();

        let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap();
        assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]);
    }
}