1use crate::asset::AssetPool;
3use crate::id::{HasID, IDLike};
4use crate::model::{Model, ModelFile};
5use crate::units::Dimensionless;
6use anyhow::{bail, ensure, Context, Result};
7use float_cmp::approx_eq;
8use indexmap::IndexMap;
9use itertools::Itertools;
10use serde::de::{Deserialize, DeserializeOwned, Deserializer};
11use std::collections::{HashMap, HashSet};
12use std::fs;
13use std::hash::Hash;
14use std::path::Path;
15
16mod agent;
17use agent::read_agents;
18mod asset;
19use asset::read_assets;
20mod commodity;
21use commodity::read_commodities;
22mod process;
23use process::read_processes;
24mod region;
25use region::read_regions;
26mod time_slice;
27use time_slice::read_time_slice_info;
28
29pub fn read_csv<'a, T: DeserializeOwned + 'a>(
37 file_path: &'a Path,
38) -> Result<impl Iterator<Item = T> + 'a> {
39 let vec = _read_csv_internal(file_path)?;
40 if vec.is_empty() {
41 bail!("CSV file {} cannot be empty", file_path.display());
42 }
43 Ok(vec.into_iter())
44}
45
46pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
52 file_path: &'a Path,
53) -> Result<impl Iterator<Item = T> + 'a> {
54 let vec = _read_csv_internal(file_path)?;
55 Ok(vec.into_iter())
56}
57
58fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
59 let vec = csv::Reader::from_path(file_path)
60 .with_context(|| input_err_msg(file_path))?
61 .into_deserialize()
62 .process_results(|iter| iter.collect_vec())
63 .with_context(|| input_err_msg(file_path))?;
64
65 Ok(vec)
66}
67
68pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
78 let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
79 let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
80 Ok(toml_data)
81}
82
83fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<Dimensionless, D::Error>
85where
86 D: Deserializer<'de>,
87{
88 let value = f64::deserialize(deserialiser)?;
89 if !(value > 0.0 && value <= 1.0) {
90 Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
91 }
92
93 Ok(Dimensionless(value))
94}
95
/// Build the standard "Error reading <path>" context message for input files.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
100
101fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
106where
107 T: HasID<ID> + DeserializeOwned,
108{
109 fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
110 where
111 T: HasID<ID> + DeserializeOwned,
112 {
113 let mut map = IndexMap::new();
114 for record in read_csv::<T>(file_path)? {
115 let id = record.get_id().clone();
116 let existing = map.insert(id.clone(), record).is_some();
117 ensure!(!existing, "Duplicate ID found: {id}");
118 }
119 ensure!(!map.is_empty(), "CSV file is empty");
120
121 Ok(map)
122 }
123
124 fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
125}
126
127fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
129where
130 I: Iterator<Item = Dimensionless>,
131{
132 let sum = fractions.sum();
133 ensure!(
134 approx_eq!(Dimensionless, sum, Dimensionless(1.0), epsilon = 1e-5),
135 "Sum of fractions does not equal one (actual: {})",
136 sum
137 );
138
139 Ok(())
140}
141
/// Check that the elements of `iter` are strictly increasing, i.e. sorted in
/// ascending order with no duplicates.
///
/// Empty and single-element sequences are trivially sorted and unique.
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
    T: PartialOrd + Clone,
    I: IntoIterator<Item = T>,
{
    let mut items = iter.into_iter();
    let Some(mut prev) = items.next() else {
        return true;
    };
    for current in items {
        // Require strict `<` for every adjacent pair, exactly as the
        // pairwise-window check did: equal or out-of-order neighbours fail.
        if !(prev < current) {
            return false;
        }
        prev = current;
    }
    true
}
150
151pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
155where
156 K: Eq + Hash + Clone + std::fmt::Debug,
157{
158 let existing = map.insert(key.clone(), value);
159 match existing {
160 Some(_) => bail!("Key {:?} already exists in the map", key),
161 None => Ok(()),
162 }
163}
164
165pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
175 let model_file = ModelFile::from_path(&model_dir)?;
176
177 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
178 let regions = read_regions(model_dir.as_ref())?;
179 let region_ids = regions.keys().cloned().collect();
180 let years = &model_file.milestone_years.years;
181
182 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
183 let processes = read_processes(
184 model_dir.as_ref(),
185 &commodities,
186 ®ion_ids,
187 &time_slice_info,
188 years,
189 )?;
190 let agents = read_agents(
191 model_dir.as_ref(),
192 &commodities,
193 &processes,
194 ®ion_ids,
195 years,
196 )?;
197 let agent_ids = agents.keys().cloned().collect();
198 let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, ®ion_ids)?;
199
200 let model_path = model_dir
201 .as_ref()
202 .canonicalize()
203 .context("Could not parse path to model")?;
204 let model = Model {
205 model_path,
206 milestone_years: model_file.milestone_years.years,
207 agents,
208 commodities,
209 processes,
210 time_slice_info,
211 regions,
212 };
213 Ok((model, AssetPool::new(assets)))
214}
215
216#[cfg(test)]
217mod tests {
218 use super::*;
219 use crate::id::GenericID;
220 use rstest::rstest;
221 use serde::de::value::{Error as ValueError, F64Deserializer};
222 use serde::de::IntoDeserializer;
223 use serde::Deserialize;
224 use std::fs::File;
225 use std::io::Write;
226 use std::path::PathBuf;
227 use tempfile::tempdir;
228
    // Minimal record type used by the CSV/TOML reading tests below: one ID
    // column plus one integer value column.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }
234
    // Allows `Record` to be used with ID-keyed helpers such as
    // `read_csv_id_file`.
    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }
