1use anyhow::{Context, Result};
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::fmt::Display;
7use std::hash::Hash;
8
/// Bundle of bounds required of every ID type.
///
/// An ID must be comparable and hashable, borrowable as a `&str` (so collections keyed by IDs can
/// be queried with plain string slices), cheaply displayable, clonable and constructible from an
/// owned `String`. Any type meeting all of these bounds is automatically `IDLike`.
pub trait IDLike
where
    Self: Eq + Hash + Borrow<str> + Clone + Display + From<String>,
{
}

impl<T: Eq + Hash + Borrow<str> + Clone + Display + From<String>> IDLike for T {}
12
/// Define a string-based ID newtype called `$id_name`, wrapping an `Rc<str>`.
///
/// The generated type derives `Clone`, `Display` (via `derive_more`), `Hash`, `PartialOrd`,
/// `Ord`, `PartialEq`, `Eq`, `Debug` and `Serialize`, and additionally implements:
///
/// * `Borrow<str>`, so collections of IDs can be queried with a `&str`;
/// * `From<&str>` and `From<String>`;
/// * a validating `Deserialize`, which trims surrounding whitespace and rejects IDs that are
///   empty or equal (ignoring ASCII case) to one of the reserved words `all` or `annual`;
/// * an inherent `new` constructor.
///
/// Only deserialisation validates: `new` and the `From` conversions accept any string as-is.
macro_rules! define_id_type {
    ($id_name:ident) => {
        #[derive(
            Clone,
            derive_more::Display,
            std::hash::Hash,
            PartialOrd,
            Ord,
            PartialEq,
            Eq,
            Debug,
            serde::Serialize,
        )]
        pub struct $id_name(pub std::rc::Rc<str>);

        impl std::borrow::Borrow<str> for $id_name {
            // The derived `Hash`/`Eq` forward to the inner `str`, which keeps hashed lookups
            // through this `Borrow` impl consistent.
            fn borrow(&self) -> &str {
                &*self.0
            }
        }

        impl From<&str> for $id_name {
            fn from(value: &str) -> Self {
                Self(value.into())
            }
        }

        impl From<String> for $id_name {
            fn from(value: String) -> Self {
                Self(value.into())
            }
        }

        impl<'de> serde::Deserialize<'de> for $id_name {
            /// Deserialise an ID from a string, trimming whitespace and rejecting empty or
            /// reserved values.
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Reserved words, compared case-insensitively (ASCII only)
                const RESERVED_IDS: [&str; 2] = ["all", "annual"];

                let raw = <String as serde::Deserialize>::deserialize(deserialiser)?;
                let trimmed = raw.trim();

                if trimmed.is_empty() {
                    return Err(<D::Error as serde::de::Error>::custom(
                        "IDs cannot be empty",
                    ));
                }

                let is_reserved = RESERVED_IDS
                    .iter()
                    .any(|reserved| trimmed.eq_ignore_ascii_case(reserved));
                if is_reserved {
                    return Err(<D::Error as serde::de::Error>::custom(format!(
                        "'{trimmed}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::from(trimmed))
            }
        }

        impl $id_name {
            /// Create an ID from a string slice. No validation is performed.
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }
    };
}
pub(crate) use define_id_type;
82
// Concrete ID type, compiled only for tests, for exercising the code generated by
// `define_id_type!`.
#[cfg(test)]
define_id_type! { GenericID }
85
86pub trait HasID<ID: IDLike> {
88 fn get_id(&self) -> &ID;
90}
91
/// Implement `crate::id::HasID<$id_type>` for `$implementor`.
///
/// The generated `get_id` returns a reference to the implementor's `id` field, so
/// `$implementor` must have a field called `id` of type `$id_type`.
macro_rules! define_id_getter {
    ($implementor:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $implementor {
            fn get_id(&self) -> &$id_type {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
103
104pub trait IDCollection<ID: IDLike> {
106 fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID>;
116}
117
/// Expand to an implementation of `IDCollection::get_id` for a set-like collection.
///
/// The expansion calls `self.get(&str)` and expects it to return an `Option` of a reference to
/// the stored ID (as `HashSet::get` and `IndexSet::get` do). A missing ID becomes an error.
macro_rules! define_id_methods {
    () => {
        fn get_id<Q: Borrow<str> + Display + ?Sized>(&self, id: &Q) -> Result<&ID> {
            self.get(id.borrow())
                .with_context(|| format!("Unknown ID {id} found"))
        }
    };
}
128
129impl<ID: IDLike> IDCollection<ID> for HashSet<ID> {
130 define_id_methods!();
131}
132
133impl<ID: IDLike> IDCollection<ID> for IndexSet<ID> {
134 define_id_methods!();
135}
136
137impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
138 fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
139 let (found, _) = self
140 .get_key_value(id.borrow())
141 .with_context(|| format!("Unknown ID {id} found"))?;
142 Ok(found)
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialise a one-field TOML record whose `id` value is the given string.
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")]
    fn test_deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    /// Surrounding whitespace is stripped from otherwise valid IDs.
    #[rstest]
    #[case(" commodity1 ", "commodity1")]
    #[case("  PROCESS", "PROCESS")]
    #[case("some commodity  ", "some commodity")]
    fn test_deserialise_id_trimmed(#[case] raw: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(raw).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")] // whitespace-only trims to empty
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn test_deserialise_id_invalid(#[case] id: &str) {
        assert!(deserialise_id(id).is_err());
    }

    /// A small, fixed set of IDs used by the collection lookup tests.
    fn sample_ids() -> [GenericID; 2] {
        [GenericID::from("a"), GenericID::new("b")]
    }

    /// Check `get_id` on a collection containing exactly the IDs from `sample_ids`.
    fn check_lookup<C: IDCollection<GenericID>>(collection: &C) {
        // Lookup by string slice
        assert_eq!(collection.get_id("a").unwrap(), &GenericID::from("a"));
        // Lookup by another ID value
        assert_eq!(
            collection.get_id(&GenericID::from("b")).unwrap(),
            &GenericID::from("b")
        );
        // Missing IDs are reported with the offending value
        assert_eq!(
            collection.get_id("c").unwrap_err().to_string(),
            "Unknown ID c found"
        );
    }

    #[test]
    fn test_get_id_hashset() {
        let set: HashSet<GenericID> = sample_ids().iter().cloned().collect();
        check_lookup(&set);
    }

    #[test]
    fn test_get_id_indexset() {
        let set: IndexSet<GenericID> = sample_ids().iter().cloned().collect();
        check_lookup(&set);
    }

    #[test]
    fn test_get_id_indexmap() {
        let map: IndexMap<GenericID, u32> = sample_ids().iter().cloned().zip(0..).collect();
        check_lookup(&map);
    }
}