muse2/
input.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
//! Common routines for handling input data.
use crate::agent::AssetPool;
use crate::model::{Model, ModelFile};
use anyhow::{ensure, Context, Result};
use float_cmp::approx_eq;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
use std::rc::Rc;

pub mod agent;
pub use agent::read_agents;
pub mod asset;
use asset::read_assets;
pub mod commodity;
pub use commodity::read_commodities;
pub mod process;
pub use process::read_processes;
pub mod region;
pub use region::read_regions;
mod time_slice;
pub use time_slice::read_time_slice_info;

/// Read every record of type `T` from a CSV file.
///
/// The whole file is parsed before this function returns, so a malformed row is reported as an
/// error here rather than surfacing later while the caller iterates.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
///
/// # Errors
///
/// Returns an error (annotated with the file path) if the file cannot be opened or if any row
/// fails to deserialise.
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let mut reader =
        csv::Reader::from_path(file_path).with_context(|| input_err_msg(file_path))?;

    // Collecting into `Result<Vec<_>, _>` stops at the first row that fails to deserialise
    let records = reader
        .deserialize::<T>()
        .collect::<Result<Vec<T>, _>>()
        .with_context(|| input_err_msg(file_path))?;

    Ok(records.into_iter())
}

/// Read the file at `file_path` and deserialise its contents as TOML.
///
/// # Arguments
///
/// * `file_path` - Path to the TOML file
///
/// # Returns
///
/// * The deserialised TOML data, or an error (annotated with the file path) if the file could
///   not be read or parsed.
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let contents = fs::read_to_string(file_path);
    let text = contents.with_context(|| input_err_msg(file_path))?;

    toml::from_str(&text).with_context(|| input_err_msg(file_path))
}

/// Read an f64, checking that it is between 0 and 1
pub fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<f64, D::Error>
where
    D: Deserializer<'de>,
{
    let value = Deserialize::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
    }

    Ok(value)
}

/// Build an "Error reading <path>" message for use with `anyhow::Context`.
///
/// Non-UTF-8 path components are replaced lossily, so this never fails.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path_text = file_path.as_ref().to_string_lossy();
    format!("Error reading {path_text}")
}

/// Indicates that the struct has an ID field
///
/// Used by [`read_csv_id_file`] to key records and by [`IntoIDMap`] to group items. For types
/// whose ID is stored in a field named `id`, the `define_id_getter!` macro in this module
/// generates the implementation.
pub trait HasID {
    /// Get a string representation of the struct's ID
    fn get_id(&self) -> &str;
}

/// Generate a `HasID` implementation for a type whose ID lives in a field named `id`.
///
/// The expansion names `HasID` unqualified, so the trait must be in scope wherever the macro
/// is invoked.
macro_rules! define_id_getter {
    ($ty:ty) => {
        impl HasID for $ty {
            fn get_id(&self) -> &str {
                &self.id
            }
        }
    };
}

// Re-export so that other modules in the crate can invoke the macro by path.
pub(crate) use define_id_getter;

/// A data structure containing a set of IDs
pub trait IDCollection {
    /// Get the ID after checking that it exists in this collection.
    ///
    /// # Arguments
    ///
    /// * `id` - The ID to look up
    ///
    /// # Returns
    ///
    /// A copy of the `Rc<str>` in `self` or an error if not found.
    fn get_id(&self, id: &str) -> Result<Rc<str>>;
}

impl IDCollection for HashSet<Rc<str>> {
    fn get_id(&self, id: &str) -> Result<Rc<str>> {
        // Return a clone of the stored `Rc` (a refcount bump) rather than a freshly allocated
        // string, so every user of an ID shares the same allocation
        self.get(id)
            .cloned()
            .with_context(|| format!("Unknown ID {id} found"))
    }
}

/// Read a CSV file of items with IDs into a map keyed by ID.
///
/// # Errors
///
/// Returns an error (annotated with the file path) if the file cannot be read, if two rows share
/// the same ID, or if the file contains no records.
pub fn read_csv_id_file<T>(file_path: &Path) -> Result<HashMap<Rc<str>, T>>
where
    T: HasID + DeserializeOwned,
{
    /// Build the ID -> record map, rejecting duplicate IDs and empty files.
    fn build_map<T>(file_path: &Path) -> Result<HashMap<Rc<str>, T>>
    where
        T: HasID + DeserializeOwned,
    {
        let mut records_by_id: HashMap<Rc<str>, T> = HashMap::new();

        for record in read_csv::<T>(file_path)? {
            let key: Rc<str> = Rc::from(record.get_id());
            ensure!(
                !records_by_id.contains_key(&key),
                "Duplicate ID found: {key}"
            );
            records_by_id.insert(key, record);
        }

        ensure!(!records_by_id.is_empty(), "CSV file is empty");

        Ok(records_by_id)
    }

    build_map(file_path).with_context(|| input_err_msg(file_path))
}

/// Trait for converting an iterator into a [`HashMap`] grouped by IDs.
///
/// This module provides a blanket implementation for every iterator whose items implement
/// [`HasID`].
pub trait IntoIDMap<T> {
    /// Convert into a [`HashMap`] grouped by IDs.
    ///
    /// # Arguments
    ///
    /// * `ids` - The set of valid IDs. The blanket implementation in this module returns an
    ///   error if an item's ID is not in this set, or if there are no items at all.
    fn into_id_map(self, ids: &HashSet<Rc<str>>) -> Result<HashMap<Rc<str>, Vec<T>>>;
}

