1use crate::agent::AgentID;
3use crate::asset::{Asset, AssetID, AssetPool, AssetRef};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::optimisation::{FlowMap, Solution};
8use crate::simulation::CommodityPrices;
9use crate::time_slice::TimeSliceID;
10use crate::units::{Flow, MoneyPerActivity, MoneyPerFlow};
11use anyhow::{Context, Result};
12use csv;
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::fs::File;
16use std::path::{Path, PathBuf};
17
18pub mod metadata;
19use metadata::write_metadata;
20
/// Root folder under which a per-model output folder is created (see [`get_output_dir`]).
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

/// File name for the commodity flows output.
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";

/// File name for the commodity prices output.
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

/// File name for the assets output.
const ASSETS_FILE_NAME: &str = "assets.csv";

/// File name for the commodity balance duals (only written when debug output is enabled).
const COMMODITY_BALANCE_DUALS_FILE_NAME: &str = "debug_commodity_balance_duals.csv";

/// File name for the activity duals (only written when debug output is enabled).
const ACTIVITY_DUALS_FILE_NAME: &str = "debug_activity_duals.csv";
38
39pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
41 let model_dir = model_dir
44 .canonicalize() .context("Could not resolve path to model")?;
46
47 let model_name = model_dir
48 .file_name()
49 .context("Model cannot be in root folder")?
50 .to_str()
51 .context("Invalid chars in model dir name")?;
52
53 Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
55}
56
57pub fn create_output_directory(output_dir: &Path) -> Result<()> {
59 if output_dir.is_dir() {
60 return Ok(());
62 }
63
64 fs::create_dir_all(output_dir)?;
66
67 Ok(())
68}
69
/// A row of the assets output file: one asset present in a given milestone year.
///
/// Field order determines the CSV column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct AssetRow {
    /// Milestone year the row was written for
    milestone_year: u32,
    /// ID of the asset
    asset_id: AssetID,
    /// ID of the asset's process
    process_id: ProcessID,
    /// ID of the asset's region
    region_id: RegionID,
    /// ID of the asset's agent
    agent_id: AgentID,
    /// The asset's commission year
    commission_year: u32,
}
80
81impl AssetRow {
82 fn new(milestone_year: u32, asset: &Asset) -> Self {
84 Self {
85 milestone_year,
86 asset_id: asset.id.unwrap(),
87 process_id: asset.process.id.clone(),
88 region_id: asset.region_id.clone(),
89 agent_id: asset.agent_id.clone().unwrap(),
90 commission_year: asset.commission_year,
91 }
92 }
93}
94
/// A row of the commodity flows output file: the flow of one commodity for one asset in one
/// time slice of a milestone year.
///
/// Field order determines the CSV column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityFlowRow {
    /// Milestone year the row was written for
    milestone_year: u32,
    /// ID of the asset
    asset_id: AssetID,
    /// ID of the commodity
    commodity_id: CommodityID,
    /// Time slice the flow applies to
    time_slice: TimeSliceID,
    /// Flow value taken from the flow map
    flow: Flow,
}
104
/// A row of the commodity prices output file: the price of one commodity in one region and
/// time slice of a milestone year.
///
/// Field order determines the CSV column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityPriceRow {
    /// Milestone year the row was written for
    milestone_year: u32,
    /// ID of the commodity
    commodity_id: CommodityID,
    /// ID of the region
    region_id: RegionID,
    /// Time slice the price applies to
    time_slice: TimeSliceID,
    /// Commodity price
    price: MoneyPerFlow,
}
114
/// A row of the (debug) activity duals output file.
///
/// Field order determines the CSV column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct ActivityDualsRow {
    /// Milestone year the row was written for
    milestone_year: u32,
    /// ID of the asset, if it has one (written as-is from the asset)
    asset_id: Option<AssetID>,
    /// Time slice the dual value applies to
    time_slice: TimeSliceID,
    /// Activity dual value, as reported by the optimisation solution
    value: MoneyPerActivity,
}
123
/// A row of the (debug) commodity balance duals output file.
///
/// Field order determines the CSV column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityBalanceDualsRow {
    /// Milestone year the row was written for
    milestone_year: u32,
    /// ID of the commodity
    commodity_id: CommodityID,
    /// ID of the region
    region_id: RegionID,
    /// Time slice the dual value applies to
    time_slice: TimeSliceID,
    /// Commodity balance dual value, as reported by the optimisation solution
    value: MoneyPerFlow,
}
133
/// Writes the debug output files (dual values from the optimisation solution).
///
/// Only created when debug output has been requested (see [`DataWriter::create`]).
struct DebugDataWriter {
    /// Writer for the commodity balance duals file
    commodity_balance_duals_writer: csv::Writer<File>,
    /// Writer for the activity duals file
    activity_duals_writer: csv::Writer<File>,
}
139
140impl DebugDataWriter {
141 fn create(output_path: &Path) -> Result<Self> {
147 let new_writer = |file_name| {
148 let file_path = output_path.join(file_name);
149 csv::Writer::from_path(file_path)
150 };
151
152 Ok(Self {
153 commodity_balance_duals_writer: new_writer(COMMODITY_BALANCE_DUALS_FILE_NAME)?,
154 activity_duals_writer: new_writer(ACTIVITY_DUALS_FILE_NAME)?,
155 })
156 }
157
158 fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
160 self.write_activity_duals(milestone_year, solution.iter_activity_duals())?;
161 self.write_commodity_balance_duals(
162 milestone_year,
163 solution.iter_commodity_balance_duals(),
164 )?;
165 Ok(())
166 }
167
168 fn write_activity_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
170 where
171 I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
172 {
173 for (asset, time_slice, value) in iter {
174 let row = ActivityDualsRow {
175 milestone_year,
176 asset_id: asset.id,
177 time_slice: time_slice.clone(),
178 value,
179 };
180 self.activity_duals_writer.serialize(row)?;
181 }
182
183 Ok(())
184 }
185
186 fn write_commodity_balance_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
188 where
189 I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, MoneyPerFlow)>,
190 {
191 for (commodity_id, region_id, time_slice, value) in iter {
192 let row = CommodityBalanceDualsRow {
193 milestone_year,
194 commodity_id: commodity_id.clone(),
195 region_id: region_id.clone(),
196 time_slice: time_slice.clone(),
197 value,
198 };
199 self.commodity_balance_duals_writer.serialize(row)?;
200 }
201
202 Ok(())
203 }
204
205 fn flush(&mut self) -> Result<()> {
207 self.commodity_balance_duals_writer.flush()?;
208 self.activity_duals_writer.flush()?;
209
210 Ok(())
211 }
212}
213
/// Writes the simulation's CSV output files (and, optionally, debug output files).
pub struct DataWriter {
    /// Writer for the assets file
    assets_writer: csv::Writer<File>,
    /// Writer for the commodity flows file
    flows_writer: csv::Writer<File>,
    /// Writer for the commodity prices file
    prices_writer: csv::Writer<File>,
    /// Writer for the debug output files; `None` unless debug output was requested
    debug_writer: Option<DebugDataWriter>,
}
221
222impl DataWriter {
223 pub fn create(output_path: &Path, model_path: &Path, save_debug_info: bool) -> Result<Self> {
231 write_metadata(output_path, model_path).context("Failed to save metadata")?;
232
233 let new_writer = |file_name| {
234 let file_path = output_path.join(file_name);
235 csv::Writer::from_path(file_path)
236 };
237
238 let debug_writer = if save_debug_info {
239 Some(DebugDataWriter::create(output_path)?)
241 } else {
242 None
243 };
244
245 Ok(Self {
246 assets_writer: new_writer(ASSETS_FILE_NAME)?,
247 flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
248 prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
249 debug_writer,
250 })
251 }
252
253 pub fn write(
255 &mut self,
256 milestone_year: u32,
257 solution: &Solution,
258 assets: &AssetPool,
259 flow_map: &FlowMap,
260 prices: &CommodityPrices,
261 ) -> Result<()> {
262 if let Some(ref mut wtr) = &mut self.debug_writer {
263 wtr.write_debug_info(milestone_year, solution)?;
264 }
265
266 self.write_assets(milestone_year, assets.iter())?;
267 self.write_flows(milestone_year, flow_map)?;
268 self.write_prices(milestone_year, prices)?;
269
270 Ok(())
271 }
272
273 fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
275 where
276 I: Iterator<Item = &'a AssetRef>,
277 {
278 for asset in assets {
279 let row = AssetRow::new(milestone_year, asset);
280 self.assets_writer.serialize(row)?;
281 }
282
283 Ok(())
284 }
285
286 fn write_flows(&mut self, milestone_year: u32, flow_map: &FlowMap) -> Result<()> {
288 for ((asset, commodity_id, time_slice), flow) in flow_map {
289 let row = CommodityFlowRow {
290 milestone_year,
291 asset_id: asset.id.unwrap(),
292 commodity_id: commodity_id.clone(),
293 time_slice: time_slice.clone(),
294 flow: *flow,
295 };
296 self.flows_writer.serialize(row)?;
297 }
298
299 Ok(())
300 }
301
302 fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
304 for (commodity_id, region_id, time_slice, price) in prices.iter() {
305 let row = CommodityPriceRow {
306 milestone_year,
307 commodity_id: commodity_id.clone(),
308 region_id: region_id.clone(),
309 time_slice: time_slice.clone(),
310 price,
311 };
312 self.prices_writer.serialize(row)?;
313 }
314
315 Ok(())
316 }
317
318 pub fn flush(&mut self) -> Result<()> {
320 self.assets_writer.flush()?;
321 self.flows_writer.flush()?;
322 self.prices_writer.flush()?;
323 if let Some(ref mut wtr) = &mut self.debug_writer {
324 wtr.flush()?;
325 }
326
327 Ok(())
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asset::AssetPool;
    use crate::fixture::{assets, commodity_id, region_id, time_slice};
    use crate::time_slice::TimeSliceID;
    use indexmap::indexmap;
    use itertools::{assert_equal, Itertools};
    use rstest::rstest;
    use std::iter;
    use tempfile::tempdir;

    #[rstest]
    fn test_write_assets(assets: AssetPool) {
        let milestone_year = 2020;
        let dir = tempdir().unwrap();

        // Write the assets; the writer is dropped at the end of this scope
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_assets(milestone_year, assets.iter()).unwrap();
            writer.flush().unwrap();
        }

        // Read the file back. The fixture pool is expected to hold exactly one asset.
        let asset = assets.iter().next().unwrap();
        let expected = AssetRow::new(milestone_year, asset);
        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
            .unwrap()
            .into_deserialize()
            .try_collect()
            .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_flows(assets: AssetPool, commodity_id: CommodityID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let asset = assets.iter().next().unwrap();
        let flow_map = indexmap! {
            (asset.clone(), commodity_id.clone(), time_slice.clone()) => Flow(42.0)
        };

        // Write the flows; the writer is dropped at the end of this scope
        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_flows(milestone_year, &flow_map).unwrap();
            writer.flush().unwrap();
        }

        // Read the file back and compare against the single expected row
        let expected = CommodityFlowRow {
            milestone_year,
            asset_id: asset.id.unwrap(),
            commodity_id,
            time_slice,
            flow: Flow(42.0),
        };
        let records: Vec<CommodityFlowRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_FLOWS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_prices(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let price = MoneyPerFlow(42.0);
        let mut prices = CommodityPrices::default();
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();

        // Write the prices; the writer is dropped at the end of this scope
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_prices(milestone_year, &prices).unwrap();
            writer.flush().unwrap();
        }

        // Read the file back and compare against the single expected row
        let expected = CommodityPriceRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            price,
        };
        let records: Vec<CommodityPriceRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_PRICES_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_commodity_balance_duals(
        commodity_id: CommodityID,
        region_id: RegionID,
        time_slice: TimeSliceID,
    ) {
        let milestone_year = 2020;
        let value = MoneyPerFlow(0.5);
        let dir = tempdir().unwrap();

        // Write the duals; the writer is dropped at the end of this scope
        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_commodity_balance_duals(
                    milestone_year,
                    iter::once((&commodity_id, &region_id, &time_slice, value)),
                )
                .unwrap();
            writer.flush().unwrap();
        }

        // Read the file back and compare against the single expected row
        let expected = CommodityBalanceDualsRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            value,
        };
        let records: Vec<CommodityBalanceDualsRow> =
            csv::Reader::from_path(dir.path().join(COMMODITY_BALANCE_DUALS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_activity_duals(assets: AssetPool, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let value = MoneyPerActivity(0.5);
        let dir = tempdir().unwrap();
        let asset = assets.iter().next().unwrap();

        // Write the duals; the writer is dropped at the end of this scope
        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_activity_duals(milestone_year, iter::once((asset, &time_slice, value)))
                .unwrap();
            writer.flush().unwrap();
        }

        // Read the file back and compare against the single expected row
        let expected = ActivityDualsRow {
            milestone_year,
            asset_id: asset.id,
            time_slice,
            value,
        };
        let records: Vec<ActivityDualsRow> =
            csv::Reader::from_path(dir.path().join(ACTIVITY_DUALS_FILE_NAME))
                .unwrap()
                .into_deserialize()
                .try_collect()
                .unwrap();
        assert_equal(records, iter::once(expected));
    }
}