muse2/model/
parameters.rs

1//! Defines the `ModelParameters` struct, which represents the contents of `model.toml`.
2use crate::ISSUES_URL;
3use crate::asset::check_capacity_valid_for_asset;
4use crate::input::{
5    deserialise_proportion_nonzero, input_err_msg, is_sorted_and_unique, read_toml,
6};
7use crate::units::{Capacity, Dimensionless, MoneyPerFlow};
8use anyhow::{Context, Result, ensure};
9use log::warn;
10use serde::Deserialize;
11use serde_string_enum::DeserializeLabeledStringEnum;
12use std::path::Path;
13use std::sync::OnceLock;
14
/// Name of the parameters file, looked up inside the model directory.
const MODEL_PARAMETERS_FILE_NAME: &str = "model.toml";

/// Name of the `model.toml` option that must be enabled before known-broken options may be used.
pub const ALLOW_BROKEN_OPTION_NAME: &str = "please_give_me_broken_results";

/// Process-wide record of whether the model config enabled broken options.
///
/// Initialised when a model file is read by `ModelParameters::from_path`.
static BROKEN_OPTIONS_ALLOWED: OnceLock<bool> = OnceLock::new();

/// Report whether the model configuration enabled known-broken options.
///
/// # Panics
///
/// Panics if the flag has not yet been recorded (i.e. no model file has been read).
pub fn broken_model_options_allowed() -> bool {
    let recorded = BROKEN_OPTIONS_ALLOWED.get().copied();
    recorded.expect("Broken options flag not set")
}
29
/// Define a zero-argument function `$name` that returns the expression `$value` as a `$type`.
///
/// Intended for use as a `#[serde(default = "...")]` function.
macro_rules! define_param_default {
    ($name:ident, $type:ty, $value:expr) => {
        fn $name() -> $type {
            $value
        }
    };
}

