1use crate::agent::AgentID;
3use crate::asset::{Asset, AssetID, AssetRef};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::optimisation::{FlowMap, Solution};
8use crate::simulation::CommodityPrices;
9use crate::time_slice::TimeSliceID;
10use crate::units::{Flow, MoneyPerFlow};
11use anyhow::{Context, Result};
12use csv;
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::fs::File;
16use std::path::{Path, PathBuf};
17
18pub mod metadata;
19use metadata::write_metadata;
20
/// Root folder in which all model run results are stored
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

/// Output file for commodity flows
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";

/// Output file for commodity prices
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

/// Output file for assets
const ASSETS_FILE_NAME: &str = "assets.csv";

/// Debug-only output file for commodity balance constraint duals
const COMMODITY_BALANCE_DUALS_FILE_NAME: &str = "debug_commodity_balance_duals.csv";

/// Debug-only output file for activity constraint duals
const ACTIVITY_DUALS_FILE_NAME: &str = "debug_activity_duals.csv";
39pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
41 let model_dir = model_dir
44 .canonicalize() .context("Could not resolve path to model")?;
46
47 let model_name = model_dir
48 .file_name()
49 .context("Model cannot be in root folder")?
50 .to_str()
51 .context("Invalid chars in model dir name")?;
52
53 Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
55}
56
57pub fn create_output_directory(output_dir: &Path) -> Result<()> {
59 if output_dir.is_dir() {
60 return Ok(());
62 }
63
64 fs::create_dir_all(output_dir)?;
66
67 Ok(())
68}
69
/// A row in the assets output CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct AssetRow {
    /// Milestone year this row applies to
    milestone_year: u32,
    /// Unique ID of the asset
    asset_id: AssetID,
    /// ID of the process the asset performs
    process_id: ProcessID,
    /// Region in which the asset is located
    region_id: RegionID,
    /// Agent that owns the asset
    agent_id: AgentID,
    /// Year the asset was commissioned
    commission_year: u32,
}
80
81impl AssetRow {
82 fn new(milestone_year: u32, asset: &Asset) -> Self {
84 Self {
85 milestone_year,
86 asset_id: asset.id.unwrap(),
87 process_id: asset.process.id.clone(),
88 region_id: asset.region_id.clone(),
89 agent_id: asset.agent_id.clone(),
90 commission_year: asset.commission_year,
91 }
92 }
93}
94
/// A row in the commodity flows output CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityFlowRow {
    /// Milestone year this row applies to
    milestone_year: u32,
    /// Asset producing/consuming the flow
    asset_id: AssetID,
    /// Commodity being produced/consumed
    commodity_id: CommodityID,
    /// Time slice in which the flow occurs
    time_slice: TimeSliceID,
    /// Size of the flow (sign convention not visible here — see `FlowMap`)
    flow: Flow,
}
104
/// A row in the commodity prices output CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityPriceRow {
    /// Milestone year this row applies to
    milestone_year: u32,
    /// Commodity being priced
    commodity_id: CommodityID,
    /// Region for which the price applies
    region_id: RegionID,
    /// Time slice for which the price applies
    time_slice: TimeSliceID,
    /// Price per unit of flow
    price: MoneyPerFlow,
}
114
/// A row in the debug activity duals CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct ActivityDualsRow {
    /// Milestone year this row applies to
    milestone_year: u32,
    /// Asset the dual relates to; `None` when the asset has no ID
    /// (NOTE(review): presumably a not-yet-commissioned/candidate asset — confirm)
    asset_id: Option<AssetID>,
    /// Time slice of the activity constraint
    time_slice: TimeSliceID,
    /// Dual value from the solver
    value: f64,
}
123
/// A row in the debug commodity balance duals CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityBalanceDualsRow {
    /// Milestone year this row applies to
    milestone_year: u32,
    /// Commodity of the balance constraint
    commodity_id: CommodityID,
    /// Region of the balance constraint
    region_id: RegionID,
    /// Time slice of the balance constraint
    time_slice: TimeSliceID,
    /// Dual value from the solver
    value: f64,
}
133
/// CSV writers for debug-only output files (solver dual values).
struct DebugDataWriter {
    /// Writer for the commodity balance duals file
    commodity_balance_duals_writer: csv::Writer<File>,
    /// Writer for the activity duals file
    activity_duals_writer: csv::Writer<File>,
}
139
140impl DebugDataWriter {
141 fn create(output_path: &Path) -> Result<Self> {
147 let new_writer = |file_name| {
148 let file_path = output_path.join(file_name);
149 csv::Writer::from_path(file_path)
150 };
151
152 Ok(Self {
153 commodity_balance_duals_writer: new_writer(COMMODITY_BALANCE_DUALS_FILE_NAME)?,
154 activity_duals_writer: new_writer(ACTIVITY_DUALS_FILE_NAME)?,
155 })
156 }
157
158 fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
160 self.write_activity_duals(milestone_year, solution.iter_activity_duals())?;
161 self.write_commodity_balance_duals(
162 milestone_year,
163 solution.iter_commodity_balance_duals(),
164 )?;
165 Ok(())
166 }
167
168 fn write_activity_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
170 where
171 I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, f64)>,
172 {
173 for (asset, time_slice, value) in iter {
174 let row = ActivityDualsRow {
175 milestone_year,
176 asset_id: asset.id,
177 time_slice: time_slice.clone(),
178 value,
179 };
180 self.activity_duals_writer.serialize(row)?;
181 }
182
183 Ok(())
184 }
185
186 fn write_commodity_balance_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
188 where
189 I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, f64)>,
190 {
191 for (commodity_id, region_id, time_slice, value) in iter {
192 let row = CommodityBalanceDualsRow {
193 milestone_year,
194 commodity_id: commodity_id.clone(),
195 region_id: region_id.clone(),
196 time_slice: time_slice.clone(),
197 value,
198 };
199 self.commodity_balance_duals_writer.serialize(row)?;
200 }
201
202 Ok(())
203 }
204
205 fn flush(&mut self) -> Result<()> {
207 self.commodity_balance_duals_writer.flush()?;
208 self.activity_duals_writer.flush()?;
209
210 Ok(())
211 }
212}
213
/// CSV writers for the standard output files, plus optional debug writers.
pub struct DataWriter {
    /// Writer for the assets file
    assets_writer: csv::Writer<File>,
    /// Writer for the commodity flows file
    flows_writer: csv::Writer<File>,
    /// Writer for the commodity prices file
    prices_writer: csv::Writer<File>,
    /// Debug writers; `None` unless debug info was requested at creation
    debug_writer: Option<DebugDataWriter>,
}
221
222impl DataWriter {
223 pub fn create(output_path: &Path, model_path: &Path, save_debug_info: bool) -> Result<Self> {
231 write_metadata(output_path, model_path).context("Failed to save metadata")?;
232
233 let new_writer = |file_name| {
234 let file_path = output_path.join(file_name);
235 csv::Writer::from_path(file_path)
236 };
237
238 let debug_writer = if save_debug_info {
239 Some(DebugDataWriter::create(output_path)?)
241 } else {
242 None
243 };
244
245 Ok(Self {
246 assets_writer: new_writer(ASSETS_FILE_NAME)?,
247 flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
248 prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
249 debug_writer,
250 })
251 }
252
253 pub fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
255 where
256 I: Iterator<Item = &'a AssetRef>,
257 {
258 for asset in assets {
259 let row = AssetRow::new(milestone_year, asset);
260 self.assets_writer.serialize(row)?;
261 }
262
263 Ok(())
264 }
265
266 pub fn write_flows(&mut self, milestone_year: u32, flow_map: &FlowMap) -> Result<()> {
268 for ((asset, commodity_id, time_slice), flow) in flow_map {
269 let row = CommodityFlowRow {
270 milestone_year,
271 asset_id: asset.id.unwrap(),
272 commodity_id: commodity_id.clone(),
273 time_slice: time_slice.clone(),
274 flow: *flow,
275 };
276 self.flows_writer.serialize(row)?;
277 }
278
279 Ok(())
280 }
281
282 pub fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
284 for (commodity_id, region_id, time_slice, price) in prices.iter() {
285 let row = CommodityPriceRow {
286 milestone_year,
287 commodity_id: commodity_id.clone(),
288 region_id: region_id.clone(),
289 time_slice: time_slice.clone(),
290 price,
291 };
292 self.prices_writer.serialize(row)?;
293 }
294
295 Ok(())
296 }
297
298 pub fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
300 if let Some(ref mut wtr) = &mut self.debug_writer {
301 wtr.write_debug_info(milestone_year, solution)?;
302 }
303
304 Ok(())
305 }
306
307 pub fn flush(&mut self) -> Result<()> {
309 self.assets_writer.flush()?;
310 self.flows_writer.flush()?;
311 self.prices_writer.flush()?;
312 if let Some(ref mut wtr) = &mut self.debug_writer {
313 wtr.flush()?;
314 }
315
316 Ok(())
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asset::AssetPool;
    use crate::fixture::{assets, commodity_id, region_id, time_slice};
    use crate::time_slice::TimeSliceID;
    use indexmap::indexmap;
    use itertools::{assert_equal, Itertools};
    use rstest::rstest;
    use std::iter;
    use tempfile::tempdir;

    /// Round-trip a single asset row through the assets CSV file.
    #[rstest]
    fn test_write_assets(assets: AssetPool) {
        let milestone_year = 2020;
        let dir = tempdir().unwrap();

        // Scope ensures the writer is flushed and dropped before reading back
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_assets(milestone_year, assets.iter()).unwrap();
            writer.flush().unwrap();
        }

        let asset = assets.iter().next().unwrap();
        let expected = AssetRow::new(milestone_year, asset);
        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(records, iter::once(expected));
    }

    /// Round-trip a single flow row through the commodity flows CSV file.
    #[rstest]
    fn test_write_flows(assets: AssetPool, commodity_id: CommodityID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let asset = assets.iter().next().unwrap();
        let flow_map = indexmap! {
            (asset.clone(), commodity_id.clone(), time_slice.clone()) => Flow(42.0)
        };

        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_flows(milestone_year, &flow_map).unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityFlowRow {
            milestone_year,
            asset_id: asset.id.unwrap(),
            commodity_id,
            time_slice,
            flow: Flow(42.0),
        };
        let records: Vec<CommodityFlowRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_FLOWS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    /// Round-trip a single price row through the commodity prices CSV file.
    #[rstest]
    fn test_write_prices(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let price = MoneyPerFlow(42.0);
        let mut prices = CommodityPrices::default();
        // Fixed mojibake: `&region_id` had been corrupted to `®ion_id`
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();

        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_prices(milestone_year, &prices).unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityPriceRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            price,
        };
        let records: Vec<CommodityPriceRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_PRICES_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    /// Round-trip a single row through the commodity balance duals CSV file.
    #[rstest]
    fn test_write_commodity_balance_duals(
        commodity_id: CommodityID,
        region_id: RegionID,
        time_slice: TimeSliceID,
    ) {
        let milestone_year = 2020;
        let value = 0.5;
        let dir = tempdir().unwrap();

        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_commodity_balance_duals(
                    milestone_year,
                    // Fixed mojibake: `&region_id` had been corrupted to `®ion_id`
                    iter::once((&commodity_id, &region_id, &time_slice, value)),
                )
                .unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityBalanceDualsRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            value,
        };
        let records: Vec<CommodityBalanceDualsRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_BALANCE_DUALS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    /// Round-trip a single row through the activity duals CSV file.
    #[rstest]
    fn test_write_activity_duals(assets: AssetPool, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let value = 0.5;
        let dir = tempdir().unwrap();
        let asset = assets.iter().next().unwrap();

        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_activity_duals(milestone_year, iter::once((asset, &time_slice, value)))
                .unwrap();
            writer.flush().unwrap();
        }

        let expected = ActivityDualsRow {
            milestone_year,
            asset_id: asset.id,
            time_slice,
            value,
        };
        let records: Vec<ActivityDualsRow> =
            csv::Reader::from_path(dir.path().join(ACTIVITY_DUALS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }
}