use crate::asset::AssetPool;
use crate::graph::investment::solve_investment_order_for_model;
use crate::graph::validate::validate_commodity_graphs_for_model;
use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
use crate::id::{HasID, IDLike};
use crate::model::{Model, ModelParameters};
use crate::region::RegionID;
use crate::units::UnitType;
use anyhow::{Context, Result, bail, ensure};
use float_cmp::approx_eq;
use indexmap::IndexMap;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::fs;
use std::hash::Hash;
use std::path::Path;

mod agent;
use agent::read_agents;
mod asset;
use asset::read_assets;
mod commodity;
use commodity::read_commodities;
mod process;
use process::read_processes;
mod region;
use region::read_regions;
mod time_slice;
use time_slice::read_time_slice_info;

/// A trait for inserting a key/value pair into a map.
///
/// This lets helpers such as [`try_insert`] work with both `HashMap` and `IndexMap`.
pub trait Insert<K, V> {
    /// Insert `value` at `key`, returning the previous value for `key`, if any.
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}

impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        HashMap::insert(self, key, value)
    }
}

impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        IndexMap::insert(self, key, value)
    }
}

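/// Read a series of type `T`s from the CSV file at `file_path`, failing if the file is empty.
///
/// A minimal usage sketch; the `Row` type and file name are hypothetical, so the example is
/// marked `ignore`:
///
/// ```ignore
/// use std::path::Path;
///
/// #[derive(serde::Deserialize)]
/// struct Row {
///     id: String,
///     value: u32,
/// }
///
/// // Iterate over the deserialised records; an empty file is an error
/// for row in read_csv::<Row>(Path::new("rows.csv"))? {
///     println!("{}: {}", row.id, row.value);
/// }
/// ```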
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let vec = read_csv_internal(file_path)?;
    if vec.is_empty() {
        bail!("CSV file {} cannot be empty", file_path.display());
    }
    Ok(vec.into_iter())
}

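/// Read a series of type `T`s from the CSV file at `file_path`, if the file exists.
///
/// Unlike [`read_csv`], a missing file yields an empty iterator rather than an error. A sketch,
/// reusing the hypothetical `Row` type from above:
///
/// ```ignore
/// // The loop body simply never runs if optional_rows.csv is absent
/// for row in read_csv_optional::<Row>(Path::new("optional_rows.csv"))? {
///     println!("{}: {}", row.id, row.value);
/// }
/// ```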
pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    if !file_path.exists() {
        return Ok(Vec::new().into_iter());
    }

    let vec = read_csv_internal(file_path)?;
    Ok(vec.into_iter())
}

/// Deserialise every record in the CSV file at `file_path`, trimming whitespace around fields.
fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
    let vec = csv::ReaderBuilder::new()
        .trim(csv::Trim::All)
        .from_path(file_path)
        .with_context(|| input_err_msg(file_path))?
        .into_deserialize()
        .process_results(|iter| iter.collect_vec())
        .with_context(|| input_err_msg(file_path))?;

    Ok(vec)
}

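/// Parse the TOML file at the specified path into a `T`.
///
/// A sketch with a hypothetical settings type:
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct Settings {
///     name: String,
/// }
///
/// let settings: Settings = read_toml(Path::new("settings.toml"))?;
/// ```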
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
    let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
    Ok(toml_data)
}

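/// Deserialise a proportion, requiring that the value lies in the half-open interval (0, 1].
///
/// Intended for use with serde's `deserialize_with` field attribute; a sketch with a
/// hypothetical record type:
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct ProcessShare {
///     // Rejects 0.0, negative values, values above 1.0, NaN and infinity
///     #[serde(deserialize_with = "deserialise_proportion_nonzero")]
///     fraction: Dimensionless,
/// }
/// ```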
pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
where
    T: UnitType,
    D: Deserializer<'de>,
{
    let value = f64::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
    }

    Ok(T::new(value))
}

/// Format a standard error message for a problem reading the file at `file_path`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    format!("Error reading {}", file_path.as_ref().display())
}

/// Read a CSV file of records with IDs into an `IndexMap` keyed by ID, rejecting duplicates.
fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
    T: HasID<ID> + DeserializeOwned,
{
    fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
    where
        T: HasID<ID> + DeserializeOwned,
    {
        let mut map = IndexMap::new();
        for record in read_csv::<T>(file_path)? {
            let id = record.get_id().clone();
            let existing = map.insert(id.clone(), record).is_some();
            ensure!(!existing, "Duplicate ID found: {id}");
        }
        ensure!(!map.is_empty(), "CSV file is empty");

        Ok(map)
    }

    fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
}

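/// Check that the given fractions sum to (approximately) one.
///
/// A sketch of the intended use, with `Dimensionless` standing in for any `UnitType`:
///
/// ```ignore
/// let shares = [Dimensionless(0.25), Dimensionless(0.75)];
/// check_values_sum_to_one_approx(shares.into_iter())?; // Ok: sums to 1.0
/// ```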
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
    T: UnitType,
    I: Iterator<Item = T>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {sum})"
    );

    Ok(())
}

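/// Check whether an iterator contains values which are sorted in ascending order and unique.
///
/// Empty and single-element iterators count as sorted. For example:
///
/// ```ignore
/// assert!(is_sorted_and_unique([2020, 2025, 2030]));
/// assert!(!is_sorted_and_unique([2020, 2020, 2030])); // duplicates are rejected
/// ```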
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    iter.into_iter().tuple_windows().all(|(a, b)| a < b)
}

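/// Insert `value` into `map` at `key`, returning an error if the key is already present.
///
/// Works with any map implementing the [`Insert`] trait. A sketch:
///
/// ```ignore
/// let mut map: HashMap<String, u32> = HashMap::new();
/// try_insert(&mut map, &"a".to_string(), 1)?; // Ok
/// assert!(try_insert(&mut map, &"a".to_string(), 2).is_err()); // duplicate key
/// ```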
pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
where
    M: Insert<K, V>,
    K: Eq + Hash + Clone + std::fmt::Debug,
{
    let existing = map.insert(key.clone(), value).is_some();
    ensure!(!existing, "Key {key:?} already exists in the map");
    Ok(())
}

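/// Format a list of items for display, capping the output at ten items.
///
/// Based on the behaviour exercised in the tests below:
///
/// ```ignore
/// let items = vec!["a", "b", "c"];
/// assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);
/// // With twelve items, the output ends with `and 2 more`
/// ```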
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let items = items.into_iter();
    let total_count = items.len();

    // Format up to MAX_DISPLAY items as a comma-separated, debug-formatted list
    let formatted_items = items
        .take(MAX_DISPLAY)
        .format_with(", ", |items, f| f(&format_args!("{items:?}")));
    let mut out = format!("[{formatted_items}]");

    // Note how many items were omitted, if any
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}

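/// Read all model input data from the specified directory, returning the model along with its
/// initial asset pool.
///
/// A usage sketch (the directory path is hypothetical):
///
/// ```ignore
/// let (model, asset_pool) = load_model("examples/simple")?;
/// println!("Milestone years: {:?}", model.parameters.milestone_years);
/// ```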
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    // Read the basic input tables
    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;

    // Build and validate the commodity graphs for each region and milestone year
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
    validate_commodity_graphs_for_model(
        &commodity_graphs,
        &processes,
        &commodities,
        &time_slice_info,
    )?;

    // Determine the order in which commodities are considered for investment
    let investment_order = solve_investment_order_for_model(&commodity_graphs, &commodities, years);

    let model_path = model_dir
        .as_ref()
        .canonicalize()
        .context("Could not parse path to model")?;
    let model = Model {
        model_path,
        parameters: model_params,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
        investment_order,
    };
    Ok((model, AssetPool::new(assets)))
}

/// Read model input data and build the commodity graph for each region and milestone year.
///
/// Unlike [`load_model`], this stops after graph construction and performs no validation.
pub fn load_commodity_graphs<P: AsRef<Path>>(
    model_dir: P,
) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
    let model_params = ModelParameters::from_path(&model_dir)?;

    let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
    let regions = read_regions(model_dir.as_ref())?;
    let region_ids = regions.keys().cloned().collect();
    let years = &model_params.milestone_years;

    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;

    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
    Ok(commodity_graphs)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create a CSV file in the given directory with the given contents
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // Whitespace around fields should be trimmed
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id , value\t\n hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // An empty file is an error for read_csv but not for read_csv_optional
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // A missing file is an error for read_csv but not for read_csv_optional
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Invalid TOML should fail
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Deserialise a value with deserialise_proportion_nonzero
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Valid inputs
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Single value equal to one
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        // Multiple values summing to one
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Single value not equal to one
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple values not summing to one
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite values
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1, 2], true)]
    #[case(&[1, 2, 3, 4], true)]
    #[case(&[2, 1], false)]
    #[case(&[1, 1], false)]
    #[case(&[1, 3, 2, 4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    #[test]
    fn test_format_items_with_cap() {
        // Under the cap: all items are shown
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Over the cap: only the first ten are shown, with a count of the rest
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}
507}