Skip to main content

muse2/
input.rs

1//! Common routines for handling input data.
2use crate::graph::investment::solve_investment_order_for_model;
3use crate::graph::validate::validate_commodity_graphs_for_model;
4use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
5use crate::id::{HasID, ID};
6use crate::model::{Model, ModelParameters};
7use crate::region::RegionID;
8use crate::units::UnitType;
9use anyhow::{Context, Result, bail, ensure};
10use float_cmp::approx_eq;
11use indexmap::IndexMap;
12use itertools::Itertools;
13use serde::de::{Deserialize, DeserializeOwned, Deserializer};
14use std::collections::HashMap;
15use std::fmt::{self, Write};
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_user_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32mod range;
33use range::{parse_range, parse_range_parts, partition};
34mod year;
35use year::parse_year_str;
36
/// Abstraction over map types that support inserting a key-value pair.
///
/// Allows helpers such as `try_insert` to work generically with both `HashMap` and `IndexMap`.
pub trait Insert<K, V> {
    /// Store `value` under `key`.
    ///
    /// Implementations should return the value previously associated with `key`, if any (as
    /// `HashMap::insert` does); `try_insert` relies on a `Some` result to detect duplicate keys.
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}
42
43impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
44    fn insert(&mut self, key: K, value: V) -> Option<V> {
45        HashMap::insert(self, key, value)
46    }
47}
48
49impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
50    fn insert(&mut self, key: K, value: V) -> Option<V> {
51        IndexMap::insert(self, key, value)
52    }
53}
54
55/// Read a series of type `T`s from a CSV file.
56///
57/// Returns an error if the file is empty.
58///
59/// # Arguments
60///
61/// * `file_path` - Path to the CSV file
62pub fn read_csv<'a, T: DeserializeOwned + 'a>(
63    file_path: &'a Path,
64) -> Result<impl Iterator<Item = T> + 'a> {
65    let vec = read_csv_internal(file_path)?;
66    if vec.is_empty() {
67        bail!("CSV file {} cannot be empty", file_path.display());
68    }
69    Ok(vec.into_iter())
70}
71
72/// Read a series of type `T`s from a CSV file.
73///
74/// # Arguments
75///
76/// * `file_path` - Path to the CSV file
77pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
78    file_path: &'a Path,
79) -> Result<impl Iterator<Item = T> + 'a> {
80    if !file_path.exists() {
81        return Ok(Vec::new().into_iter());
82    }
83
84    let vec = read_csv_internal(file_path)?;
85    Ok(vec.into_iter())
86}
87
88fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
89    let vec = csv::ReaderBuilder::new()
90        .trim(csv::Trim::All)
91        .from_path(file_path)
92        .with_context(|| input_err_msg(file_path))?
93        .into_deserialize()
94        .process_results(|iter| iter.collect_vec())
95        .with_context(|| input_err_msg(file_path))?;
96
97    Ok(vec)
98}
99
100/// Parse a TOML file at the specified path.
101///
102/// # Arguments
103///
104/// * `file_path` - Path to the TOML file
105///
106/// # Returns
107///
108/// * The deserialised TOML data or an error if the file could not be read or parsed.
109pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
110    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
111    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
112    Ok(toml_data)
113}
114
115/// Read a Dimensionless float, checking that it is between 0 and 1
116pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
117where
118    T: UnitType,
119    D: Deserializer<'de>,
120{
121    let value = f64::deserialize(deserialiser)?;
122    if !(value > 0.0 && value <= 1.0) {
123        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
124    }
125
126    Ok(T::new(value))
127}
128
/// Build the message used to annotate errors that occur while reading `file_path`.
///
/// Intended for use with `anyhow::Context`, e.g. `.with_context(|| input_err_msg(path))`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path: &Path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
133
134/// Read a CSV file of items with IDs.
135///
136/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
137/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
138fn read_csv_id_file<T, I: ID>(file_path: &Path) -> Result<IndexMap<I, T>>
139where
140    T: HasID<I> + DeserializeOwned,
141{
142    fn fill_and_validate_map<T, I: ID>(file_path: &Path) -> Result<IndexMap<I, T>>
143    where
144        T: HasID<I> + DeserializeOwned,
145    {
146        let mut map = IndexMap::new();
147        for record in read_csv::<T>(file_path)? {
148            let id = record.get_id().clone();
149            let existing = map.insert(id.clone(), record).is_some();
150            ensure!(!existing, "Duplicate ID found: {id}");
151        }
152        ensure!(!map.is_empty(), "CSV file is empty");
153
154        Ok(map)
155    }
156
157    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
158}
159
160/// Check that fractions sum to (approximately) one
161fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
162where
163    T: UnitType,
164    I: Iterator<Item = T>,
165{
166    let sum = fractions.sum();
167    ensure!(
168        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
169        "Sum of fractions does not equal one (actual: {sum})"
170    );
171
172    Ok(())
173}
174
175/// Check whether an iterator contains values that are sorted and unique
176pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
177where
178    T: PartialOrd + Clone,
179    I: IntoIterator<Item = T>,
180{
181    is_sorted_and_unique_with(iter, |a, b| a < b)
182}
183
/// Return `true` if every adjacent pair `(previous, current)` yielded by `iter` satisfies
/// `less_than(&previous, &current)`.
///
/// Pairs are checked in iteration order, and evaluation stops at the first pair that fails.
/// Inputs with fewer than two elements trivially pass.
pub fn is_sorted_and_unique_with<T, I, F>(iter: I, mut less_than: F) -> bool
where
    T: Clone,
    I: IntoIterator<Item = T>,
    F: FnMut(&T, &T) -> bool,
{
    let mut items = iter.into_iter();
    let Some(mut previous) = items.next() else {
        return true;
    };

    for current in items {
        if !less_than(&previous, &current) {
            return false;
        }
        previous = current;
    }

    true
}
196
197/// Insert a key-value pair into a map implementing the `Insert` trait if the key does not
198/// already exist.
199///
200/// If the key already exists, this returns an error with a message indicating the key's
201/// existence.
202pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
203where
204    M: Insert<K, V>,
205    K: Eq + Hash + Clone + std::fmt::Debug,
206{
207    let existing = map.insert(key.clone(), value).is_some();
208    ensure!(!existing, "Key {key:?} already exists in the map");
209    Ok(())
210}
211
/// Format a list of items with a cap on display count, for use in error messages.
///
/// At most the first 10 items are shown, as a bracketed, comma-separated list of their `Debug`
/// representations. If more items were supplied, the output is suffixed with `" and N more"`,
/// where `N` is the number of items left out.
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let iter = items.into_iter();
    let total = iter.len();

    let mut out = String::from("[");
    for (index, item) in iter.take(MAX_DISPLAY).enumerate() {
        if index > 0 {
            out.push_str(", ");
        }
        // Writing into a `String` cannot fail
        write!(out, "{item:?}").unwrap();
    }
    out.push(']');

    let hidden = total.saturating_sub(MAX_DISPLAY);
    if hidden > 0 {
        write!(out, " and {hidden} more").unwrap();
    }

    out
}
237
238/// Read a model from the specified directory.
239///
240/// # Arguments
241///
242/// * `model_dir` - Folder containing model configuration files
243///
244/// # Returns
245///
246/// The static model data ([`Model`]) or an error.
247pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<Model> {
248    let model_params = ModelParameters::from_path(&model_dir)?;
249
250    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
251    let regions = read_regions(model_dir.as_ref())?;
252    let region_ids = regions.keys().cloned().collect();
253    let years = &model_params.milestone_years;
254
255    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
256    let processes = read_processes(
257        model_dir.as_ref(),
258        &commodities,
259        &region_ids,
260        &time_slice_info,
261        years,
262    )?;
263    let agents = read_agents(
264        model_dir.as_ref(),
265        &commodities,
266        &processes,
267        &region_ids,
268        years,
269    )?;
270    let agent_ids = agents.keys().cloned().collect();
271    let user_assets = read_user_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;
272
273    // Build and validate commodity graphs for all regions and years
274    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
275    validate_commodity_graphs_for_model(
276        &commodity_graphs,
277        &processes,
278        &commodities,
279        &time_slice_info,
280    )?;
281
282    // Solve investment order for each region/year
283    let investment_order = solve_investment_order_for_model(&commodity_graphs, &commodities, years);
284
285    let model_path = model_dir
286        .as_ref()
287        .canonicalize()
288        .context("Could not parse path to model")?;
289    let model = Model {
290        model_path,
291        parameters: model_params,
292        agents,
293        commodities,
294        processes,
295        time_slice_info,
296        regions,
297        user_assets,
298        investment_order,
299    };
300    Ok(model)
301}
302
303/// Load commodity flow graphs for a model.
304///
305/// Loads necessary input data and creates a graph of commodity flows for each region and year,
306/// where nodes are commodities and edges are processes.
307///
308/// Graphs validation is NOT performed. This ensures that graphs can be generated even when
309/// validation would fail, which may be helpful for debugging.
310pub fn load_commodity_graphs<P: AsRef<Path>>(
311    model_dir: P,
312) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
313    let model_params = ModelParameters::from_path(&model_dir)?;
314
315    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
316    let regions = read_regions(model_dir.as_ref())?;
317    let region_ids = regions.keys().cloned().collect();
318    let years = &model_params.milestone_years;
319
320    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
321    let processes = read_processes(
322        model_dir.as_ref(),
323        &commodities,
324        &region_ids,
325        &time_slice_info,
326        years,
327    )?;
328
329    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
330    Ok(commodity_graphs)
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create (or overwrite) an example CSV file named `test.csv` in `dir_path`
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn read_csv_works() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with leading/trailing whitespace
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id  , value\t\n  hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // Missing file
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        // Optional CSVs should return an empty iterator
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn read_csv_id_file_works() {
        let dir = tempdir().unwrap();

        // Unique IDs: records are kept in file order
        let file_path = create_csv_file(dir.path(), "id,value\nworld,2\nhello,1\n");
        let map = read_csv_id_file::<Record, GenericID>(&file_path).unwrap();
        assert_eq!(
            map.into_values().collect::<Vec<_>>(),
            vec![
                Record {
                    id: "world".into(),
                    value: 2,
                },
                Record {
                    id: "hello".into(),
                    value: 1,
                },
            ]
        );

        // Duplicate IDs are rejected
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nhello,2\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());

        // Files with no records are rejected
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());
    }

    #[test]
    fn read_toml_works() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        read_toml::<Record>(&file_path).unwrap_err();
    }

    /// Deserialise value with `deserialise_proportion_nonzero()`
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn deserialise_proportion_nonzero_works() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        deserialise_f64(0.0).unwrap_err();
        deserialise_f64(-1.0).unwrap_err();
        deserialise_f64(2.0).unwrap_err();
        deserialise_f64(f64::NAN).unwrap_err();
        deserialise_f64(f64::INFINITY).unwrap_err();
    }

    #[test]
    fn check_values_sum_to_one_approx_works() {
        // Single input, valid
        check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).unwrap();

        // Multiple inputs, valid
        check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
            .unwrap();

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn is_sorted_and_unique_works(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected);
    }

    #[test]
    fn is_sorted_and_unique_with_works() {
        // Custom comparator: strictly decreasing order
        assert!(is_sorted_and_unique_with([3, 2, 1], |a, b| a > b));
        assert!(!is_sorted_and_unique_with([3, 3, 1], |a, b| a > b));
        assert!(!is_sorted_and_unique_with([1, 2, 3], |a, b| a > b));
    }

    #[test]
    fn try_insert_works() {
        let mut map: HashMap<&str, u32> = HashMap::new();
        try_insert(&mut map, &"a", 1).unwrap();
        try_insert(&mut map, &"b", 2).unwrap();
        assert!(try_insert(&mut map, &"a", 3).is_err());

        let mut index_map: IndexMap<&str, u32> = IndexMap::new();
        try_insert(&mut index_map, &"x", 1).unwrap();
        assert!(try_insert(&mut index_map, &"x", 2).is_err());
    }

    #[test]
    fn format_items_with_cap_works() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Test with more than 10 items to trigger the cap
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}