muse2/
output.rs

1//! The module responsible for writing output data to disk.
2use crate::agent::AgentID;
3use crate::asset::{Asset, AssetID, AssetRef};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::optimisation::{FlowMap, Solution};
8use crate::simulation::CommodityPrices;
9use crate::time_slice::TimeSliceID;
10use crate::units::{Flow, MoneyPerFlow};
11use anyhow::{Context, Result};
12use csv;
13use serde::{Deserialize, Serialize};
14use std::fs;
15use std::fs::File;
16use std::path::{Path, PathBuf};
17
18pub mod metadata;
19use metadata::write_metadata;
20
/// Root folder (a path relative to the working directory) under which each model gets its own
/// output folder
const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results";

// Main output files: always created when a `DataWriter` is created

/// File name for the assets output
const ASSETS_FILE_NAME: &str = "assets.csv";
/// File name for the commodity flows output
const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";
/// File name for the commodity prices output
const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";

// Debug output files: only created when debug info is requested

/// File name for the activity duals debug output
const ACTIVITY_DUALS_FILE_NAME: &str = "debug_activity_duals.csv";
/// File name for the commodity balance duals debug output
const COMMODITY_BALANCE_DUALS_FILE_NAME: &str = "debug_commodity_balance_duals.csv";
38
39/// Get the model name from the specified directory path
40pub fn get_output_dir(model_dir: &Path) -> Result<PathBuf> {
41    // Get the model name from the dir path. This ends up being convoluted because we need to check
42    // for all possible errors. Ugh.
43    let model_dir = model_dir
44        .canonicalize() // canonicalise in case the user has specified "."
45        .context("Could not resolve path to model")?;
46
47    let model_name = model_dir
48        .file_name()
49        .context("Model cannot be in root folder")?
50        .to_str()
51        .context("Invalid chars in model dir name")?;
52
53    // Construct path
54    Ok([OUTPUT_DIRECTORY_ROOT, model_name].iter().collect())
55}
56
57/// Create a new output directory for the model specified at `model_dir`.
58pub fn create_output_directory(output_dir: &Path) -> Result<()> {
59    if output_dir.is_dir() {
60        // already exists
61        return Ok(());
62    }
63
64    // Try to create the directory, with parents
65    fs::create_dir_all(output_dir)?;
66
67    Ok(())
68}
69
/// Represents a row in the assets output CSV file.
///
/// Rows are written with `csv::Writer::serialize`, so the field names become the CSV column
/// headers and the field order sets the column order. Renaming or reordering fields changes the
/// output file format.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct AssetRow {
    /// Milestone year this row was written for
    milestone_year: u32,
    /// ID of the asset
    asset_id: AssetID,
    /// ID of the asset's process
    process_id: ProcessID,
    /// ID of the asset's region
    region_id: RegionID,
    /// ID of the asset's agent
    agent_id: AgentID,
    /// The asset's commission year
    commission_year: u32,
}
80
81impl AssetRow {
82    /// Create a new [`AssetRow`]
83    fn new(milestone_year: u32, asset: &Asset) -> Self {
84        Self {
85            milestone_year,
86            asset_id: asset.id.unwrap(),
87            process_id: asset.process.id.clone(),
88            region_id: asset.region_id.clone(),
89            agent_id: asset.agent_id.clone(),
90            commission_year: asset.commission_year,
91        }
92    }
93}
94
/// Represents the flow-related data in a row of the commodity flows CSV file.
///
/// Field names and order define the CSV column headers and column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityFlowRow {
    /// Milestone year this row was written for
    milestone_year: u32,
    /// ID of the asset the flow belongs to
    asset_id: AssetID,
    /// ID of the commodity that flows
    commodity_id: CommodityID,
    /// Time slice the flow applies to
    time_slice: TimeSliceID,
    /// Flow value for this (asset, commodity, time slice) combination
    flow: Flow,
}
104
/// Represents a row in the commodity prices CSV file
///
/// Field names and order define the CSV column headers and column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityPriceRow {
    /// Milestone year this row was written for
    milestone_year: u32,
    /// ID of the priced commodity
    commodity_id: CommodityID,
    /// Region the price applies to
    region_id: RegionID,
    /// Time slice the price applies to
    time_slice: TimeSliceID,
    /// Price of the commodity, in money per unit of flow
    price: MoneyPerFlow,
}
114
/// Represents the activity duals data in a row of the activity duals CSV file
///
/// Field names and order define the CSV column headers and column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct ActivityDualsRow {
    /// Milestone year this row was written for
    milestone_year: u32,
    /// ID of the asset, copied as-is from the asset (so `None` if the asset has no ID)
    asset_id: Option<AssetID>,
    /// Time slice the dual applies to
    time_slice: TimeSliceID,
    /// Dual value as reported by the solution
    value: f64,
}
123
/// Represents the commodity balance duals data in a row of the commodity balance duals CSV file
///
/// Field names and order define the CSV column headers and column order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct CommodityBalanceDualsRow {
    /// Milestone year this row was written for
    milestone_year: u32,
    /// ID of the commodity the balance refers to
    commodity_id: CommodityID,
    /// Region the balance refers to
    region_id: RegionID,
    /// Time slice the balance refers to
    time_slice: TimeSliceID,
    /// Dual value as reported by the solution
    value: f64,
}
133
/// For writing extra debug information about the model
///
/// Holds one CSV writer per debug output file.
struct DebugDataWriter {
    /// Writer for the commodity balance duals CSV file
    commodity_balance_duals_writer: csv::Writer<File>,
    /// Writer for the activity duals CSV file
    activity_duals_writer: csv::Writer<File>,
}
139
140impl DebugDataWriter {
141    /// Open CSV files to write debug info to
142    ///
143    /// # Arguments
144    ///
145    /// * `output_path` - Folder where files will be saved
146    fn create(output_path: &Path) -> Result<Self> {
147        let new_writer = |file_name| {
148            let file_path = output_path.join(file_name);
149            csv::Writer::from_path(file_path)
150        };
151
152        Ok(Self {
153            commodity_balance_duals_writer: new_writer(COMMODITY_BALANCE_DUALS_FILE_NAME)?,
154            activity_duals_writer: new_writer(ACTIVITY_DUALS_FILE_NAME)?,
155        })
156    }
157
158    /// Write all debug info to output files
159    fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
160        self.write_activity_duals(milestone_year, solution.iter_activity_duals())?;
161        self.write_commodity_balance_duals(
162            milestone_year,
163            solution.iter_commodity_balance_duals(),
164        )?;
165        Ok(())
166    }
167
168    /// Write activity duals to file
169    fn write_activity_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
170    where
171        I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, f64)>,
172    {
173        for (asset, time_slice, value) in iter {
174            let row = ActivityDualsRow {
175                milestone_year,
176                asset_id: asset.id,
177                time_slice: time_slice.clone(),
178                value,
179            };
180            self.activity_duals_writer.serialize(row)?;
181        }
182
183        Ok(())
184    }
185
186    /// Write commodity balance duals to file
187    fn write_commodity_balance_duals<'a, I>(&mut self, milestone_year: u32, iter: I) -> Result<()>
188    where
189        I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, f64)>,
190    {
191        for (commodity_id, region_id, time_slice, value) in iter {
192            let row = CommodityBalanceDualsRow {
193                milestone_year,
194                commodity_id: commodity_id.clone(),
195                region_id: region_id.clone(),
196                time_slice: time_slice.clone(),
197                value,
198            };
199            self.commodity_balance_duals_writer.serialize(row)?;
200        }
201
202        Ok(())
203    }
204
205    /// Flush the underlying streams
206    fn flush(&mut self) -> Result<()> {
207        self.commodity_balance_duals_writer.flush()?;
208        self.activity_duals_writer.flush()?;
209
210        Ok(())
211    }
212}
213
/// An object for writing simulation output (assets, commodity flows and commodity prices) to CSV
/// files, optionally along with extra debug files
pub struct DataWriter {
    /// Writer for the assets CSV file
    assets_writer: csv::Writer<File>,
    /// Writer for the commodity flows CSV file
    flows_writer: csv::Writer<File>,
    /// Writer for the commodity prices CSV file
    prices_writer: csv::Writer<File>,
    /// Writer for the debug CSV files; `None` unless debug output was requested when the
    /// `DataWriter` was created
    debug_writer: Option<DebugDataWriter>,
}
221
222impl DataWriter {
223    /// Open CSV files to write output data to
224    ///
225    /// # Arguments
226    ///
227    /// * `output_path` - Folder where files will be saved
228    /// * `model_path` - Path to input model
229    /// * `save_debug_info` - Whether to include extra CSV files for debugging model
230    pub fn create(output_path: &Path, model_path: &Path, save_debug_info: bool) -> Result<Self> {
231        write_metadata(output_path, model_path).context("Failed to save metadata")?;
232
233        let new_writer = |file_name| {
234            let file_path = output_path.join(file_name);
235            csv::Writer::from_path(file_path)
236        };
237
238        let debug_writer = if save_debug_info {
239            // Create debug CSV files
240            Some(DebugDataWriter::create(output_path)?)
241        } else {
242            None
243        };
244
245        Ok(Self {
246            assets_writer: new_writer(ASSETS_FILE_NAME)?,
247            flows_writer: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
248            prices_writer: new_writer(COMMODITY_PRICES_FILE_NAME)?,
249            debug_writer,
250        })
251    }
252
253    /// Write assets to a CSV file
254    pub fn write_assets<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
255    where
256        I: Iterator<Item = &'a AssetRef>,
257    {
258        for asset in assets {
259            let row = AssetRow::new(milestone_year, asset);
260            self.assets_writer.serialize(row)?;
261        }
262
263        Ok(())
264    }
265
266    /// Write commodity flows to a CSV file
267    pub fn write_flows(&mut self, milestone_year: u32, flow_map: &FlowMap) -> Result<()> {
268        for ((asset, commodity_id, time_slice), flow) in flow_map {
269            let row = CommodityFlowRow {
270                milestone_year,
271                asset_id: asset.id.unwrap(),
272                commodity_id: commodity_id.clone(),
273                time_slice: time_slice.clone(),
274                flow: *flow,
275            };
276            self.flows_writer.serialize(row)?;
277        }
278
279        Ok(())
280    }
281
282    /// Write commodity prices to a CSV file
283    pub fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
284        for (commodity_id, region_id, time_slice, price) in prices.iter() {
285            let row = CommodityPriceRow {
286                milestone_year,
287                commodity_id: commodity_id.clone(),
288                region_id: region_id.clone(),
289                time_slice: time_slice.clone(),
290                price,
291            };
292            self.prices_writer.serialize(row)?;
293        }
294
295        Ok(())
296    }
297
298    /// Write debug information to CSV files
299    pub fn write_debug_info(&mut self, milestone_year: u32, solution: &Solution) -> Result<()> {
300        if let Some(ref mut wtr) = &mut self.debug_writer {
301            wtr.write_debug_info(milestone_year, solution)?;
302        }
303
304        Ok(())
305    }
306
307    /// Flush the underlying streams
308    pub fn flush(&mut self) -> Result<()> {
309        self.assets_writer.flush()?;
310        self.flows_writer.flush()?;
311        self.prices_writer.flush()?;
312        if let Some(ref mut wtr) = &mut self.debug_writer {
313            wtr.flush()?;
314        }
315
316        Ok(())
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asset::AssetPool;
    use crate::fixture::{assets, commodity_id, region_id, time_slice};
    use crate::time_slice::TimeSliceID;
    use indexmap::indexmap;
    use itertools::{assert_equal, Itertools};
    use rstest::rstest;
    use std::iter;
    use tempfile::tempdir;

    /// Read every record back from the CSV file at `path`, panicking on any I/O or parse error
    fn read_csv<T: serde::de::DeserializeOwned>(path: &Path) -> Vec<T> {
        csv::Reader::from_path(path)
            .unwrap()
            .into_deserialize::<T>()
            .try_collect()
            .unwrap()
    }

    #[rstest]
    fn test_write_assets(assets: AssetPool) {
        let milestone_year = 2020;
        let dir = tempdir().unwrap();

        // Write an asset
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_assets(milestone_year, assets.iter()).unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let asset = assets.iter().next().unwrap();
        let expected = AssetRow::new(milestone_year, asset);
        let records: Vec<AssetRow> = read_csv(&dir.path().join(ASSETS_FILE_NAME));
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_flows(assets: AssetPool, commodity_id: CommodityID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let asset = assets.iter().next().unwrap();
        let flow_map = indexmap! {
            (asset.clone(), commodity_id.clone(), time_slice.clone()) => Flow(42.0)
        };

        // Write a flow
        let dir = tempdir().unwrap();
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_flows(milestone_year, &flow_map).unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = CommodityFlowRow {
            milestone_year,
            asset_id: asset.id.unwrap(),
            commodity_id,
            time_slice,
            flow: Flow(42.0),
        };
        let records: Vec<CommodityFlowRow> =
            read_csv(&dir.path().join(COMMODITY_FLOWS_FILE_NAME));
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_prices(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let price = MoneyPerFlow(42.0);
        let mut prices = CommodityPrices::default();
        prices.insert(&commodity_id, &region_id, &time_slice, price);

        let dir = tempdir().unwrap();

        // Write a price
        {
            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
            writer.write_prices(milestone_year, &prices).unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = CommodityPriceRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            price,
        };
        let records: Vec<CommodityPriceRow> =
            read_csv(&dir.path().join(COMMODITY_PRICES_FILE_NAME));
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_commodity_balance_duals(
        commodity_id: CommodityID,
        region_id: RegionID,
        time_slice: TimeSliceID,
    ) {
        let milestone_year = 2020;
        let value = 0.5;
        let dir = tempdir().unwrap();

        // Write commodity balance dual
        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_commodity_balance_duals(
                    milestone_year,
                    iter::once((&commodity_id, &region_id, &time_slice, value)),
                )
                .unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = CommodityBalanceDualsRow {
            milestone_year,
            commodity_id,
            region_id,
            time_slice,
            value,
        };
        let records: Vec<CommodityBalanceDualsRow> =
            read_csv(&dir.path().join(COMMODITY_BALANCE_DUALS_FILE_NAME));
        assert_equal(records, iter::once(expected));
    }

    #[rstest]
    fn test_write_activity_duals(assets: AssetPool, time_slice: TimeSliceID) {
        let milestone_year = 2020;
        let value = 0.5;
        let dir = tempdir().unwrap();
        let asset = assets.iter().next().unwrap();

        // Write activity dual
        {
            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
            writer
                .write_activity_duals(milestone_year, iter::once((asset, &time_slice, value)))
                .unwrap();
            writer.flush().unwrap();
        }

        // Read back and compare
        let expected = ActivityDualsRow {
            milestone_year,
            asset_id: asset.id,
            time_slice,
            value,
        };
        let records: Vec<ActivityDualsRow> = read_csv(&dir.path().join(ACTIVITY_DUALS_FILE_NAME));
        assert_equal(records, iter::once(expected));
    }
}