//! Common routines for handling input data.
use crate::asset::AssetPool;
use crate::graph::investment::solve_investment_order_for_model;
use crate::graph::validate::validate_commodity_graphs_for_model;
use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
use crate::id::{HasID, IDLike};
use crate::model::{Model, ModelParameters};
use crate::region::RegionID;
use crate::units::UnitType;
use anyhow::{Context, Result, bail, ensure};
use float_cmp::approx_eq;
use indexmap::IndexMap;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::fs;
use std::hash::Hash;
use std::path::Path;

mod agent;
use agent::read_agents;
mod asset;
use asset::read_assets;
mod commodity;
use commodity::read_commodities;
mod process;
use process::read_processes;
mod region;
use region::read_regions;
mod time_slice;
use time_slice::read_time_slice_info;

/// A trait which provides a method to insert a key and value into a map
pub trait Insert<K, V> {
    /// Insert a key and value into the map
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}

impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        HashMap::insert(self, key, value)
    }
}

impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        IndexMap::insert(self, key, value)
    }
}

/// Read a series of type `T`s from a CSV file.
///
/// Returns an error if the file is empty.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
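///
/// # Examples
///
/// A sketch of typical usage (the `Record` type and file name are illustrative):
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct Record {
///     id: String,
///     value: u32,
/// }
///
/// let records: Vec<Record> = read_csv(Path::new("records.csv"))?.collect();
/// ```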
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = read_csv_internal(file_path)?;
    if vec.is_empty() {
        bail!("CSV file {} cannot be empty", file_path.display());
    }
    Ok(vec.into_iter())
}

/// Read a series of type `T`s from a CSV file.
///
/// Returns an empty iterator if the file does not exist.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
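///
/// # Examples
///
/// Unlike [`read_csv`], a missing file is not an error (sketch; `Record` is a
/// placeholder type):
///
/// ```ignore
/// let mut iter = read_csv_optional::<Record>(Path::new("maybe_missing.csv"))?;
/// assert!(iter.next().is_none()); // empty iterator if the file is absent
/// ```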
pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    if !file_path.exists() {
        return Ok(Vec::new().into_iter());
    }

    let vec = read_csv_internal(file_path)?;
    Ok(vec.into_iter())
}

fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
    let vec = csv::ReaderBuilder::new()
        .trim(csv::Trim::All)
        .from_path(file_path)
        .with_context(|| input_err_msg(file_path))?
        .into_deserialize()
        .process_results(|iter| iter.collect_vec())
        .with_context(|| input_err_msg(file_path))?;

    Ok(vec)
}

/// Parse a TOML file at the specified path.
///
/// # Arguments
///
/// * `file_path` - Path to the TOML file
///
/// # Returns
///
/// * The deserialised TOML data or an error if the file could not be read or parsed.
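///
/// # Examples
///
/// A sketch with a placeholder `Settings` type:
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct Settings {
///     name: String,
/// }
///
/// let settings: Settings = read_toml(Path::new("settings.toml"))?;
/// ```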
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
    Ok(toml_data)
}

/// Deserialise a proportion as a float, checking that it is greater than zero and at most one
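///
/// # Examples
///
/// Intended for use with serde's `deserialize_with` attribute (sketch; the
/// struct and field names are illustrative):
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct FlowShare {
///     #[serde(deserialize_with = "deserialise_proportion_nonzero")]
///     share: Dimensionless,
/// }
/// ```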
pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
where
    T: UnitType,
    D: Deserializer<'de>,
{
    let value = f64::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
    }

    Ok(T::new(value))
}

/// Format an error message to include the file path. To be used with `anyhow::Context`.
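///
/// # Examples
///
/// ```ignore
/// let contents = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
/// ```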
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    format!("Error reading {}", file_path.as_ref().display())
}

/// Read a CSV file of items with IDs.
///
/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
/// the IDs for a given type), we use an ordered map to preserve the order of entries in the input
/// files.
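///
/// # Examples
///
/// A sketch assuming a `Region` type that implements `HasID<RegionID>`:
///
/// ```ignore
/// let regions: IndexMap<RegionID, Region> = read_csv_id_file(&file_path)?;
/// ```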
fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
    T: HasID<ID> + DeserializeOwned,
{
    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
    where
        T: HasID<ID> + DeserializeOwned,
    {
        let mut map = IndexMap::new();
        for record in read_csv::<T>(file_path)? {
            let id = record.get_id().clone();
            let existing = map.insert(id.clone(), record).is_some();
            ensure!(!existing, "Duplicate ID found: {id}");
        }
        ensure!(!map.is_empty(), "CSV file is empty");

        Ok(map)
    }

    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
}

/// Check that fractions sum to (approximately) one
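///
/// # Examples
///
/// ```ignore
/// check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())?;
/// ```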
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
    T: UnitType,
    I: Iterator<Item = T>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {sum})"
    );

    Ok(())
}

/// Check whether an iterator contains values that are sorted and unique
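///
/// # Examples
///
/// ```ignore
/// assert!(is_sorted_and_unique([1, 2, 3]));
/// assert!(!is_sorted_and_unique([1, 1, 2])); // duplicates are not allowed
/// ```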
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    iter.into_iter().tuple_windows().all(|(a, b)| a < b)
}

/// Inserts a key-value pair into a map if the key does not already exist.
///
/// If the key already exists, an error is returned indicating the duplicate key.
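///
/// # Examples
///
/// A sketch using a `HashMap` (which implements [`Insert`]):
///
/// ```ignore
/// let mut map = HashMap::new();
/// try_insert(&mut map, &"key", 1)?; // Ok: key was absent
/// assert!(try_insert(&mut map, &"key", 2).is_err()); // Err: duplicate key
/// ```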
pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
where
    M: Insert<K, V>,
    K: Eq + Hash + Clone + std::fmt::Debug,
{
    let existing = map.insert(key.clone(), value).is_some();
    ensure!(!existing, "Key {key:?} already exists in the map");
    Ok(())
}

/// Format a list of items for error messages, capping the number displayed
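///
/// # Examples
///
/// ```ignore
/// assert_eq!(format_items_with_cap(["a", "b"]), r#"["a", "b"]"#);
/// ```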
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let items = items.into_iter();
    let total_count = items.len();

    // Format items with fmt::Debug::fmt() and separate with commas
    let formatted_items = items
        .take(MAX_DISPLAY)
        .format_with(", ", |items, f| f(&format_args!("{items:?}")));
    let mut out = format!("[{formatted_items}]");

    // If there are remaining items, include the count
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}

/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
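///
/// # Examples
///
/// A sketch of typical usage (the directory path is illustrative):
///
/// ```ignore
/// let (model, asset_pool) = load_model("path/to/model")?;
/// ```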
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate commodity graphs for all regions and years
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
    validate_commodity_graphs_for_model(
        &commodity_graphs,
        &processes,
        &commodities,
        &time_slice_info,
    )?;

    // Solve investment order for each region/year
    let investment_order = solve_investment_order_for_model(&commodity_graphs, &commodities, years);

    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_params,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        investment_order,
    };
    Ok((model, AssetPool::new(assets)))
}

/// Load commodity flow graphs for a model.
///
/// Loads the necessary input data and creates a graph of commodity flows for each region and year,
/// where nodes are commodities and edges are processes.
///
/// Graph validation is NOT performed. This ensures that graphs can be generated even when
/// validation would fail, which may be helpful for debugging.
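///
/// # Examples
///
/// ```ignore
/// let graphs = load_commodity_graphs("path/to/model")?;
/// for ((region, year), graph) in &graphs {
///     // Inspect each region/year graph here
/// }
/// ```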
pub fn load_commodity_graphs<P: AsRef<Path>>(
    model_dir: P,
) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;

    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
    Ok(commodity_graphs)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in `dir_path`
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with leading/trailing whitespace
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id  , value\t\n  hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // Missing file
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        // Optional CSVs should return an empty iterator
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Single input, valid
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1, 2], true)]
    #[case(&[1, 2, 3, 4], true)]
    #[case(&[2, 1], false)]
    #[case(&[1, 1], false)]
    #[case(&[1, 3, 2, 4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    #[test]
    fn test_format_items_with_cap() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Test with more than 10 items to trigger the cap
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}