use crate::agent::AgentID;
use crate::asset::{Asset, AssetID, AssetPool};
use crate::commodity::CommodityID;
use crate::process::ProcessID;
use crate::region::RegionID;
use crate::simulation::CommodityPrices;
use crate::time_slice::TimeSliceID;
use anyhow::{Context, Result};
use csv;
use serde::{Deserialize, Serialize};
use std::fs;
use std::fs::File;
use std::path::{Path, PathBuf};

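/// The root folder in which model-specific output folders are created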
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

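/// The output file name for commodity flows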
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";

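/// The output file name for commodity prices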
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

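/// The output file name for assets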
const ASSETS_FILE_NAME: &str = "assets.csv";

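/// Get the output directory for the specified model.
///
/// The returned path is `OUTPUT_DIRECTORY_ROOT` joined with the model folder's name.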
pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
    // Resolve the model path so we can extract the folder name
    let model_dir = model_dir
        .canonicalize()
        .context("Could not resolve path to model")?;

    let model_name = model_dir
        .file_name()
        .context("Model cannot be in root folder")?
        .to_str()
        .context("Invalid chars in model dir name")?;

    Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
}

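/// Create the output directory, if it does not already exist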
pub fn create_output_directory(output_dir: &Path) -> Result<()> {
    if output_dir.is_dir() {
        // Directory already exists
        return Ok(());
    }

    fs::create_dir_all(output_dir)?;

    Ok(())
}

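/// Represents a row in the assets output file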
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct AssetRow {
    milestone_year: u32,
    process_id: ProcessID,
    region_id: RegionID,
    agent_id: AgentID,
    commission_year: u32,
}

impl AssetRow {
    /// Create a new [`AssetRow`] for the given milestone year and asset
    fn new(milestone_year: u32, asset: &Asset) -> Self {
        Self {
            milestone_year,
            process_id: asset.process.id.clone(),
            region_id: asset.region_id.clone(),
            agent_id: asset.agent_id.clone(),
            commission_year: asset.commission_year,
        }
    }
}

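/// Represents the flow-specific columns of a row in the commodity flows output file.
///
/// The asset-related columns are provided by [`AssetRow`]; the two are serialised together.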
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityFlowRow {
    commodity_id: CommodityID,
    time_slice: String,
    flow: f64,
}

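/// Represents a row in the commodity prices output file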
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityPriceRow {
    milestone_year: u32,
    commodity_id: CommodityID,
    region_id: RegionID,
    time_slice: String,
    price: f64,
}

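/// Writes simulation output data to CSV files in the output directory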
pub struct DataWriter {
    assets_writer: csv::Writer<File>,
    flows_writer: csv::Writer<File>,
    prices_writer: csv::Writer<File>,
}

impl DataWriter {
    /// Open a CSV writer for each of the output files in `output_path`
    pub fn create(output_path: &Path) -> Result<Self> {
        let new_writer = |file_name| {
            let file_path = output_path.join(file_name);
            csv::Writer::from_path(file_path)
        };

        Ok(Self {
            assets_writer: new_writer(ASSETS_FILE_NAME)?,
            flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
            prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
        })
    }

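    /// Write asset data for the given milestone year to the assets output file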
    pub fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
    where
        I: Iterator<Item = &'a Asset>,
    {
        for asset in assets {
            let row = AssetRow::new(milestone_year, asset);
            self.assets_writer.serialize(row)?;
        }

        Ok(())
    }

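    /// Write commodity flow data for the given milestone year to the flows output file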
    pub fn write_flows<'a, I>(
        &mut self,
        milestone_year: u32,
        assets: &AssetPool,
        flows: I,
    ) -> Result<()>
    where
        I: Iterator<Item = (AssetID, &'a CommodityID, &'a TimeSliceID, f64)>,
    {
        for (asset_id, commodity_id, time_slice, flow) in flows {
            let asset = assets.get(asset_id).unwrap();
            let asset_row = AssetRow::new(milestone_year, asset);
            let flow_row = CommodityFlowRow {
                commodity_id: commodity_id.clone(),
                time_slice: time_slice.to_string(),
                flow,
            };

            // The asset and flow columns are serialised together as a single CSV row
            self.flows_writer.serialize((asset_row, flow_row))?;
        }

        Ok(())
    }

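    /// Write commodity price data for the given milestone year to the prices output file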
    pub fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
        for (commodity_id, region_id, time_slice, price) in prices.iter() {
            let row = CommodityPriceRow {
                milestone_year,
                commodity_id: commodity_id.clone(),
                region_id: region_id.clone(),
                time_slice: time_slice.to_string(),
                price,
            };
            self.prices_writer.serialize(row)?;
        }

        Ok(())
    }

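    /// Flush the underlying streams so that all buffered rows are written to disk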
    pub fn flush(&mut self) -> Result<()> {
        self.assets_writer.flush()?;
        self.flows_writer.flush()?;
        self.prices_writer.flush()?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::process;
    use crate::process::Process;
    use crate::time_slice::TimeSliceID;
    use itertools::{assert_equal, Itertools};
    use rstest::{fixture, rstest};
    use std::iter;
    use tempfile::tempdir;

    #[fixture]
    pub fn asset(process: Process) -> Asset {
        let region_id: RegionID = "GBR".into();
        let agent_id = "agent1".into();
        let commission_year = 2015;
        Asset::new(agent_id, process.into(), region_id, 2.0, commission_year).unwrap()
    }

    #[rstest]
    fn test_write_assets(asset: Asset) {
        let milestone_year = 2020;
        let dir = tempdir().unwrap();

        // Write the asset to the output file
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer
                .write_assets(milestone_year, iter::once(&asset))
                .unwrap();
            writer.flush().unwrap();
        }

        // Read the file back in and verify the contents
        let expected = AssetRow::new(milestone_year, &asset);
        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_flows(asset: Asset) {
        let milestone_year = 2020;
        let commodity_id = "commodity1".into();
        let time_slice = TimeSliceID {
            season: "winter".into(),
            time_of_day: "day".into(),
        };
        let mut assets = AssetPool::new(vec![asset]);
        assets.commission_new(2020);
        let flow_item = (
            assets.iter().next().unwrap().id,
            &commodity_id,
            &time_slice,
            42.0,
        );

        let dir = tempdir().unwrap();

        // Write a single flow to the output file
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer
                .write_flows(milestone_year, &assets, iter::once(flow_item))
                .unwrap();
            writer.flush().unwrap();
        }

        // Read the file back in and verify the contents
        let expected = CommodityFlowRow {
            commodity_id,
            time_slice: time_slice.to_string(),
            flow: 42.0,
        };
        let records: Vec<CommodityFlowRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_FLOWS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[test]
    fn test_write_prices() {
        let commodity_id = "commodity1".into();
        let region_id = "GBR".into();
        let time_slice = TimeSliceID {
            season: "winter".into(),
            time_of_day: "day".into(),
        };
        let milestone_year = 2020;
        let price = 42.0;
        let mut prices = CommodityPrices::default();
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();

        // Write the prices to the output file
        {
            let mut writer = DataWriter::create(dir.path()).unwrap();
            writer.write_prices(milestone_year, &prices).unwrap();
            writer.flush().unwrap();
        }

        // Read the file back in and verify the contents
        let expected = CommodityPriceRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice: time_slice.to_string(),
            price,
        };
        let records: Vec<CommodityPriceRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_PRICES_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }
}