muse2/
id.rs

1//! Code for handling IDs
2use anyhow::{Context, Result};
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::fmt::Display;
7use std::hash::Hash;
8
/// Marker trait for types that can be used as IDs.
///
/// It bundles the bounds needed to keep IDs in hashed collections, look them up by `&str`,
/// clone and print them, and build them from owned strings. Thanks to the blanket impl it is
/// implemented automatically for every type meeting those bounds and never needs a manual impl.
pub trait IDLike: Eq + Hash + Borrow<str> + Clone + Display + From<String> {}

impl<T: Eq + Hash + Borrow<str> + Clone + Display + From<String>> IDLike for T {}
12
/// Define a new ID type called `$name`.
///
/// The generated type is a newtype around `Rc<str>`, so cloning an ID only bumps a reference
/// count. It provides `Borrow<str>`, `Display`, `From<&str>`, `From<String>`, a `new`
/// constructor, `Serialize` and a validating `Deserialize`, which together satisfy [`IDLike`].
///
/// Deserialisation trims surrounding whitespace, then rejects the ID if it is empty or equals
/// (ignoring ASCII case) one of the reserved words `all` or `annual`. The `From` conversions
/// and `new` do not validate their input.
macro_rules! define_id_type {
    ($name:ident) => {
        #[derive(
            Clone, std::hash::Hash, PartialOrd, Ord, PartialEq, Eq, Debug, serde::Serialize,
        )]
        /// An ID type (e.g. `AgentID`, `CommodityID`, etc.)
        pub struct $name(pub std::rc::Rc<str>);

        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                &self.0
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // `pad` honours width/fill/alignment/precision flags (e.g. `{:<10}`) exactly as
                // formatting a plain `&str` would; `write!(f, "{}", ..)` silently dropped them
                f.pad(&self.0)
            }
        }

        impl From<&str> for $name {
            fn from(s: &str) -> Self {
                $name(std::rc::Rc::from(s))
            }
        }

        impl From<String> for $name {
            fn from(s: String) -> Self {
                $name(std::rc::Rc::from(s))
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                use serde::de::Error;

                // Reserved words that may never be used as IDs (compared ignoring ASCII case)
                const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];

                let id: String = serde::Deserialize::deserialize(deserialiser)?;

                // Ignore surrounding whitespace before validating
                let id = id.trim();
                if id.is_empty() {
                    return Err(D::Error::custom("IDs cannot be empty"));
                }

                if FORBIDDEN_IDS
                    .iter()
                    .any(|forbidden| id.eq_ignore_ascii_case(forbidden))
                {
                    return Err(D::Error::custom(format!(
                        "'{id}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::from(id))
            }
        }

        impl $name {
            /// Create a new ID from a string slice (no validation is performed)
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }
    };
}
pub(crate) use define_id_type;
80
// Concrete ID type used by this module's unit tests
#[cfg(test)]
define_id_type! {GenericID}
83
84/// Indicates that the struct has an ID field
85pub trait HasID<ID: IDLike> {
86    /// Get the struct's ID
87    fn get_id(&self) -> &ID;
88}
89
/// Implement [`HasID`] for a struct whose ID is stored in a field named `id`.
///
/// The first argument is the struct type and the second is the ID type, e.g.
/// `define_id_getter! {MyStruct, MyID}`.
macro_rules! define_id_getter {
    ($struct_ty:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $struct_ty {
            fn get_id(&self) -> &$id_type {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
101
102/// A data structure containing a set of IDs
103pub trait IDCollection<ID: IDLike> {
104    /// Check if the ID is in the collection, returning a copy of it if found.
105    ///
106    /// # Arguments
107    ///
108    /// * `id` - The ID to check (can be string or ID type)
109    ///
110    /// # Returns
111    ///
112    /// A copy of the ID in `self`, or an error if not found.
113    fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID>;
114}
115
/// Generate an `IDCollection::get_id` implementation for a collection with a `get` method
/// that accepts a `&str` key and returns `Option<&ID>` (used for the `HashSet` and `IndexSet`
/// impls in this module).
///
/// The expansion refers to `ID`, `Borrow`, `Display`, `Result` and `Context` by name, so these
/// must be in scope wherever the macro is invoked.
macro_rules! define_id_methods {
    () => {
        fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
            let key: &str = id.borrow();
            self.get(key).with_context(|| format!("Unknown ID {id} found"))
        }
    };
}
126
127impl<ID: IDLike> IDCollection<ID> for HashSet<ID> {
128    define_id_methods!();
129}
130
131impl<ID: IDLike> IDCollection<ID> for IndexSet<ID> {
132    define_id_methods!();
133}
134
135impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
136    fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
137        let (found, _) = self
138            .get_key_value(id.borrow())
139            .with_context(|| format!("Unknown ID {id} found"))?;
140        Ok(found)
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    use serde::Deserialize;

    /// Minimal record with a single ID field, used to exercise `Deserialize` for ID types
    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialise a [`Record`] from a one-line TOML document containing the given ID.
    ///
    /// The ID is interpolated verbatim into a TOML basic string, so it must not contain `"`
    /// or `\`.
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")] // unicode supported
    fn test_deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    #[rstest]
    #[case(" commodity1", "commodity1")]
    #[case("commodity1  ", "commodity1")]
    #[case("  some commodity  ", "some commodity")]
    fn test_deserialise_id_trimmed(#[case] input: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(input).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")] // whitespace only
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    #[case(" annual ")]
    fn test_deserialise_id_invalid(#[case] id: &str) {
        assert!(deserialise_id(id).is_err());
    }

    #[test]
    fn test_get_id_hash_set() {
        let mut ids = HashSet::new();
        ids.insert(GenericID::new("a"));
        ids.insert(GenericID::new("b"));

        // Lookup works with both string slices and ID values
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(
            ids.get_id(&GenericID::new("b")).unwrap(),
            &GenericID::new("b")
        );
        assert_eq!(
            ids.get_id("c").unwrap_err().to_string(),
            "Unknown ID c found"
        );
    }

    #[test]
    fn test_get_id_index_map() {
        let mut map = IndexMap::new();
        map.insert(GenericID::new("a"), 1);

        assert_eq!(map.get_id("a").unwrap(), &GenericID::new("a"));
        assert!(map.get_id("b").is_err());
    }
}