1use crate::asset::AssetRef;
3use crate::commodity::CommodityID;
4use crate::model::{Model, PricingStrategy};
5use crate::region::RegionID;
6use crate::simulation::optimisation::Solution;
7use crate::time_slice::{TimeSliceID, TimeSliceInfo};
8use crate::units::{Dimensionless, MoneyPerActivity, MoneyPerFlow, Year};
9use std::collections::{BTreeMap, HashMap, btree_map};
10
11pub fn calculate_prices(model: &Model, solution: &Solution) -> CommodityPrices {
21 let shadow_prices = CommodityPrices::from_iter(solution.iter_commodity_balance_duals());
22 match model.parameters.pricing_strategy {
23 PricingStrategy::ShadowPrices => shadow_prices,
25 PricingStrategy::ScarcityAdjusted => shadow_prices
27 .clone()
28 .with_scarcity_adjustment(solution.iter_activity_duals()),
29 }
30}
31
32#[derive(Default, Clone)]
34pub struct CommodityPrices(BTreeMap<(CommodityID, RegionID, TimeSliceID), MoneyPerFlow>);
35
36impl CommodityPrices {
37 fn with_scarcity_adjustment<'a, I>(mut self, activity_duals: I) -> Self
43 where
44 I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
45 {
46 let highest_duals = get_highest_activity_duals(activity_duals);
47
48 for (key, highest) in &highest_duals {
51 if let Some(price) = self.0.get_mut(key) {
52 *price += MoneyPerFlow(highest.value());
54 }
55 }
56
57 self
58 }
59
60 pub fn extend<T>(&mut self, iter: T)
62 where
63 T: IntoIterator<Item = ((CommodityID, RegionID, TimeSliceID), MoneyPerFlow)>,
64 {
65 self.0.extend(iter);
66 }
67
68 pub fn insert(
70 &mut self,
71 commodity_id: &CommodityID,
72 region_id: &RegionID,
73 time_slice: &TimeSliceID,
74 price: MoneyPerFlow,
75 ) {
76 let key = (commodity_id.clone(), region_id.clone(), time_slice.clone());
77 self.0.insert(key, price);
78 }
79
80 pub fn iter(
86 &self,
87 ) -> impl Iterator<Item = (&CommodityID, &RegionID, &TimeSliceID, MoneyPerFlow)> {
88 self.0
89 .iter()
90 .map(|((commodity_id, region_id, ts), price)| (commodity_id, region_id, ts, *price))
91 }
92
93 pub fn get(
95 &self,
96 commodity_id: &CommodityID,
97 region_id: &RegionID,
98 time_slice: &TimeSliceID,
99 ) -> Option<MoneyPerFlow> {
100 self.0
101 .get(&(commodity_id.clone(), region_id.clone(), time_slice.clone()))
102 .copied()
103 }
104
105 pub fn keys(&self) -> btree_map::Keys<'_, (CommodityID, RegionID, TimeSliceID), MoneyPerFlow> {
107 self.0.keys()
108 }
109
110 pub fn remove(
112 &mut self,
113 commodity_id: &CommodityID,
114 region_id: &RegionID,
115 time_slice: &TimeSliceID,
116 ) -> Option<MoneyPerFlow> {
117 self.0
118 .remove(&(commodity_id.clone(), region_id.clone(), time_slice.clone()))
119 }
120
121 fn time_slice_weighted_averages(
129 &self,
130 time_slice_info: &TimeSliceInfo,
131 ) -> HashMap<(CommodityID, RegionID), MoneyPerFlow> {
132 let mut weighted_prices = HashMap::new();
133
134 for ((commodity_id, region_id, time_slice_id), price) in &self.0 {
135 let weight = time_slice_info.time_slices[time_slice_id] / Year(1.0);
137 let key = (commodity_id.clone(), region_id.clone());
138 *weighted_prices.entry(key).or_default() += *price * weight;
139 }
140
141 weighted_prices
142 }
143
144 pub fn within_tolerance_weighted(
154 &self,
155 other: &Self,
156 tolerance: Dimensionless,
157 time_slice_info: &TimeSliceInfo,
158 ) -> bool {
159 let self_averages = self.time_slice_weighted_averages(time_slice_info);
160 let other_averages = other.time_slice_weighted_averages(time_slice_info);
161
162 for (key, &price) in &self_averages {
163 let other_price = other_averages[key];
164 let abs_diff = (price - other_price).abs();
165
166 if price == MoneyPerFlow(0.0) {
168 if other_price != MoneyPerFlow(0.0) {
170 return false;
171 }
172 } else if abs_diff / price.abs() > tolerance {
174 return false;
175 }
176 }
177 true
178 }
179}
180
181impl<'a> FromIterator<(&'a CommodityID, &'a RegionID, &'a TimeSliceID, MoneyPerFlow)>
182 for CommodityPrices
183{
184 fn from_iter<I>(iter: I) -> Self
185 where
186 I: IntoIterator<Item = (&'a CommodityID, &'a RegionID, &'a TimeSliceID, MoneyPerFlow)>,
187 {
188 let map = iter
189 .into_iter()
190 .map(|(commodity_id, region_id, time_slice, price)| {
191 (
192 (commodity_id.clone(), region_id.clone(), time_slice.clone()),
193 price,
194 )
195 })
196 .collect();
197 CommodityPrices(map)
198 }
199}
200
201impl IntoIterator for CommodityPrices {
202 type Item = ((CommodityID, RegionID, TimeSliceID), MoneyPerFlow);
203 type IntoIter =
204 std::collections::btree_map::IntoIter<(CommodityID, RegionID, TimeSliceID), MoneyPerFlow>;
205
206 fn into_iter(self) -> Self::IntoIter {
207 self.0.into_iter()
208 }
209}
210
211fn get_highest_activity_duals<'a, I>(
212 activity_duals: I,
213) -> HashMap<(CommodityID, RegionID, TimeSliceID), MoneyPerActivity>
214where
215 I: Iterator<Item = (&'a AssetRef, &'a TimeSliceID, MoneyPerActivity)>,
216{
217 let mut highest_duals = HashMap::new();
219 for (asset, time_slice, dual) in activity_duals {
220 for flow in asset.iter_flows().filter(|flow| flow.is_output()) {
222 highest_duals
224 .entry((
225 flow.commodity.id.clone(),
226 asset.region_id().clone(),
227 time_slice.clone(),
228 ))
229 .and_modify(|current_dual| {
230 if dual > *current_dual {
231 *current_dual = dual;
232 }
233 })
234 .or_insert(dual);
235 }
236 }
237
238 highest_duals
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commodity::CommodityID;
    use crate::fixture::{commodity_id, region_id, time_slice, time_slice_info};
    use crate::region::RegionID;
    use crate::time_slice::TimeSliceID;
    use rstest::rstest;

    #[rstest]
    // Identical prices pass even with zero tolerance
    #[case(MoneyPerFlow(100.0), MoneyPerFlow(100.0), Dimensionless(0.0), true)]
    // 5% difference is within a 10% tolerance
    #[case(MoneyPerFlow(100.0), MoneyPerFlow(105.0), Dimensionless(0.1), true)]
    // Negative prices: 5% difference is within a 10% tolerance
    #[case(MoneyPerFlow(-100.0), MoneyPerFlow(-105.0), Dimensionless(0.1), true)]
    // Both prices zero
    #[case(MoneyPerFlow(0.0), MoneyPerFlow(0.0), Dimensionless(0.1), true)]
    // 5% difference exceeds a 1% tolerance
    #[case(MoneyPerFlow(100.0), MoneyPerFlow(105.0), Dimensionless(0.01), false)]
    // Sign flip gives a 205% relative difference
    #[case(MoneyPerFlow(100.0), MoneyPerFlow(-105.0), Dimensionless(0.1), false)]
    // Zero baseline with a nonzero (positive) other price
    #[case(MoneyPerFlow(0.0), MoneyPerFlow(10.0), Dimensionless(0.1), false)]
    // Zero baseline with a nonzero (negative) other price
    #[case(MoneyPerFlow(0.0), MoneyPerFlow(-10.0), Dimensionless(0.1), false)]
    // Positive baseline dropping to zero is a 100% relative difference
    #[case(MoneyPerFlow(10.0), MoneyPerFlow(0.0), Dimensionless(0.1), false)]
    // Negative baseline rising to zero is a 100% relative difference
    #[case(MoneyPerFlow(-10.0), MoneyPerFlow(0.0), Dimensionless(0.1), false)]
    fn test_within_tolerance_scenarios(
        #[case] price1: MoneyPerFlow,
        #[case] price2: MoneyPerFlow,
        #[case] tolerance: Dimensionless,
        #[case] expected: bool,
        time_slice_info: TimeSliceInfo,
        time_slice: TimeSliceID,
    ) {
        let mut prices1 = CommodityPrices::default();
        let mut prices2 = CommodityPrices::default();

        let commodity = CommodityID::new("test_commodity");
        let region = RegionID::new("test_region");
        prices1.insert(&commodity, &region, &time_slice, price1);
        prices2.insert(&commodity, &region, &time_slice, price2);

        assert_eq!(
            prices1.within_tolerance_weighted(&prices2, tolerance, &time_slice_info),
            expected
        );
    }

    #[rstest]
    fn test_time_slice_weighted_averages(
        commodity_id: CommodityID,
        region_id: RegionID,
        time_slice_info: TimeSliceInfo,
        time_slice: TimeSliceID,
    ) {
        let mut prices = CommodityPrices::default();

        prices.insert(&commodity_id, &region_id, &time_slice, MoneyPerFlow(100.0));

        let averages = prices.time_slice_weighted_averages(&time_slice_info);

        // The average equals the single stored price, which presumes the fixture's time slice
        // has a weight of one (a single slice spanning the whole year) — confirm in fixtures.
        assert_eq!(averages[&(commodity_id, region_id)], MoneyPerFlow(100.0));
    }
}