1use crate::graph::investment::solve_investment_order_for_model;
3use crate::graph::validate::validate_commodity_graphs_for_model;
4use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
5use crate::id::{HasID, ID};
6use crate::model::{Model, ModelParameters};
7use crate::region::RegionID;
8use crate::units::UnitType;
9use anyhow::{Context, Result, bail, ensure};
10use float_cmp::approx_eq;
11use indexmap::IndexMap;
12use itertools::Itertools;
13use serde::de::{Deserialize, DeserializeOwned, Deserializer};
14use std::collections::HashMap;
15use std::fmt::{self, Write};
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_user_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32mod range;
33use range::{parse_range, parse_range_parts, partition};
34mod year;
35use year::parse_year_str;
36
/// A map-like collection into which a key-value pair can be inserted.
///
/// This lets helpers such as [`try_insert`] work with both [`HashMap`] and [`IndexMap`].
pub trait Insert<K, V> {
    /// Insert `value` under `key`.
    ///
    /// Returns the value previously stored under `key`, or `None` if the key was not present.
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}
42
43impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
44 fn insert(&mut self, key: K, value: V) -> Option<V> {
45 HashMap::insert(self, key, value)
46 }
47}
48
49impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
50 fn insert(&mut self, key: K, value: V) -> Option<V> {
51 IndexMap::insert(self, key, value)
52 }
53}
54
55pub fn read_csv<'a, T: DeserializeOwned + 'a>(
63 file_path: &'a Path,
64) -> Result<impl Iterator<Item = T> + 'a> {
65 let vec = read_csv_internal(file_path)?;
66 if vec.is_empty() {
67 bail!("CSV file {} cannot be empty", file_path.display());
68 }
69 Ok(vec.into_iter())
70}
71
72pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
78 file_path: &'a Path,
79) -> Result<impl Iterator<Item = T> + 'a> {
80 if !file_path.exists() {
81 return Ok(Vec::new().into_iter());
82 }
83
84 let vec = read_csv_internal(file_path)?;
85 Ok(vec.into_iter())
86}
87
88fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
89 let vec = csv::ReaderBuilder::new()
90 .trim(csv::Trim::All)
91 .from_path(file_path)
92 .with_context(|| input_err_msg(file_path))?
93 .into_deserialize()
94 .process_results(|iter| iter.collect_vec())
95 .with_context(|| input_err_msg(file_path))?;
96
97 Ok(vec)
98}
99
100pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
110 let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
111 let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
112 Ok(toml_data)
113}
114
115pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
117where
118 T: UnitType,
119 D: Deserializer<'de>,
120{
121 let value = f64::deserialize(deserialiser)?;
122 if !(value > 0.0 && value <= 1.0) {
123 Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
124 }
125
126 Ok(T::new(value))
127}
128
/// Build the standard error-context message for a failure to read `file_path`.
///
/// The message has the form `Error reading <path>`.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path: &Path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
133
134fn read_csv_id_file<T, I: ID>(file_path: &Path) -> Result<IndexMap<I, T>>
139where
140 T: HasID<I> + DeserializeOwned,
141{
142 fn fill_and_validate_map<T, I: ID>(file_path: &Path) -> Result<IndexMap<I, T>>
143 where
144 T: HasID<I> + DeserializeOwned,
145 {
146 let mut map = IndexMap::new();
147 for record in read_csv::<T>(file_path)? {
148 let id = record.get_id().clone();
149 let existing = map.insert(id.clone(), record).is_some();
150 ensure!(!existing, "Duplicate ID found: {id}");
151 }
152 ensure!(!map.is_empty(), "CSV file is empty");
153
154 Ok(map)
155 }
156
157 fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
158}
159
160fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
162where
163 T: UnitType,
164 I: Iterator<Item = T>,
165{
166 let sum = fractions.sum();
167 ensure!(
168 approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
169 "Sum of fractions does not equal one (actual: {sum})"
170 );
171
172 Ok(())
173}
174
175pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
177where
178 T: PartialOrd + Clone,
179 I: IntoIterator<Item = T>,
180{
181 is_sorted_and_unique_with(iter, |a, b| a < b)
182}
183
/// Check whether every item of `iter` is "less than" the next one, according to `less_than`.
///
/// Only adjacent pairs are compared, and checking stops at the first pair that fails.
/// Empty and single-element inputs return `true`.
///
/// Items are compared by reference and never cloned, so `T` does not need to implement
/// [`Clone`].
pub fn is_sorted_and_unique_with<T, I, F>(iter: I, less_than: F) -> bool
where
    I: IntoIterator<Item = T>,
    F: FnMut(&T, &T) -> bool,
{
    // `is_sorted_by` calls the predicate on each consecutive pair, so a strict comparator
    // (e.g. `<`) rejects equal neighbours as well as out-of-order ones
    iter.into_iter().is_sorted_by(less_than)
}
196
197pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
203where
204 M: Insert<K, V>,
205 K: Eq + Hash + Clone + std::fmt::Debug,
206{
207 let existing = map.insert(key.clone(), value).is_some();
208 ensure!(!existing, "Key {key:?} already exists in the map");
209 Ok(())
210}
211
/// Format items as a `Debug`-style list, showing at most ten of them.
///
/// Each item is rendered with `{:?}`. Items are separated by `", "` and wrapped in square
/// brackets. When there are more than ten items, the rest are summarised with a trailing
/// `" and N more"`.
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let iter = items.into_iter();
    let total_count = iter.len();

    let mut rendered = String::from("[");
    for (index, item) in iter.take(MAX_DISPLAY).enumerate() {
        if index > 0 {
            rendered.push_str(", ");
        }
        // Writing to a `String` cannot fail
        write!(rendered, "{item:?}").unwrap();
    }
    rendered.push(']');

    if total_count > MAX_DISPLAY {
        write!(rendered, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    rendered
}
237
238pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<Model> {
248 let model_params = ModelParameters::from_path(&model_dir)?;
249
250 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
251 let regions = read_regions(model_dir.as_ref())?;
252 let region_ids = regions.keys().cloned().collect();
253 let years = &model_params.milestone_years;
254
255 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
256 let processes = read_processes(
257 model_dir.as_ref(),
258 &commodities,
259 ®ion_ids,
260 &time_slice_info,
261 years,
262 )?;
263 let agents = read_agents(
264 model_dir.as_ref(),
265 &commodities,
266 &processes,
267 ®ion_ids,
268 years,
269 )?;
270 let agent_ids = agents.keys().cloned().collect();
271 let user_assets = read_user_assets(model_dir.as_ref(), &agent_ids, &processes, ®ion_ids)?;
272
273 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years);
275 validate_commodity_graphs_for_model(
276 &commodity_graphs,
277 &processes,
278 &commodities,
279 &time_slice_info,
280 )?;
281
282 let investment_order = solve_investment_order_for_model(&commodity_graphs, &commodities, years);
284
285 let model_path = model_dir
286 .as_ref()
287 .canonicalize()
288 .context("Could not parse path to model")?;
289 let model = Model {
290 model_path,
291 parameters: model_params,
292 agents,
293 commodities,
294 processes,
295 time_slice_info,
296 regions,
297 user_assets,
298 investment_order,
299 };
300 Ok(model)
301}
302
303pub fn load_commodity_graphs<P: AsRef<Path>>(
311 model_dir: P,
312) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
313 let model_params = ModelParameters::from_path(&model_dir)?;
314
315 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
316 let regions = read_regions(model_dir.as_ref())?;
317 let region_ids = regions.keys().cloned().collect();
318 let years = &model_params.milestone_years;
319
320 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
321 let processes = read_processes(
322 model_dir.as_ref(),
323 &commodities,
324 ®ion_ids,
325 &time_slice_info,
326 years,
327 )?;
328
329 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years);
330 Ok(commodity_graphs)
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// A simple record type with an ID, used to exercise the CSV and TOML readers
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Write `contents` (plus a trailing newline) to `test.csv` in `dir_path` and return its path
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    /// Test reading normal files, whitespace trimming, empty files and missing files
    #[test]
    fn read_csv_works() {
        // Simple case
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // Whitespace around headers and fields should be trimmed
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id , value\t\n hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // A header-only file is an error for read_csv, but read_csv_optional yields nothing
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // Likewise for a file which doesn't exist
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    /// Test that records are keyed by ID in file order and duplicate IDs are rejected
    #[test]
    fn read_csv_id_file_works() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let map: IndexMap<GenericID, Record> = read_csv_id_file(&file_path).unwrap();
        let values: Vec<u32> = map.values().map(|record| record.value).collect();
        assert_eq!(values, [1, 2]);

        // The same ID appearing twice is an error
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nhello,2\n");
        assert!(read_csv_id_file::<Record, GenericID>(&file_path).is_err());
    }

    /// Test reading a valid TOML file and rejecting an invalid one
    #[test]
    fn read_toml_works() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        read_toml::<Record>(&file_path).unwrap_err();
    }

    /// Deserialise `value` with `deserialise_proportion_nonzero`
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn deserialise_proportion_nonzero_works() {
        // Valid inputs, including the upper bound
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Invalid inputs: zero, out of range and non-finite values
        deserialise_f64(0.0).unwrap_err();
        deserialise_f64(-1.0).unwrap_err();
        deserialise_f64(2.0).unwrap_err();
        deserialise_f64(f64::NAN).unwrap_err();
        deserialise_f64(f64::INFINITY).unwrap_err();
    }

    #[test]
    fn check_values_sum_to_one_approx_works() {
        // A single value of one
        check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).unwrap();

        // Multiple values summing to one
        check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
            .unwrap();

        // A single value which isn't one
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        // Multiple values which don't sum to one
        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite values
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1, 2], true)]
    #[case(&[1, 2, 3, 4], true)]
    #[case(&[2, 1], false)]
    #[case(&[1, 1], false)]
    #[case(&[1, 3, 2, 4], false)]
    fn is_sorted_and_unique_works(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected);
    }

    /// Test that new keys are accepted and an existing key is rejected
    #[test]
    fn try_insert_works() {
        let mut map: HashMap<&str, u32> = HashMap::new();
        try_insert(&mut map, &"a", 1).unwrap();
        try_insert(&mut map, &"b", 2).unwrap();
        assert_eq!(map.len(), 2);

        // Inserting an existing key is an error
        assert!(try_insert(&mut map, &"a", 3).is_err());
    }

    #[test]
    fn format_items_with_cap_works() {
        // Few enough items to show them all
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Too many items: only the first ten are shown
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}