1use anyhow::Result;
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::error::Error;
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9use std::marker::PhantomData;
10
/// A cheap-to-clone, string-backed identifier.
///
/// Implementors can be borrowed as `str` (so collections keyed by IDs can be
/// queried with plain string slices), displayed, and built from a `String`.
pub trait ID
where
    Self: Eq + Hash + Borrow<str> + Clone + Display + Debug + From<String>,
{
    /// Human-readable name for this kind of ID (used in `MissingIDError` messages).
    fn get_type_name() -> &'static str;
}
16
/// Defines a newtype ID `$name` wrapping an `Rc<str>`.
///
/// The generated type implements [`ID`] (with `get_type_name` returning
/// `$type_name`), `Borrow<str>`, `From<&str>`, `From<String>`, `Display`,
/// `Hash`, `Ord`, `Debug`, `Serialize`, and a validating `Deserialize`.
///
/// Deserialisation trims surrounding whitespace, then rejects empty IDs and the
/// reserved values "all" and "annual" (compared ASCII case-insensitively).
/// `$name::new` performs no such validation.
macro_rules! define_id_type {
    ($name:ident, $type_name:expr) => {
        #[derive(
            Clone,
            derive_more::Display,
            std::hash::Hash,
            PartialOrd,
            Ord,
            PartialEq,
            Eq,
            Debug,
            serde::Serialize,
        )]
        pub struct $name(pub std::rc::Rc<str>);

        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                &self.0
            }
        }

        impl From<&str> for $name {
            fn from(s: &str) -> Self {
                $name(std::rc::Rc::from(s))
            }
        }

        impl From<String> for $name {
            fn from(s: String) -> Self {
                $name(std::rc::Rc::from(s))
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                use serde::de::Error;
                // Values that may never be used as an ID (matched case-insensitively).
                const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];

                let id: String = serde::Deserialize::deserialize(deserialiser)?;
                // Validation and storage both operate on the trimmed value.
                let id = id.trim();
                if id.is_empty() {
                    return Err(D::Error::custom("IDs cannot be empty"));
                }

                if FORBIDDEN_IDS
                    .iter()
                    .any(|forbidden| id.eq_ignore_ascii_case(forbidden))
                {
                    return Err(D::Error::custom(format!(
                        "'{id}' is an invalid value for an ID"
                    )));
                }

                Ok(id.into())
            }
        }

        // `$crate` keeps this path anchored to the defining crate regardless of
        // where the macro is invoked.
        impl $crate::id::ID for $name {
            fn get_type_name() -> &'static str {
                $type_name
            }
        }

        impl $name {
            /// Create an ID from a string slice, without any validation.
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }
    };
}
pub(crate) use define_id_type;
92
// Concrete ID type used to exercise the ID machinery in this module's unit tests.
#[cfg(test)]
define_id_type! { GenericID, "generic ID" }
95
96pub trait HasID<T: ID> {
98 fn get_id(&self) -> &T;
100}
101
/// Implements [`HasID<$id_ty>`](HasID) for `$t` by returning a reference to
/// its `id` field (which must be of type `$id_ty`).
macro_rules! define_id_getter {
    ($t:ty, $id_ty:ty) => {
        // `$crate` keeps the trait path anchored to the defining crate.
        impl $crate::id::HasID<$id_ty> for $t {
            fn get_id(&self) -> &$id_ty {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
113
/// Error returned when a looked-up ID is not present in a collection.
///
/// The ID type `T` is tracked only at the type level so the error message can
/// name the kind of ID. Bounds on `T` live on the impls that need them, not on
/// the struct definition.
#[derive(Debug)]
pub struct MissingIDError<T> {
    /// The ID string that was searched for but not found.
    missing_id: String,
    /// Type marker. `fn() -> T` (rather than `T`) keeps the error `Send`/`Sync`
    /// even when `T` is not (e.g. an `Rc`-backed ID).
    _phantom: PhantomData<fn() -> T>,
}
120
121impl<T: ID> MissingIDError<T> {
122 pub fn new(missing_id: &str) -> MissingIDError<T> {
124 MissingIDError::<T> {
125 missing_id: missing_id.to_string(),
126 _phantom: std::marker::PhantomData,
127 }
128 }
129}
130
131impl<T: ID> Display for MissingIDError<T> {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "Unknown {} '{}'", T::get_type_name(), self.missing_id)
134 }
135}
136
137impl<T: ID> Error for MissingIDError<T> {}
138
139pub trait IDCollection<T: ID> {
141 fn get_id<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<&T, MissingIDError<T>>;
151}
152
/// Expands to an `IDCollection::get_id` implementation for set-like
/// collections whose `get(&str)` returns `Option<&T>`.
macro_rules! define_id_methods {
    () => {
        fn get_id<S>(&self, id: &S) -> Result<&T, MissingIDError<T>>
        where
            S: Borrow<str> + ?Sized,
        {
            let key: &str = id.borrow();
            match self.get(key) {
                Some(found) => Ok(found),
                None => Err(MissingIDError::new(key)),
            }
        }
    };
}
161
162impl<T: ID> IDCollection<T> for HashSet<T> {
163 define_id_methods!();
164}
165
166impl<T: ID> IDCollection<T> for IndexSet<T> {
167 define_id_methods!();
168}
169
170impl<T: ID, V> IDCollection<T> for IndexMap<T, V> {
171 fn get_id<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<&T, MissingIDError<T>> {
172 let (found, _) = self
173 .get_key_value(id.borrow())
174 .ok_or_else(|| MissingIDError::new(id.borrow()))?;
175 Ok(found)
176 }
177}
178
179pub trait GetIDValue<K: ID, V> {
181 fn get_id_value<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<(&K, &V), MissingIDError<K>>;
183}
184
185impl<K: ID + Borrow<str>, V> GetIDValue<K, V> for IndexMap<K, V> {
186 fn get_id_value<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<(&K, &V), MissingIDError<K>> {
187 self.get_key_value(id.borrow())
188 .ok_or_else(|| MissingIDError::new(id.borrow()))
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    use serde::Deserialize;

    /// Minimal document with a single ID field, used to drive deserialisation.
    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialise a [`Record`] from the TOML snippet `id = "<id>"`.
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")]
    fn deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    /// Surrounding whitespace is stripped from IDs during deserialisation.
    #[rstest]
    #[case("  commodity1  ", "commodity1")]
    #[case(" some commodity", "some commodity")]
    fn deserialise_id_trimmed(#[case] id: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")]
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn deserialise_id_invalid(#[case] id: &str) {
        deserialise_id(id).unwrap_err();
    }

    #[test]
    fn id_collection_get_id() {
        let set: HashSet<GenericID> = vec![GenericID::new("a")].into_iter().collect();
        assert_eq!(set.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(
            set.get_id("b").unwrap_err().to_string(),
            "Unknown generic ID 'b'"
        );

        let index_set: IndexSet<GenericID> = vec![GenericID::new("a")].into_iter().collect();
        assert_eq!(index_set.get_id("a").unwrap(), &GenericID::new("a"));
        assert!(index_set.get_id("b").is_err());

        let map: IndexMap<GenericID, u32> =
            vec![(GenericID::new("a"), 1)].into_iter().collect();
        assert_eq!(map.get_id("a").unwrap(), &GenericID::new("a"));
        assert!(map.get_id("b").is_err());
    }

    #[test]
    fn get_id_value_lookup() {
        let map: IndexMap<GenericID, u32> =
            vec![(GenericID::new("a"), 1)].into_iter().collect();
        let (key, value) = map.get_id_value("a").unwrap();
        assert_eq!(key, &GenericID::new("a"));
        assert_eq!(*value, 1);
        assert_eq!(
            map.get_id_value("z").unwrap_err().to_string(),
            "Unknown generic ID 'z'"
        );
    }
}