// muse2/input.rs

1//! Common routines for handling input data.
2use crate::graph::investment::solve_investment_order_for_model;
3use crate::graph::validate::validate_commodity_graphs_for_model;
4use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
5use crate::id::{HasID, IDLike};
6use crate::model::{Model, ModelParameters};
7use crate::region::RegionID;
8use crate::units::UnitType;
9use anyhow::{Context, Result, bail, ensure};
10use float_cmp::approx_eq;
11use indexmap::IndexMap;
12use itertools::Itertools;
13use serde::de::{Deserialize, DeserializeOwned, Deserializer};
14use std::collections::HashMap;
15use std::fmt::{self, Write};
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_user_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32
/// A trait which provides a method to insert a key and value into a map.
///
/// This abstracts over map types (this file implements it for [`HashMap`] and
/// [`IndexMap`]) so that generic helpers such as [`try_insert`] can work with
/// either.
pub trait Insert<K, V> {
    /// Insert a key and value into the map.
    ///
    /// Returns the value previously stored under `key`, if any.
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}
38
39impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
40    fn insert(&mut self, key: K, value: V) -> Option<V> {
41        HashMap::insert(self, key, value)
42    }
43}
44
45impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
46    fn insert(&mut self, key: K, value: V) -> Option<V> {
47        IndexMap::insert(self, key, value)
48    }
49}
50
51/// Read a series of type `T`s from a CSV file.
52///
53/// Returns an error if the file is empty.
54///
55/// # Arguments
56///
57/// * `file_path` - Path to the CSV file
58pub fn read_csv<'a, T: DeserializeOwned + 'a>(
59    file_path: &'a Path,
60) -> Result<impl Iterator<Item = T> + 'a> {
61    let vec = read_csv_internal(file_path)?;
62    if vec.is_empty() {
63        bail!("CSV file {} cannot be empty", file_path.display());
64    }
65    Ok(vec.into_iter())
66}
67
68/// Read a series of type `T`s from a CSV file.
69///
70/// # Arguments
71///
72/// * `file_path` - Path to the CSV file
73pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
74    file_path: &'a Path,
75) -> Result<impl Iterator<Item = T> + 'a> {
76    if !file_path.exists() {
77        return Ok(Vec::new().into_iter());
78    }
79
80    let vec = read_csv_internal(file_path)?;
81    Ok(vec.into_iter())
82}
83
84fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
85    let vec = csv::ReaderBuilder::new()
86        .trim(csv::Trim::All)
87        .from_path(file_path)
88        .with_context(|| input_err_msg(file_path))?
89        .into_deserialize()
90        .process_results(|iter| iter.collect_vec())
91        .with_context(|| input_err_msg(file_path))?;
92
93    Ok(vec)
94}
95
96/// Parse a TOML file at the specified path.
97///
98/// # Arguments
99///
100/// * `file_path` - Path to the TOML file
101///
102/// # Returns
103///
104/// * The deserialised TOML data or an error if the file could not be read or parsed.
105pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
106    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
107    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
108    Ok(toml_data)
109}
110
111/// Read a Dimensionless float, checking that it is between 0 and 1
112pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
113where
114    T: UnitType,
115    D: Deserializer<'de>,
116{
117    let value = f64::deserialize(deserialiser)?;
118    if !(value > 0.0 && value <= 1.0) {
119        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
120    }
121
122    Ok(T::new(value))
123}
124
/// Format an error message to include the file path. To be used with `anyhow::Context`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
129
130/// Read a CSV file of items with IDs.
131///
132/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
133/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
134fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
135where
136    T: HasID<ID> + DeserializeOwned,
137{
138    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
139    where
140        T: HasID<ID> + DeserializeOwned,
141    {
142        let mut map = IndexMap::new();
143        for record in read_csv::<T>(file_path)? {
144            let id = record.get_id().clone();
145            let existing = map.insert(id.clone(), record).is_some();
146            ensure!(!existing, "Duplicate ID found: {id}");
147        }
148        ensure!(!map.is_empty(), "CSV file is empty");
149
150        Ok(map)
151    }
152
153    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
154}
155
156/// Check that fractions sum to (approximately) one
157fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
158where
159    T: UnitType,
160    I: Iterator<Item = T>,
161{
162    let sum = fractions.sum();
163    ensure!(
164        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
165        "Sum of fractions does not equal one (actual: {sum})"
166    );
167
168    Ok(())
169}
170
/// Check whether an iterator contains values that are sorted and unique
/// (i.e. strictly increasing).
///
/// Empty and single-element sequences are trivially sorted and unique.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    let mut prev = match items.next() {
        Some(first) => first,
        None => return true,
    };
    for item in items {
        // Require strictly increasing values; `!(prev < item)` also rejects
        // incomparable values (e.g. NaN), matching an `a < b` pairwise check
        if !(prev < item) {
            return false;
        }
        prev = item;
    }
    true
}
179
180/// Insert a key-value pair into a map implementing the `Insert` trait if the key does not
181/// already exist.
182///
183/// If the key already exists, this returns an error with a message indicating the key's
184/// existence.
185pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
186where
187    M: Insert<K, V>,
188    K: Eq + Hash + Clone + std::fmt::Debug,
189{
190    let existing = map.insert(key.clone(), value).is_some();
191    ensure!(!existing, "Key {key:?} already exists in the map");
192    Ok(())
193}
194
/// Format a list of items with a cap on display count for error messages.
///
/// Items are rendered with their `Debug` representation, comma-separated
/// inside square brackets. If more than ten items are supplied, only the
/// first ten are shown and the surplus is summarised as "and N more".
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let items = items.into_iter();
    let total_count = items.len();

    // Comma-join the Debug representations of at most MAX_DISPLAY items
    let mut out = String::from("[");
    for (i, item) in items.take(MAX_DISPLAY).enumerate() {
        if i > 0 {
            out.push_str(", ");
        }
        write!(&mut out, "{item:?}").unwrap();
    }
    out.push(']');

    // If there are remaining items, include the count
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}
220
221/// Read a model from the specified directory.
222///
223/// # Arguments
224///
225/// * `model_dir` - Folder containing model configuration files
226///
227/// # Returns
228///
229/// The static model data ([`Model`]) or an error.
230pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<Model> {
231    let model_params = ModelParameters::from_path(&model_dir)?;
232
233    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
234    let regions = read_regions(model_dir.as_ref())?;
235    let region_ids = regions.keys().cloned().collect();
236    let years = &model_params.milestone_years;
237
238    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
239    let processes = read_processes(
240        model_dir.as_ref(),
241        &commodities,
242        &region_ids,
243        &time_slice_info,
244        years,
245    )?;
246    let agents = read_agents(
247        model_dir.as_ref(),
248        &commodities,
249        &processes,
250        &region_ids,
251        years,
252    )?;
253    let agent_ids = agents.keys().cloned().collect();
254    let user_assets = read_user_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;
255
256    // Build and validate commodity graphs for all regions and years
257    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
258    validate_commodity_graphs_for_model(
259        &commodity_graphs,
260        &processes,
261        &commodities,
262        &time_slice_info,
263    )?;
264
265    // Solve investment order for each region/year
266    let investment_order =
267        solve_investment_order_for_model(&commodity_graphs, &commodities, years)?;
268
269    let model_path = model_dir
270        .as_ref()
271        .canonicalize()
272        .context("Could not parse path to model")?;
273    let model = Model {
274        model_path,
275        parameters: model_params,
276        agents,
277        commodities,
278        processes,
279        time_slice_info,
280        regions,
281        user_assets,
282        investment_order,
283    };
284    Ok(model)
285}
286
287/// Load commodity flow graphs for a model.
288///
289/// Loads necessary input data and creates a graph of commodity flows for each region and year,
290/// where nodes are commodities and edges are processes.
291///
292/// Graphs validation is NOT performed. This ensures that graphs can be generated even when
293/// validation would fail, which may be helpful for debugging.
294pub fn load_commodity_graphs<P: AsRef<Path>>(
295    model_dir: P,
296) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
297    let model_params = ModelParameters::from_path(&model_dir)?;
298
299    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
300    let regions = read_regions(model_dir.as_ref())?;
301    let region_ids = regions.keys().cloned().collect();
302    let years = &model_params.milestone_years;
303
304    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
305    let processes = read_processes(
306        model_dir.as_ref(),
307        &commodities,
308        &region_ids,
309        &time_slice_info,
310        years,
311    )?;
312
313    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
314    Ok(commodity_graphs)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type used to exercise the CSV and TOML readers
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    // Implement `HasID` so that `Record` can be used with ID-keyed helpers
    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in `dir_path`
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn read_csv_works() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with leading/trailing whitespace: fields should be trimmed on read
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id  , value\t\n  hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers): an error for read_csv,
        // but an empty iterator for read_csv_optional
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // Missing file
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        // optional csv's should return empty iterator
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    /// Test reading valid and invalid TOML files
    #[test]
    fn read_toml_works() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Overwrite with malformed TOML: reading should now fail
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        read_toml::<Record>(&file_path).unwrap_err();
    }

    /// Deserialise value with `deserialise_proportion_nonzero()`
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn deserialise_proportion_nonzero_works() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs (outside (0, 1], or non-finite)
        deserialise_f64(0.0).unwrap_err();
        deserialise_f64(-1.0).unwrap_err();
        deserialise_f64(2.0).unwrap_err();
        deserialise_f64(f64::NAN).unwrap_err();
        deserialise_f64(f64::INFINITY).unwrap_err();
    }

    #[test]
    fn check_values_sum_to_one_approx_works() {
        // Single input, valid
        check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).unwrap();

        // Multiple inputs, valid
        check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
            .unwrap();

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases: non-finite sums must be rejected
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    // Parameterised cases: strictly increasing sequences are accepted;
    // duplicates and out-of-order values are not
    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn is_sorted_and_unique_works(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected);
    }

    #[test]
    fn format_items_with_cap_works() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Test with more than 10 items to trigger the cap
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}