1use crate::{Error, Header, Result};
2use bytes::{Bytes, BytesMut};
3use core::marker::{PhantomData, PhantomPinned};
4
5pub trait Decodable: Sized {
7 fn decode(buf: &mut &[u8]) -> Result<Self>;
10}
11
12pub struct Rlp<'a> {
14 payload_view: &'a [u8],
15}
16
17impl<'a> Rlp<'a> {
18 pub fn new(mut payload: &'a [u8]) -> Result<Self> {
20 let payload_view = Header::decode_bytes(&mut payload, true)?;
21 Ok(Self { payload_view })
22 }
23
24 #[inline]
26 pub fn get_next<T: Decodable>(&mut self) -> Result<Option<T>> {
27 if self.payload_view.is_empty() {
28 Ok(None)
29 } else {
30 T::decode(&mut self.payload_view).map(Some)
31 }
32 }
33}
34
35impl<T: ?Sized> Decodable for PhantomData<T> {
36 fn decode(_buf: &mut &[u8]) -> Result<Self> {
37 Ok(Self)
38 }
39}
40
41impl Decodable for PhantomPinned {
42 fn decode(_buf: &mut &[u8]) -> Result<Self> {
43 Ok(Self)
44 }
45}
46
47impl Decodable for bool {
48 #[inline]
49 fn decode(buf: &mut &[u8]) -> Result<Self> {
50 Ok(match u8::decode(buf)? {
51 0 => false,
52 1 => true,
53 _ => return Err(Error::Custom("invalid bool value, must be 0 or 1")),
54 })
55 }
56}
57
58impl<const N: usize> Decodable for [u8; N] {
59 #[inline]
60 fn decode(from: &mut &[u8]) -> Result<Self> {
61 let bytes = Header::decode_bytes(from, false)?;
62 Self::try_from(bytes).map_err(|_| Error::UnexpectedLength)
63 }
64}
65
66macro_rules! decode_integer {
67 ($($t:ty),+ $(,)?) => {$(
68 impl Decodable for $t {
69 #[inline]
70 fn decode(buf: &mut &[u8]) -> Result<Self> {
71 let bytes = Header::decode_bytes(buf, false)?;
72 static_left_pad(bytes).map(<$t>::from_be_bytes)
73 }
74 }
75 )+};
76}
77
78decode_integer!(u8, u16, u32, u64, usize, u128);
79
80impl Decodable for Bytes {
81 #[inline]
82 fn decode(buf: &mut &[u8]) -> Result<Self> {
83 Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec()))
84 }
85}
86
87impl Decodable for BytesMut {
88 #[inline]
89 fn decode(buf: &mut &[u8]) -> Result<Self> {
90 Header::decode_bytes(buf, false).map(Self::from)
91 }
92}
93
94impl Decodable for alloc::string::String {
95 #[inline]
96 fn decode(buf: &mut &[u8]) -> Result<Self> {
97 Header::decode_str(buf).map(Into::into)
98 }
99}
100
101impl<T: Decodable> Decodable for alloc::vec::Vec<T> {
102 #[inline]
103 fn decode(buf: &mut &[u8]) -> Result<Self> {
104 let mut bytes = Header::decode_bytes(buf, true)?;
105 let mut vec = Self::new();
106 let payload_view = &mut bytes;
107 while !payload_view.is_empty() {
108 vec.push(T::decode(payload_view)?);
109 }
110 Ok(vec)
111 }
112}
113
114macro_rules! wrap_impl {
115 ($($(#[$attr:meta])* [$($gen:tt)*] <$t:ty>::$new:ident($t2:ty)),+ $(,)?) => {$(
116 $(#[$attr])*
117 impl<$($gen)*> Decodable for $t {
118 #[inline]
119 fn decode(buf: &mut &[u8]) -> Result<Self> {
120 <$t2 as Decodable>::decode(buf).map(<$t>::$new)
121 }
122 }
123 )+};
124}
125
126wrap_impl! {
127 #[cfg(feature = "arrayvec")]
128 [const N: usize] <arrayvec::ArrayVec<u8, N>>::from([u8; N]),
129 [T: Decodable] <alloc::boxed::Box<T>>::new(T),
130 [T: Decodable] <alloc::rc::Rc<T>>::new(T),
131 [T: Decodable] <alloc::sync::Arc<T>>::new(T),
132}
133
134impl<T: ?Sized + alloc::borrow::ToOwned> Decodable for alloc::borrow::Cow<'_, T>
135where
136 T::Owned: Decodable,
137{
138 #[inline]
139 fn decode(buf: &mut &[u8]) -> Result<Self> {
140 T::Owned::decode(buf).map(Self::Owned)
141 }
142}
143
#[cfg(feature = "std")]
mod std_impl {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Decodes an address from an RLP string. A 4-byte payload is read as
    /// IPv4, a 16-byte payload as IPv6, and any other length is rejected.
    impl Decodable for IpAddr {
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let octets = Header::decode_bytes(buf, false)?;
            if octets.len() == 4 {
                let v4 = slice_to_array::<4>(octets).expect("infallible");
                Ok(Self::V4(Ipv4Addr::from(v4)))
            } else if octets.len() == 16 {
                let v6 = slice_to_array::<16>(octets).expect("infallible");
                Ok(Self::V6(Ipv6Addr::from(v6)))
            } else {
                Err(Error::UnexpectedLength)
            }
        }
    }

    /// Decodes an IPv4 address from an RLP string of exactly 4 bytes.
    impl Decodable for Ipv4Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let octets: [u8; 4] = slice_to_array(Header::decode_bytes(buf, false)?)?;
            Ok(Self::from(octets))
        }
    }

    /// Decodes an IPv6 address from an RLP string of exactly 16 bytes.
    impl Decodable for Ipv6Addr {
        #[inline]
        fn decode(buf: &mut &[u8]) -> Result<Self> {
            let octets: [u8; 16] = slice_to_array(Header::decode_bytes(buf, false)?)?;
            Ok(Self::from(octets))
        }
    }
}
178
179#[inline]
185pub fn decode_exact<T: Decodable>(bytes: impl AsRef<[u8]>) -> Result<T> {
186 let mut buf = bytes.as_ref();
187 let out = T::decode(&mut buf)?;
188
189 if !buf.is_empty() {
191 return Err(Error::UnexpectedLength);
193 }
194
195 Ok(out)
196}
197
198#[inline]
204pub(crate) fn static_left_pad<const N: usize>(data: &[u8]) -> Result<[u8; N]> {
205 if data.len() > N {
206 return Err(Error::Overflow);
207 }
208
209 let mut v = [0; N];
210
211 if data.is_empty() {
213 return Ok(v);
214 }
215
216 if data[0] == 0 {
217 return Err(Error::LeadingZero);
218 }
219
220 unsafe { v.get_unchecked_mut(N - data.len()..) }.copy_from_slice(data);
222 Ok(v)
223}
224
/// Copies `slice` into a fixed-size array. Fails with
/// [`Error::UnexpectedLength`] unless the slice is exactly `N` bytes long.
#[cfg(feature = "std")]
#[inline]
fn slice_to_array<const N: usize>(slice: &[u8]) -> Result<[u8; N]> {
    match <[u8; N]>::try_from(slice) {
        Ok(array) => Ok(array),
        Err(_) => Err(Error::UnexpectedLength),
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{encode, Encodable};
    use core::fmt::Debug;
    use hex_literal::hex;

    // With `std` enabled, `String` and `Vec` already come from the prelude, so
    // these imports are unused in that configuration. They are kept,
    // presumably, for builds without `std`.
    #[allow(unused_imports)]
    use alloc::{string::String, vec::Vec};

    /// Decodes each fixture's input and compares it with the expected result.
    /// For successful fixtures it also checks that encoding round-trips and that
    /// the whole input is consumed.
    fn check_decode<'a, T, IT>(fixtures: IT)
    where
        T: Encodable + Decodable + PartialEq + Debug,
        IT: IntoIterator<Item = (Result<T>, &'a [u8])>,
    {
        for (expected, mut input) in fixtures {
            if let Ok(expected) = &expected {
                assert_eq!(crate::encode(expected), input, "{expected:?}");
            }

            let orig = input;
            assert_eq!(
                T::decode(&mut input),
                expected,
                "input: {}{}",
                hex::encode(orig),
                if let Ok(expected) = &expected {
                    format!("; expected: {}", hex::encode(crate::encode(expected)))
                } else {
                    String::new()
                }
            );

            if expected.is_ok() {
                assert_eq!(input, &[]);
            }
        }
    }

    #[test]
    fn rlp_bool() {
        let out = [0x80];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(false), val);

        let out = [0x01];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Ok(true), val);

        // Integers other than 0 and 1 are not valid booleans.
        let out = [0x02];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Err(Error::Custom("invalid bool value, must be 0 or 1")), val);

        // `false` must use the canonical encoding of zero, not a literal zero byte.
        let out = [0x00];
        let val = bool::decode(&mut &out[..]);
        assert_eq!(Err(Error::LeadingZero), val);
    }

    #[test]
    fn rlp_strings() {
        check_decode::<Bytes, _>([
            (Ok(hex!("00")[..].to_vec().into()), &hex!("00")[..]),
            (
                Ok(hex!("6f62636465666768696a6b6c6d")[..].to_vec().into()),
                &hex!("8D6F62636465666768696A6B6C6D")[..],
            ),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
        ])
    }

    #[test]
    fn rlp_fixed_length() {
        check_decode([
            (Ok(hex!("6f62636465666768696a6b6c6d")), &hex!("8D6F62636465666768696A6B6C6D")[..]),
            (Err(Error::UnexpectedLength), &hex!("8C6F62636465666768696A6B6C")[..]),
            (Err(Error::UnexpectedLength), &hex!("8E6F62636465666768696A6B6C6D6E")[..]),
        ])
    }

    #[test]
    fn rlp_u64() {
        check_decode([
            (Ok(9_u64), &hex!("09")[..]),
            (Ok(0_u64), &hex!("80")[..]),
            (Ok(0x0505_u64), &hex!("820505")[..]),
            (Ok(0xCE05050505_u64), &hex!("85CE05050505")[..]),
            (Err(Error::Overflow), &hex!("8AFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::InputTooShort), &hex!("8BFFFFFFFFFFFFFFFFFF7C")[..]),
            (Err(Error::UnexpectedList), &hex!("C0")[..]),
            (Err(Error::LeadingZero), &hex!("00")[..]),
            (Err(Error::NonCanonicalSingleByte), &hex!("8105")[..]),
            (Err(Error::LeadingZero), &hex!("8200F4")[..]),
            (Err(Error::NonCanonicalSize), &hex!("B8020004")[..]),
            (
                Err(Error::Overflow),
                &hex!("A101000000000000000000000000000000000000008B000000000000000000000000")[..],
            ),
        ])
    }

    #[test]
    fn rlp_vectors() {
        check_decode::<Vec<u64>, _>([
            (Ok(vec![]), &hex!("C0")[..]),
            (Ok(vec![0xBBCCB5_u64, 0xFFC0B5_u64]), &hex!("C883BBCCB583FFC0B5")[..]),
        ])
    }

    #[cfg(feature = "std")]
    #[test]
    fn rlp_ip() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

        let localhost4 = Ipv4Addr::new(127, 0, 0, 1);
        let localhost6 = localhost4.to_ipv6_mapped();
        let expected4 = &hex!("847F000001")[..];
        let expected6 = &hex!("9000000000000000000000ffff7f000001")[..];
        check_decode::<Ipv4Addr, _>([(Ok(localhost4), expected4)]);
        check_decode::<Ipv6Addr, _>([(Ok(localhost6), expected6)]);
        check_decode::<IpAddr, _>([
            (Ok(IpAddr::V4(localhost4)), expected4),
            (Ok(IpAddr::V6(localhost6)), expected6),
        ]);
    }

    #[test]
    fn malformed_rlp() {
        check_decode::<Bytes, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<[u8; 5], _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        #[cfg(feature = "std")]
        check_decode::<std::net::IpAddr, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<Vec<u8>, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<String, _>([
            (Err(Error::InputTooShort), &hex!("C1")[..]),
            (Err(Error::InputTooShort), &hex!("D7")[..]),
        ]);
        check_decode::<u8, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
        check_decode::<u64, _>([(Err(Error::InputTooShort), &hex!("82")[..])]);
    }

    #[test]
    fn rlp_full() {
        fn check_decode_exact<T: Decodable + Encodable + PartialEq + Debug>(input: T) {
            let encoded = encode(&input);
            assert_eq!(decode_exact::<T>(&encoded), Ok(input));
            assert_eq!(
                decode_exact::<T>([encoded, vec![0x00]].concat()),
                Err(Error::UnexpectedLength)
            );
        }

        check_decode_exact::<String>("".into());
        check_decode_exact::<String>("test1234".into());
        check_decode_exact::<Vec<u64>>(vec![]);
        check_decode_exact::<Vec<u64>>(vec![0; 4]);
    }
}
395}