1use crate::asset::AssetPool;
3use crate::id::{HasID, IDLike};
4use crate::model::{Model, ModelFile};
5use anyhow::{bail, ensure, Context, Result};
6use float_cmp::approx_eq;
7use indexmap::IndexMap;
8use itertools::Itertools;
9use serde::de::{Deserialize, DeserializeOwned, Deserializer};
10use std::collections::{HashMap, HashSet};
11use std::fs;
12use std::hash::Hash;
13use std::path::Path;
14
15mod agent;
16use agent::read_agents;
17mod asset;
18use asset::read_assets;
19mod commodity;
20use commodity::read_commodities;
21mod process;
22use process::read_processes;
23mod region;
24use region::read_regions;
25mod time_slice;
26use time_slice::read_time_slice_info;
27
28pub fn read_csv<'a, T: DeserializeOwned + 'a>(
36 file_path: &'a Path,
37) -> Result<impl Iterator<Item = T> + 'a> {
38 let vec = _read_csv_internal(file_path)?;
39 if vec.is_empty() {
40 bail!("CSV file {} cannot be empty", file_path.display());
41 }
42 Ok(vec.into_iter())
43}
44
45pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
51 file_path: &'a Path,
52) -> Result<impl Iterator<Item = T> + 'a> {
53 let vec = _read_csv_internal(file_path)?;
54 Ok(vec.into_iter())
55}
56
57fn _read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
58 let vec = csv::Reader::from_path(file_path)
59 .with_context(|| input_err_msg(file_path))?
60 .into_deserialize()
61 .process_results(|iter| iter.collect_vec())
62 .with_context(|| input_err_msg(file_path))?;
63
64 Ok(vec)
65}
66
67pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
77 let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
78 let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
79 Ok(toml_data)
80}
81
82fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<f64, D::Error>
84where
85 D: Deserializer<'de>,
86{
87 let value = Deserialize::deserialize(deserialiser)?;
88 if !(value > 0.0 && value <= 1.0) {
89 Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
90 }
91
92 Ok(value)
93}
94
/// Build the standard "Error reading <path>" message used to annotate input errors.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path = file_path.as_ref();
    format!("Error reading {}", path.display())
}
99
100fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
105where
106 T: HasID<ID> + DeserializeOwned,
107{
108 fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
109 where
110 T: HasID<ID> + DeserializeOwned,
111 {
112 let mut map = IndexMap::new();
113 for record in read_csv::<T>(file_path)? {
114 let id = record.get_id().clone();
115 let existing = map.insert(id.clone(), record).is_some();
116 ensure!(!existing, "Duplicate ID found: {id}");
117 }
118 ensure!(!map.is_empty(), "CSV file is empty");
119
120 Ok(map)
121 }
122
123 fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
124}
125
126fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
128where
129 I: Iterator<Item = f64>,
130{
131 let sum = fractions.sum();
132 ensure!(
133 approx_eq!(f64, sum, 1.0, epsilon = 1e-5),
134 "Sum of fractions does not equal one (actual: {})",
135 sum
136 );
137
138 Ok(())
139}
140
141pub fn try_insert<K, V>(map: &mut HashMap<K, V>, key: K, value: V) -> Result<()>
145where
146 K: Eq + Hash + Clone + std::fmt::Debug,
147{
148 let existing = map.insert(key.clone(), value);
149 match existing {
150 Some(_) => bail!("Key {:?} already exists in the map", key),
151 None => Ok(()),
152 }
153}
154
155pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
165 let model_file = ModelFile::from_path(&model_dir)?;
166
167 let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
168 let regions = read_regions(model_dir.as_ref())?;
169 let region_ids = regions.keys().cloned().collect();
170 let years = &model_file.milestone_years.years;
171
172 let commodities = read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years)?;
173 let processes = read_processes(
174 model_dir.as_ref(),
175 &commodities,
176 ®ion_ids,
177 &time_slice_info,
178 years,
179 )?;
180 let agents = read_agents(
181 model_dir.as_ref(),
182 &commodities,
183 &processes,
184 ®ion_ids,
185 years,
186 )?;
187 let agent_ids = agents.keys().cloned().collect();
188 let assets = read_assets(model_dir.as_ref(), &agent_ids, &processes, ®ion_ids)?;
189
190 let model = Model {
191 milestone_years: model_file.milestone_years.years,
192 agents,
193 commodities,
194 processes,
195 time_slice_info,
196 regions,
197 };
198 Ok((model, AssetPool::new(assets)))
199}
200
#[cfg(test)]
mod tests {
    use crate::id::GenericID;

