use crate::asset::AssetPool;
use crate::graph::build_and_validate_commodity_graphs_for_model;
use crate::id::{HasID, IDLike};
use crate::model::{Model, ModelFile};
use crate::units::UnitType;
use anyhow::{Context, Result, bail, ensure};
use float_cmp::approx_eq;
use indexmap::IndexMap;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::hash::Hash;
use std::path::Path;

mod agent;
use agent::read_agents;
mod asset;
use asset::read_assets;
mod commodity;
use commodity::read_commodities;
mod process;
use process::read_processes;
mod region;
use region::read_regions;
mod time_slice;
use time_slice::read_time_slice_info;

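/// Read a series of type `T`s from the specified CSV file.
///
/// Returns an error if the file cannot be read or is empty.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file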
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = _read_csv_internal(file_path)?;
    if vec.is_empty() {
        bail!("CSV file {} cannot be empty", file_path.display());
    }
    Ok(vec.into_iter())
}

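/// Read a series of type `T`s from the specified CSV file, permitting an empty file.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file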
pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = _read_csv_internal(file_path)?;
    Ok(vec.into_iter())
}

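/// Shared implementation for `read_csv` and `read_csv_optional`.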
fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
    let vec = csv::Reader::from_path(file_path)
        .with_context(|| input_err_msg(file_path))?
        .into_deserialize()
        .process_results(|iter| iter.collect_vec())
        .with_context(|| input_err_msg(file_path))?;

    Ok(vec)
}

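/// Parse a TOML file at the specified path.
///
/// # Arguments
///
/// * `file_path` - Path to the TOML file
///
/// # Returns
///
/// The deserialised TOML data or an error if the file is missing or invalid.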
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
    Ok(toml_data)
}

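/// Custom deserialiser for a proportion, which must be greater than zero and at most one.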
pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
where
    T: UnitType,
    D: Deserializer<'de>,
{
    let value = f64::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
    }

    Ok(T::new(value))
}

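/// Format a standard error message for a file that could not be read.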
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    format!("Error reading {}", file_path.as_ref().display())
}

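/// Read a CSV file of records with unique IDs into a map keyed by ID.
///
/// Returns an error if the file is empty or contains duplicate IDs.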
fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
    T: HasID<ID> + DeserializeOwned,
{
    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
    where
        T: HasID<ID> + DeserializeOwned,
    {
        let mut map = IndexMap::new();
        for record in read_csv::<T>(file_path)? {
            let id = record.get_id().clone();
            let existing = map.insert(id.clone(), record).is_some();
            ensure!(!existing, "Duplicate ID found: {id}");
        }
        ensure!(!map.is_empty(), "CSV file is empty");

        Ok(map)
    }

    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
}

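/// Check that the given fractions sum to approximately one (within a tolerance of 1e-5).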
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
    T: UnitType,
    I: Iterator<Item = T>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {})",
        sum
    );

    Ok(())
}

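/// Return whether the values are in strictly ascending order (i.e. sorted with no duplicates).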
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    iter.into_iter().tuple_windows().all(|(a, b)| a < b)
}

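/// Insert a key-value pair into the map, returning an error if the key was already present.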
pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
where
    K: Eq + Hash + Clone + std::fmt::Debug,
{
    let existing = map.insert(key.clone(), value);
    match existing {
        Some(_) => bail!("Key {:?} already exists in the map", key),
        None => Ok(()),
    }
}

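/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The model and its initial asset pool, or an error if the input is invalid.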
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_file = ModelFile::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_file.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    let commodity_order = build_and_validate_commodity_graphs_for_model(
        &processes,
        &commodities,
        &region_ids,
        years,
        &time_slice_info,
    )?;

    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_file,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        commodity_order,
    };
    Ok((model, AssetPool::new(assets)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

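    /// Create a CSV file in the given directory with the given contents.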
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // An empty file is an error for read_csv but not for read_csv_optional
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Values in (0, 1] are accepted
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Everything else is rejected, including NaN and infinity
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Sums close to one pass
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Sums away from one fail
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite values fail
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1, 2], true)]
    #[case(&[1, 2, 3, 4], true)]
    #[case(&[2, 1], false)]
    #[case(&[1, 1], false)]
    #[case(&[1, 3, 2, 4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }
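
    #[test]
    fn test_read_csv_id_file() {
        // Sketch of `read_csv_id_file` usage: records are keyed by ID and
        // duplicate IDs are rejected (assumes `GenericID` satisfies `IDLike`).
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let map = read_csv_id_file::<Record, GenericID>(&file_path).unwrap();
        assert_eq!(map.len(), 2);
        let id: GenericID = "hello".into();
        assert_eq!(map.get(&id).unwrap().value, 1);

        // A repeated ID should produce an error
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nhello,2\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());
    }

    #[test]
    fn test_try_insert() {
        // Sketch of `try_insert` usage: the first insert for a key succeeds
        // and a second insert for the same key is rejected.
        let mut map = HashMap::new();
        assert!(try_insert(&mut map, "key", 1).is_ok());
        assert!(try_insert(&mut map, "key", 2).is_err());
        assert!(try_insert(&mut map, "other", 3).is_ok());
    }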
}