// muse2/input.rs

1//! Common routines for handling input data.
2use crate::asset::AssetPool;
3use crate::id::{HasID, IDLike};
4use crate::model::{Model, ModelFile};
5use crate::units::Dimensionless;
6use anyhow::{bail, ensure, Context, Result};
7use float_cmp::approx_eq;
8use indexmap::IndexMap;
9use itertools::Itertools;
10use serde::de::{Deserialize, DeserializeOwned, Deserializer};
11use std::collections::{HashMap, HashSet};
12use std::fs;
13use std::hash::Hash;
14use std::path::Path;
15
16mod agent;
17use agent::read_agents;
18mod asset;
19use asset::read_assets;
20mod commodity;
21use commodity::read_commodities;
22mod process;
23use process::read_processes;
24mod region;
25use region::read_regions;
26mod time_slice;
27use time_slice::read_time_slice_info;
28
29/// Read a series of type `T`s from a CSV file.
30///
31/// Will raise an error if the file is empty.
32///
33/// # Arguments
34///
35/// * `file_path` - Path to the CSV file
36pub fn read_csv<'a, T: DeserializeOwned + 'a>(
37    file_path: &'a Path,
38) -> Result<impl Iterator<Item = T> + 'a> {
39    let vec = _read_csv_internal(file_path)?;
40    if vec.is_empty() {
41        bail!("CSV file {} cannot be empty", file_path.display());
42    }
43    Ok(vec.into_iter())
44}
45
46/// Read a series of type `T`s from a CSV file.
47///
48/// # Arguments
49///
50/// * `file_path` - Path to the CSV file
51pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
52    file_path: &'a Path,
53) -> Result<impl Iterator<Item = T> + 'a> {
54    let vec = _read_csv_internal(file_path)?;
55    Ok(vec.into_iter())
56}
57
58fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
59    let vec = csv::Reader::from_path(file_path)
60        .with_context(|| input_err_msg(file_path))?
61        .into_deserialize()
62        .process_results(|iter| iter.collect_vec())
63        .with_context(|| input_err_msg(file_path))?;
64
65    Ok(vec)
66}
67
68/// Parse a TOML file at the specified path.
69///
70/// # Arguments
71///
72/// * `file_path` - Path to the TOML file
73///
74/// # Returns
75///
76/// * The deserialised TOML data or an error if the file could not be read or parsed.
77pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
78    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
79    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
80    Ok(toml_data)
81}
82
83/// Read a Dimensionless float, checking that it is between 0 and 1
84fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<Dimensionless, D::Error>
85where
86    D: Deserializer<'de>,
87{
88    let value = f64::deserialize(deserialiser)?;
89    if !(value > 0.0 && value <= 1.0) {
90        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
91    }
92
93    Ok(Dimensionless(value))
94}
95
/// Format an error message to include the file path. To be used with `anyhow::Context`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
100
101/// Read a CSV file of items with IDs.
102///
103/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
104/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
105fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
106where
107    T: HasID<ID> + DeserializeOwned,
108{
109    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
110    where
111        T: HasID<ID> + DeserializeOwned,
112    {
113        let mut map = IndexMap::new();
114        for record in read_csv::<T>(file_path)? {
115            let id = record.get_id().clone();
116            let existing = map.insert(id.clone(), record).is_some();
117            ensure!(!existing, "Duplicate ID found: {id}");
118        }
119        ensure!(!map.is_empty(), "CSV file is empty");
120
121        Ok(map)
122    }
123
124    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
125}
126
127/// Check that fractions sum to (approximately) one
128fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
129where
130    I: Iterator<Item = Dimensionless>,
131{
132    let sum = fractions.sum();
133    ensure!(
134        approx_eq!(Dimensionless, sum, Dimensionless(1.0), epsilon = 1e-5),
135        "Sum of fractions does not equal one (actual: {})",
136        sum
137    );
138
139    Ok(())
140}
141
/// Check whether an iterator contains values that are sorted and unique.
///
/// Equivalent to requiring every adjacent pair `(a, b)` to satisfy `a < b`;
/// empty and single-element sequences trivially pass.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    let mut prev = match items.next() {
        Some(first) => first,
        // Nothing to compare: trivially sorted and unique
        None => return true,
    };

    for item in items {
        // `!(prev < item)` rather than `prev >= item` so that incomparable
        // values (e.g. NaN) fail the check, matching a strict `<` chain
        if !(prev < item) {
            return false;
        }
        prev = item;
    }

    true
}
150
151/// Inserts a key-value pair into a HashMap if the key does not already exist.
152///
153/// If the key already exists, it returns an error with a message indicating the key's existence.
154pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
155where
156    K: Eq + Hash + Clone + std::fmt::Debug,
157{
158    let existing = map.insert(key.clone(), value);
159    match existing {
160        Some(_) => bail!("Key {:?} already exists in the map", key),
161        None => Ok(()),
162    }
163}
164
/// Read a model from the specified directory.
///
/// Reads the model TOML file plus the CSV inputs for time slices, regions,
/// commodities, processes, agents and assets, in dependency order (each reader
/// is handed the IDs/data produced by the earlier ones for validation).
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_file = ModelFile::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    // Region IDs are passed to the later readers so they can validate references
    let region_ids = regions.keys().cloned().collect();
    let years = &model_file.milestone_years.years;

    // Each of these readers validates its cross-references against the data
    // loaded above, so the call order matters
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Store an absolute path in the model; canonicalize also fails if the
    // directory does not exist
    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        milestone_years: model_file.milestone_years.years,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
    };
    Ok((model, AssetPool::new(assets)))
}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use rstest::rstest;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use serde::de::IntoDeserializer;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type with an ID field, shared by the CSV and TOML tests
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in dir_path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", contents).unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers): read_csv must error, but
        // read_csv_optional must succeed with an empty iterator
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(read_csv_optional::<Record>(&file_path)
            .unwrap()
            .next()
            .is_none());
    }

    /// Test reading both valid and syntactically invalid TOML files
    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            // Scoped so the file handle is closed (and flushed) before reading
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            // Overwrite with invalid TOML to exercise the parse-error path
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    /// Only values in the half-open interval (0, 1] should be accepted
    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    /// Fractions must sum to one within the tolerance; NaN/inf must fail
    #[test]
    fn test_check_fractions_sum_to_one() {
        // Single input, valid
        assert!(check_fractions_sum_to_one([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(
            check_fractions_sum_to_one([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single input, invalid
        assert!(check_fractions_sum_to_one([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_fractions_sum_to_one([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases
        assert!(check_fractions_sum_to_one([Dimensionless(f64::INFINITY)].into_iter()).is_err());
        assert!(check_fractions_sum_to_one([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    /// Strictly increasing sequences (incl. empty/singleton) pass; any
    /// out-of-order or repeated element fails
    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }
}