use crate::agent::AssetPool;
use crate::model::{Model, ModelFile};
use anyhow::{ensure, Context, Result};
use float_cmp::approx_eq;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::Path;
use std::rc::Rc;
pub mod agent;
pub use agent::read_agents;
pub mod asset;
use asset::read_assets;
pub mod commodity;
pub use commodity::read_commodities;
pub mod process;
pub use process::read_processes;
pub mod region;
pub use region::read_regions;
mod time_slice;
pub use time_slice::read_time_slice_info;
/// Read a series of records of type `T` from a CSV file.
///
/// The whole file is deserialised eagerly; the returned iterator yields the
/// already-parsed records in file order.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
///
/// # Errors
///
/// Returns an error (with context naming `file_path`) if the file cannot be
/// opened or if any record fails to deserialise.
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
    file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
    let reader = csv::Reader::from_path(file_path).with_context(|| input_err_msg(file_path))?;

    // Stop at the first malformed record and report it against this file.
    let records = reader
        .into_deserialize::<T>()
        .process_results(|rows| rows.collect::<Vec<T>>())
        .with_context(|| input_err_msg(file_path))?;

    Ok(records.into_iter())
}
/// Parse a TOML file at the specified path.
///
/// # Arguments
///
/// * `file_path` - Path to the TOML file
///
/// # Returns
///
/// The deserialised TOML data.
///
/// # Errors
///
/// Returns an error (with context naming `file_path`) if the file cannot be
/// read or its contents cannot be deserialised into `T`.
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
    let contents = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
    toml::from_str(&contents).with_context(|| input_err_msg(file_path))
}
pub fn deserialise_proportion_nonzero<'de, D>(deserialiser: D) -> Result<f64, D::Error>
where
D: Deserializer<'de>,
{
let value = Deserialize::deserialize(deserialiser)?;
if !(value > 0.0 && value <= 1.0) {
Err(serde::de::Error::custom("Value must be > 0 and <= 1"))?
}
Ok(value)
}
/// Format an error message to include the file path. To be used with
/// `anyhow::Context`.
///
/// Non-UTF-8 parts of the path are converted lossily.
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
    let path: &Path = file_path.as_ref();
    let shown = path.to_string_lossy();
    format!("Error reading {shown}")
}
/// A type which can report its ID as a string slice.
///
/// Used by the readers in this module to key and group records by ID.
pub trait HasID {
/// Get a string representation of this value's ID.
fn get_id(&self) -> &str;
}
/// Implement the `HasID` trait for the given type.
///
/// The type must have a field named `id` which can be borrowed as a `&str`.
/// `HasID` must be in scope where the macro is invoked.
macro_rules! define_id_getter {
    ($record_type:ty) => {
        impl HasID for $record_type {
            fn get_id(&self) -> &str {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
/// A data structure containing a set of IDs which can be looked up by string.
pub trait IDCollection {
/// Get the stored ID matching the given string.
///
/// # Arguments
///
/// * `id` - The ID to look up
///
/// # Returns
///
/// A reference-counted copy of the ID, or an error. The `HashSet<Rc<str>>`
/// implementation in this file returns an error when `id` is not present.
fn get_id(&self, id: &str) -> Result<Rc<str>>;
}
impl IDCollection for HashSet<Rc<str>> {
    /// Look up `id` in the set and return a new handle to the stored `Rc<str>`.
    ///
    /// # Errors
    ///
    /// Returns an error naming the ID if it is not a member of the set.
    fn get_id(&self, id: &str) -> Result<Rc<str>> {
        self.get(id)
            // Cloning an `Rc` only bumps the reference count; the string is shared.
            .cloned()
            .with_context(|| format!("Unknown ID {id} found"))
    }
}
/// Read a CSV file of items with IDs into a map keyed by ID.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
///
/// # Returns
///
/// A `HashMap` from each record's ID to the record itself.
///
/// # Errors
///
/// Returns an error (with context naming `file_path`) if the file cannot be
/// read, if two records share an ID, or if the file contains no records.
pub fn read_csv_id_file<T>(file_path: &Path) -> Result<HashMap<Rc<str>, T>>
where
    T: HasID + DeserializeOwned,
{
    // Do the work in a closure so that every failure gets the same file context.
    let load = || -> Result<HashMap<Rc<str>, T>> {
        let mut records_by_id: HashMap<Rc<str>, T> = HashMap::new();
        for record in read_csv::<T>(file_path)? {
            let key: Rc<str> = record.get_id().into();
            ensure!(
                !records_by_id.contains_key(&key),
                "Duplicate ID found: {key}"
            );
            records_by_id.insert(key, record);
        }
        ensure!(!records_by_id.is_empty(), "CSV file is empty");

        Ok(records_by_id)
    };

    load().with_context(|| input_err_msg(file_path))
}
/// Conversion of a sequence of items into a map from ID to the items with that ID.
pub trait IntoIDMap<T> {
/// Consume `self` and group its items by ID.
///
/// # Arguments
///
/// * `ids` - The set of known IDs
///
/// # Returns
///
/// A map from ID to the items carrying that ID, or an error. The iterator
/// implementation in this file fails if an item's ID is not in `ids` or if
/// there are no items at all.
fn into_id_map(self, ids: &HashSet<Rc<str>>) -> Result<HashMap<Rc<str>, Vec<T>>>;
}
impl<T, I> IntoIDMap<T> for I
where
    T: HasID,
    I: Iterator<Item = T>,
{
    /// Group the items of this iterator by ID.
    ///
    /// Items sharing an ID are kept in the order in which they were yielded.
    ///
    /// # Errors
    ///
    /// Returns an error if any item's ID is not present in `ids`, or if the
    /// iterator yields no items.
    fn into_id_map(self, ids: &HashSet<Rc<str>>) -> Result<HashMap<Rc<str>, Vec<T>>> {
        let mut grouped: HashMap<Rc<str>, Vec<T>> = HashMap::new();
        for item in self {
            // Use the shared `Rc<str>` from `ids` as the key rather than allocating a new one.
            let id = ids.get_id(item.get_id())?;
            grouped.entry(id).or_default().push(item);
        }
        ensure!(!grouped.is_empty(), "CSV file is empty");

        Ok(grouped)
    }
}
pub fn check_fractions_sum_to_one<I>(fractions: I) -> Result<()>
where
I: Iterator<Item = f64>,
{
let sum = fractions.sum();
ensure!(
approx_eq!(f64, sum, 1.0, epsilon = 1e-5),
"Sum of fractions does not equal one (actual: {})",
sum
);
Ok(())
}
/// Read a model from the specified directory.
///
/// # Arguments
///
/// * `model_dir` - Folder containing model configuration files
///
/// # Returns
///
/// The loaded [`Model`] together with the [`AssetPool`] read from the same
/// directory.
///
/// # Errors
///
/// Returns an error if any of the input readers fails, or if the model file
/// contains no milestone years.
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<(Model, AssetPool)> {
    let model_file = ModelFile::from_path(&model_dir)?;
    let model_dir: &Path = model_dir.as_ref();

    let time_slice_info = read_time_slice_info(model_dir)?;
    let regions = read_regions(model_dir)?;
    let region_ids = regions.keys().cloned().collect();

    // The year range spans the first to the last milestone year. Report an
    // empty list as an input error instead of panicking.
    let years = &model_file.milestone_years.years;
    let first_year = *years
        .first()
        .context("No milestone years were provided")?;
    let last_year = *years.last().context("No milestone years were provided")?;
    let year_range = first_year..=last_year;

    let commodities = read_commodities(model_dir, &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir,
        &commodities,
        &region_ids,
        &time_slice_info,
        &year_range,
    )?;
    let agents = read_agents(model_dir, &commodities, &processes, &region_ids)?;
    let agent_ids = agents.keys().cloned().collect();
    let assets = read_assets(model_dir, &agent_ids, &processes, &region_ids)?;

    let model = Model {
        milestone_years: model_file.milestone_years.years,
        agents,
        commodities,
        processes,
        time_slice_info,
        regions,
    };
    Ok((model, assets))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::value::{Error as ValueError, F64Deserializer};
    use serde::de::IntoDeserializer;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::tempdir;

    /// Minimal record type used to exercise the CSV and TOML readers.
    #[derive(Debug, PartialEq, Deserialize)]
    struct Record {
        id: String,
        value: u32,
    }

    impl HasID for Record {
        fn get_id(&self) -> &str {
            &self.id
        }
    }

    /// Write `contents` (plus a trailing newline) to `test.csv` in `dir_path`
    /// and return the path of the new file.
    fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
        let csv_path = dir_path.join("test.csv");
        let mut handle = File::create(&csv_path).unwrap();
        writeln!(handle, "{contents}").unwrap();
        csv_path
    }

    /// Records are returned in file order with the expected field values.
    #[test]
    fn test_read_csv() {
        let tmp = tempdir().unwrap();
        let csv_path = create_csv_file(tmp.path(), "id,value\nhello,1\nworld,2\n");

        let expected = vec![
            Record {
                id: "hello".to_string(),
                value: 1,
            },
            Record {
                id: "world".to_string(),
                value: 2,
            },
        ];
        let actual: Vec<Record> = read_csv(&csv_path).unwrap().collect();
        assert_eq!(actual, expected);
    }

    /// Valid TOML is deserialised; invalid TOML produces an error.
    #[test]
    fn test_read_toml() {
        let tmp = tempdir().unwrap();
        let toml_path = tmp.path().join("test.toml");

        // (Re)create the file with the given contents; the handle is closed on return.
        let write_contents = |contents: &str| {
            let mut handle = File::create(&toml_path).unwrap();
            writeln!(handle, "{contents}").unwrap();
        };

        write_contents("id = \"hello\"\nvalue = 1");
        let expected = Record {
            id: "hello".to_string(),
            value: 1,
        };
        assert_eq!(read_toml::<Record>(&toml_path).unwrap(), expected);

        write_contents("bad toml syntax");
        assert!(read_toml::<Record>(&toml_path).is_err());
    }

    /// Run `deserialise_proportion_nonzero` against a bare `f64` value.
    fn deserialise_f64(value: f64) -> Result<f64, ValueError> {
        let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
        deserialise_proportion_nonzero(deserialiser)
    }

    #[test]
    fn test_deserialise_proportion_nonzero() {
        // Values in (0, 1] are passed through unchanged
        for valid in [0.01, 0.5, 1.0] {
            assert_eq!(deserialise_f64(valid), Ok(valid));
        }

        // Everything else, including non-finite values, is rejected
        for invalid in [0.0, -1.0, 2.0, f64::NAN, f64::INFINITY] {
            assert!(deserialise_f64(invalid).is_err());
        }
    }

    #[test]
    fn test_check_fractions_sum_to_one() {
        let sums_to_one =
            |fractions: &[f64]| check_fractions_sum_to_one(fractions.iter().copied()).is_ok();

        // Single input, which equals one
        assert!(sums_to_one(&[1.0]));

        // Multiple inputs, summing to one
        assert!(sums_to_one(&[0.4, 0.6]));

        // Single input, which does not equal one
        assert!(!sums_to_one(&[0.5]));

        // Multiple inputs, not summing to one
        assert!(!sums_to_one(&[0.4, 0.3]));

        // Non-finite inputs
        assert!(!sums_to_one(&[f64::INFINITY]));
        assert!(!sums_to_one(&[f64::NAN]));
    }
}