// muse2/input.rs

1//! Common routines for handling input data.
2use crate::asset::AssetPool;
3use crate::graph::build_and_validate_commodity_graphs_for_model;
4use crate::id::{HasID, IDLike};
5use crate::model::{Model, ModelFile};
6use crate::units::UnitType;
7use anyhow::{Context, Result, bail, ensure};
8use float_cmp::approx_eq;
9use indexmap::IndexMap;
10use itertools::Itertools;
11use serde::de::{Deserialize, DeserializeOwned, Deserializer};
12use std::collections::{HashMap, HashSet};
13use std::fs;
14use std::hash::Hash;
15use std::path::Path;
16
17mod agent;
18use agent::read_agents;
19mod asset;
20use asset::read_assets;
21mod commodity;
22use commodity::read_commodities;
23mod process;
24use process::read_processes;
25mod region;
26use region::read_regions;
27mod time_slice;
28use time_slice::read_time_slice_info;
29
30/// Read a series of type `T`s from a CSV file.
31///
32/// Will raise an error if the file is empty.
33///
34/// # Arguments
35///
36/// * `file_path` - Path to the CSV file
37pub fn read_csv<'a, T: DeserializeOwned + 'a>(
38    file_path: &'a Path,
39) -> Result<impl Iterator<Item = T> + 'a> {
40    let vec = _read_csv_internal(file_path)?;
41    if vec.is_empty() {
42        bail!("CSV file {} cannot be empty", file_path.display());
43    }
44    Ok(vec.into_iter())
45}
46
47/// Read a series of type `T`s from a CSV file.
48///
49/// # Arguments
50///
51/// * `file_path` - Path to the CSV file
52pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
53    file_path: &'a Path,
54) -> Result<impl Iterator<Item = T> + 'a> {
55    let vec = _read_csv_internal(file_path)?;
56    Ok(vec.into_iter())
57}
58
59fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
60    let vec = csv::Reader::from_path(file_path)
61        .with_context(|| input_err_msg(file_path))?
62        .into_deserialize()
63        .process_results(|iter| iter.collect_vec())
64        .with_context(|| input_err_msg(file_path))?;
65
66    Ok(vec)
67}
68
69/// Parse a TOML file at the specified path.
70///
71/// # Arguments
72///
73/// * `file_path` - Path to the TOML file
74///
75/// # Returns
76///
77/// * The deserialised TOML data or an error if the file could not be read or parsed.
78pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
79    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
80    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
81    Ok(toml_data)
82}
83
84/// Read a Dimensionless float, checking that it is between 0 and 1
85pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
86where
87    T: UnitType,
88    D: Deserializer<'de>,
89{
90    let value = f64::deserialize(deserialiser)?;
91    if !(value > 0.0 && value <= 1.0) {
92        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
93    }
94
95    Ok(T::new(value))
96}
97
/// Format an error message to include the file path. To be used with `anyhow::Context`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
102
103/// Read a CSV file of items with IDs.
104///
105/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
106/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
107fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
108where
109    T: HasID<ID> + DeserializeOwned,
110{
111    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
112    where
113        T: HasID<ID> + DeserializeOwned,
114    {
115        let mut map = IndexMap::new();
116        for record in read_csv::<T>(file_path)? {
117            let id = record.get_id().clone();
118            let existing = map.insert(id.clone(), record).is_some();
119            ensure!(!existing, "Duplicate ID found: {id}");
120        }
121        ensure!(!map.is_empty(), "CSV file is empty");
122
123        Ok(map)
124    }
125
126    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
127}
128
129/// Check that fractions sum to (approximately) one
130fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
131where
132    T: UnitType,
133    I: Iterator<Item = T>,
134{
135    let sum = fractions.sum();
136    ensure!(
137        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
138        "Sum of fractions does not equal one (actual: {})",
139        sum
140    );
141
142    Ok(())
143}
144
/// Check whether an iterator contains values that are sorted and unique
///
/// Empty and single-element sequences are trivially sorted and unique. Every consecutive
/// pair must be strictly increasing, so equal neighbours cause a `false` result.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    let mut prev = match items.next() {
        Some(first) => first,
        None => return true,
    };
    for item in items {
        // `!(prev < item)` (rather than `prev >= item`) preserves the original
        // behaviour for partially ordered types such as floats with NaN
        if !(prev < item) {
            return false;
        }
        prev = item;
    }
    true
}
153
154/// Inserts a key-value pair into a HashMap if the key does not already exist.
155///
156/// If the key already exists, it returns an error with a message indicating the key's existence.
157pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
158where
159    K: Eq + Hash + Clone + std::fmt::Debug,
160{
161    let existing = map.insert(key.clone(), value);
162    match existing {
163        Some(_) => bail!("Key {:?} already exists in the map", key),
164        None => Ok(()),
165    }
166}
167
/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    // The top-level model file supplies the milestone years used for validation below
    let model_file = ModelFile::from_path(&model_dir)?;

    // Time slices and regions are read first: most other input files refer to them
    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_file.milestone_years;

    // Commodities -> processes -> agents -> assets: each reader validates its input
    // against the entities read in the preceding steps
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate commodity graphs for all regions and years
    // This gives us the commodity order for each region/year which is passed to the model
    let commodity_order = build_and_validate_commodity_graphs_for_model(
        &processes,
        &commodities,
        &region_ids,
        years,
        &time_slice_info,
    )?;

    // Canonicalise so the stored path is absolute; fails if the directory vanished
    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_file,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        commodity_order,
    };
    Ok((model, AssetPool::new(assets)))
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type for exercising the CSV/TOML readers
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create an example CSV file in dir_path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    /// Test reading TOML files, both well-formed and malformed
    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Overwrite with invalid TOML; reading should now fail
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    /// Test that only values in (0, 1] are accepted as proportions
    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    /// Test the approximate sum-to-one check, including non-finite edge cases
    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Single input, valid
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single input, invalid
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Edge cases
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    /// Test strict-ordering detection for a range of sorted/unsorted slices
    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }
}