/// Like `define_param_default!`, but wraps `$value` with the unit type's `new` constructor,
/// i.e. the generated function returns `<$type>::new($value)`.
macro_rules! define_unit_param_default {
    ($name:ident, $type:ty, $value:expr) => {
        define_param_default!($name, $type, <$type>::new($value));
    };
}
45
46define_unit_param_default!(default_candidate_asset_capacity, Capacity, 0.0001);
47define_unit_param_default!(default_capacity_limit_factor, Dimensionless, 0.1);
48define_unit_param_default!(default_value_of_lost_load, MoneyPerFlow, 1e9);
49define_unit_param_default!(default_price_tolerance, Dimensionless, 1e-6);
50define_param_default!(default_max_ironing_out_iterations, u32, 10);
51
52/// Model parameters as defined in the `model.toml` file.
53///
54/// NOTE: If you add or change a field in this struct, you must also update the schema in
55/// `schemas/input/model.yaml`.
56#[derive(Debug, Deserialize, PartialEq)]
57pub struct ModelParameters {
58    /// Milestone years
59    pub milestone_years: Vec<u32>,
60    /// Allow known-broken options to be enabled.
61    #[serde(default, rename = "please_give_me_broken_results")] // Can't use constant here :-(
62    pub allow_broken_options: bool,
63    /// The (small) value of capacity given to candidate assets.
64    ///
65    /// Don't change unless you know what you're doing.
66    #[serde(default = "default_candidate_asset_capacity")]
67    pub candidate_asset_capacity: Capacity,
68    /// Defines the strategy used for calculating commodity prices
69    #[serde(default)]
70    pub pricing_strategy: PricingStrategy,
71    /// Affects the maximum capacity that can be given to a newly created asset.
72    ///
73    /// It is the proportion of maximum capacity that could be required across time slices.
74    #[serde(default = "default_capacity_limit_factor")]
75    #[serde(deserialize_with = "deserialise_proportion_nonzero")]
76    pub capacity_limit_factor: Dimensionless,
77    /// The cost applied to unmet demand.
78    ///
79    /// Currently this only applies to the LCOX appraisal.
80    #[serde(default = "default_value_of_lost_load")]
81    pub value_of_lost_load: MoneyPerFlow,
82    /// The maximum number of iterations to run the "ironing out" step of agent investment for
83    #[serde(default = "default_max_ironing_out_iterations")]
84    pub max_ironing_out_iterations: u32,
85    /// The relative tolerance for price convergence in the ironing out loop
86    #[serde(default = "default_price_tolerance")]
87    pub price_tolerance: Dimensionless,
88}
89
90/// The strategy used for calculating commodity prices
91#[derive(DeserializeLabeledStringEnum, Debug, PartialEq, Default)]
92pub enum PricingStrategy {
93    /// Take commodity prices directly from the shadow prices
94    #[default]
95    #[string = "shadow_prices"]
96    ShadowPrices,
97    /// Adjust shadow prices for scarcity
98    #[string = "scarcity_adjusted"]
99    ScarcityAdjusted,
100}
101
102/// Check that the `milestone_years` parameter is valid
103fn check_milestone_years(years: &[u32]) -> Result<()> {
104    ensure!(!years.is_empty(), "`milestone_years` is empty");
105
106    ensure!(
107        is_sorted_and_unique(years),
108        "`milestone_years` must be composed of unique values in order"
109    );
110
111    Ok(())
112}
113
114/// Check that the `value_of_lost_load` parameter is valid
115fn check_value_of_lost_load(value: MoneyPerFlow) -> Result<()> {
116    ensure!(
117        value.is_finite() && value > MoneyPerFlow(0.0),
118        "value_of_lost_load must be a finite number greater than zero"
119    );
120
121    Ok(())
122}
123
124/// Check that the `max_ironing_out_iterations` parameter is valid
125fn check_max_ironing_out_iterations(value: u32) -> Result<()> {
126    ensure!(value > 0, "max_ironing_out_iterations cannot be zero");
127
128    Ok(())
129}
130
131/// Check the `price_tolerance` parameter is valid
132fn check_price_tolerance(value: Dimensionless) -> Result<()> {
133    ensure!(
134        value.is_finite() && value >= Dimensionless(0.0),
135        "price_tolerance must be a finite number greater than or equal to zero"
136    );
137
138    Ok(())
139}
140
141impl ModelParameters {
142    /// Read a model file from the specified directory.
143    ///
144    /// # Arguments
145    ///
146    /// * `model_dir` - Folder containing model configuration files
147    ///
148    /// # Returns
149    ///
150    /// The model file contents as a [`ModelParameters`] struct or an error if the file is invalid
151    pub fn from_path<P: AsRef<Path>>(model_dir: P) -> Result<ModelParameters> {
152        let file_path = model_dir.as_ref().join(MODEL_PARAMETERS_FILE_NAME);
153        let model_params: ModelParameters = read_toml(&file_path)?;
154
155        // Set flag signalling whether broken model options are allowed or not
156        BROKEN_OPTIONS_ALLOWED
157            .set(model_params.allow_broken_options)
158            .unwrap(); // Will only fail if there is a race condition, which shouldn't happen
159
160        model_params
161            .validate()
162            .with_context(|| input_err_msg(file_path))?;
163
164        Ok(model_params)
165    }
166
167    /// Validate parameters after reading in file
168    fn validate(&self) -> Result<()> {
169        if self.allow_broken_options {
170            warn!(
171                "!!! You've enabled the {ALLOW_BROKEN_OPTION_NAME} option. !!!\n\
172                I see you like to live dangerously 😈. This option should ONLY be used by \
173                developers as it can cause peculiar behaviour that breaks things. NEVER enable it \
174                for results you actually care about or want to publish. You have been warned!"
175            );
176        }
177
178        // milestone_years
179        check_milestone_years(&self.milestone_years)?;
180
181        // pricing_strategy
182        if self.pricing_strategy == PricingStrategy::ScarcityAdjusted {
183            ensure!(
184                self.allow_broken_options,
185                "The pricing strategy is set to 'scarcity_adjusted', which is known to be broken. \
186                If you are sure that you want to enable it anyway, you need to set the \
187                {ALLOW_BROKEN_OPTION_NAME} option to true."
188            );
189
190            warn!(
191                "The pricing strategy is set to 'scarcity_adjusted'. Commodity prices may be \
192                incorrect if assets have more than one output commodity. See: {ISSUES_URL}/677"
193            );
194        }
195
196        // capacity_limit_factor already validated with deserialise_proportion_nonzero
197
198        // candidate_asset_capacity
199        check_capacity_valid_for_asset(self.candidate_asset_capacity)
200            .context("Invalid value for candidate_asset_capacity")?;
201
202        // value_of_lost_load
203        check_value_of_lost_load(self.value_of_lost_load)?;
204
205        // max_ironing_out_iterations
206        check_max_ironing_out_iterations(self.max_ironing_out_iterations)?;
207
208        // price_tolerance
209        check_price_tolerance(self.price_tolerance)?;
210
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::fmt::Display;
    use std::fs;
    use tempfile::tempdir;

    /// Assert that `result` matches the expected validity of `value`.
    ///
    /// Valid values must yield `Ok`. Invalid values must yield an error whose message contains
    /// `expected_error_fragment`.
    fn assert_validation_result<T, U: Display>(
        result: Result<T>,
        expected_valid: bool,
        value: U,
        expected_error_fragment: &str,
    ) {
        match result {
            Ok(_) => assert!(
                expected_valid,
                "Expected value {value} to be invalid, but it was accepted",
            ),
            Err(err) if expected_valid => panic!(
                "Expected value {value} to be valid, but got error: {:?}",
                Some(err)
            ),
            Err(err) => {
                let error_message = err.to_string();
                assert!(
                    error_message.contains(expected_error_fragment),
                    "Error message should mention the validation constraint, got: {error_message}",
                );
            }
        }
    }

    #[test]
    fn test_check_milestone_years() {
        let accepted: [&[u32]; 2] = [&[1], &[1, 2]];
        for years in accepted {
            assert!(check_milestone_years(years).is_ok());
        }

        // Empty, duplicated and descending year lists are all rejected
        let rejected: [&[u32]; 3] = [&[], &[1, 1], &[2, 1]];
        for years in rejected {
            assert!(check_milestone_years(years).is_err());
        }
    }

    #[test]
    fn test_model_params_from_path() {
        let dir = tempdir().unwrap();
        fs::write(
            dir.path().join(MODEL_PARAMETERS_FILE_NAME),
            "milestone_years = [2020, 2100]\n",
        )
        .unwrap();

        let model_params = ModelParameters::from_path(dir.path()).unwrap();
        assert_eq!(model_params.milestone_years, [2020, 2100]);
    }

    #[rstest]
    #[case(1.0, true)] // positive
    #[case(1e-10, true)] // tiny positive
    #[case(1e9, true)] // the default
    #[case(f64::MAX, true)] // largest finite value
    #[case(0.0, false)] // zero is rejected
    #[case(-1.0, false)] // negative
    #[case(-1e-10, false)] // tiny negative
    #[case(f64::INFINITY, false)] // +inf
    #[case(f64::NEG_INFINITY, false)] // -inf
    #[case(f64::NAN, false)] // NaN
    fn test_check_value_of_lost_load(#[case] value: f64, #[case] expected_valid: bool) {
        assert_validation_result(
            check_value_of_lost_load(MoneyPerFlow::new(value)),
            expected_valid,
            value,
            "value_of_lost_load must be a finite number greater than zero",
        );
    }

    #[rstest]
    #[case(1, true)] // smallest accepted value
    #[case(10, true)] // the default
    #[case(100, true)] // large
    #[case(u32::MAX, true)] // largest representable value
    #[case(0, false)] // zero is rejected
    fn test_check_max_ironing_out_iterations(#[case] value: u32, #[case] expected_valid: bool) {
        assert_validation_result(
            check_max_ironing_out_iterations(value),
            expected_valid,
            value,
            "max_ironing_out_iterations cannot be zero",
        );
    }

    #[rstest]
    #[case(0.0, true)] // zero is accepted
    #[case(1e-10, true)] // tiny positive
    #[case(1e-6, true)] // the default
    #[case(1.0, true)] // larger value
    #[case(f64::MAX, true)] // largest finite value
    #[case(-1e-10, false)] // tiny negative
    #[case(-1.0, false)] // negative
    #[case(f64::INFINITY, false)] // +inf
    #[case(f64::NEG_INFINITY, false)] // -inf
    #[case(f64::NAN, false)] // NaN
    fn test_check_price_tolerance(#[case] value: f64, #[case] expected_valid: bool) {
        assert_validation_result(
            check_price_tolerance(Dimensionless::new(value)),
            expected_valid,
            value,
            "price_tolerance must be a finite number greater than or equal to zero",
        );
    }
}