Skip to main content

muse2/
id.rs

1//! Code for handling IDs
2use anyhow::Result;
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::error::Error;
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9use std::marker::PhantomData;
10
/// Behaviour shared by every strongly-typed identifier.
///
/// An ID can be hashed, compared for equality, cloned, printed, built from an owned `String`
/// and borrowed as a plain `&str` (which is what lets collections of IDs be queried by string).
pub trait ID
where
    Self: Eq + Hash + Borrow<str> + Clone + Display + Debug + From<String>,
{
    /// Human-readable name of this kind of ID (e.g. "commodity ID", "region ID"), as used
    /// when reporting a missing ID.
    fn get_type_name() -> &'static str;
}
16
/// Define a newtype ID wrapping an `Rc<str>`, together with the trait impls every ID needs.
///
/// * `$name` - identifier of the new type
/// * `$type_name` - expression returned by the type's `ID::get_type_name` implementation
///
/// Deserialisation trims surrounding whitespace and rejects IDs that are empty or that equal
/// "all" or "annual" (ignoring ASCII case). The `From` impls and `new` perform no validation.
macro_rules! define_id_type {
    ($name:ident, $type_name:expr) => {
        /// An ID type (e.g. `AgentID`, `CommodityID`, etc.)
        #[derive(
            Clone,
            derive_more::Display,
            std::hash::Hash,
            PartialOrd,
            Ord,
            PartialEq,
            Eq,
            Debug,
            serde::Serialize,
        )]
        pub struct $name(pub std::rc::Rc<str>);

        impl $name {
            /// Create a new ID from a string slice
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }

        impl From<&str> for $name {
            fn from(s: &str) -> Self {
                Self(s.into())
            }
        }

        impl From<String> for $name {
            fn from(s: String) -> Self {
                Self(s.into())
            }
        }

        // Hash/Eq/Ord are derived from the wrapped `str`, so borrowing as `&str` agrees with
        // them and allows string slices to be used as lookup keys for collections of IDs.
        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                self.0.as_ref()
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                // Values that may never be used as IDs (compared case-insensitively)
                const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];

                let raw: String = serde::Deserialize::deserialize(deserialiser)?;
                let id = raw.trim();

                if id.is_empty() {
                    return Err(<D::Error as serde::de::Error>::custom("IDs cannot be empty"));
                }

                let is_forbidden = FORBIDDEN_IDS
                    .iter()
                    .any(|forbidden| id.eq_ignore_ascii_case(forbidden));
                if is_forbidden {
                    return Err(<D::Error as serde::de::Error>::custom(format!(
                        "'{id}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::from(id))
            }
        }

        impl crate::id::ID for $name {
            fn get_type_name() -> &'static str {
                $type_name
            }
        }
    };
}
pub(crate) use define_id_type;
92
// Concrete ID type that exists only so this module's unit tests have an ID to work with.
#[cfg(test)]
define_id_type! { GenericID, "generic ID" }
95
96/// Indicates that the struct has an ID field
97pub trait HasID<T: ID> {
98    /// Get the struct's ID
99    fn get_id(&self) -> &T;
100}
101
/// Implement `HasID<$id_type>` for `$owner` by returning a reference to its `id` field.
///
/// `$owner` must have a field named `id` of type `$id_type`.
macro_rules! define_id_getter {
    ($owner:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $owner {
            fn get_id(&self) -> &$id_type { &self.id }
        }
    };
}
pub(crate) use define_id_getter;
113
114/// Indicates that the specified ID was not found in a given collection
115#[derive(Debug)]
116pub struct MissingIDError<T: ID> {
117    missing_id: String,
118    _phantom: PhantomData<fn() -> T>,
119}
120
121impl<T: ID> MissingIDError<T> {
122    /// Create a new `MissingIDError`
123    pub fn new(missing_id: &str) -> MissingIDError<T> {
124        MissingIDError::<T> {
125            missing_id: missing_id.to_string(),
126            _phantom: std::marker::PhantomData,
127        }
128    }
129}
130
131impl<T: ID> Display for MissingIDError<T> {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        write!(f, "Unknown {} '{}'", T::get_type_name(), self.missing_id)
134    }
135}
136
137impl<T: ID> Error for MissingIDError<T> {}
138
139/// A data structure containing a set of IDs
140pub trait IDCollection<T: ID> {
141    /// Check if the ID is in the collection, returning a reference to it if found.
142    ///
143    /// # Arguments
144    ///
145    /// * `id` - The ID to check (can be string or ID type)
146    ///
147    /// # Returns
148    ///
149    /// A reference to the ID in `self`, or an error if not found.
150    fn get_id<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<&T, MissingIDError<T>>;
151}
152
/// Shared `IDCollection::get_id` body for set-like collections whose `get` method accepts a
/// `&str` and returns `Option<&T>`. Expects `T` to be the implementing impl's ID parameter.
macro_rules! define_id_methods {
    () => {
        fn get_id<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<&T, MissingIDError<T>> {
            let key: &str = id.borrow();
            match self.get(key) {
                Some(found) => Ok(found),
                None => Err(MissingIDError::new(key)),
            }
        }
    };
}
161
162impl<T: ID> IDCollection<T> for HashSet<T> {
163    define_id_methods!();
164}
165
166impl<T: ID> IDCollection<T> for IndexSet<T> {
167    define_id_methods!();
168}
169
170impl<T: ID, V> IDCollection<T> for IndexMap<T, V> {
171    fn get_id<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<&T, MissingIDError<T>> {
172        let (found, _) = self
173            .get_key_value(id.borrow())
174            .ok_or_else(|| MissingIDError::new(id.borrow()))?;
175        Ok(found)
176    }
177}
178
179/// A trait for getting an ID and a value from a map
180pub trait GetIDValue<K: ID, V> {
181    /// Get the ID and value, if any, for the given collection
182    fn get_id_value<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<(&K, &V), MissingIDError<K>>;
183}
184
185impl<K: ID + Borrow<str>, V> GetIDValue<K, V> for IndexMap<K, V> {
186    fn get_id_value<S: Borrow<str> + ?Sized>(&self, id: &S) -> Result<(&K, &V), MissingIDError<K>> {
187        self.get_key_value(id.borrow())
188            .ok_or_else(|| MissingIDError::new(id.borrow()))
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialise a TOML record whose only field is `id`, set to the given string.
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")] // unicode supported
    fn deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    // Surrounding whitespace is stripped from otherwise-valid IDs
    #[rstest]
    #[case("  commodity1  ", "commodity1")]
    #[case(" some commodity", "some commodity")]
    fn deserialise_id_trims_whitespace(#[case] raw: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(raw).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")] // whitespace only: empty after trimming
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn deserialise_id_invalid(#[case] id: &str) {
        deserialise_id(id).unwrap_err();
    }

    #[test]
    fn get_id_hash_set() {
        let ids = HashSet::from([GenericID::new("a"), GenericID::new("b")]);
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(
            ids.get_id("c").unwrap_err().to_string(),
            "Unknown generic ID 'c'"
        );
    }

    #[test]
    fn get_id_and_value_index_map() {
        let mut map = IndexMap::new();
        map.insert(GenericID::new("x"), 42);

        assert_eq!(map.get_id("x").unwrap(), &GenericID::new("x"));
        assert!(map.get_id("y").is_err());

        let (id, value) = map.get_id_value("x").unwrap();
        assert_eq!(id, &GenericID::new("x"));
        assert_eq!(*value, 42);
        assert_eq!(
            map.get_id_value("y").unwrap_err().to_string(),
            "Unknown generic ID 'y'"
        );
    }
}