muse2/
input.rs

1//! Common routines for handling input data.
2use crate::asset::AssetPool;
3use crate::id::{HasID, IDLike};
4use crate::model::{Model, ModelFile};
5use anyhow::{bail, ensure, Context, Result};
6use float_cmp::approx_eq;
7use indexmap::IndexMap;
8use itertools::Itertools;
9use serde::de::{Deserialize, DeserializeOwned, Deserializer};
10use std::collections::{HashMap, HashSet};
11use std::fs;
12use std::hash::Hash;
13use std::path::Path;
14
15mod agent;
16use agent::read_agents;
17mod asset;
18use asset::read_assets;
19mod commodity;
20use commodity::read_commodities;
21mod process;
22use process::read_processes;
23mod region;
24use region::read_regions;
25mod time_slice;
26use time_slice::read_time_slice_info;
27
28/// Read a series of type `T`s from a CSV file.
29///
30/// Will raise an error if the file is empty.
31///
32/// # Arguments
33///
34/// * `file_path` - Path to the CSV file
35pub fn read_csv<'a, T: DeserializeOwned + 'a>(
36    file_path: &'a Path,
37) -> Result<impl Iterator<Item = T> + 'a> {
38    let vec = _read_csv_internal(file_path)?;
39    if vec.is_empty() {
40        bail!("CSV file {} cannot be empty", file_path.display());
41    }
42    Ok(vec.into_iter())
43}
44
45/// Read a series of type `T`s from a CSV file.
46///
47/// # Arguments
48///
49/// * `file_path` - Path to the CSV file
50pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
51    file_path: &'a Path,
52) -> Result<impl Iterator<Item = T> + 'a> {
53    let vec = _read_csv_internal(file_path)?;
54    Ok(vec.into_iter())
55}
56
57fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
58    let vec = csv::Reader::from_path(file_path)
59        .with_context(|| input_err_msg(file_path))?
60        .into_deserialize()
61        .process_results(|iter| iter.collect_vec())
62        .with_context(|| input_err_msg(file_path))?;
63
64    Ok(vec)
65}
66
67/// Parse a TOML file at the specified path.
68///
69/// # Arguments
70///
71/// * `file_path` - Path to the TOML file
72///
73/// # Returns
74///
75/// * The deserialised TOML data or an error if the file could not be read or parsed.
76pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
77    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
78    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
79    Ok(toml_data)
80}
81
82/// Read an f64, checking that it is between 0 and 1
83fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<f64, D::Error>
84where
85    D: Deserializer<'de>,
86{
87    let value = Deserialize::deserialize(deserialiser)?;
88    if !(value > 0.0 && value <= 1.0) {
89        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
90    }
91
92    Ok(value)
93}
94
/// Build the context message attached to errors that occur while reading an input file.
///
/// Intended for use with `anyhow::Context` so that the resulting error names the file.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
99
100/// Read a CSV file of items with IDs.
101///
102/// As this function is only ever used for top-level CSV files (i.e. the ones which actually define
103/// the IDs for a given type), we use an ordered map to maintain the order in the input files.
104fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
105where
106    T: HasID<ID> + DeserializeOwned,
107{
108    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
109    where
110        T: HasID<ID> + DeserializeOwned,
111    {
112        let mut map = IndexMap::new();
113        for record in read_csv::<T>(file_path)? {
114            let id = record.get_id().clone();
115            let existing = map.insert(id.clone(), record).is_some();
116            ensure!(!existing, "Duplicate ID found: {id}");
117        }
118        ensure!(!map.is_empty(), "CSV file is empty");
119
120        Ok(map)
121    }
122
123    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
124}
125
126/// Check that fractions sum to (approximately) one
127fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
128where
129    I: Iterator<Item = f64>,
130{
131    let sum = fractions.sum();
132    ensure!(
133        approx_eq!(f64, sum, 1.0, epsilon = 1e-5),
134        "Sum of fractions does not equal one (actual: {})",
135        sum
136    );
137
138    Ok(())
139}
140
141/// Inserts a key-value pair into a HashMap if the key does not already exist.
142///
143/// If the key already exists, it returns an error with a message indicating the key's existence.
144pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
145where
146    K: Eq + Hash + Clone + std::fmt::Debug,
147{
148    let existing = map.insert(key.clone(), value);
149    match existing {
150        Some(_) => bail!("Key {:?} already exists in the map", key),
151        None => Ok(()),
152    }
153}
154
155/// Read a model from the specified directory.
156///
157/// # Arguments
158///
159/// * `model_dir` - Folder containing model configuration files
160///
161/// # Returns
162///
163/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
164pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
165    let model_file = ModelFile::from_path(&model_dir)?;
166
167    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
168    let regions = read_regions(model_dir.as_ref())?;
169    let region_ids = regions.keys().cloned().collect();
170    let years = &model_file.milestone_years.years;
171
172    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
173    let processes = read_processes(
174        model_dir.as_ref(),
175        &commodities,
176        &region_ids,
177        &time_slice_info,
178        years,
179    )?;
180    let agents = read_agents(
181        model_dir.as_ref(),
182        &commodities,
183        &processes,
184        &region_ids,
185        years,
186    )?;
187    let agent_ids = agents.keys().cloned().collect();
188    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;
189
190    let model = Model {
191        milestone_years: model_file.milestone_years.years,
192        agents,
193        commodities,
194        processes,
195        time_slice_info,
196        regions,
197    };
198    Ok((model, AssetPool::new(assets)))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use serde::de::IntoDeserializer;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type used to exercise the CSV and TOML readers
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Write `contents` to `test.csv` in `dir_path` (replacing any previous file) and return
    /// its path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", contents).unwrap();
        file_path
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(read_csv_optional::<Record>(&file_path)
            .unwrap()
            .next()
            .is_none());
    }

    /// Records are keyed by ID in file order; duplicate IDs and empty files are rejected
    #[test]
    fn test_read_csv_id_file() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let map = read_csv_id_file::<Record, GenericID>(&file_path).unwrap();
        let values: Vec<u32> = map.values().map(|record| record.value).collect();
        assert_eq!(values, vec![1, 2]);

        // Duplicate ID
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nhello,2\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());

        // File with no data (only column headers)
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<f64, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(0.01));
        assert_eq!(deserialise_f64(0.5), Ok(0.5));
        assert_eq!(deserialise_f64(1.0), Ok(1.0));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_fractions_sum_to_one() {
        // Single input, valid
        assert!(check_fractions_sum_to_one([1.0].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(check_fractions_sum_to_one([0.4, 0.6].into_iter()).is_ok());

        // Single input, invalid
        assert!(check_fractions_sum_to_one([0.5].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(check_fractions_sum_to_one([0.4, 0.3].into_iter()).is_err());

        // Edge cases
        assert!(check_fractions_sum_to_one([f64::INFINITY].into_iter()).is_err());
        assert!(check_fractions_sum_to_one([f64::NAN].into_iter()).is_err());
    }

    /// Inserting a new key succeeds; inserting an existing key fails
    #[test]
    fn test_try_insert() {
        let mut map = HashMap::new();
        assert!(try_insert(&mut map, "a", 1).is_ok());
        assert!(try_insert(&mut map, "b", 2).is_ok());
        assert!(try_insert(&mut map, "a", 3).is_err());
        assert_eq!(map.len(), 2);
    }
}