// muse2/id.rs

//! Code for handling IDs: ID newtypes, traits for types that carry an ID, and ID lookups in collections.
2use anyhow::{Context, Result};
3use indexmap::{IndexMap, IndexSet};
4use std::borrow::Borrow;
5use std::collections::HashSet;
6use std::fmt::Display;
7use std::hash::Hash;
8
/// Shorthand ("trait alias") for the full set of capabilities an ID type must provide.
///
/// Any type that can be compared and hashed, borrowed as a `&str`, cloned, displayed and
/// constructed from an owned `String` implements this trait automatically.
pub trait IDLike
where
    Self: Eq + Hash + Borrow<str> + Clone + Display + From<String>,
{
}

impl<T: Eq + Hash + Borrow<str> + Clone + Display + From<String>> IDLike for T {}
12
/// Define a newtype ID backed by a reference-counted string slice.
///
/// The generated `$name` wraps an `Rc<str>` and provides:
///
/// * a `new` constructor plus `From<&str>` and `From<String>` conversions;
/// * `Borrow<str>`, so hashed collections of the ID can be queried with a `&str`;
/// * a validating `serde::Deserialize` implementation which trims surrounding whitespace and
///   rejects empty IDs as well as the reserved values `all` and `annual` (compared ignoring
///   ASCII case).
macro_rules! define_id_type {
    ($name:ident) => {
        /// An ID type (e.g. `AgentID`, `CommodityID`, etc.)
        #[derive(
            Clone,
            derive_more::Display,
            std::hash::Hash,
            PartialOrd,
            Ord,
            PartialEq,
            Eq,
            Debug,
            serde::Serialize,
        )]
        pub struct $name(pub std::rc::Rc<str>);

        impl $name {
            /// Create a new ID from a string slice
            pub fn new(id: &str) -> Self {
                Self::from(id)
            }
        }

        impl From<String> for $name {
            fn from(owned: String) -> Self {
                Self(std::rc::Rc::from(owned))
            }
        }

        impl From<&str> for $name {
            fn from(slice: &str) -> Self {
                Self(std::rc::Rc::from(slice))
            }
        }

        impl std::borrow::Borrow<str> for $name {
            fn borrow(&self) -> &str {
                &self.0
            }
        }

        impl<'de> serde::Deserialize<'de> for $name {
            fn deserialize<D>(deserialiser: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                use serde::de::Error;

                // Values that may never be used as an ID (compared ignoring ASCII case)
                const FORBIDDEN_IDS: [&str; 2] = ["all", "annual"];

                let raw = <String as serde::Deserialize>::deserialize(deserialiser)?;
                // Surrounding whitespace is not treated as part of the ID
                let trimmed = raw.trim();
                if trimmed.is_empty() {
                    return Err(D::Error::custom("IDs cannot be empty"));
                }

                let forbidden = FORBIDDEN_IDS
                    .iter()
                    .any(|reserved| trimmed.eq_ignore_ascii_case(reserved));
                if forbidden {
                    return Err(D::Error::custom(format!(
                        "'{trimmed}' is an invalid value for an ID"
                    )));
                }

                Ok(Self::from(trimmed))
            }
        }
    };
}
pub(crate) use define_id_type;
82
// ID type compiled only for test builds; used by the unit tests in this file
#[cfg(test)]
define_id_type! { GenericID }
85
86/// Indicates that the struct has an ID field
87pub trait HasID<ID: IDLike> {
88    /// Get the struct's ID
89    fn get_id(&self) -> &ID;
90}
91
/// Implement the `HasID` trait for a type by returning a reference to its `id` field.
///
/// The first argument is the type to implement the trait for; the second is the ID type, which
/// must match the type of that `id` field.
macro_rules! define_id_getter {
    ($owner:ty, $id_type:ty) => {
        impl crate::id::HasID<$id_type> for $owner {
            fn get_id(&self) -> &$id_type {
                &self.id
            }
        }
    };
}
pub(crate) use define_id_getter;
103
104/// A data structure containing a set of IDs
105pub trait IDCollection<ID: IDLike> {
106    /// Check if the ID is in the collection, returning a copy of it if found.
107    ///
108    /// # Arguments
109    ///
110    /// * `id` - The ID to check (can be string or ID type)
111    ///
112    /// # Returns
113    ///
114    /// A copy of the ID in `self`, or an error if not found.
115    fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID>;
116}
117
// Shared `IDCollection::get_id` body for collections whose `get` method can be queried with a
// `&str` key; yields a reference to the stored ID or an "Unknown ID" error.
macro_rules! define_id_methods {
    () => {
        fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
            let key: &str = id.borrow();
            self.get(key)
                .with_context(|| format!("Unknown ID {id} found"))
        }
    };
}
128
129impl<ID: IDLike> IDCollection<ID> for HashSet<ID> {
130    define_id_methods!();
131}
132
133impl<ID: IDLike> IDCollection<ID> for IndexSet<ID> {
134    define_id_methods!();
135}
136
137impl<ID: IDLike, V> IDCollection<ID> for IndexMap<ID, V> {
138    fn get_id<T: Borrow<str> + Display + ?Sized>(&self, id: &T) -> Result<&ID> {
139        let (found, _) = self
140            .get_key_value(id.borrow())
141            .with_context(|| format!("Unknown ID {id} found"))?;
142        Ok(found)
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde::Deserialize;

    /// A minimal record used to exercise `GenericID`'s `Deserialize` implementation
    #[derive(Debug, Deserialize)]
    struct Record {
        id: GenericID,
    }

    /// Deserialise a TOML document of the form `id = "<id>"`
    fn deserialise_id(id: &str) -> Result<Record> {
        Ok(toml::from_str(&format!("id = \"{id}\""))?)
    }

    /// Build a fresh list of the IDs used by the `get_id` tests
    fn known_ids() -> Vec<GenericID> {
        vec![GenericID::new("a"), GenericID::new("b")]
    }

    #[rstest]
    #[case("commodity1")]
    #[case("some commodity")]
    #[case("PROCESS")]
    #[case("café")] // unicode supported
    fn test_deserialise_id_valid(#[case] id: &str) {
        assert_eq!(deserialise_id(id).unwrap().id.to_string(), id);
    }

    #[rstest]
    #[case(" commodity1", "commodity1")]
    #[case("commodity1 ", "commodity1")]
    #[case("  some commodity  ", "some commodity")] // inner whitespace is preserved
    fn test_deserialise_id_trimmed(#[case] input: &str, #[case] expected: &str) {
        assert_eq!(deserialise_id(input).unwrap().id.to_string(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("   ")] // whitespace only, so empty once trimmed
    #[case("all")]
    #[case("annual")]
    #[case("ALL")]
    #[case("Annual")]
    #[case(" ALL ")]
    fn test_deserialise_id_invalid(#[case] id: &str) {
        assert!(deserialise_id(id).is_err());
    }

    #[test]
    fn test_get_id_hash_set() {
        let ids: HashSet<GenericID> = known_ids().into_iter().collect();
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(ids.get_id(&GenericID::new("b")).unwrap(), &GenericID::new("b"));
        assert_eq!(ids.get_id("c").unwrap_err().to_string(), "Unknown ID c found");
    }

    #[test]
    fn test_get_id_index_set() {
        let ids: IndexSet<GenericID> = known_ids().into_iter().collect();
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(ids.get_id(&GenericID::new("b")).unwrap(), &GenericID::new("b"));
        assert_eq!(ids.get_id("c").unwrap_err().to_string(), "Unknown ID c found");
    }

    #[test]
    fn test_get_id_index_map() {
        let ids: IndexMap<GenericID, usize> = known_ids().into_iter().zip(0..).collect();
        assert_eq!(ids.get_id("a").unwrap(), &GenericID::new("a"));
        assert_eq!(ids.get_id(&GenericID::new("b")).unwrap(), &GenericID::new("b"));
        assert_eq!(ids.get_id("c").unwrap_err().to_string(), "Unknown ID c found");
    }
}