use crate::asset::AssetPool;
use crate::graph::{
    CommoditiesGraph, build_commodity_graphs_for_model, validate_commodity_graphs_for_model,
};
use crate::id::{HasID, IDLike};
use crate::model::{Model, ModelParameters};
use crate::region::RegionID;
use crate::units::UnitType;
use anyhow::{Context, Result, bail, ensure};
use float_cmp::approx_eq;
use indexmap::IndexMap;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::fs;
use std::hash::Hash;
use std::path::Path;

mod agent;
use agent::read_agents;
mod asset;
use asset::read_assets;
mod commodity;
use commodity::read_commodities;
mod process;
use process::read_processes;
mod region;
use region::read_regions;
mod time_slice;
use time_slice::read_time_slice_info;

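/// Read a series of type `T`s from a CSV file.
///
/// Will raise an error if the file cannot be read or contains no records.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file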
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = read_csv_internal(file_path)?;
    if vec.is_empty() {
        bail!("CSV file {} cannot be empty", file_path.display());
    }
    Ok(vec.into_iter())
}

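/// Read a series of type `T`s from a CSV file, treating an empty file as valid.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file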
pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = read_csv_internal(file_path)?;
    Ok(vec.into_iter())
}

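/// Deserialise every record in a CSV file into a `Vec<T>`, attaching file context to errors.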
fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
    let vec = csv::Reader::from_path(file_path)
        .with_context(|| input_err_msg(file_path))?
        .into_deserialize()
        .process_results(|iter| iter.collect_vec())
        .with_context(|| input_err_msg(file_path))?;

    Ok(vec)
}

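/// Read a TOML file at the specified path.
///
/// # Arguments
///
/// * `file_path` - Path to the TOML file
///
/// # Returns
///
/// The deserialised TOML data, or an error if the file is missing or invalid.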
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
    Ok(toml_data)
}

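/// Deserialise a proportion, checking that the value lies in the half-open range (0, 1].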
pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
where
    T: UnitType,
    D: Deserializer<'de>,
{
    let value = f64::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        return Err(serde::de::Error::custom("Value must be > 0 and <= 1"));
    }

    Ok(T::new(value))
}

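/// Format a generic error message to include the path of the problem file.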
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    format!("Error reading {}", file_path.as_ref().display())
}

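/// Read a CSV file of items with unique IDs.
///
/// Returns a map of IDs to items, raising an error if the file is empty or contains
/// duplicate IDs.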
fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
    T: HasID<ID> + DeserializeOwned,
{
    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
    where
        T: HasID<ID> + DeserializeOwned,
    {
        let mut map = IndexMap::new();
        for record in read_csv::<T>(file_path)? {
            let id = record.get_id().clone();
            let existing = map.insert(id.clone(), record).is_some();
            ensure!(!existing, "Duplicate ID found: {id}");
        }
        ensure!(!map.is_empty(), "CSV file is empty");

        Ok(map)
    }

    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
}

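/// Check that the given values sum to (approximately) one.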
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
    T: UnitType,
    I: Iterator<Item = T>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {sum})"
    );

    Ok(())
}

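/// Check that the values yielded by an iterator are in ascending order with no duplicates.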
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    iter.into_iter().tuple_windows().all(|(a, b)| a < b)
}

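/// Insert a key/value pair into a map, raising an error if the key is already present.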
pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: &K, value: V) -> Result<()>
where
    K: Eq + Hash + Clone + fmt::Debug,
{
    let existing = map.insert(key.clone(), value).is_some();
    ensure!(!existing, "Key {key:?} already exists in the map");
    Ok(())
}

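/// Format a list of items for display, capping the output at a maximum number of items.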
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let items = items.into_iter();
    let total_count = items.len();

    // Format the first MAX_DISPLAY items as a comma-separated list
    let formatted_items = items
        .take(MAX_DISPLAY)
        .format_with(", ", |items, f| f(&format_args!("{items:?}")));
    let mut out = format!("[{formatted_items}]");

    // Indicate how many items were omitted, if any
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}

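/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The model along with an `AssetPool`, or an error if the input data is invalid.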
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate the commodity graphs for each region and milestone year,
    // determining the order in which commodities should be processed
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
    let commodity_order = validate_commodity_graphs_for_model(
        &commodity_graphs,
        &processes,
        &commodities,
        &time_slice_info,
    )?;

    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_params,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        commodity_order,
    };
    Ok((model, AssetPool::new(assets)))
}

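/// Build the commodity graphs for a model without loading it in full.
///
/// Only the inputs needed to construct the graphs (time slices, regions, commodities and
/// processes) are read. Returns one graph per region and milestone year.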
pub fn load_commodity_graphs<P: AsRef<Path>>(
    model_dir: P,
) -> Result<HashMap<(RegionID, u32), CommoditiesGraph>> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;

    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years)?;
    Ok(commodity_graphs)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

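    /// Create a CSV file in the given directory with the given contents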
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // An empty CSV file is an error for read_csv, but not for read_csv_optional
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

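    /// Deserialise a raw `f64` using `deserialise_proportion_nonzero`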
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid proportions
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid proportions
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // A single value of one
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple values summing to one
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // A single value not equal to one
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple values not summing to one
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite values
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1, 2], true)]
    #[case(&[1, 2, 3, 4], true)]
    #[case(&[2, 1], false)]
    #[case(&[1, 1], false)]
    #[case(&[1, 3, 2, 4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    #[test]
    fn test_format_items_with_cap() {
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // With more than MAX_DISPLAY items, the output is truncated
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}