muse2/
input.rs

1//! Common routines for handling input data.
2use crate::asset::AssetPool;
3use crate::graph::{
4    CommoditiesGraph, build_commodity_graphs_for_model, validate_commodity_graphs_for_model,
5};
6use crate::id::{HasID, IDLike};
7use crate::model::{Model, ModelParameters};
8use crate::region::RegionID;
9use crate::units::UnitType;
10use anyhow::{Context, Result, bail, ensure};
11use float_cmp::approx_eq;
12use indexmap::IndexMap;
13use itertools::Itertools;
14use serde::de::{Deserialize, DeserializeOwned, Deserializer};
15use std::collections::HashMap;
16use std::fmt::{self, Write};
17use std::fs;
18use std::hash::Hash;
19use std::path::Path;
20
21mod agent;
22use agent::read_agents;
23mod asset;
24use asset::read_assets;
25mod commodity;
26use commodity::read_commodities;
27mod process;
28use process::read_processes;
29mod region;
30use region::read_regions;
31mod time_slice;
32use time_slice::read_time_slice_info;
33
34/// Read a series of type `T`s from a CSV file.
35///
36/// Will raise an error if the file is empty.
37///
38/// # Arguments
39///
40/// * `file_path` - Path to the CSV file
41pub fn read_csv<'a, T: DeserializeOwned + 'a>(
42    file_path: &'a Path,
43) -> Result<impl Iterator<Item = T> + 'a> {
44    let vec = read_csv_internal(file_path)?;
45    if vec.is_empty() {
46        bail!("CSV file {} cannot be empty", file_path.display());
47    }
48    Ok(vec.into_iter())
49}
50
51/// Read a series of type `T`s from a CSV file.
52///
53/// # Arguments
54///
55/// * `file_path` - Path to the CSV file
56pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
57    file_path: &'a Path,
58) -> Result<impl Iterator<Item = T> + 'a> {
59    let vec = read_csv_internal(file_path)?;
60    Ok(vec.into_iter())
61}
62
63fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
64    let vec = csv::Reader::from_path(file_path)
65        .with_context(|| input_err_msg(file_path))?
66        .into_deserialize()
67        .process_results(|iter| iter.collect_vec())
68        .with_context(|| input_err_msg(file_path))?;
69
70    Ok(vec)
71}
72
73/// Parse a TOML file at the specified path.
74///
75/// # Arguments
76///
77/// * `file_path` - Path to the TOML file
78///
79/// # Returns
80///
81/// * The deserialised TOML data or an error if the file could not be read or parsed.
82pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
83    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
84    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
85    Ok(toml_data)
86}
87
88/// Read a Dimensionless float, checking that it is between 0 and 1
89pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
90where
91    T: UnitType,
92    D: Deserializer<'de>,
93{
94    let value = f64::deserialize(deserialiser)?;
95    if !(value > 0.0 && value <= 1.0) {
96        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
97    }
98
99    Ok(T::new(value))
100}
101
/// Format an error message to include the file path. To be used with `anyhow::Context`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
106
107/// Read a CSV file of items with IDs.
108///
109/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
110/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
111fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
112where
113    T: HasID<ID> + DeserializeOwned,
114{
115    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
116    where
117        T: HasID<ID> + DeserializeOwned,
118    {
119        let mut map = IndexMap::new();
120        for record in read_csv::<T>(file_path)? {
121            let id = record.get_id().clone();
122            let existing = map.insert(id.clone(), record).is_some();
123            ensure!(!existing, "Duplicate ID found: {id}");
124        }
125        ensure!(!map.is_empty(), "CSV file is empty");
126
127        Ok(map)
128    }
129
130    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
131}
132
133/// Check that fractions sum to (approximately) one
134fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
135where
136    T: UnitType,
137    I: Iterator<Item = T>,
138{
139    let sum = fractions.sum();
140    ensure!(
141        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
142        "Sum of fractions does not equal one (actual: {sum})"
143    );
144
145    Ok(())
146}
147
/// Check whether an iterator contains values that are sorted and unique
///
/// Returns `true` for empty and single-element sequences. Any adjacent pair that is equal,
/// out of order or incomparable (e.g. involving NaN) makes the result `false`.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut previous: Option<T> = None;
    for current in iter {
        if let Some(prev) = &previous {
            // Strictly-increasing check; the negated form also rejects incomparable pairs
            if !(*prev < current) {
                return false;
            }
        }
        previous = Some(current);
    }
    true
}
156
157/// Inserts a key-value pair into a `HashMap` if the key does not already exist.
158///
159/// If the key already exists, it returns an error with a message indicating the key's existence.
160pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: &K, value: V) -> Result<()>
161where
162    K: Eq + Hash + Clone + fmt::Debug,
163{
164    let existing = map.insert(key.clone(), value).is_some();
165    ensure!(!existing, "Key {key:?} already exists in the map");
166    Ok(())
167}
168
/// Format a list of items with a cap on display count for error messages
///
/// Up to ten items are debug-formatted, comma-separated, inside square brackets; any remainder
/// is summarised as ` and N more`.
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let iter = items.into_iter();
    // ExactSizeIterator lets us count without consuming
    let total_count = iter.len();

    let mut out = String::from("[");
    for (idx, item) in iter.take(MAX_DISPLAY).enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        // Writing to a String cannot fail
        write!(&mut out, "{item:?}").unwrap();
    }
    out.push(']');

    // If there are remaining items, include the count
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}
194
/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    // Time slices and regions are read first; the readers below validate against them
    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    // Read order matters: commodities -> processes -> agents -> assets, as each reader takes
    // the previously read data as input
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate commodity graphs for all regions and years
    // This gives us the commodity order for each region/year which is passed to the model
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
    let commodity_order = validate_commodity_graphs_for_model(
        &commodity_graphs,
        &processes,
        &commodities,
        &time_slice_info,
    )?;

    // Canonicalise so the stored path is absolute and symlink-free; fails if the dir vanished
    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_params,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        commodity_order,
    };
    Ok((model, AssetPool::new(assets)))
}
256
/// Load commodity flow graphs for a model.
///
/// Loads necessary input data and creates a graph of commodity flows for each region and year,
/// where nodes are commodities and edges are processes.
///
/// Graphs validation is NOT performed. This ensures that graphs can be generated even when
/// validation would fail, which may be helpful for debugging.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// A map from (region, milestone year) to the corresponding commodity graph, or an error.
pub fn load_commodity_graphs<P: AsRef<Path>>(
    model_dir: P,
) -> Result<HashMap<(RegionID, u32), CommoditiesGraph>> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    // Same read sequence as `load_model`, but stopping after processes: that is all the graph
    // construction needs
    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;

    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
    Ok(commodity_graphs)
}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Simple record type used as the deserialisation target throughout these tests
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in dir_path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        // read_csv_optional accepts the same file, yielding an empty iterator
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    /// Test reading both valid and malformed TOML files
    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        // Scope the file handle so it is flushed and closed before reading
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Overwrite with invalid TOML; parsing should now fail
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    /// Test that only values in (0, 1] are accepted
    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    /// Test the approximate sum-to-one check, including non-finite edge cases
    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Single input, valid
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    /// Parameterised cases covering empty, sorted, duplicate and out-of-order inputs
    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    /// Test formatting below and above the ten-item display cap
    #[test]
    fn test_format_items_with_cap() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Test with more than 10 items to trigger the cap
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}