ruint/support/
serde.rs

1//! Support for the [`serde`](https://crates.io/crates/serde) crate.
2
3#![cfg(feature = "serde")]
4#![cfg_attr(docsrs, doc(cfg(feature = "serde")))]
5
6use crate::{nbytes, Bits, Uint};
7use core::{
8    fmt::{Formatter, Result as FmtResult, Write},
9    str,
10};
11use serde::{
12    de::{Error, Unexpected, Visitor},
13    Deserialize, Deserializer, Serialize, Serializer,
14};
15
16#[allow(unused_imports)]
17use alloc::string::String;
18
/// Human-readable encoding of the value zero.
///
/// This is the only accepted human-readable form of a `Uint<0, 0>`, and it is
/// also what the minimal encoder emits for `Uint<BITS, LIMBS>::ZERO` at any
/// bit width.
const ZERO_STR: &'static str = "0x0";
22
23impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
24    fn serialize_human_full<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
25        if BITS == 0 {
26            return s.serialize_str(ZERO_STR);
27        }
28
29        let mut result = String::with_capacity(2 + nbytes(BITS) * 2);
30        result.push_str("0x");
31
32        self.as_le_bytes()
33            .iter()
34            .rev()
35            .try_for_each(|byte| write!(result, "{byte:02x}"))
36            .unwrap();
37
38        s.serialize_str(&result)
39    }
40
41    fn serialize_human_minimal<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
42        if BITS == 0 {
43            return s.serialize_str(ZERO_STR);
44        }
45
46        let le_bytes = self.as_le_bytes();
47        let mut bytes = le_bytes.iter().rev().skip_while(|b| **b == 0);
48
49        // We avoid String allocation if there is no non-0 byte
50        // If there is a first byte, we allocate a string, and write the prefix
51        // and first byte to it
52        let mut result = match bytes.next() {
53            Some(b) => {
54                let mut result = String::with_capacity(2 + nbytes(BITS) * 2);
55                write!(result, "0x{b:x}").unwrap();
56                result
57            }
58            None => return s.serialize_str(ZERO_STR),
59        };
60        bytes
61            .try_for_each(|byte| write!(result, "{byte:02x}"))
62            .unwrap();
63
64        s.serialize_str(&result)
65    }
66
67    fn serialize_binary<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
68        s.serialize_bytes(&self.to_be_bytes_vec())
69    }
70}
71
72/// Serialize a [`Uint`] value.
73///
74/// For human readable formats a `0x` prefixed lower case hex string is used.
75/// For binary formats a byte array is used. Leading zeros are included.
76impl<const BITS: usize, const LIMBS: usize> Serialize for Uint<BITS, LIMBS> {
77    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
78        if serializer.is_human_readable() {
79            self.serialize_human_minimal(serializer)
80        } else {
81            self.serialize_binary(serializer)
82        }
83    }
84}
85
86/// Deserialize human readable hex strings or byte arrays into hashes.
87/// Hex strings can be upper/lower/mixed case, have an optional `0x` prefix, and
88/// can be any length. They are interpreted big-endian.
89impl<'de, const BITS: usize, const LIMBS: usize> Deserialize<'de> for Uint<BITS, LIMBS> {
90    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
91        if deserializer.is_human_readable() {
92            deserializer.deserialize_any(HrVisitor)
93        } else {
94            deserializer.deserialize_bytes(ByteVisitor)
95        }
96    }
97}
98
99impl<const BITS: usize, const LIMBS: usize> Serialize for Bits<BITS, LIMBS> {
100    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
101        if serializer.is_human_readable() {
102            self.as_uint().serialize_human_full(serializer)
103        } else {
104            self.as_uint().serialize_binary(serializer)
105        }
106    }
107}
108
109impl<'de, const BITS: usize, const LIMBS: usize> Deserialize<'de> for Bits<BITS, LIMBS> {
110    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
111        Uint::deserialize(deserializer).map(Self::from)
112    }
113}
114
115/// Serde Visitor for human readable formats.
116///
117/// Accepts either a primitive number, a decimal or a hexadecimal string.
118struct HrVisitor<const BITS: usize, const LIMBS: usize>;
119
120impl<'de, const BITS: usize, const LIMBS: usize> Visitor<'de> for HrVisitor<BITS, LIMBS> {
121    type Value = Uint<BITS, LIMBS>;
122
123    fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
124        write!(formatter, "a {} byte hex string", nbytes(BITS))
125    }
126
127    fn visit_u64<E: Error>(self, v: u64) -> Result<Self::Value, E> {
128        Uint::try_from(v).map_err(|_| Error::invalid_value(Unexpected::Unsigned(v), &self))
129    }
130
131    fn visit_u128<E: Error>(self, v: u128) -> Result<Self::Value, E> {
132        // `Unexpected::Unsigned` cannot contain a `u128`
133        Uint::try_from(v).map_err(Error::custom)
134    }
135
136    fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
137        // Shortcut for common case
138        if value == ZERO_STR {
139            return Ok(Uint::<BITS, LIMBS>::ZERO);
140        }
141        // `ZERO_STR` is the only valid serialization of `Uint<0, 0>`, so if we
142        // have not shortcut, we are in an error case
143        if BITS == 0 {
144            return Err(Error::invalid_value(Unexpected::Str(value), &self));
145        }
146
147        value
148            .parse()
149            .map_err(|_| Error::invalid_value(Unexpected::Str(value), &self))
150    }
151}
152
153/// Serde Visitor for non-human readable formats
154struct ByteVisitor<const BITS: usize, const LIMBS: usize>;
155
156impl<'de, const BITS: usize, const LIMBS: usize> Visitor<'de> for ByteVisitor<BITS, LIMBS> {
157    type Value = Uint<BITS, LIMBS>;
158
159    fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
160        write!(formatter, "{BITS} bits of binary data in big endian order")
161    }
162
163    fn visit_bytes<E: Error>(self, value: &[u8]) -> Result<Self::Value, E> {
164        if value.len() != nbytes(BITS) {
165            return Err(Error::invalid_length(value.len(), &self));
166        }
167        Uint::try_from_be_slice(value).ok_or_else(|| {
168            Error::invalid_value(
169                Unexpected::Other(&format!("too large for Uint<{BITS}>")),
170                &self,
171            )
172        })
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use proptest::proptest;

    #[allow(unused_imports)]
    use alloc::vec::Vec;

    /// Random `Uint` and `Bits` values of every tested width must survive a
    /// JSON (human-readable) round trip.
    #[test]
    fn test_serde_human_readable() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            proptest!(|(value: Uint<BITS, LIMBS>)| {
                let json = serde_json::to_string(&value).unwrap();
                let decoded: Uint<BITS, LIMBS> = serde_json::from_str(&json).unwrap();
                assert_eq!(value, decoded);
            });
            proptest!(|(value: Bits<BITS, LIMBS>)| {
                let json = serde_json::to_string(&value).unwrap();
                let decoded: Bits<BITS, LIMBS> = serde_json::from_str(&json).unwrap();
                assert_eq!(value, decoded);
            });
        });
    }

    /// Integers and every supported radix prefix decode to the same value, and
    /// prefixes without digits decode to zero.
    #[test]
    fn test_human_readable_de() {
        let ones_json = r#"[1, "0x1", "0o1", "0b1"]"#;
        let ones: Vec<Uint<1, 1>> = serde_json::from_str(ones_json).unwrap();
        uint! {
            assert_eq!(ones, vec![1_U1, 1_U1, 1_U1, 1_U1]);
        }

        let zeros_json = r#"["", "0x", "0o", "0b"]"#;
        let zeros: Vec<Uint<1, 1>> = serde_json::from_str(zeros_json).unwrap();
        uint! {
            assert_eq!(zeros, vec![0_U1, 0_U1, 0_U1, 0_U1]);
        }
    }

    /// Random values of every tested width must survive a bincode (binary)
    /// round trip.
    #[test]
    fn test_serde_machine_readable() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            proptest!(|(value: Uint<BITS, LIMBS>)| {
                let encoded = bincode::serialize(&value).unwrap();
                let decoded: Uint<BITS, LIMBS> = bincode::deserialize(&encoded[..]).unwrap();
                assert_eq!(value, decoded);
            });
            proptest!(|(value: Bits<BITS, LIMBS>)| {
                let encoded = bincode::serialize(&value).unwrap();
                let decoded: Bits<BITS, LIMBS> = bincode::deserialize(&encoded[..]).unwrap();
                assert_eq!(value, decoded);
            });
        });
    }

    /// Appending one more hex digit to the largest value of a width must make
    /// deserialization fail.
    #[test]
    fn test_serde_invalid_size_error() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            let max_json = serde_json::to_string(&Uint::<BITS, LIMBS>::MAX).unwrap();

            // Expect a quoted, `0x`-prefixed string.
            assert_eq!(&max_json[..3], "\"0x");
            assert_eq!(&max_json[max_json.len() - 1..], "\"");

            // Re-quote the string with an extra `0` digit before the closing quote.
            let unquoted_end = max_json.len() - 1;
            let serialized = format!("{}0\"", &max_json[..unquoted_end]);
            let deserialized = serde_json::from_str::<Uint<BITS, LIMBS>>(&serialized);
            assert!(deserialized.is_err(), "{BITS} {serialized}");
        });
    }
}