Skip to main content

muse2/
output.rs

1//! The module responsible for writing output data to disk.
2use crate::agent::AgentID;
3use crate::asset::{Asset, AssetGroupID, AssetID, AssetRef};
4use crate::commodity::CommodityID;
5use crate::process::ProcessID;
6use crate::region::RegionID;
7use crate::simulation::CommodityPrices;
8use crate::simulation::investment::appraisal::AppraisalOutput;
9use crate::simulation::optimisation::{FlowMap, Solution};
10use crate::time_slice::TimeSliceID;
11use crate::units::{
12    Activity, Capacity, Flow, Money, MoneyPerActivity, MoneyPerCapacity, MoneyPerFlow,
13};
14use anyhow::{Context, Result, ensure};
15use csv;
16use indexmap::IndexMap;
17use serde::{Deserialize, Serialize};
18use std::collections::HashSet;
19use std::fs;
20use std::fs::File;
21use std::path::{Path, PathBuf};
22
23pub mod metadata;
24use metadata::write_metadata;
25
26/// The output file name for commodity flows
27const COMMODITY_FLOWS_FILE_NAME: &str = "commodity_flows.csv";
28
29/// The output file name for commodity prices
30const COMMODITY_PRICES_FILE_NAME: &str = "commodity_prices.csv";
31
32/// The output file name for assets
33const ASSETS_FILE_NAME: &str = "assets.csv";
34
35/// The output file name for asset capacities
36const ASSET_CAPACITIES_FILE_NAME: &str = "asset_capacities.csv";
37
38/// Debug output file for asset dispatch
39const ACTIVITY_ASSET_DISPATCH: &str = "debug_dispatch_assets.csv";
40
41/// The output file name for commodity balance duals
42const COMMODITY_BALANCE_DUALS_FILE_NAME: &str = "debug_commodity_balance_duals.csv";
43
44/// The output file name for unmet demand values
45const UNMET_DEMAND_FILE_NAME: &str = "debug_unmet_demand.csv";
46
47/// The output file name for extra solver output values
48const SOLVER_VALUES_FILE_NAME: &str = "debug_solver.csv";
49
50/// The output file name for appraisal results
51const APPRAISAL_RESULTS_FILE_NAME: &str = "debug_appraisal_results.csv";
52
53/// The output file name for appraisal time slice results
54const APPRAISAL_RESULTS_TIME_SLICE_FILE_NAME: &str = "debug_appraisal_results_time_slices.csv";
55
56/// Get the default output directory for the model
57pub fn get_output_dir(model_dir: &Path, results_root: PathBuf) -> Result<PathBuf> {
58    // Get the model name from the dir path. This ends up being convoluted because we need to check
59    // for all possible errors. Ugh.
60    let model_dir = model_dir
61        .canonicalize() // canonicalise in case the user has specified "."
62        .context("Could not resolve path to model")?;
63
64    let model_name = model_dir
65        .file_name()
66        .context("Model cannot be in root folder")?
67        .to_str()
68        .context("Invalid chars in model dir name")?;
69
70    // Construct path
71    Ok([results_root, model_name.into()].iter().collect())
72}
73
74/// Get the default output directory for commodity flow graphs for the model
75pub fn get_graphs_dir(model_dir: &Path, graph_results_root: PathBuf) -> Result<PathBuf> {
76    let model_dir = model_dir
77        .canonicalize() // canonicalise in case the user has specified "."
78        .context("Could not resolve path to model")?;
79    let model_name = model_dir
80        .file_name()
81        .context("Model cannot be in root folder")?
82        .to_str()
83        .context("Invalid chars in model dir name")?;
84    Ok([graph_results_root, model_name.into()].iter().collect())
85}
86
87/// Create a new output directory for the model, optionally overwriting existing data
88///
89/// # Arguments
90///
91/// * `output_dir` - The output directory to create/overwrite
92/// * `allow_overwrite` - Whether to delete and recreate the folder if it is non-empty
93///
94/// # Returns
95///
96/// True if the output dir contained existing data that was deleted, false if not, or an error.
97pub fn create_output_directory(output_dir: &Path, allow_overwrite: bool) -> Result<bool> {
98    // If the folder already exists, then delete it
99    let overwrite = if let Ok(mut it) = fs::read_dir(output_dir) {
100        if it.next().is_none() {
101            // Folder exists and is empty: nothing to do
102            return Ok(false);
103        }
104
105        ensure!(
106            allow_overwrite,
107            "Output folder already exists and is not empty. \
108            Please delete the folder or pass the --overwrite command-line option."
109        );
110
111        fs::remove_dir_all(output_dir).context("Could not delete folder")?;
112        true
113    } else {
114        false
115    };
116
117    // Try to create the directory, with parents
118    fs::create_dir_all(output_dir)?;
119
120    Ok(overwrite)
121}
122
123/// Copy input files to output directory
124pub fn copy_input_files(model_dir: &Path, output_dir: &Path, model_name: &str) -> Result<()> {
125    // Get the model name from the dir path.
126    let mut input_copy_dir = output_dir.to_path_buf();
127    input_copy_dir.extend(["input", model_name]);
128
129    fs::create_dir_all(&input_copy_dir).context("Could not create input copy directory")?;
130
131    for entry in fs::read_dir(model_dir)? {
132        let entry = entry?;
133        let path = entry.path();
134        if path.is_file() {
135            let file_name = path.file_name().unwrap();
136            fs::copy(&path, input_copy_dir.join(file_name))?;
137        }
138    }
139    Ok(())
140}
141
142/// Represents a row in the assets output CSV file.
143#[derive(Serialize, Deserialize, Debug, PartialEq)]
144struct AssetRow {
145    asset_id: Option<AssetID>,
146    group_id: Option<AssetGroupID>,
147    process_id: ProcessID,
148    region_id: RegionID,
149    agent_id: AgentID,
150    commission_year: u32,
151}
152
153impl AssetRow {
154    /// Create a new [`AssetRow`] for a non-group asset
155    fn new(asset: &Asset) -> Self {
156        Self {
157            asset_id: asset.id(),
158            group_id: None,
159            process_id: asset.process_id().clone(),
160            region_id: asset.region_id().clone(),
161            agent_id: asset.agent_id().unwrap().clone(),
162            commission_year: asset.commission_year(),
163        }
164    }
165
166    /// Create a new [`AssetRow`] for a group, using the parent asset's metadata
167    fn from_parent(parent: &Asset) -> Self {
168        Self {
169            asset_id: None,
170            group_id: parent.group_id(),
171            process_id: parent.process_id().clone(),
172            region_id: parent.region_id().clone(),
173            agent_id: parent.agent_id().unwrap().clone(),
174            commission_year: parent.commission_year(),
175        }
176    }
177}
178
179/// Represents a row in the asset capacities output CSV file.
180#[derive(Serialize, Deserialize, Debug, PartialEq)]
181struct AssetCapacityRow {
182    milestone_year: u32,
183    asset_id: Option<AssetID>,
184    group_id: Option<AssetGroupID>,
185    capacity: Capacity,
186    num_units: Option<u32>,
187}
188
189/// Represents the flow-related data in a row of the commodity flows CSV file.
190#[derive(Serialize, Deserialize, Debug, PartialEq)]
191struct CommodityFlowRow {
192    milestone_year: u32,
193    asset_id: Option<AssetID>,
194    group_id: Option<AssetGroupID>,
195    commodity_id: CommodityID,
196    time_slice: TimeSliceID,
197    flow: Flow,
198}
199
200/// Represents a row in the commodity prices CSV file
201#[derive(Serialize, Deserialize, Debug, PartialEq)]
202struct CommodityPriceRow {
203    milestone_year: u32,
204    commodity_id: CommodityID,
205    region_id: RegionID,
206    time_slice: TimeSliceID,
207    price: MoneyPerFlow,
208}
209
210/// Represents the activity in a row of the dispatch CSV file
211#[derive(Serialize, Deserialize, Debug, PartialEq)]
212struct DispatchRow {
213    milestone_year: u32,
214    run_description: String,
215    asset_id: Option<AssetID>,
216    group_id: Option<AssetGroupID>,
217    process_id: ProcessID,
218    region_id: RegionID,
219    time_slice: TimeSliceID,
220    activity: Option<Activity>,
221    activity_dual: Option<MoneyPerActivity>,
222    column_dual: Option<MoneyPerActivity>,
223}
224
225/// Represents the commodity balance duals data in a row of the commodity balance duals CSV file
226#[derive(Serialize, Deserialize, Debug, PartialEq)]
227struct CommodityBalanceDualsRow {
228    milestone_year: u32,
229    run_description: String,
230    commodity_id: CommodityID,
231    region_id: RegionID,
232    time_slice: TimeSliceID,
233    value: MoneyPerFlow,
234}
235
236/// Represents the unmet demand data in a row of the unmet demand CSV file
237#[derive(Serialize, Deserialize, Debug, PartialEq)]
238struct UnmetDemandRow {
239    milestone_year: u32,
240    run_description: String,
241    commodity_id: CommodityID,
242    region_id: RegionID,
243    time_slice: TimeSliceID,
244    value: Flow,
245}
246
247/// Represents solver output values
248#[derive(Serialize, Deserialize, Debug, PartialEq)]
249struct SolverValuesRow {
250    milestone_year: u32,
251    run_description: String,
252    objective_value: Money,
253}
254
255/// Represents the appraisal results in a row of the appraisal results CSV file
256#[derive(Serialize, Deserialize, Debug, PartialEq)]
257struct AppraisalResultsRow {
258    milestone_year: u32,
259    run_description: String,
260    asset_id: Option<AssetID>,
261    process_id: ProcessID,
262    region_id: RegionID,
263    capacity: Capacity,
264    capacity_coefficient: MoneyPerCapacity,
265    metric: Option<f64>,
266}
267
268/// Represents the appraisal results in a row of the appraisal results CSV file
269#[derive(Serialize, Deserialize, Debug, PartialEq)]
270struct AppraisalResultsTimeSliceRow {
271    milestone_year: u32,
272    run_description: String,
273    asset_id: Option<AssetID>,
274    process_id: ProcessID,
275    region_id: RegionID,
276    time_slice: TimeSliceID,
277    activity: Activity,
278    activity_coefficient: MoneyPerActivity,
279    demand: Flow,
280    unmet_demand: Flow,
281}
282
283/// For writing extra debug information about the model
284struct DebugDataWriter {
285    context: Option<String>,
286    commodity_balance_duals_writer: csv::Writer<File>,
287    unmet_demand_writer: csv::Writer<File>,
288    solver_values_writer: csv::Writer<File>,
289    appraisal_results_writer: csv::Writer<File>,
290    appraisal_results_time_slice_writer: csv::Writer<File>,
291    dispatch_asset_writer: csv::Writer<File>,
292}
293
294impl DebugDataWriter {
295    /// Open CSV files to write debug info to
296    ///
297    /// # Arguments
298    ///
299    /// * `output_path` - Folder where files will be saved
300    fn create(output_path: &Path) -> Result<Self> {
301        let new_writer = |file_name| {
302            let file_path = output_path.join(file_name);
303            csv::Writer::from_path(file_path)
304        };
305
306        Ok(Self {
307            context: None,
308            commodity_balance_duals_writer: new_writer(COMMODITY_BALANCE_DUALS_FILE_NAME)?,
309            unmet_demand_writer: new_writer(UNMET_DEMAND_FILE_NAME)?,
310            solver_values_writer: new_writer(SOLVER_VALUES_FILE_NAME)?,
311            appraisal_results_writer: new_writer(APPRAISAL_RESULTS_FILE_NAME)?,
312            appraisal_results_time_slice_writer: new_writer(
313                APPRAISAL_RESULTS_TIME_SLICE_FILE_NAME,
314            )?,
315            dispatch_asset_writer: new_writer(ACTIVITY_ASSET_DISPATCH)?,
316        })
317    }
318
319    /// Prepend the current context to the run description
320    fn with_context(&self, run_description: &str) -> String {
321        if let Some(context) = &self.context {
322            format!("{context}; {run_description}")
323        } else {
324            run_description.to_string()
325        }
326    }
327
328    /// Write debug info about the dispatch optimisation
329    fn write_dispatch_debug_info(
330        &mut self,
331        milestone_year: u32,
332        run_description: &str,
333        solution: &Solution,
334    ) -> Result<()> {
335        self.write_dispatch(
336            milestone_year,
337            run_description,
338            solution.iter_activity(),
339            solution.iter_activity_duals(),
340            solution.iter_column_duals(),
341        )?;
342        self.write_commodity_balance_duals(
343            milestone_year,
344            run_description,
345            solution.iter_commodity_balance_duals(),
346        )?;
347        self.write_unmet_demand(
348            milestone_year,
349            run_description,
350            solution.iter_unmet_demand(),
351        )?;
352        self.write_solver_values(milestone_year, run_description, solution.objective_value)?;
353        Ok(())
354    }
355
356    // Write activity to file
357    fn write_dispatch<'a, I, J, K>(
358        &mut self,
359        milestone_year: u32,
360        run_description: &str,
361        iter_activity: I,
362        iter_activity_duals: J,
363        iter_column_duals: K,
364    ) -> Result<()>
365    where
366        I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, Activity)>,
367        J: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
368        K: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
369    {
370        // To account for different order of entries or missing ones, we first compile data in hash map
371        type CompiledActivityData = (
372            Option<Activity>,
373            Option<MoneyPerActivity>,
374            Option<MoneyPerActivity>,
375        );
376        let mut map: IndexMap<(&AssetRef, &TimeSliceID), CompiledActivityData> = IndexMap::new();
377
378        // For the activities
379        for (asset, time_slice, activity) in iter_activity {
380            map.entry((asset, time_slice)).or_default().0 = Some(activity);
381        }
382        // The activity duals
383        for (asset, time_slice, activity_dual) in iter_activity_duals {
384            map.entry((asset, time_slice)).or_default().1 = Some(activity_dual);
385        }
386        // And the column duals
387        for (asset, time_slice, column_dual) in iter_column_duals {
388            map.entry((asset, time_slice)).or_default().2 = Some(column_dual);
389        }
390
391        for (asset, time_slice, activity, activity_dual, column_dual) in
392            map.iter()
393                .map(|(&(agent, ts), &(activity, activity_dual, column_dual))| {
394                    (agent, ts, activity, activity_dual, column_dual)
395                })
396        {
397            let row = DispatchRow {
398                milestone_year,
399                run_description: self.with_context(run_description),
400                asset_id: asset.id(),
401                group_id: asset.group_id(),
402                process_id: asset.process_id().clone(),
403                region_id: asset.region_id().clone(),
404                time_slice: time_slice.clone(),
405                activity,
406                activity_dual,
407                column_dual,
408            };
409            self.dispatch_asset_writer.serialize(row)?;
410        }
411
412        Ok(())
413    }
414
415    /// Write commodity balance duals to file
416    fn write_commodity_balance_duals<'a, I>(
417        &mut self,
418        milestone_year: u32,
419        run_description: &str,
420        iter: I,
421    ) -> Result<()>
422    where
423        I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, MoneyPerFlow)>,
424    {
425        for (commodity_id, region_id, time_slice, value) in iter {
426            let row = CommodityBalanceDualsRow {
427                milestone_year,
428                run_description: self.with_context(run_description),
429                commodity_id: commodity_id.clone(),
430                region_id: region_id.clone(),
431                time_slice: time_slice.clone(),
432                value,
433            };
434            self.commodity_balance_duals_writer.serialize(row)?;
435        }
436
437        Ok(())
438    }
439
440    /// Write unmet demand values to file
441    fn write_unmet_demand<'a, I>(
442        &mut self,
443        milestone_year: u32,
444        run_description: &str,
445        iter: I,
446    ) -> Result<()>
447    where
448        I: Iterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, Flow)>,
449    {
450        for (commodity_id, region_id, time_slice, value) in iter {
451            let row = UnmetDemandRow {
452                milestone_year,
453                run_description: self.with_context(run_description),
454                commodity_id: commodity_id.clone(),
455                region_id: region_id.clone(),
456                time_slice: time_slice.clone(),
457                value,
458            };
459            self.unmet_demand_writer.serialize(row)?;
460        }
461
462        Ok(())
463    }
464
465    /// Write additional solver output values to file
466    fn write_solver_values(
467        &mut self,
468        milestone_year: u32,
469        run_description: &str,
470        objective_value: Money,
471    ) -> Result<()> {
472        let row = SolverValuesRow {
473            milestone_year,
474            run_description: self.with_context(run_description),
475            objective_value,
476        };
477        self.solver_values_writer.serialize(row)?;
478        self.solver_values_writer.flush()?;
479
480        Ok(())
481    }
482
483    /// Write appraisal results to file
484    fn write_appraisal_results(
485        &mut self,
486        milestone_year: u32,
487        run_description: &str,
488        appraisal_results: &[AppraisalOutput],
489    ) -> Result<()> {
490        for result in appraisal_results {
491            let row = AppraisalResultsRow {
492                milestone_year,
493                run_description: self.with_context(run_description),
494                asset_id: result.asset.id(),
495                process_id: result.asset.process_id().clone(),
496                region_id: result.asset.region_id().clone(),
497                capacity: result.capacity.total_capacity(),
498                capacity_coefficient: result.coefficients.capacity_coefficient,
499                metric: result.metric.as_ref().map(|m| m.value()),
500            };
501            self.appraisal_results_writer.serialize(row)?;
502        }
503
504        Ok(())
505    }
506
507    /// Write appraisal results to file
508    fn write_appraisal_time_slice_results(
509        &mut self,
510        milestone_year: u32,
511        run_description: &str,
512        appraisal_results: &[AppraisalOutput],
513        demand: &IndexMap<TimeSliceID, Flow>,
514    ) -> Result<()> {
515        for result in appraisal_results {
516            for (time_slice, activity) in &result.activity {
517                let activity_coefficient = result.coefficients.activity_coefficients[time_slice];
518                let demand = demand[time_slice];
519                let unmet_demand = result.unmet_demand[time_slice];
520                let row = AppraisalResultsTimeSliceRow {
521                    milestone_year,
522                    run_description: self.with_context(run_description),
523                    asset_id: result.asset.id(),
524                    process_id: result.asset.process_id().clone(),
525                    region_id: result.asset.region_id().clone(),
526                    time_slice: time_slice.clone(),
527                    activity: *activity,
528                    activity_coefficient,
529                    demand,
530                    unmet_demand,
531                };
532                self.appraisal_results_time_slice_writer.serialize(row)?;
533            }
534        }
535
536        Ok(())
537    }
538
539    /// Flush the underlying streams
540    fn flush(&mut self) -> Result<()> {
541        self.commodity_balance_duals_writer.flush()?;
542        self.unmet_demand_writer.flush()?;
543        self.solver_values_writer.flush()?;
544        self.appraisal_results_writer.flush()?;
545        self.appraisal_results_time_slice_writer.flush()?;
546        self.dispatch_asset_writer.flush()?;
547
548        Ok(())
549    }
550}
551
552/// An object for writing output data to file
553pub struct DataWriter {
554    assets: csv::Writer<File>,
555    asset_capacities: csv::Writer<File>,
556    flows: csv::Writer<File>,
557    prices: csv::Writer<File>,
558    debug: Option<DebugDataWriter>,
559}
560
561impl DataWriter {
562    /// Open CSV files to write output data to
563    ///
564    /// # Arguments
565    ///
566    /// * `output_path` - Folder where files will be saved
567    /// * `model_path` - Path to input model
568    /// * `save_debug_info` - Whether to include extra CSV files for debugging model
569    pub fn create(output_path: &Path, model_path: &Path, save_debug_info: bool) -> Result<Self> {
570        write_metadata(output_path, model_path).context("Failed to save metadata")?;
571
572        let new_writer = |file_name| {
573            let file_path = output_path.join(file_name);
574            csv::Writer::from_path(file_path)
575        };
576
577        let debug_writer = if save_debug_info {
578            // Create debug CSV files
579            Some(DebugDataWriter::create(output_path)?)
580        } else {
581            None
582        };
583
584        Ok(Self {
585            assets: new_writer(ASSETS_FILE_NAME)?,
586            asset_capacities: new_writer(ASSET_CAPACITIES_FILE_NAME)?,
587            flows: new_writer(COMMODITY_FLOWS_FILE_NAME)?,
588            prices: new_writer(COMMODITY_PRICES_FILE_NAME)?,
589            debug: debug_writer,
590        })
591    }
592
593    /// Write debug info about the dispatch optimisation
594    pub fn write_dispatch_debug_info(
595        &mut self,
596        milestone_year: u32,
597        run_description: &str,
598        solution: &Solution,
599    ) -> Result<()> {
600        if let Some(wtr) = &mut self.debug {
601            wtr.write_dispatch_debug_info(milestone_year, run_description, solution)?;
602        }
603
604        Ok(())
605    }
606
607    /// Write debug info about the investment appraisal
608    pub fn write_appraisal_debug_info(
609        &mut self,
610        milestone_year: u32,
611        run_description: &str,
612        appraisal_results: &[AppraisalOutput],
613        demand: &IndexMap<TimeSliceID, Flow>,
614    ) -> Result<()> {
615        if let Some(wtr) = &mut self.debug {
616            wtr.write_appraisal_results(milestone_year, run_description, appraisal_results)?;
617            wtr.write_appraisal_time_slice_results(
618                milestone_year,
619                run_description,
620                appraisal_results,
621                demand,
622            )?;
623        }
624
625        Ok(())
626    }
627
628    /// Append newly commissioned asset definitions to the assets CSV file.
629    ///
630    /// For divisible asset groups, a single row is emitted per group (using the parent asset's
631    /// metadata).
632    pub fn write_assets<'a, I>(&mut self, assets: I) -> Result<()>
633    where
634        I: Iterator<Item = &'a AssetRef>,
635    {
636        let mut seen_group_ids: HashSet<AssetGroupID> = HashSet::new();
637        for asset in assets {
638            if let Some(parent) = asset.parent() {
639                // Active child of a group: emit one row for the group (first child wins)
640                let group_id = asset.group_id().unwrap();
641                if seen_group_ids.insert(group_id) {
642                    self.assets.serialize(AssetRow::from_parent(parent))?;
643                }
644            } else {
645                self.assets.serialize(AssetRow::new(asset))?;
646            }
647        }
648
649        Ok(())
650    }
651
652    /// Write asset capacities for the current milestone year to a CSV file.
653    ///
654    /// This file is appended to on each invocation. For divisible asset groups, a single row is
655    /// emitted per group with the total capacity.
656    pub fn write_asset_capacities<'a, I>(&mut self, milestone_year: u32, assets: I) -> Result<()>
657    where
658        I: Iterator<Item = &'a AssetRef>,
659    {
660        let mut seen_group_ids: HashSet<AssetGroupID> = HashSet::new();
661        for asset in assets {
662            if let Some(parent) = asset.parent() {
663                let group_id = asset.group_id().unwrap();
664                if seen_group_ids.insert(group_id) {
665                    let row = AssetCapacityRow {
666                        milestone_year,
667                        asset_id: None,
668                        group_id: Some(group_id),
669                        capacity: parent.total_capacity(),
670                        num_units: parent.capacity().n_units(),
671                    };
672                    self.asset_capacities.serialize(row)?;
673                }
674            } else {
675                let row = AssetCapacityRow {
676                    milestone_year,
677                    asset_id: asset.id(),
678                    group_id: None,
679                    capacity: asset.total_capacity(),
680                    num_units: None,
681                };
682                self.asset_capacities.serialize(row)?;
683            }
684        }
685
686        Ok(())
687    }
688
689    /// Write commodity flows to a CSV file
690    pub fn write_flows(&mut self, milestone_year: u32, flow_map: &FlowMap) -> Result<()> {
691        for ((asset, commodity_id, time_slice), flow) in flow_map {
692            if asset.parent().is_some() {
693                // Skip child assets, as their flows are included in the parent asset's flow
694                continue;
695            }
696
697            let row = CommodityFlowRow {
698                milestone_year,
699                asset_id: asset.id(),
700                group_id: asset.group_id(),
701                commodity_id: commodity_id.clone(),
702                time_slice: time_slice.clone(),
703                flow: *flow,
704            };
705            self.flows.serialize(row)?;
706        }
707
708        Ok(())
709    }
710
711    /// Write commodity prices to a CSV file
712    pub fn write_prices(&mut self, milestone_year: u32, prices: &CommodityPrices) -> Result<()> {
713        for (commodity_id, region_id, time_slice, price) in prices.iter() {
714            let row = CommodityPriceRow {
715                milestone_year,
716                commodity_id: commodity_id.clone(),
717                region_id: region_id.clone(),
718                time_slice: time_slice.clone(),
719                price,
720            };
721            self.prices.serialize(row)?;
722        }
723
724        Ok(())
725    }
726
727    /// Flush the underlying streams
728    pub fn flush(&mut self) -> Result<()> {
729        self.assets.flush()?;
730        self.asset_capacities.flush()?;
731        self.flows.flush()?;
732        self.prices.flush()?;
733        if let Some(wtr) = &mut self.debug {
734            wtr.flush()?;
735        }
736
737        Ok(())
738    }
739
740    /// Add context to the debug writer
741    pub fn set_debug_context(&mut self, context: String) {
742        if let Some(wtr) = &mut self.debug {
743            wtr.context = Some(context);
744        }
745    }
746
747    /// Clear context from the debug writer
748    pub fn clear_debug_context(&mut self) {
749        if let Some(wtr) = &mut self.debug {
750            wtr.context = None;
751        }
752    }
753}
754
755#[cfg(test)]
756mod tests {
757    use super::*;
758    use crate::asset::AssetPool;
759    use crate::fixture::{
760        appraisal_output, asset, asset_divisible, assets, commodity_id, region_id, time_slice,
761    };
762    use crate::simulation::investment::appraisal::AppraisalOutput;
763    use crate::time_slice::TimeSliceID;
764    use indexmap::indexmap;
765    use itertools::{Itertools, assert_equal};
766    use rstest::rstest;
767    use std::iter;
768    use tempfile::tempdir;
769
770    #[rstest]
771    fn write_assets(assets: AssetPool) {
772        let dir = tempdir().unwrap();
773
774        // Write an asset
775        {
776            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
777            writer.write_assets(assets.iter()).unwrap();
778            writer.flush().unwrap();
779        }
780
781        // Read back and compare
782        let asset = assets.iter().next().unwrap();
783        let expected = AssetRow::new(asset);
784        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
785            .unwrap()
786            .into_deserialize()
787            .try_collect()
788            .unwrap();
789        assert_equal(records, iter::once(expected));
790    }
791
792    #[rstest]
793    fn write_asset_capacities(assets: AssetPool) {
794        let milestone_year = 2020;
795        let dir = tempdir().unwrap();
796
797        // Write asset capacities
798        {
799            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
800            writer
801                .write_asset_capacities(milestone_year, assets.iter())
802                .unwrap();
803            writer.flush().unwrap();
804        }
805
806        // Read back and compare
807        let asset = assets.iter().next().unwrap();
808        let expected = AssetCapacityRow {
809            milestone_year,
810            asset_id: asset.id(),
811            group_id: None,
812            capacity: asset.total_capacity(),
813            num_units: None,
814        };
815        let records: Vec<AssetCapacityRow> =
816            csv::Reader::from_path(dir.path().join(ASSET_CAPACITIES_FILE_NAME))
817                .unwrap()
818                .into_deserialize()
819                .try_collect()
820                .unwrap();
821        assert_equal(records, iter::once(expected));
822    }
823
824    #[rstest]
825    fn write_assets_divisible_group_deduplicated(asset_divisible: Asset) {
826        let milestone_year = asset_divisible.commission_year();
827        let mut pool = AssetPool::new();
828        let mut user_assets = vec![asset_divisible.into()];
829
830        // Commission a divisible asset so the active pool contains multiple children in one group
831        let commissioned = pool
832            .commission_new(milestone_year, &mut user_assets)
833            .to_vec();
834        assert!(commissioned.len() > 1);
835
836        let dir = tempdir().unwrap();
837
838        // Write all active assets: divisible children should collapse to one group row
839        {
840            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
841            writer.write_assets(pool.iter()).unwrap();
842            writer.flush().unwrap();
843        }
844
845        // Read back and compare: we expect a single group row with parent-derived metadata
846        let records: Vec<AssetRow> = csv::Reader::from_path(dir.path().join(ASSETS_FILE_NAME))
847            .unwrap()
848            .into_deserialize()
849            .try_collect()
850            .unwrap();
851        assert_eq!(records.len(), 1);
852
853        let first_child = commissioned.first().unwrap();
854        let parent = first_child.parent().unwrap();
855        let expected = AssetRow::from_parent(parent);
856        assert_eq!(records[0], expected);
857        assert_eq!(records[0].asset_id, None);
858        assert_eq!(records[0].group_id, parent.group_id());
859    }
860
861    #[rstest]
862    fn write_asset_capacities_divisible_group_deduplicated(asset_divisible: Asset) {
863        let milestone_year = asset_divisible.commission_year();
864        let mut pool = AssetPool::new();
865        let mut user_assets = vec![asset_divisible.into()];
866
867        // Commission a divisible asset so we get several children under one parent/group
868        let commissioned = pool
869            .commission_new(milestone_year, &mut user_assets)
870            .to_vec();
871        assert!(commissioned.len() > 1);
872
873        let dir = tempdir().unwrap();
874
875        // Write capacities: divisible children should be deduplicated to one group entry
876        {
877            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
878            writer
879                .write_asset_capacities(milestone_year, pool.iter())
880                .unwrap();
881            writer.flush().unwrap();
882        }
883
884        // Read back and compare: group capacity and unit count must match the parent
885        let records: Vec<AssetCapacityRow> =
886            csv::Reader::from_path(dir.path().join(ASSET_CAPACITIES_FILE_NAME))
887                .unwrap()
888                .into_deserialize()
889                .try_collect()
890                .unwrap();
891        assert_eq!(records.len(), 1);
892
893        let first_child = commissioned.first().unwrap();
894        let parent = first_child.parent().unwrap();
895        let expected = AssetCapacityRow {
896            milestone_year,
897            asset_id: None,
898            group_id: parent.group_id(),
899            capacity: parent.total_capacity(),
900            num_units: parent.capacity().n_units(),
901        };
902        assert_eq!(records[0], expected);
903    }
904
905    #[rstest]
906    fn write_flows(assets: AssetPool, commodity_id: CommodityID, time_slice: TimeSliceID) {
907        let milestone_year = 2020;
908        let asset = assets.iter().next().unwrap();
909        let flow_map = indexmap! {
910            (asset.clone(), commodity_id.clone(), time_slice.clone()) => Flow(42.0)
911        };
912
913        // Write a flow
914        let dir = tempdir().unwrap();
915        {
916            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
917            writer.write_flows(milestone_year, &flow_map).unwrap();
918            writer.flush().unwrap();
919        }
920
921        // Read back and compare
922        let expected = CommodityFlowRow {
923            milestone_year,
924            asset_id: asset.id(),
925            group_id: None,
926            commodity_id,
927            time_slice,
928            flow: Flow(42.0),
929        };
930        let records: Vec<CommodityFlowRow> =
931            csv::Reader::from_path(dir.path().join(COMMODITY_FLOWS_FILE_NAME))
932                .unwrap()
933                .into_deserialize()
934                .try_collect()
935                .unwrap();
936        assert_equal(records, iter::once(expected));
937    }
938
939    #[rstest]
940    fn write_prices(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
941        let milestone_year = 2020;
942        let price = MoneyPerFlow(42.0);
943        let mut prices = CommodityPrices::default();
944        prices.insert(&commodity_id, &region_id, &time_slice, price);
945
946        let dir = tempdir().unwrap();
947
948        // Write a price
949        {
950            let mut writer = DataWriter::create(dir.path(), dir.path(), false).unwrap();
951            writer.write_prices(milestone_year, &prices).unwrap();
952            writer.flush().unwrap();
953        }
954
955        // Read back and compare
956        let expected = CommodityPriceRow {
957            milestone_year,
958            commodity_id,
959            region_id,
960            time_slice,
961            price,
962        };
963        let records: Vec<CommodityPriceRow> =
964            csv::Reader::from_path(dir.path().join(COMMODITY_PRICES_FILE_NAME))
965                .unwrap()
966                .into_deserialize()
967                .try_collect()
968                .unwrap();
969        assert_equal(records, iter::once(expected));
970    }
971
972    #[rstest]
973    fn write_commodity_balance_duals(
974        commodity_id: CommodityID,
975        region_id: RegionID,
976        time_slice: TimeSliceID,
977    ) {
978        let milestone_year = 2020;
979        let run_description = "test_run".to_string();
980        let value = MoneyPerFlow(0.5);
981        let dir = tempdir().unwrap();
982
983        // Write commodity balance dual
984        {
985            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
986            writer
987                .write_commodity_balance_duals(
988                    milestone_year,
989                    &run_description,
990                    iter::once((&commodity_id, &region_id, &time_slice, value)),
991                )
992                .unwrap();
993            writer.flush().unwrap();
994        }
995
996        // Read back and compare
997        let expected = CommodityBalanceDualsRow {
998            milestone_year,
999            run_description,
1000            commodity_id,
1001            region_id,
1002            time_slice,
1003            value,
1004        };
1005        let records: Vec<CommodityBalanceDualsRow> =
1006            csv::Reader::from_path(dir.path().join(COMMODITY_BALANCE_DUALS_FILE_NAME))
1007                .unwrap()
1008                .into_deserialize()
1009                .try_collect()
1010                .unwrap();
1011        assert_equal(records, iter::once(expected));
1012    }
1013
1014    #[rstest]
1015    fn write_unmet_demand(commodity_id: CommodityID, region_id: RegionID, time_slice: TimeSliceID) {
1016        let milestone_year = 2020;
1017        let run_description = "test_run".to_string();
1018        let value = Flow(0.5);
1019        let dir = tempdir().unwrap();
1020
1021        // Write unmet demand
1022        {
1023            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1024            writer
1025                .write_unmet_demand(
1026                    milestone_year,
1027                    &run_description,
1028                    iter::once((&commodity_id, &region_id, &time_slice, value)),
1029                )
1030                .unwrap();
1031            writer.flush().unwrap();
1032        }
1033
1034        // Read back and compare
1035        let expected = UnmetDemandRow {
1036            milestone_year,
1037            run_description,
1038            commodity_id,
1039            region_id,
1040            time_slice,
1041            value,
1042        };
1043        let records: Vec<UnmetDemandRow> =
1044            csv::Reader::from_path(dir.path().join(UNMET_DEMAND_FILE_NAME))
1045                .unwrap()
1046                .into_deserialize()
1047                .try_collect()
1048                .unwrap();
1049        assert_equal(records, iter::once(expected));
1050    }
1051
1052    #[rstest]
1053    fn write_dispatch(assets: AssetPool, time_slice: TimeSliceID) {
1054        let milestone_year = 2020;
1055        let run_description = "test_run".to_string();
1056        let activity = Activity(100.5);
1057        let activity_dual = MoneyPerActivity(-1.5);
1058        let column_dual = MoneyPerActivity(5.0);
1059        let dir = tempdir().unwrap();
1060        let asset = assets.iter().next().unwrap();
1061
1062        // Write activity
1063        {
1064            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1065            writer
1066                .write_dispatch(
1067                    milestone_year,
1068                    &run_description,
1069                    iter::once((asset, &time_slice, activity)),
1070                    iter::once((asset, &time_slice, activity_dual)),
1071                    iter::once((asset, &time_slice, column_dual)),
1072                )
1073                .unwrap();
1074            writer.flush().unwrap();
1075        }
1076
1077        // Read back and compare
1078        let expected = DispatchRow {
1079            milestone_year,
1080            run_description,
1081            asset_id: asset.id(),
1082            group_id: asset.group_id(),
1083            process_id: asset.process_id().clone(),
1084            region_id: asset.region_id().clone(),
1085            time_slice,
1086            activity: Some(activity),
1087            activity_dual: Some(activity_dual),
1088            column_dual: Some(column_dual),
1089        };
1090        let records: Vec<DispatchRow> =
1091            csv::Reader::from_path(dir.path().join(ACTIVITY_ASSET_DISPATCH))
1092                .unwrap()
1093                .into_deserialize()
1094                .try_collect()
1095                .unwrap();
1096        assert_equal(records, iter::once(expected));
1097    }
1098
1099    #[rstest]
1100    fn write_dispatch_with_missing_keys(assets: AssetPool, time_slice: TimeSliceID) {
1101        let milestone_year = 2020;
1102        let run_description = "test_run".to_string();
1103        let activity = Activity(100.5);
1104        let dir = tempdir().unwrap();
1105        let asset = assets.iter().next().unwrap();
1106
1107        // Write activity
1108        {
1109            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1110            writer
1111                .write_dispatch(
1112                    milestone_year,
1113                    &run_description,
1114                    iter::once((asset, &time_slice, activity)),
1115                    iter::empty::<(&AssetRef, &TimeSliceID, MoneyPerActivity)>(),
1116                    iter::empty::<(&AssetRef, &TimeSliceID, MoneyPerActivity)>(),
1117                )
1118                .unwrap();
1119            writer.flush().unwrap();
1120        }
1121
1122        // Read back and compare
1123        let expected = DispatchRow {
1124            milestone_year,
1125            run_description,
1126            asset_id: asset.id(),
1127            group_id: asset.group_id(),
1128            process_id: asset.process_id().clone(),
1129            region_id: asset.region_id().clone(),
1130            time_slice,
1131            activity: Some(activity),
1132            activity_dual: None,
1133            column_dual: None,
1134        };
1135        let records: Vec<DispatchRow> =
1136            csv::Reader::from_path(dir.path().join(ACTIVITY_ASSET_DISPATCH))
1137                .unwrap()
1138                .into_deserialize()
1139                .try_collect()
1140                .unwrap();
1141        assert_equal(records, iter::once(expected));
1142    }
1143
1144    #[rstest]
1145    fn write_solver_values() {
1146        let milestone_year = 2020;
1147        let run_description = "test_run".to_string();
1148        let objective_value = Money(1234.56);
1149        let dir = tempdir().unwrap();
1150
1151        // Write solver values
1152        {
1153            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1154            writer
1155                .write_solver_values(milestone_year, &run_description, objective_value)
1156                .unwrap();
1157            writer.flush().unwrap();
1158        }
1159
1160        // Read back and compare
1161        let expected = SolverValuesRow {
1162            milestone_year,
1163            run_description,
1164            objective_value,
1165        };
1166        let records: Vec<SolverValuesRow> =
1167            csv::Reader::from_path(dir.path().join(SOLVER_VALUES_FILE_NAME))
1168                .unwrap()
1169                .into_deserialize()
1170                .try_collect()
1171                .unwrap();
1172        assert_equal(records, iter::once(expected));
1173    }
1174
1175    #[rstest]
1176    fn write_appraisal_results(asset: Asset, appraisal_output: AppraisalOutput) {
1177        let milestone_year = 2020;
1178        let run_description = "test_run".to_string();
1179        let dir = tempdir().unwrap();
1180
1181        // Write appraisal results
1182        {
1183            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1184            writer
1185                .write_appraisal_results(milestone_year, &run_description, &[appraisal_output])
1186                .unwrap();
1187            writer.flush().unwrap();
1188        }
1189
1190        // Read back and compare
1191        let expected = AppraisalResultsRow {
1192            milestone_year,
1193            run_description,
1194            asset_id: None,
1195            process_id: asset.process_id().clone(),
1196            region_id: asset.region_id().clone(),
1197            capacity: Capacity(42.0),
1198            capacity_coefficient: MoneyPerCapacity(2.14),
1199            metric: Some(4.14),
1200        };
1201        let records: Vec<AppraisalResultsRow> =
1202            csv::Reader::from_path(dir.path().join(APPRAISAL_RESULTS_FILE_NAME))
1203                .unwrap()
1204                .into_deserialize()
1205                .try_collect()
1206                .unwrap();
1207        assert_equal(records, iter::once(expected));
1208    }
1209
1210    #[rstest]
1211    fn write_appraisal_time_slice_results(
1212        asset: Asset,
1213        appraisal_output: AppraisalOutput,
1214        time_slice: TimeSliceID,
1215    ) {
1216        let milestone_year = 2020;
1217        let run_description = "test_run".to_string();
1218        let dir = tempdir().unwrap();
1219        let demand = indexmap! {time_slice.clone() => Flow(100.0) };
1220
1221        // Write appraisal time slice results
1222        {
1223            let mut writer = DebugDataWriter::create(dir.path()).unwrap();
1224            writer
1225                .write_appraisal_time_slice_results(
1226                    milestone_year,
1227                    &run_description,
1228                    &[appraisal_output],
1229                    &demand,
1230                )
1231                .unwrap();
1232            writer.flush().unwrap();
1233        }
1234
1235        // Read back and compare
1236        let expected = AppraisalResultsTimeSliceRow {
1237            milestone_year,
1238            run_description,
1239            asset_id: None,
1240            process_id: asset.process_id().clone(),
1241            region_id: asset.region_id().clone(),
1242            time_slice: time_slice.clone(),
1243            activity: Activity(10.0),
1244            activity_coefficient: MoneyPerActivity(0.5),
1245            demand: Flow(100.0),
1246            unmet_demand: Flow(5.0),
1247        };
1248        let records: Vec<AppraisalResultsTimeSliceRow> =
1249            csv::Reader::from_path(dir.path().join(APPRAISAL_RESULTS_TIME_SLICE_FILE_NAME))
1250                .unwrap()
1251                .into_deserialize()
1252                .try_collect()
1253                .unwrap();
1254        assert_equal(records, iter::once(expected));
1255    }
1256
1257    #[test]
1258    fn create_output_directory_new_directory() {
1259        let temp_dir = tempdir().unwrap();
1260        let output_dir = temp_dir.path().join("new_output");
1261
1262        // Create a new directory should succeed and return false (no overwrite)
1263        let result = create_output_directory(&output_dir, false).unwrap();
1264        assert!(!result);
1265        assert!(output_dir.exists());
1266        assert!(output_dir.is_dir());
1267    }
1268
1269    #[test]
1270    fn create_output_directory_existing_empty_directory() {
1271        let temp_dir = tempdir().unwrap();
1272        let output_dir = temp_dir.path().join("empty_output");
1273
1274        // Create the directory first
1275        fs::create_dir(&output_dir).unwrap();
1276
1277        // Creating again should succeed and return false (no overwrite needed)
1278        let result = create_output_directory(&output_dir, false).unwrap();
1279        assert!(!result);
1280        assert!(output_dir.exists());
1281        assert!(output_dir.is_dir());
1282    }
1283
1284    #[test]
1285    fn create_output_directory_existing_with_files_no_overwrite() {
1286        let temp_dir = tempdir().unwrap();
1287        let output_dir = temp_dir.path().join("output_with_files");
1288
1289        // Create directory with a file
1290        fs::create_dir(&output_dir).unwrap();
1291        fs::write(output_dir.join("existing_file.txt"), "some content").unwrap();
1292
1293        // Should fail when allow_overwrite is false
1294        let result = create_output_directory(&output_dir, false);
1295        assert!(result.is_err());
1296        assert!(
1297            result
1298                .unwrap_err()
1299                .to_string()
1300                .contains("Output folder already exists")
1301        );
1302    }
1303
1304    #[test]
1305    fn create_output_directory_existing_with_files_allow_overwrite() {
1306        let temp_dir = tempdir().unwrap();
1307        let output_dir = temp_dir.path().join("output_with_files");
1308
1309        // Create directory with a file
1310        fs::create_dir(&output_dir).unwrap();
1311        let file_path = output_dir.join("existing_file.txt");
1312        fs::write(&file_path, "some content").unwrap();
1313
1314        // Should succeed when allow_overwrite is true and return true (overwrite occurred)
1315        let result = create_output_directory(&output_dir, true).unwrap();
1316        assert!(result);
1317        assert!(output_dir.exists());
1318        assert!(output_dir.is_dir());
1319        assert!(!file_path.exists()); // File should be gone
1320    }
1321
1322    #[test]
1323    fn create_output_directory_nested_path() {
1324        let temp_dir = tempdir().unwrap();
1325        let output_dir = temp_dir.path().join("nested").join("path").join("output");
1326
1327        // Should create nested directories and return false (no overwrite)
1328        let result = create_output_directory(&output_dir, false).unwrap();
1329        assert!(!result);
1330        assert!(output_dir.exists());
1331        assert!(output_dir.is_dir());
1332    }
1333
1334    #[test]
1335    fn create_output_directory_existing_subdirs_with_files_allow_overwrite() {
1336        let temp_dir = tempdir().unwrap();
1337        let output_dir = temp_dir.path().join("output_with_subdirs");
1338
1339        // Create directory structure with files
1340        fs::create_dir_all(output_dir.join("subdir")).unwrap();
1341        fs::write(output_dir.join("file1.txt"), "content1").unwrap();
1342        fs::write(output_dir.join("subdir").join("file2.txt"), "content2").unwrap();
1343
1344        // Should succeed when allow_overwrite is true and return true (overwrite occurred)
1345        let result = create_output_directory(&output_dir, true).unwrap();
1346        assert!(result);
1347        assert!(output_dir.exists());
1348        assert!(output_dir.is_dir());
1349        // All previous content should be gone
1350        assert!(!output_dir.join("file1.txt").exists());
1351        assert!(!output_dir.join("subdir").exists());
1352    }
1353}