    use super::*;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use serde::de::IntoDeserializer;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type used to exercise the CSV and TOML readers.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: GenericID,
        value: u32,
    }

    impl HasID<GenericID> for Record {
        fn get_id(&self) -> &GenericID {
            &self.id
        }
    }

    /// Create a file called "test.csv" in `dir_path` containing `contents`
    /// (plus a trailing newline) and return its path.
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let file_path = dir_path.join("test.csv");
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "{}", contents).unwrap();
        file_path
    }

    #[test]
    fn test_read_csv() {
        let dir = tempdir().unwrap();

        // A file with records parses into those records, in file order.
        let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
        let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
        assert_eq!(
            records,
            &[
                Record {
                    id: "hello".into(),
                    value: 1,
                },
                Record {
                    id: "world".into(),
                    value: 2,
                }
            ]
        );

        // A header-only (empty) file: an error for `read_csv`, but an empty
        // iterator for `read_csv_optional`.
        let file_path = create_csv_file(dir.path(), "id,value\n");
        assert!(read_csv::<Record>(&file_path).is_err());
        assert!(read_csv_optional::<Record>(&file_path)
            .unwrap()
            .next()
            .is_none());
    }

    #[test]
    fn test_read_toml() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.toml");

        // Valid TOML deserialises into the target type.
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
        }

        assert_eq!(
            read_toml::<Record>(&file_path).unwrap(),
            Record {
                id: "hello".into(),
                value: 1,
            }
        );

        // Malformed TOML is reported as an error rather than panicking.
        {
            let mut file = File::create(&file_path).unwrap();
            writeln!(file, "bad toml syntax").unwrap();
        }

        assert!(read_toml::<Record>(&file_path).is_err());
    }

    /// Run `deserialise_proportion_nonzero` on a raw `f64` via serde's
    /// value deserialiser, without going through a file format.
    fn deserialise_f64(value: f64) -> Result<f64, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Values in (0, 1] are accepted, including the boundary value 1.0.
        assert_eq!(deserialise_f64(0.01), Ok(0.01));
        assert_eq!(deserialise_f64(0.5), Ok(0.5));
        assert_eq!(deserialise_f64(1.0), Ok(1.0));

        // Zero, negatives, values above one and non-finite values are rejected.
        assert!(deserialise_f64(0.0).is_err());
        assert!(deserialise_f64(-1.0).is_err());
        assert!(deserialise_f64(2.0).is_err());
        assert!(deserialise_f64(f64::NAN).is_err());
        assert!(deserialise_f64(f64::INFINITY).is_err());
    }

    #[test]
    fn test_check_fractions_sum_to_one() {
        // Exact sums of one pass, for one or several fractions.
        assert!(check_fractions_sum_to_one([1.0].into_iter()).is_ok());

        assert!(check_fractions_sum_to_one([0.4, 0.6].into_iter()).is_ok());

        // Sums clearly away from one fail.
        assert!(check_fractions_sum_to_one([0.5].into_iter()).is_err());

        assert!(check_fractions_sum_to_one([0.4, 0.3].into_iter()).is_err());

        // Non-finite sums fail rather than being treated as approximately one.
        assert!(check_fractions_sum_to_one([f64::INFINITY].into_iter()).is_err());
        assert!(check_fractions_sum_to_one([f64::NAN].into_iter()).is_err());
    }
}