muse2/
output.rs

1//! The module responsible for writing output data to disk.
2use crate::agent::AgentID;
3use crate::asset::{Asset, AssetID, AssetPool};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::CommodityPrices;
8use crate::time_slice::TimeSliceID;
9use anyhow::{Context, Result};
10use csv;
11use serde::{Deserialize, Serialize};
12use std::fs;
13use std::fs::File;
14use std::path::{Path, PathBuf};
15
/// Root directory under which a separate output folder is created for each model
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

/// Name of the CSV file that commodity flows are written to
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";

/// Name of the CSV file that commodity prices are written to
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

/// Name of the CSV file that assets are written to
const ASSETS_FILE_NAME: &str = "assets.csv";
27
28/// Get the model name from the specified directory path
29pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
30    // Get the model name from the dir path. This ends up being convoluted because we need to check
31    // for all possible errors. Ugh.
32    let model_dir = model_dir
33        .canonicalize() // canonicalise in case the user has specified "."
34        .context("Could not resolve path to model")?;
35
36    let model_name = model_dir
37        .file_name()
38        .context("Model cannot be in root folder")?
39        .to_str()
40        .context("Invalid chars in model dir name")?;
41
42    // Construct path
43    Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
44}
45
46/// Create a new output directory for the model specified at `model_dir`.
47pub fn create_output_directory(output_dir: &Path) -> Result<()> {
48    if output_dir.is_dir() {
49        // already exists
50        return Ok(());
51    }
52
53    // Try to create the directory, with parents
54    fs::create_dir_all(output_dir)?;
55
56    Ok(())
57}
58
59/// Represents a row in the assets output CSV file
60#[derive(Serialize, Deserialize, Debug, PartialEq)]
61struct AssetRow {
62    milestone_year: u32,
63    process_id: ProcessID,
64    region_id: RegionID,
65    agent_id: AgentID,
66    commission_year: u32,
67}
68
69impl AssetRow {
70    fn new(milestone_year: u32, asset: &Asset) -> Self {
71        Self {
72            milestone_year,
73            process_id: asset.process.id.clone(),
74            region_id: asset.region_id.clone(),
75            agent_id: asset.agent_id.clone(),
76            commission_year: asset.commission_year,
77        }
78    }
79}
80
81/// Represents the flow-related data in a row of the commodity flows CSV file.
82///
83/// This will be written along with an [`AssetRow`] containing asset-related info.
84#[derive(Serialize, Deserialize, Debug, PartialEq)]
85struct CommodityFlowRow {
86    commodity_id: CommodityID,
87    time_slice: String,
88    flow: f64,
89}
90
91/// Represents a row in the commodity prices CSV file
92#[derive(Serialize, Deserialize, Debug, PartialEq)]
93struct CommodityPriceRow {
94    milestone_year: u32,
95    commodity_id: CommodityID,
96    region_id: RegionID,
97    time_slice: String,
98    price: f64,
99}
100
101/// An object for writing commodity prices to file
102pub struct DataWriter {
103    assets_writer: csv::Writer<File>,
104    flows_writer: csv::Writer<File>,
105    prices_writer: csv::Writer<File>,
106}
107
108impl DataWriter {
109    /// Create a new CSV files to write output data to
110    pub fn create(output_path: &Path) -> Result<Self> {
111        let new_writer = |file_name| {
112            let file_path = output_path.join(file_name);
113            csv::Writer::from_path(file_path)
114        };
115
116        Ok(Self {
117            assets_writer: new_writer(ASSETS_FILE_NAME)?,
118            flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
119            prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
120        })
121    }
122
123    /// Write assets to a CSV file
124    pub fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
125    where
126        I: Iterator<Item = &'a Asset>,
127    {
128        for asset in assets {
129            let row = AssetRow::new(milestone_year, asset);
130            self.assets_writer.serialize(row)?;
131        }
132
133        Ok(())
134    }
135
136    /// Write commodity flows to a CSV file
137    pub fn write_flows<'a, I>(
138        &mut self,
139        milestone_year: u32,
140        assets: &AssetPool,
141        flows: I,
142    ) -> Result<()>
143    where
144        I: Iterator<Item = (AssetID, &'a CommodityID, &'a TimeSliceID, f64)>,
145    {
146        for (asset_id, commodity_id, time_slice, flow) in flows {
147            let asset = assets.get(asset_id).unwrap();
148            let asset_row = AssetRow::new(milestone_year, asset);
149            let flow_row = CommodityFlowRow {
150                commodity_id: commodity_id.clone(),
151                time_slice: time_slice.to_string(),
152                flow,
153            };
154            self.flows_writer.serialize((asset_row, flow_row))?;
155        }
156
157        Ok(())
158    }
159
160    /// Write commodity prices to a CSV file
161    pub fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
162        for (commodity_id, region_id, time_slice, price) in prices.iter() {
163            let row = CommodityPriceRow {
164                milestone_year,
165                commodity_id: commodity_id.clone(),
166                region_id: region_id.clone(),
167                time_slice: time_slice.to_string(),
168                price,
169            };
170            self.prices_writer.serialize(row)?;
171        }
172
173        Ok(())
174    }
175
176    /// Flush the underlying streams
177    pub fn flush(&mut self) -> Result<()> {
178        self.assets_writer.flush()?;
179        self.flows_writer.flush()?;
180        self.prices_writer.flush()?;
181
182        Ok(())
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::process;
    use crate::process::Process;
    use crate::time_slice::TimeSliceID;
    use itertools::{assert_equal, Itertools};
    use rstest::{fixture, rstest};
    use std::iter;
    use tempfile::tempdir;

    /// An asset for agent "agent1" in region "GBR", with commission year 2015
    #[fixture]
    pub fn asset(process: Process) -> Asset {
        let region_id: RegionID = "GBR".into();
        let agent_id = "agent1".into();
        let commission_year = 2015;
        Asset::new(agent_id, process.into(), region_id, 2.0, commission_year).unwrap()
    }

    #[test]
    fn test_get_output_dir() {
        let dir = tempdir().unwrap();
        let model_dir = dir.path().join("my_model");
        fs::create_dir(&model_dir).unwrap();

        // The output dir is named after the model dir and placed under the results root
        let output_dir = get_output_dir(&model_dir).unwrap();
        assert_eq!(output_dir, Path::new(OUTPUT_DIRECTORY_ROOT).join("my_model"));
    }

    #[test]
    fn test_get_output_dir_nonexistent_model_dir() {
        let dir = tempdir().unwrap();
        assert!(get_output_dir(&dir.path().join("nonexistent")).is_err());
    }

    #[test]
    fn test_create_output_directory() {
        let dir = tempdir().unwrap();
        let output_dir = dir.path().join("results").join("model");

        // Missing parent directories should be created too
        create_output_directory(&output_dir).unwrap();
        assert!(output_dir.is_dir());

        // Creating a directory that already exists is not an error
        create_output_directory(&output_dir).unwrap();
        assert!(output_dir.is_dir());
    }

    #[rstest]
    fn test_write_assets(asset: Asset) {
        let milestone_year = 2020;
        let dir = tempdir().unwrap();

        // Write an asset
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer
                .write_assets(milestone_year, iter::once(&asset))
                .unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = AssetRow::new(milestone_year, &asset);
        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_flows(asset: Asset) {
        let milestone_year = 2020;
        let commodity_id = "commodity1".into();
        let time_slice = TimeSliceID {
            season: "winter".into(),
            time_of_day: "day".into(),
        };
        let mut assets = AssetPool::new(vec![asset]);
        assets.commission_new(2020);
        let asset = assets.iter().next().unwrap();
        let flow_item = (asset.id, &commodity_id, &time_slice, 42.0);

        // Write a flow
        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer
                .write_flows(milestone_year, &assets, iter::once(flow_item))
                .unwrap();
            writer.flush().unwrap();
        }
        let flows_path = dir.path().join(COMMODITY_FLOWS_FILE_NAME);

        // Read back the asset-related columns and compare
        let asset_records: Vec<AssetRow> = csv::Reader::from_path(&flows_path)
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(asset_records, iter::once(AssetRow::new(milestone_year, asset)));

        // Read back the flow-related columns and compare
        let expected = CommodityFlowRow {
            commodity_id,
            time_slice: time_slice.to_string(),
            flow: 42.0,
        };
        let records: Vec<CommodityFlowRow> = csv::Reader::from_path(&flows_path)
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[test]
    fn test_write_prices() {
        let commodity_id = "commodity1".into();
        let region_id = "GBR".into();
        let time_slice = TimeSliceID {
            season: "winter".into(),
            time_of_day: "day".into(),
        };
        let milestone_year = 2020;
        let price = 42.0;
        let mut prices = CommodityPrices::default();
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();

        // Write a price
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer.write_prices(milestone_year, &prices).unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = CommodityPriceRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice: time_slice.to_string(),
            price,
        };
        let records: Vec<CommodityPriceRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_PRICES_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }
}