1use anyhow::{Context, Result};
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::fmt::Display;
7use std::hash::Hash;
8
/// Marker trait for any type that can act as a string-backed identifier.
///
/// A conforming type must be hashable, comparable, clonable, printable, viewable as a `&str`
/// (so collections keyed by it can be queried with plain strings) and constructible from an
/// owned `String`.
pub trait IDLike: Clone + Eq + Hash + Display + Borrow<str> + From<String> {}

// Blanket implementation: every type meeting the bounds above is automatically `IDLike`.
impl<T: Clone + Eq + Hash + Display + Borrow<str> + From<String>> IDLike for T {}
12
/// Defines a string ID newtype called `$name`, backed by a shared `Rc<str>`.
///
/// The generated type implements `Borrow<str>`, `Display`, `From<&str>`, `From<String>`,
/// `serde::Serialize` (derived) and a hand-written `serde::Deserialize`. Deserialisation trims
/// surrounding whitespace, then rejects empty IDs and the reserved values `all` and `annual`
/// (compared with ASCII case-insensitivity). The inherent `new` constructor performs no
/// validation.
macro_rules! define_id_type {
    ($name:ident) => {
        /// A string identifier stored as a reference-counted `str`.
        #[derive(
            Debug, Clone, PartialEq, Eq, PartialOrd, Ord, std::hash::Hash, serde::Serialize,
        )]
        pub struct $name(pub std::rc::Rc<str>);

        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                // The derived `Hash`/`Eq` only look at the inner string, so exposing it as
                // `&str` keeps lookups by plain string slices consistent.
                &self.0
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(&self.0)
            }
        }

        impl From<&str> for $name {
            fn from(s: &str) -> Self {
                Self(s.into())
            }
        }

        impl From<String> for $name {
            fn from(s: String) -> Self {
                Self(s.into())
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                use serde::de::Error;

                // Values reserved for other meanings; they may never be used as IDs.
                const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];

                let raw: String = serde::Deserialize::deserialize(deserialiser)?;
                let id = raw.trim();

                if id.is_empty() {
                    return Err(D::Error::custom("IDs cannot be empty"));
                }

                let is_reserved = FORBIDDEN_IDS
                    .iter()
                    .any(|reserved| id.eq_ignore_ascii_case(reserved));
                if is_reserved {
                    return Err(D::Error::custom(format!(
                        "'{id}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::from(id))
            }
        }

        impl $name {
            /// Creates an ID from a string slice as-is (no trimming or validation).
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }
    };
}
pub(crate) use define_id_type;
80
// A throwaway ID type, compiled only for tests, used to exercise `define_id_type!`.
#[cfg(test)]
define_id_type! { GenericID }
83
84pub trait HasID<ID: IDLike> {
86 fn get_id(&self) -> &ID;
88}
89
/// Implements `crate::id::HasID<$id_type>` for `$owner`.
///
/// The generated `get_id` returns a reference to the field named `id`, so `$owner` must
/// have such a field of type `$id_type`.
macro_rules! define_id_getter {
    ($owner:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $owner {
            fn get_id(&self) -> &$id_type {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
101
102pub trait IDCollection<ID: IDLike> {
104 fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID>;
114}
115
/// Expands to an `IDCollection::get_id` implementation for set-like collections whose
/// `get` method accepts a `&str` key and returns `Option<&ID>`.
macro_rules! define_id_methods {
    () => {
        fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
            let key: &str = id.borrow();
            self.get(key)
                .with_context(|| format!("Unknown ID {id} found"))
        }
    };
}
126
127impl<ID: IDLike> IDCollection<ID> for HashSet<ID> {
128 define_id_methods!();
129}
130
131impl<ID: IDLike> IDCollection<ID> for IndexSet<ID> {
132 define_id_methods!();
133}
134
135impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
136 fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
137 let (found, _) = self
138 .get_key_value(id.borrow())
139 .with_context(|| format!("Unknown ID {id} found"))?;
140 Ok(found)
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialises a one-field TOML record whose `id` is the given string.
    ///
    /// The value is interpolated verbatim into a TOML basic string, so test inputs must not
    /// contain characters that TOML requires to be escaped (such as `"` or `\`).
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")]
    fn test_deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    // Deserialisation trims surrounding whitespace but keeps interior spaces.
    #[rstest]
    #[case("  commodity1", "commodity1")]
    #[case("commodity1  ", "commodity1")]
    #[case(" some commodity ", "some commodity")]
    fn test_deserialise_id_trims_whitespace(#[case] raw: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(raw).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")] // empty once trimmed
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn test_deserialise_id_invalid(#[case] id: &str) {
        assert!(deserialise_id(id).is_err());
    }
}