muse2/
output.rs

1//! The module responsible for writing output data to disk.
2use crate::agent::AgentID;
3use crate::asset::{Asset, AssetID, AssetPool, AssetRef};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::optimisation::{FlowMap, Solution};
8use crate::simulation::CommodityPrices;
9use crate::time_slice::TimeSliceID;
10use crate::units::{Flow, MoneyPerActivity, MoneyPerFlow};
11use anyhow::{Context, Result};
12use csv;
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::fs::File;
16use std::path::{Path, PathBuf};
17
18pub mod metadata;
19use metadata::write_metadata;
20
/// Top-level directory under which a results folder is created for each model
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

// Main output files

/// Name of the CSV file listing assets for each milestone year
const ASSETS_FILE_NAME: &str = "assets.csv";

/// Name of the CSV file listing commodity flows for each asset
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";

/// Name of the CSV file listing commodity prices
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

// Debug output files (only written when debug info is requested); note the `debug_` prefix

/// Name of the debug CSV file listing activity duals
const ACTIVITY_DUALS_FILE_NAME: &str = "debug_activity_duals.csv";

/// Name of the debug CSV file listing commodity balance duals
const COMMODITY_BALANCE_DUALS_FILE_NAME: &str = "debug_commodity_balance_duals.csv";
38
39/// Get the model name from the specified directory path
40pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
41    // Get the model name from the dir path. This ends up being convoluted because we need to check
42    // for all possible errors. Ugh.
43    let model_dir = model_dir
44        .canonicalize() // canonicalise in case the user has specified "."
45        .context("Could not resolve path to model")?;
46
47    let model_name = model_dir
48        .file_name()
49        .context("Model cannot be in root folder")?
50        .to_str()
51        .context("Invalid chars in model dir name")?;
52
53    // Construct path
54    Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
55}
56
57/// Create a new output directory for the model specified at `model_dir`.
58pub fn create_output_directory(output_dir: &Path) -> Result<()> {
59    if output_dir.is_dir() {
60        // already exists
61        return Ok(());
62    }
63
64    // Try to create the directory, with parents
65    fs::create_dir_all(output_dir)?;
66
67    Ok(())
68}
69
/// Represents a row in the assets output CSV file.
///
/// Rows are written with `csv::Writer::serialize`, so the field order below is also the column
/// order of the CSV file.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct AssetRow {
    /// Milestone year for which the asset is being recorded
    milestone_year: u32,
    /// ID of the asset
    asset_id: AssetID,
    /// ID of the asset's process
    process_id: ProcessID,
    /// ID of the asset's region
    region_id: RegionID,
    /// ID of the asset's agent
    agent_id: AgentID,
    /// Commission year of the asset
    commission_year: u32,
}
80
81impl AssetRow {
82    /// Create a new [`AssetRow`]
83    fn new(milestone_year: u32, asset: &Asset) -> Self {
84        Self {
85            milestone_year,
86            asset_id: asset.id.unwrap(),
87            process_id: asset.process.id.clone(),
88            region_id: asset.region_id.clone(),
89            agent_id: asset.agent_id.clone().unwrap(),
90            commission_year: asset.commission_year,
91        }
92    }
93}
94
/// Represents the flow-related data in a row of the commodity flows CSV file.
///
/// Field order is also the CSV column order, because rows are written with
/// `csv::Writer::serialize`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityFlowRow {
    /// Milestone year the flow belongs to
    milestone_year: u32,
    /// ID of the asset the flow belongs to
    asset_id: AssetID,
    /// ID of the commodity flowing
    commodity_id: CommodityID,
    /// Time slice of the flow
    time_slice: TimeSliceID,
    /// Value of the flow
    flow: Flow,
}
104
/// Represents a row in the commodity prices CSV file
///
/// Field order is also the CSV column order, because rows are written with
/// `csv::Writer::serialize`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityPriceRow {
    /// Milestone year the price belongs to
    milestone_year: u32,
    /// ID of the priced commodity
    commodity_id: CommodityID,
    /// Region the price applies to
    region_id: RegionID,
    /// Time slice the price applies to
    time_slice: TimeSliceID,
    /// The price itself
    price: MoneyPerFlow,
}
114
/// Represents the activity duals data in a row of the activity duals CSV file
///
/// Field order is also the CSV column order, because rows are written with
/// `csv::Writer::serialize`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct ActivityDualsRow {
    /// Milestone year the dual belongs to
    milestone_year: u32,
    /// ID of the asset, if it has one (`None` is written as an empty CSV field)
    asset_id: Option<AssetID>,
    /// Time slice the dual applies to
    time_slice: TimeSliceID,
    /// Value of the dual
    value: MoneyPerActivity,
}
123
/// Represents the commodity balance duals data in a row of the commodity balance duals CSV file
///
/// Field order is also the CSV column order, because rows are written with
/// `csv::Writer::serialize`.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityBalanceDualsRow {
    /// Milestone year the dual belongs to
    milestone_year: u32,
    /// ID of the commodity the balance refers to
    commodity_id: CommodityID,
    /// Region the balance refers to
    region_id: RegionID,
    /// Time slice the balance refers to
    time_slice: TimeSliceID,
    /// Value of the dual
    value: MoneyPerFlow,
}
133
/// For writing extra debug information about the model
///
/// Holds one CSV writer per debug output file.
struct DebugDataWriter {
    /// Writer for the commodity balance duals debug CSV file
    commodity_balance_duals_writer: csv::Writer<File>,
    /// Writer for the activity duals debug CSV file
    activity_duals_writer: csv::Writer<File>,
}
139
140impl DebugDataWriter {
141    /// Open CSV files to write debug info to
142    ///
143    /// # Arguments
144    ///
145    /// * `output_path` - Folder where files will be saved
146    fn create(output_path: &Path) -> Result<Self> {
147        let new_writer = |file_name| {
148            let file_path = output_path.join(file_name);
149            csv::Writer::from_path(file_path)
150        };
151
152        Ok(Self {
153            commodity_balance_duals_writer: new_writer(COMMODITY_BALANCE_DUALS_FILE_NAME)?,
154            activity_duals_writer: new_writer(ACTIVITY_DUALS_FILE_NAME)?,
155        })
156    }
157
158    /// Write all debug info to output files
159    fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
160        self.write_activity_duals(milestone_year, solution.iter_activity_duals())?;
161        self.write_commodity_balance_duals(
162            milestone_year,
163            solution.iter_commodity_balance_duals(),
164        )?;
165        Ok(())
166    }
167
168    /// Write activity duals to file
169    fn write_activity_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
170    where
171        I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
172    {
173        for (asset, time_slice, value) in iter {
174            let row = ActivityDualsRow {
175                milestone_year,
176                asset_id: asset.id,
177                time_slice: time_slice.clone(),
178                value,
179            };
180            self.activity_duals_writer.serialize(row)?;
181        }
182
183        Ok(())
184    }
185
186    /// Write commodity balance duals to file
187    fn write_commodity_balance_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
188    where
189        I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, MoneyPerFlow)>,
190    {
191        for (commodity_id, region_id, time_slice, value) in iter {
192            let row = CommodityBalanceDualsRow {
193                milestone_year,
194                commodity_id: commodity_id.clone(),
195                region_id: region_id.clone(),
196                time_slice: time_slice.clone(),
197                value,
198            };
199            self.commodity_balance_duals_writer.serialize(row)?;
200        }
201
202        Ok(())
203    }
204
205    /// Flush the underlying streams
206    fn flush(&mut self) -> Result<()> {
207        self.commodity_balance_duals_writer.flush()?;
208        self.activity_duals_writer.flush()?;
209
210        Ok(())
211    }
212}
213
214/// An object for writing commodity prices to file
215pub struct DataWriter {
216    assets_writer: csv::Writer<File>,
217    flows_writer: csv::Writer<File>,
218    prices_writer: csv::Writer<File>,
219    debug_writer: Option<DebugDataWriter>,
220}
221
222impl DataWriter {
223    /// Open CSV files to write output data to
224    ///
225    /// # Arguments
226    ///
227    /// * `output_path` - Folder where files will be saved
228    /// * `model_path` - Path to input model
229    /// * `save_debug_info` - Whether to include extra CSV files for debugging model
230    pub fn create(output_path: &Path, model_path: &Path, save_debug_info: bool) -> Result<Self> {
231        write_metadata(output_path, model_path).context("Failed to save metadata")?;
232
233        let new_writer = |file_name| {
234            let file_path = output_path.join(file_name);
235            csv::Writer::from_path(file_path)
236        };
237
238        let debug_writer = if save_debug_info {
239            // Create debug CSV files
240            Some(DebugDataWriter::create(output_path)?)
241        } else {
242            None
243        };
244
245        Ok(Self {
246            assets_writer: new_writer(ASSETS_FILE_NAME)?,
247            flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
248            prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
249            debug_writer,
250        })
251    }
252
253    /// Write information to various output CSV files
254    pub fn write(
255        &mut self,
256        milestone_year: u32,
257        solution: &Solution,
258        assets: &AssetPool,
259        flow_map: &FlowMap,
260        prices: &CommodityPrices,
261    ) -> Result<()> {
262        if let Some(ref mut wtr) = &mut self.debug_writer {
263            wtr.write_debug_info(milestone_year, solution)?;
264        }
265
266        self.write_assets(milestone_year, assets.iter())?;
267        self.write_flows(milestone_year, flow_map)?;
268        self.write_prices(milestone_year, prices)?;
269
270        Ok(())
271    }
272
273    /// Write assets to a CSV file
274    fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
275    where
276        I: Iterator<Item = &'a AssetRef>,
277    {
278        for asset in assets {
279            let row = AssetRow::new(milestone_year, asset);
280            self.assets_writer.serialize(row)?;
281        }
282
283        Ok(())
284    }
285
286    /// Write commodity flows to a CSV file
287    fn write_flows(&mut self, milestone_year: u32, flow_map: &FlowMap) -> Result<()> {
288        for ((asset, commodity_id, time_slice), flow) in flow_map {
289            let row = CommodityFlowRow {
290                milestone_year,
291                asset_id: asset.id.unwrap(),
292                commodity_id: commodity_id.clone(),
293                time_slice: time_slice.clone(),
294                flow: *flow,
295            };
296            self.flows_writer.serialize(row)?;
297        }
298
299        Ok(())
300    }
301
302    /// Write commodity prices to a CSV file
303    fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
304        for (commodity_id, region_id, time_slice, price) in prices.iter() {
305            let row = CommodityPriceRow {
306                milestone_year,
307                commodity_id: commodity_id.clone(),
308                region_id: region_id.clone(),
309                time_slice: time_slice.clone(),
310                price,
311            };
312            self.prices_writer.serialize(row)?;
313        }
314
315        Ok(())
316    }
317
318    /// Flush the underlying streams
319    pub fn flush(&mut self) -> Result<()> {
320        self.assets_writer.flush()?;
321        self.flows_writer.flush()?;
322        self.prices_writer.flush()?;
323        if let Some(ref mut wtr) = &mut self.debug_writer {
324            wtr.flush()?;
325        }
326
327        Ok(())
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asset::AssetPool;
    use crate::fixture::{assets, commodity_id, region_id, time_slice};
    use crate::time_slice::TimeSliceID;
    use indexmap::indexmap;
    use itertools::Itertools;
    use rstest::rstest;
    use std::iter;
    use tempfile::tempdir;

    /// Milestone year used throughout these tests
    const MILESTONE_YEAR: u32 = 2020;

    /// Read back every row of the CSV file `file_name` in `dir`, panicking on any error
    fn read_rows<T>(dir: &Path, file_name: &str) -> Vec<T>
    where
        T: serde::de::DeserializeOwned,
    {
        csv::Reader::from_path(dir.join(file_name))
            .unwrap()
            .into_deserialize::<T>()
            .try_collect()
            .unwrap()
    }

    #[rstest]
    fn test_write_assets(assets: AssetPool) {
        let dir = tempdir().unwrap();

        // The writer is scoped so that it is dropped before the file is read back
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_assets(MILESTONE_YEAR, assets.iter()).unwrap();
            writer.flush().unwrap();
        }

        // The fixture is expected to hold exactly one asset
        let first_asset = assets.iter().next().unwrap();
        let rows: Vec<AssetRow> = read_rows(dir.path(), ASSETS_FILE_NAME);
        assert_eq!(rows, vec![AssetRow::new(MILESTONE_YEAR, first_asset)]);
    }

    #[rstest]
    fn test_write_flows(assets: AssetPool, commodity_id: CommodityID, time_slice: TimeSliceID) {
        let asset = assets.iter().next().unwrap();
        let flow = Flow(42.0);
        let flow_map = indexmap! {
            (asset.clone(), commodity_id.clone(), time_slice.clone()) => flow
        };

        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_flows(MILESTONE_YEAR, &flow_map).unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityFlowRow {
            milestone_year: MILESTONE_YEAR,
            asset_id: asset.id.unwrap(),
            commodity_id,
            time_slice,
            flow,
        };
        let rows: Vec<CommodityFlowRow> = read_rows(dir.path(), COMMODITY_FLOWS_FILE_NAME);
        assert_eq!(rows, vec![expected]);
    }

    #[rstest]
    fn test_write_prices(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
        let price = MoneyPerFlow(42.0);
        let mut prices = CommodityPrices::default();
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_prices(MILESTONE_YEAR, &prices).unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityPriceRow {
            milestone_year: MILESTONE_YEAR,
            commodity_id,
            region_id,
            time_slice,
            price,
        };
        let rows: Vec<CommodityPriceRow> = read_rows(dir.path(), COMMODITY_PRICES_FILE_NAME);
        assert_eq!(rows, vec![expected]);
    }

    #[rstest]
    fn test_write_commodity_balance_duals(
        commodity_id: CommodityID,
        region_id: RegionID,
        time_slice: TimeSliceID,
    ) {
        let value = MoneyPerFlow(0.5);
        let dir = tempdir().unwrap();

        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            let duals = iter::once((&commodity_id, &region_id, &time_slice, value));
            writer
                .write_commodity_balance_duals(MILESTONE_YEAR, duals)
                .unwrap();
            writer.flush().unwrap();
        }

        let expected = CommodityBalanceDualsRow {
            milestone_year: MILESTONE_YEAR,
            commodity_id,
            region_id,
            time_slice,
            value,
        };
        let rows: Vec<CommodityBalanceDualsRow> =
            read_rows(dir.path(), COMMODITY_BALANCE_DUALS_FILE_NAME);
        assert_eq!(rows, vec![expected]);
    }

    #[rstest]
    fn test_write_activity_duals(assets: AssetPool, time_slice: TimeSliceID) {
        let value = MoneyPerActivity(0.5);
        let dir = tempdir().unwrap();
        let asset = assets.iter().next().unwrap();

        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            let duals = iter::once((asset, &time_slice, value));
            writer.write_activity_duals(MILESTONE_YEAR, duals).unwrap();
            writer.flush().unwrap();
        }

        let expected = ActivityDualsRow {
            milestone_year: MILESTONE_YEAR,
            asset_id: asset.id,
            time_slice,
            value,
        };
        let rows: Vec<ActivityDualsRow> = read_rows(dir.path(), ACTIVITY_DUALS_FILE_NAME);
        assert_eq!(rows, vec![expected]);
    }
}