1use crate::graph::investment::solve_investment_order_for_model;
3use crate::graph::validate::validate_commodity_graphs_for_model;
4use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
5use crate::id::{HasID, IDLike};
6use crate::model::{Model, ModelParameters};
7use crate::region::RegionID;
8use crate::units::UnitType;
9use anyhow::{Context, Result, bail, ensure};
10use float_cmp::approx_eq;
11use indexmap::IndexMap;
12use itertools::Itertools;
13use serde::de::{Deserialize, DeserializeOwned, Deserializer};
14use std::collections::HashMap;
15use std::fmt::{self, Write};
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_user_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32
/// Abstraction over map types that support key/value insertion.
///
/// Implemented below for both [`HashMap`] and [`IndexMap`] so that generic
/// helpers (e.g. `try_insert`) can work with either map type.
pub trait Insert<K, V> {
    /// Inserts `value` under `key`, returning the previously stored value, if any.
    fn insert(&mut self, key: K, value: V) -> Option<V>;
}
38
impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        // Fully-qualified call so we delegate to the inherent `HashMap::insert`
        // rather than recursing into this trait method.
        HashMap::insert(self, key, value)
    }
}
44
impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        // Fully-qualified call so we delegate to the inherent `IndexMap::insert`
        // rather than recursing into this trait method.
        IndexMap::insert(self, key, value)
    }
}
50
51pub fn read_csv<'a, T: DeserializeOwned + 'a>(
59 file_path: &'a Path,
60) -> Result<impl Iterator<Item = T> + 'a> {
61 let vec = read_csv_internal(file_path)?;
62 if vec.is_empty() {
63 bail!("CSV file {} cannot be empty", file_path.display());
64 }
65 Ok(vec.into_iter())
66}
67
68pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
74 file_path: &'a Path,
75) -> Result<impl Iterator<Item = T> + 'a> {
76 if !file_path.exists() {
77 return Ok(Vec::new().into_iter());
78 }
79
80 let vec = read_csv_internal(file_path)?;
81 Ok(vec.into_iter())
82}
83
84fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
85 let vec = csv::ReaderBuilder::new()
86 .trim(csv::Trim::All)
87 .from_path(file_path)
88 .with_context(|| input_err_msg(file_path))?
89 .into_deserialize()
90 .process_results(|iter| iter.collect_vec())
91 .with_context(|| input_err_msg(file_path))?;
92
93 Ok(vec)
94}
95
96pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
106 let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
107 let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
108 Ok(toml_data)
109}
110
111pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
113where
114 T: UnitType,
115 D: Deserializer<'de>,
116{
117 let value = f64::deserialize(deserialiser)?;
118 if !(value > 0.0 && value <= 1.0) {
119 Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
120 }
121
122 Ok(T::new(value))
123}
124
/// Build the standard error-context message for a failed input-file read.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
129
130fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
135where
136 T: HasID<ID> + DeserializeOwned,
137{
138 fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
139 where
140 T: HasID<ID> + DeserializeOwned,
141 {
142 let mut map = IndexMap::new();
143 for record in read_csv::<T>(file_path)? {
144 let id = record.get_id().clone();
145 let existing = map.insert(id.clone(), record).is_some();
146 ensure!(!existing, "Duplicate ID found: {id}");
147 }
148 ensure!(!map.is_empty(), "CSV file is empty");
149
150 Ok(map)
151 }
152
153 fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
154}
155
/// Check that an iterator of fractions sums to (approximately) one.
///
/// The comparison is approximate (epsilon = 1e-5) to tolerate floating-point
/// rounding; a NaN or infinite sum fails the comparison and yields an error.
///
/// # Errors
///
/// Returns an error if the sum differs from one by more than the epsilon.
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
    T: UnitType,
    I: Iterator<Item = T>,
{
    let sum = fractions.sum();
    ensure!(
        approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
        "Sum of fractions does not equal one (actual: {sum})"
    );

    Ok(())
}
170
/// Check whether the items are in strictly increasing order (sorted with no
/// duplicates). Empty and single-element sequences are trivially valid.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    let Some(mut prev) = items.next() else {
        return true;
    };
    for current in items {
        // `!(prev < current)` (rather than `prev >= current`) so that
        // incomparable pairs under `PartialOrd` also count as unsorted.
        if !(prev < current) {
            return false;
        }
        prev = current;
    }
    true
}
179
180pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
186where
187 M: Insert<K, V>,
188 K: Eq + Hash + Clone + std::fmt::Debug,
189{
190 let existing = map.insert(key.clone(), value).is_some();
191 ensure!(!existing, "Key {key:?} already exists in the map");
192 Ok(())
193}
194
/// Format a collection for display, capping output at ten items.
///
/// Items are rendered with their `Debug` representation inside brackets; when
/// the collection is longer than the cap, a trailing " and N more" is appended.
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
    I: IntoIterator<Item = T, IntoIter = J>,
    J: ExactSizeIterator<Item = T>,
    T: fmt::Debug,
{
    const MAX_DISPLAY: usize = 10;

    let iter = items.into_iter();
    let total_count = iter.len();

    let shown: Vec<String> = iter
        .take(MAX_DISPLAY)
        .map(|item| format!("{item:?}"))
        .collect();
    let mut out = format!("[{}]", shown.join(", "));

    // Note how many items were omitted, if any.
    if total_count > MAX_DISPLAY {
        write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
    }

    out
}
220
221pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<Model> {
231 let model_params = ModelParameters::from_path(&model_dir)?;
232
233 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
234 let regions = read_regions(model_dir.as_ref())?;
235 let region_ids = regions.keys().cloned().collect();
236 let years = &model_params.milestone_years;
237
238 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
239 let processes = read_processes(
240 model_dir.as_ref(),
241 &commodities,
242 ®ion_ids,
243 &time_slice_info,
244 years,
245 )?;
246 let agents = read_agents(
247 model_dir.as_ref(),
248 &commodities,
249 &processes,
250 ®ion_ids,
251 years,
252 )?;
253 let agent_ids = agents.keys().cloned().collect();
254 let user_assets = read_user_assets(model_dir.as_ref(), &agent_ids, &processes, ®ion_ids)?;
255
256 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years);
258 validate_commodity_graphs_for_model(
259 &commodity_graphs,
260 &processes,
261 &commodities,
262 &time_slice_info,
263 )?;
264
265 let investment_order =
267 solve_investment_order_for_model(&commodity_graphs, &commodities, years)?;
268
269 let model_path = model_dir
270 .as_ref()
271 .canonicalize()
272 .context("Could not parse path to model")?;
273 let model = Model {
274 model_path,
275 parameters: model_params,
276 agents,
277 commodities,
278 processes,
279 time_slice_info,
280 regions,
281 user_assets,
282 investment_order,
283 };
284 Ok(model)
285}
286
287pub fn load_commodity_graphs<P: AsRef<Path>>(
295 model_dir: P,
296) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
297 let model_params = ModelParameters::from_path(&model_dir)?;
298
299 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
300 let regions = read_regions(model_dir.as_ref())?;
301 let region_ids = regions.keys().cloned().collect();
302 let years = &model_params.milestone_years;
303
304 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
305 let processes = read_processes(
306 model_dir.as_ref(),
307 &commodities,
308 ®ion_ids,
309 &time_slice_info,
310 years,
311 )?;
312
313 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years);
314 Ok(commodity_graphs)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type for exercising the CSV and TOML readers.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Write `contents` to a file named "test.csv" in `dir_path` and return its path.
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    #[test]
    fn read_csv_works() {
        // Well-formed CSV deserialises into records in file order
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // Whitespace around headers and fields is trimmed (csv::Trim::All)
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id , value\t\n hello\t ,1\n world ,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // Header-only file: read_csv errors, read_csv_optional yields nothing
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );

        // Missing file: read_csv errors, read_csv_optional yields nothing
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("a_missing_file.csv");
        assert!(!file_path.exists());
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn read_toml_works() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            // Scoped so the file handle is closed before reading
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Invalid TOML syntax must be reported as an error
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        read_toml::<Record>(&file_path).unwrap_err();
    }

    /// Run `deserialise_proportion_nonzero` against a raw f64 value.
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn deserialise_proportion_nonzero_works() {
        // Values in (0, 1] are accepted
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Zero, out-of-range and non-finite values are rejected
        deserialise_f64(0.0).unwrap_err();
        deserialise_f64(-1.0).unwrap_err();
        deserialise_f64(2.0).unwrap_err();
        deserialise_f64(f64::NAN).unwrap_err();
        deserialise_f64(f64::INFINITY).unwrap_err();
    }

    #[test]
    fn check_values_sum_to_one_approx_works() {
        // A single value of exactly one passes
        check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).unwrap();

        // Multiple values summing to one pass
        check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
            .unwrap();

        // Sums different from one fail
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite sums fail the approximate comparison
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn is_sorted_and_unique_works(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected);
    }

    #[test]
    fn format_items_with_cap_works() {
        // At or below the cap: all items shown, no suffix
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // Above the cap: first ten shown plus an "and N more" suffix
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}