impl<T, I> IntoIDMap<T> for I
where
    T: HasID,
    I: Iterator<Item = T>,
{
    /// Group the items of this iterator by ID, returning a `HashMap` from ID to items.
    ///
    /// Items keep their original relative order within each group.
    ///
    /// # Arguments
    ///
    /// `ids` - The set of valid IDs to check against.
    ///
    /// # Errors
    ///
    /// Fails on the first item whose ID is not in `ids`, or if the iterator yields no items.
    fn into_id_map(self, ids: &HashSet<Rc<str>>) -> Result<HashMap<Rc<str>, Vec<T>>> {
        let mut grouped: HashMap<Rc<str>, Vec<T>> = HashMap::new();

        for item in self {
            // Key by the `Rc` stored in `ids` so that the map shares its allocations
            let id = ids.get_id(item.get_id())?;
            grouped.entry(id).or_default().push(item);
        }

        ensure!(!grouped.is_empty(), "CSV file is empty");

        Ok(grouped)
    }
}

/// Check that fractions sum to (approximately) one
pub fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
where
    I: Iterator<Item = f64>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(f64, sum, 1.0, epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {})",
        sum
    );

    Ok(())
}

/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The static model data ([`Model`]) and an [`AssetPool`] struct or an error.
///
/// # Errors
///
/// Returns an error if any of the input files cannot be read or fail validation, or if the
/// model file specifies no milestone years.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_file = ModelFile::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_file.milestone_years.years;

    // Milestone years come from user input, so report an error rather than panicking if the
    // list is empty
    let (&first_year, &last_year) = years
        .first()
        .zip(years.last())
        .context("No milestone years specified")?;
    let year_range = first_year..=last_year;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        &year_range,
    )?;
    let agents = read_agents(model_dir.as_ref(), &commodities, &processes, &region_ids)?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    let model = Model {
        milestone_years: model_file.milestone_years.years,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
    };
    Ok((model, assets))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use serde::de::IntoDeserializer;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: String,
        value: u32,
    }

    impl HasID for Record {
        fn get_id(&self) -> &str {
            &self.id
        }
    }

    /// Create an example CSV file in dir_path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", contents).unwrap();
        file_path
    }

    /// Build a set of IDs from string slices
    fn make_ids(ids: &[&str]) -> HashSet<Rc<str>> {
        ids.iter().map(|&id| Rc::<str>::from(id)).collect()
    }

    /// Test a normal read
    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".to_string(),
                    value: 1,
                },
                Record {
                    id: "world".to_string(),
                    value: 2,
                }
            ]
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".to_string(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise value with deserialise_proportion_nonzero()
    fn deserialise_f64(value: f64) -> Result<f64, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(0.01));
        assert_eq!(deserialise_f64(0.5), Ok(0.5));
        assert_eq!(deserialise_f64(1.0), Ok(1.0));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_input_err_msg() {
        assert_eq!(input_err_msg("dir/file.csv"), "Error reading dir/file.csv");
    }

    #[test]
    fn test_id_collection_get_id() {
        let ids = make_ids(&["hello"]);
        let id = ids.get_id("hello").unwrap();
        assert_eq!(&*id, "hello");
        assert!(ids.get_id("world").is_err());
    }

    #[test]
    fn test_read_csv_id_file() {
        let dir = tempdir().unwrap();

        // Normal read
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let map: HashMap<Rc<str>, Record> = read_csv_id_file(&file_path).unwrap();
        assert_eq!(map.len(), 2);
        assert_eq!(map["hello"].value, 1);
        assert_eq!(map["world"].value, 2);

        // Duplicate IDs are rejected
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nhello,2\n");
        assert!(read_csv_id_file::<Record>(&file_path).is_err());

        // A file with a header but no records is rejected
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv_id_file::<Record>(&file_path).is_err());
    }

    #[test]
    fn test_into_id_map() {
        let ids = make_ids(&["hello", "world"]);

        // Items are grouped by ID, keeping their relative order
        let records = [
            Record {
                id: "hello".to_string(),
                value: 1,
            },
            Record {
                id: "world".to_string(),
                value: 3,
            },
            Record {
                id: "hello".to_string(),
                value: 2,
            },
        ];
        let map = records.into_iter().into_id_map(&ids).unwrap();
        assert_eq!(map.len(), 2);
        let hello_values: Vec<u32> = map["hello"].iter().map(|r| r.value).collect();
        assert_eq!(hello_values, [1, 2]);
        assert_eq!(map["world"].len(), 1);

        // Unknown IDs are rejected
        let records = [Record {
            id: "unknown".to_string(),
            value: 1,
        }];
        assert!(records.into_iter().into_id_map(&ids).is_err());

        // Empty input is rejected
        assert!(std::iter::empty::<Record>().into_id_map(&ids).is_err());
    }

    #[test]
    fn test_check_fractions_sum_to_one() {
        // Single input, valid
        assert!(check_fractions_sum_to_one([1.0].into_iter()).is_ok());

        // Multiple inputs, valid
        assert!(check_fractions_sum_to_one([0.4, 0.6].into_iter()).is_ok());

        // Single input, invalid
        assert!(check_fractions_sum_to_one([0.5].into_iter()).is_err());

        // Multiple inputs, invalid
        assert!(check_fractions_sum_to_one([0.4, 0.3].into_iter()).is_err());

        // Edge cases
        assert!(check_fractions_sum_to_one([f64::INFINITY].into_iter()).is_err());
        assert!(check_fractions_sum_to_one([f64::NAN].into_iter()).is_err());
    }
}