// muse2/input.rs

1//! Common routines for handling input data.
2use crate::asset::AssetPool;
3use crate::graph::{
4    CommoditiesGraph, build_commodity_graphs_for_model, validate_commodity_graphs_for_model,
5};
6use crate::id::{HasID, IDLike};
7use crate::model::{Model, ModelParameters};
8use crate::region::RegionID;
9use crate::units::UnitType;
10use anyhow::{Context, Result, bail, ensure};
11use float_cmp::approx_eq;
12use indexmap::IndexMap;
13use itertools::Itertools;
14use serde::de::{Deserialize, DeserializeOwned, Deserializer};
15use std::collections::HashMap;
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32
33/// Read a series of type `T`s from a CSV file.
34///
35/// Will raise an error if the file is empty.
36///
37/// # Arguments
38///
39/// * `file_path` - Path to the CSV file
40pub fn read_csv<'a, T: DeserializeOwned + 'a>(
41    file_path: &'a Path,
42) -> Result<impl Iterator<Item = T> + 'a> {
43    let vec = read_csv_internal(file_path)?;
44    if vec.is_empty() {
45        bail!("CSV file {} cannot be empty", file_path.display());
46    }
47    Ok(vec.into_iter())
48}
49
50/// Read a series of type `T`s from a CSV file.
51///
52/// # Arguments
53///
54/// * `file_path` - Path to the CSV file
55pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
56    file_path: &'a Path,
57) -> Result<impl Iterator<Item = T> + 'a> {
58    let vec = read_csv_internal(file_path)?;
59    Ok(vec.into_iter())
60}
61
62fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
63    let vec = csv::Reader::from_path(file_path)
64        .with_context(|| input_err_msg(file_path))?
65        .into_deserialize()
66        .process_results(|iter| iter.collect_vec())
67        .with_context(|| input_err_msg(file_path))?;
68
69    Ok(vec)
70}
71
72/// Parse a TOML file at the specified path.
73///
74/// # Arguments
75///
76/// * `file_path` - Path to the TOML file
77///
78/// # Returns
79///
80/// * The deserialised TOML data or an error if the file could not be read or parsed.
81pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
82    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
83    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
84    Ok(toml_data)
85}
86
87/// Read a Dimensionless float, checking that it is between 0 and 1
88pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
89where
90    T: UnitType,
91    D: Deserializer<'de>,
92{
93    let value = f64::deserialize(deserialiser)?;
94    if !(value > 0.0 && value <= 1.0) {
95        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
96    }
97
98    Ok(T::new(value))
99}
100
/// Format an error message to include the file path. To be used with `anyhow::Context`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
105
106/// Read a CSV file of items with IDs.
107///
108/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
109/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
110fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
111where
112    T: HasID<ID> + DeserializeOwned,
113{
114    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
115    where
116        T: HasID<ID> + DeserializeOwned,
117    {
118        let mut map = IndexMap::new();
119        for record in read_csv::<T>(file_path)? {
120            let id = record.get_id().clone();
121            let existing = map.insert(id.clone(), record).is_some();
122            ensure!(!existing, "Duplicate ID found: {id}");
123        }
124        ensure!(!map.is_empty(), "CSV file is empty");
125
126        Ok(map)
127    }
128
129    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
130}
131
132/// Check that fractions sum to (approximately) one
133fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
134where
135    T: UnitType,
136    I: Iterator<Item = T>,
137{
138    let sum = fractions.sum();
139    ensure!(
140        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
141        "Sum of fractions does not equal one (actual: {sum})"
142    );
143
144    Ok(())
145}
146
/// Check whether an iterator contains values that are sorted and unique
///
/// Every adjacent pair must satisfy `a < b` (so for types with incomparable
/// values such as NaN, any incomparable pair yields `false`). Empty and
/// single-element sequences are trivially sorted and unique.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    match items.next() {
        None => true,
        Some(mut prev) => {
            for current in items {
                // NOTE: `!(prev < current)` is not the same as `prev >= current`
                // for partial orders; it must mirror the strict `a < b` check.
                if !(prev < current) {
                    return false;
                }
                prev = current;
            }
            true
        }
    }
}
155
156/// Inserts a key-value pair into a `HashMap` if the key does not already exist.
157///
158/// If the key already exists, it returns an error with a message indicating the key's existence.
159pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: &K, value: V) -> Result<()>
160where
161    K: Eq + Hash + Clone + std::fmt::Debug,
162{
163    let existing = map.insert(key.clone(), value).is_some();
164    ensure!(!existing, "Key {key:?} already exists in the map");
165    Ok(())
166}
167
/// Format a list of items with a cap on display count for error messages
///
/// Lists of up to ten items are shown in full; longer lists show the first ten
/// followed by "and N more".
pub fn format_items_with_cap<T: std::fmt::Debug>(items: &[T]) -> String {
    const MAX_DISPLAY: usize = 10;
    let hidden = items.len().saturating_sub(MAX_DISPLAY);
    if hidden == 0 {
        format!("{items:?}")
    } else {
        format!("{:?} and {hidden} more", &items[..MAX_DISPLAY])
    }
}
181
/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    // Top-level model parameters (incl. milestone years) come from the TOML config
    let model_params = ModelParameters::from_path(&model_dir)?;

    // Time slices and regions are read first as most other readers depend on them
    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    // Each reader validates its data against the entities read before it:
    // processes against commodities, agents against processes/commodities, etc.
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate commodity graphs for all regions and years
    // This gives us the commodity order for each region/year which is passed to the model
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
    let commodity_order = validate_commodity_graphs_for_model(
        &commodity_graphs,
        &processes,
        &commodities,
        &time_slice_info,
    )?;

    // Store an absolute path so the model is unambiguous about its source folder
    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_params,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        commodity_order,
    };
    Ok((model, AssetPool::new(assets)))
}
243
244/// Load commodity flow graphs for a model.
245///
246/// Loads necessary input data and creates a graph of commodity flows for each region and year,
247/// where nodes are commodities and edges are processes.
248///
249/// Graphs validation is NOT performed. This ensures that graphs can be generated even when
250/// validation would fail, which may be helpful for debugging.
251pub fn load_commodity_graphs<P: AsRef<Path>>(
252    model_dir: P,
253) -> Result<HashMap<(RegionID, u32), CommoditiesGraph>> {
254    let model_params = ModelParameters::from_path(&model_dir)?;
255
256    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
257    let regions = read_regions(model_dir.as_ref())?;
258    let region_ids = regions.keys().cloned().collect();
259    let years = &model_params.milestone_years;
260
261    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
262    let processes = read_processes(
263        model_dir.as_ref(),
264        &commodities,
265        &region_ids,
266        &time_slice_info,
267        years,
268    )?;
269
270    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
271    Ok(commodity_graphs)
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type used to exercise the CSV/TOML readers
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in dir_path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        // read_csv must error; read_csv_optional must yield an empty iterator
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            // Scoped so the file handle is closed (and flushed) before reading
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            // Overwrite with invalid TOML; parsing should now fail
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs: the accepted range is 0 < value <= 1
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs: zero, out-of-range and non-finite values are all rejected
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Single input, valid
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases: non-finite sums must also fail
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    #[test]
    fn test_format_items_with_cap() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Test with more than 10 items to trigger the cap
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}