1use crate::asset::AssetPool;
3use crate::graph::{
4 CommoditiesGraph, build_commodity_graphs_for_model, validate_commodity_graphs_for_model,
5};
6use crate::id::{HasID, IDLike};
7use crate::model::{Model, ModelParameters};
8use crate::region::RegionID;
9use crate::units::UnitType;
10use anyhow::{Context, Result, bail, ensure};
11use float_cmp::approx_eq;
12use indexmap::IndexMap;
13use itertools::Itertools;
14use serde::de::{Deserialize, DeserializeOwned, Deserializer};
15use std::collections::HashMap;
16use std::fs;
17use std::hash::Hash;
18use std::path::Path;
19
20mod agent;
21use agent::read_agents;
22mod asset;
23use asset::read_assets;
24mod commodity;
25use commodity::read_commodities;
26mod process;
27use process::read_processes;
28mod region;
29use region::read_regions;
30mod time_slice;
31use time_slice::read_time_slice_info;
32
33pub fn read_csv<'a, T: DeserializeOwned + 'a>(
41 file_path: &'a Path,
42) -> Result<impl Iterator<Item = T> + 'a> {
43 let vec = read_csv_internal(file_path)?;
44 if vec.is_empty() {
45 bail!("CSV file {} cannot be empty", file_path.display());
46 }
47 Ok(vec.into_iter())
48}
49
50pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
56 file_path: &'a Path,
57) -> Result<impl Iterator<Item = T> + 'a> {
58 let vec = read_csv_internal(file_path)?;
59 Ok(vec.into_iter())
60}
61
62fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
63 let vec = csv::Reader::from_path(file_path)
64 .with_context(|| input_err_msg(file_path))?
65 .into_deserialize()
66 .process_results(|iter| iter.collect_vec())
67 .with_context(|| input_err_msg(file_path))?;
68
69 Ok(vec)
70}
71
72pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
82 let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
83 let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
84 Ok(toml_data)
85}
86
87pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
89where
90 T: UnitType,
91 D: Deserializer<'de>,
92{
93 let value = f64::deserialize(deserialiser)?;
94 if !(value > 0.0 && value <= 1.0) {
95 Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?;
96 }
97
98 Ok(T::new(value))
99}
100
/// Format a standard error message for a failed read of an input file.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
105
106fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
111where
112 T: HasID<ID> + DeserializeOwned,
113{
114 fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
115 where
116 T: HasID<ID> + DeserializeOwned,
117 {
118 let mut map = IndexMap::new();
119 for record in read_csv::<T>(file_path)? {
120 let id = record.get_id().clone();
121 let existing = map.insert(id.clone(), record).is_some();
122 ensure!(!existing, "Duplicate ID found: {id}");
123 }
124 ensure!(!map.is_empty(), "CSV file is empty");
125
126 Ok(map)
127 }
128
129 fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
130}
131
132fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
134where
135 T: UnitType,
136 I: Iterator<Item = T>,
137{
138 let sum = fractions.sum();
139 ensure!(
140 approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
141 "Sum of fractions does not equal one (actual: {sum})"
142 );
143
144 Ok(())
145}
146
/// Check whether the values in an iterator are strictly increasing
/// (i.e. sorted and free of duplicates).
///
/// Empty and single-element sequences are trivially sorted and unique.
/// Incomparable values (e.g. NaN) cause `false` to be returned.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut previous: Option<T> = None;
    for current in iter {
        if let Some(prev) = previous {
            // Each element must be strictly greater than its predecessor
            if !(prev < current) {
                return false;
            }
        }
        previous = Some(current);
    }
    true
}
155
156pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: &K, value: V) -> Result<()>
160where
161 K: Eq + Hash + Clone + std::fmt::Debug,
162{
163 let existing = map.insert(key.clone(), value).is_some();
164 ensure!(!existing, "Key {key:?} already exists in the map");
165 Ok(())
166}
167
/// Render a slice for display, truncating long lists.
///
/// Slices of at most ten items are shown in full; longer slices show the
/// first ten items followed by "and N more".
pub fn format_items_with_cap<T: std::fmt::Debug>(items: &[T]) -> String {
    const MAX_DISPLAY: usize = 10;
    match items.len().checked_sub(MAX_DISPLAY) {
        // More than MAX_DISPLAY items: show a prefix plus the hidden count
        Some(extra) if extra > 0 => {
            let shown = &items[..MAX_DISPLAY];
            format!("{shown:?} and {extra} more")
        }
        // Short enough to show everything
        _ => format!("{items:?}"),
    }
}
181
182pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
192 let model_params = ModelParameters::from_path(&model_dir)?;
193
194 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
195 let regions = read_regions(model_dir.as_ref())?;
196 let region_ids = regions.keys().cloned().collect();
197 let years = &model_params.milestone_years;
198
199 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
200 let processes = read_processes(
201 model_dir.as_ref(),
202 &commodities,
203 ®ion_ids,
204 &time_slice_info,
205 years,
206 )?;
207 let agents = read_agents(
208 model_dir.as_ref(),
209 &commodities,
210 &processes,
211 ®ion_ids,
212 years,
213 )?;
214 let agent_ids = agents.keys().cloned().collect();
215 let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, ®ion_ids)?;
216
217 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years)?;
220 let commodity_order = validate_commodity_graphs_for_model(
221 &commodity_graphs,
222 &processes,
223 &commodities,
224 &time_slice_info,
225 )?;
226
227 let model_path = model_dir
228 .as_ref()
229 .canonicalize()
230 .context("Could not parse path to model")?;
231 let model = Model {
232 model_path,
233 parameters: model_params,
234 agents,
235 commodities,
236 processes,
237 time_slice_info,
238 regions,
239 commodity_order,
240 };
241 Ok((model, AssetPool::new(assets)))
242}
243
244pub fn load_commodity_graphs<P: AsRef<Path>>(
252 model_dir: P,
253) -> Result<HashMap<(RegionID, u32), CommoditiesGraph>> {
254 let model_params = ModelParameters::from_path(&model_dir)?;
255
256 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
257 let regions = read_regions(model_dir.as_ref())?;
258 let region_ids = regions.keys().cloned().collect();
259 let years = &model_params.milestone_years;
260
261 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
262 let processes = read_processes(
263 model_dir.as_ref(),
264 &commodities,
265 ®ion_ids,
266 &time_slice_info,
267 years,
268 )?;
269
270 let commodity_graphs = build_commodity_graphs_for_model(&processes, ®ion_ids, years)?;
271 Ok(commodity_graphs)
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::id::GenericID;
    use crate::units::Dimensionless;
    use rstest::rstest;
    use serde::Deserialize;
    use serde::de::IntoDeserializer;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Simple record type for exercising the CSV and TOML readers.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Write `contents` to a file named "test.csv" in `dir_path` and return its path.
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{contents}").unwrap();
        file_path
    }

    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // A header-only file is an error for read_csv but fine for read_csv_optional,
        // which should yield an empty iterator
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(
            read_csv_optional::<Record>(&file_path)
                .unwrap()
                .next()
                .is_none()
        );
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Invalid TOML syntax should be reported as an error
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Run deserialise_proportion_nonzero on a raw f64 via a serde value deserialiser.
    fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Values in (0, 1] are accepted
        assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
        assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
        assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));

        // Zero, negatives, values > 1, NaN and infinity are all rejected
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_values_sum_to_one_approx() {
        // Sums equal to one pass
        assert!(check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).is_ok());

        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
                .is_ok()
        );

        // Sums different from one fail
        assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());

        assert!(
            check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
                .is_err()
        );

        // Non-finite sums also fail
        assert!(
            check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
        );
        assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
    }

    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }

    #[test]
    fn test_format_items_with_cap() {
        // At most ten items: shown in full
        let items = vec!["a", "b", "c"];
        assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);

        // More than ten items: truncated with a count of the remainder
        let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
        assert_eq!(
            format_items_with_cap(&many_items),
            r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
        );
    }
}