240
241 fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
243 let file_path = dir_path.join("test.csv");
244 let mut file = File::create(&file_path).unwrap();
245 writeln!(file, "{}", contents).unwrap();
246 file_path
247 }
248
249 #[test]
251 fn test_read_csv() {
252 let dir = tempdir().unwrap();
253 let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
254 let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
255 assert_eq!(
256 records,
257 &[
258 Record {
259 id: "hello".into(),
260 value: 1,
261 },
262 Record {
263 id: "world".into(),
264 value: 2,
265 }
266 ]
267 );
268
269 let file_path = create_csv_file(dir.path(), "id,value\n");
271 assert!(read_csv::<Record>(&file_path).is_err());
272 assert!(read_csv_optional::<Record>(&file_path)
273 .unwrap()
274 .next()
275 .is_none());
276 }
277
278 #[test]
279 fn test_read_toml() {
280 let dir = tempdir().unwrap();
281 let file_path = dir.path().join("test.toml");
282 {
283 let mut file = File::create(&file_path).unwrap();
284 writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
285 }
286
287 assert_eq!(
288 read_toml::<Record>(&file_path).unwrap(),
289 Record {
290 id: "hello".into(),
291 value: 1,
292 }
293 );
294
295 {
296 let mut file = File::create(&file_path).unwrap();
297 writeln!(file, "bad toml syntax").unwrap();
298 }
299
300 assert!(read_toml::<Record>(&file_path).is_err());
301 }
302
303 fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
305 let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
306 deserialise_proportion_nonzero(deserialiser)
307 }
308
309 #[test]
310 fn test_deserialise_proportion_nonzero() {
311 assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
313 assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
314 assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));
315
316 assert!(deserialise_f64(0.0).is_err());
318 assert!(deserialise_f64(-1.0).is_err());
319 assert!(deserialise_f64(2.0).is_err());
320 assert!(deserialise_f64(f64::NAN).is_err());
321 assert!(deserialise_f64(f64::INFINITY).is_err());
322 }
323
324 #[test]
325 fn test_check_fractions_sum_to_one() {
326 assert!(check_fractions_sum_to_one([Dimensionless(1.0)].into_iter()).is_ok());
328
329 assert!(
331 check_fractions_sum_to_one([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
332 .is_ok()
333 );
334
335 assert!(check_fractions_sum_to_one([Dimensionless(0.5)].into_iter()).is_err());
337
338 assert!(
340 check_fractions_sum_to_one([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
341 .is_err()
342 );
343
344 assert!(check_fractions_sum_to_one([Dimensionless(f64::INFINITY)].into_iter()).is_err());
346 assert!(check_fractions_sum_to_one([Dimensionless(f64::NAN)].into_iter()).is_err());
347 }
348
    // Each case pairs an input slice with the expected result: strictly
    // increasing sequences (including the empty and single-element ones)
    // count as sorted-and-unique; repeats or out-of-order pairs do not.
    #[rstest]
    #[case(&[], true)]
    #[case(&[1], true)]
    #[case(&[1,2], true)]
    #[case(&[1,2,3,4], true)]
    #[case(&[2,1],false)]
    #[case(&[1,1],false)]
    #[case(&[1,3,2,4], false)]
    fn test_is_sorted_and_unique(#[case] values: &[u32], #[case] expected: bool) {
        assert_eq!(is_sorted_and_unique(values), expected)
    }
360}