alloy_sol_macro_input/
attr.rs

1use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{punctuated::Punctuated, Attribute, Error, LitBool, LitStr, Path, Result, Token};
5
/// Message reported when the same `sol` key is given more than once.
const DUPLICATE_ERROR: &str = "duplicate attribute";
/// Message reported for a `sol` key that is not recognized.
const UNKNOWN_ERROR: &str = "unknown `sol` attribute";
8
9/// Wraps the argument in a doc attribute.
10pub fn mk_doc(s: impl quote::ToTokens) -> TokenStream {
11    quote!(#[doc = #s])
12}
13
14/// Returns `true` if the attribute is `#[doc = "..."]`.
15pub fn is_doc(attr: &Attribute) -> bool {
16    attr.path().is_ident("doc")
17}
18
19/// Returns `true` if the attribute is `#[derive(...)]`.
20pub fn is_derive(attr: &Attribute) -> bool {
21    attr.path().is_ident("derive")
22}
23
24/// Returns an iterator over all the `#[doc = "..."]` attributes.
25pub fn docs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
26    attrs.iter().filter(|a| is_doc(a))
27}
28
29/// Flattens all the `#[doc = "..."]` attributes into a single string.
30pub fn docs_str(attrs: &[Attribute]) -> String {
31    let mut doc = String::new();
32    for attr in docs(attrs) {
33        let syn::Meta::NameValue(syn::MetaNameValue {
34            value: syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(s), .. }),
35            ..
36        }) = &attr.meta
37        else {
38            continue;
39        };
40
41        let value = s.value();
42        if !value.is_empty() {
43            if !doc.is_empty() {
44                doc.push('\n');
45            }
46            doc.push_str(&value);
47        }
48    }
49    doc
50}
51
52/// Returns an iterator over all the `#[derive(...)]` attributes.
53pub fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
54    attrs.iter().filter(|a| is_derive(a))
55}
56
57/// Returns an iterator over all the rust `::` paths in the `#[derive(...)]`
58/// attributes.
59pub fn derives_mapped(attrs: &[Attribute]) -> impl Iterator<Item = Path> + '_ {
60    derives(attrs).flat_map(parse_derives)
61}
62
63/// Parses the `#[derive(...)]` attributes into a list of paths.
64pub fn parse_derives(attr: &Attribute) -> Punctuated<Path, Token![,]> {
65    attr.parse_args_with(Punctuated::<Path, Token![,]>::parse_terminated).unwrap_or_default()
66}
67
68// When adding a new attribute:
69// 1. add a field to this struct,
70// 2. add a match arm in the `parse` function below,
71// 3. add test cases in the `tests` module at the bottom of this file,
72// 4. implement the attribute in your `SolInputExpander` implementation,
73// 5. document the attribute in the [`sol!`] macro docs.
74
75/// `#[sol(...)]` attributes.
76#[derive(Debug, Default, PartialEq, Eq)]
77pub struct SolAttrs {
78    /// `#[sol(rpc)]`
79    pub rpc: Option<bool>,
80    /// `#[sol(abi)]`
81    pub abi: Option<bool>,
82    /// `#[sol(all_derives)]`
83    pub all_derives: Option<bool>,
84    /// `#[sol(extra_methods)]`
85    pub extra_methods: Option<bool>,
86    /// `#[sol(docs)]`
87    pub docs: Option<bool>,
88
89    /// `#[sol(alloy_sol_types = alloy_core::sol_types)]`
90    pub alloy_sol_types: Option<Path>,
91    /// `#[sol(alloy_contract = alloy_contract)]`
92    pub alloy_contract: Option<Path>,
93
94    // TODO: Implement
95    /// UNIMPLEMENTED: `#[sol(rename = "new_name")]`
96    pub rename: Option<LitStr>,
97    // TODO: Implement
98    /// UNIMPLMENTED: `#[sol(rename_all = "camelCase")]`
99    pub rename_all: Option<CasingStyle>,
100
101    /// `#[sol(bytecode = "0x1234")]`
102    pub bytecode: Option<LitStr>,
103    /// `#[sol(deployed_bytecode = "0x1234")]`
104    pub deployed_bytecode: Option<LitStr>,
105
106    /// UDVT only `#[sol(type_check = "my_function")]`
107    pub type_check: Option<LitStr>,
108}
109
110impl SolAttrs {
111    /// Parse the `#[sol(...)]` attributes from a list of attributes.
112    pub fn parse(attrs: &[Attribute]) -> Result<(Self, Vec<Attribute>)> {
113        let mut this = Self::default();
114        let mut others = Vec::with_capacity(attrs.len());
115        for attr in attrs {
116            if !attr.path().is_ident("sol") {
117                others.push(attr.clone());
118                continue;
119            }
120
121            attr.meta.require_list()?.parse_nested_meta(|meta| {
122                let path = meta.path.get_ident().ok_or_else(|| meta.error("expected ident"))?;
123                let s = path.to_string();
124
125                macro_rules! match_ {
126                    ($($l:ident => $e:expr),* $(,)?) => {
127                        match s.as_str() {
128                            $(
129                                stringify!($l) => if this.$l.is_some() {
130                                    return Err(meta.error(DUPLICATE_ERROR))
131                                } else {
132                                    this.$l = Some($e);
133                                },
134                            )*
135                            _ => return Err(meta.error(UNKNOWN_ERROR)),
136                        }
137                    };
138                }
139
140                // `path` => true, `path = <bool>` => <bool>
141                let bool = || {
142                    if let Ok(input) = meta.value() {
143                        input.parse::<LitBool>().map(|lit| lit.value)
144                    } else {
145                        Ok(true)
146                    }
147                };
148
149                // `path = <path>`
150                let path = || meta.value()?.parse::<Path>();
151
152                // `path = "<str>"`
153                let lit = || meta.value()?.parse::<LitStr>();
154
155                // `path = "0x<hex>"`
156                let bytes = || {
157                    let lit = lit()?;
158                    if let Err(e) = hex::check(lit.value()) {
159                        let msg = format!("invalid hex value: {e}");
160                        return Err(Error::new(lit.span(), msg));
161                    }
162                    Ok(lit)
163                };
164
165                match_! {
166                    rpc => bool()?,
167                    abi => bool()?,
168                    all_derives => bool()?,
169                    extra_methods => bool()?,
170                    docs => bool()?,
171
172                    alloy_sol_types => path()?,
173                    alloy_contract => path()?,
174
175                    rename => lit()?,
176                    rename_all => CasingStyle::from_lit(&lit()?)?,
177
178                    bytecode => bytes()?,
179                    deployed_bytecode => bytes()?,
180
181                    type_check => lit()?,
182                };
183                Ok(())
184            })?;
185        }
186        Ok((this, others))
187    }
188}
189
190/// Trait for items that contain `#[sol(...)]` attributes among other
191/// attributes. This is usually a shortcut  for [`SolAttrs::parse`].
192pub trait ContainsSolAttrs {
193    /// Get the list of attributes.
194    fn attrs(&self) -> &[Attribute];
195
196    /// Parse the `#[sol(...)]` attributes from the list of attributes.
197    fn split_attrs(&self) -> syn::Result<(SolAttrs, Vec<Attribute>)> {
198        SolAttrs::parse(self.attrs())
199    }
200}
201
202impl ContainsSolAttrs for syn_solidity::File {
203    fn attrs(&self) -> &[Attribute] {
204        &self.attrs
205    }
206}
207
208impl ContainsSolAttrs for syn_solidity::ItemContract {
209    fn attrs(&self) -> &[Attribute] {
210        &self.attrs
211    }
212}
213
214impl ContainsSolAttrs for syn_solidity::ItemEnum {
215    fn attrs(&self) -> &[Attribute] {
216        &self.attrs
217    }
218}
219
220impl ContainsSolAttrs for syn_solidity::ItemError {
221    fn attrs(&self) -> &[Attribute] {
222        &self.attrs
223    }
224}
225
226impl ContainsSolAttrs for syn_solidity::ItemEvent {
227    fn attrs(&self) -> &[Attribute] {
228        &self.attrs
229    }
230}
231
232impl ContainsSolAttrs for syn_solidity::ItemFunction {
233    fn attrs(&self) -> &[Attribute] {
234        &self.attrs
235    }
236}
237
238impl ContainsSolAttrs for syn_solidity::ItemStruct {
239    fn attrs(&self) -> &[Attribute] {
240        &self.attrs
241    }
242}
243
244impl ContainsSolAttrs for syn_solidity::ItemUdt {
245    fn attrs(&self) -> &[Attribute] {
246        &self.attrs
247    }
248}
249
/// A naming convention that can be selected with `#[sol(rename_all = "...")]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CasingStyle {
    /// `lowerCamelCase`: word boundaries marked by an uppercase letter, except
    /// for the first word.
    Camel,
    /// `kebab-case`: all lowercase, words separated by hyphens.
    Kebab,
    /// `PascalCase`: word boundaries marked by an uppercase letter, including
    /// the first word.
    Pascal,
    /// `SCREAMING_SNAKE_CASE`: all uppercase, words separated by underscores.
    ScreamingSnake,
    /// `snake_case`: all lowercase, words separated by underscores.
    Snake,
    /// `lowercase`: all lowercase, word boundaries removed.
    Lower,
    /// `UPPERCASE`: all uppercase, word boundaries removed.
    Upper,
    /// The name exactly as written in the source, unchanged.
    Verbatim,
}
274
275impl CasingStyle {
276    fn from_lit(name: &LitStr) -> Result<Self> {
277        let normalized = name.value().to_upper_camel_case().to_lowercase();
278        let s = match normalized.as_ref() {
279            "camel" | "camelcase" => Self::Camel,
280            "kebab" | "kebabcase" => Self::Kebab,
281            "pascal" | "pascalcase" => Self::Pascal,
282            "screamingsnake" | "screamingsnakecase" => Self::ScreamingSnake,
283            "snake" | "snakecase" => Self::Snake,
284            "lower" | "lowercase" => Self::Lower,
285            "upper" | "uppercase" => Self::Upper,
286            "verbatim" | "verbatimcase" => Self::Verbatim,
287            s => return Err(Error::new(name.span(), format!("unsupported casing: {s}"))),
288        };
289        Ok(s)
290    }
291
292    /// Apply the casing style to the given string.
293    #[allow(dead_code)]
294    pub fn apply(self, s: &str) -> String {
295        match self {
296            Self::Pascal => s.to_upper_camel_case(),
297            Self::Kebab => s.to_kebab_case(),
298            Self::Camel => s.to_lower_camel_case(),
299            Self::ScreamingSnake => s.to_shouty_snake_case(),
300            Self::Snake => s.to_snake_case(),
301            Self::Lower => s.to_snake_case().replace('_', ""),
302            Self::Upper => s.to_shouty_snake_case().replace('_', ""),
303            Self::Verbatim => s.to_owned(),
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // Generates one `#[test]` function per group; each
    // `#[attr]... => expected` row is checked with `run_test`.
    macro_rules! test_sol_attrs {
        ($($group:ident { $($t:tt)* })+) => {$(
            #[test]
            fn $group() {
                test_sol_attrs! { $($t)* }
            }
        )+};

        ($( $(#[$attr:meta])* => $expected:expr ),+ $(,)?) => {$(
            run_test(
                &[$(stringify!(#[$attr])),*],
                $expected
            );
        )+};
    }

    // Builds a `SolAttrs` with the given fields set to `Some(..)` and all
    // others left at their default.
    macro_rules! sol_attrs {
        ($($id:ident : $e:expr),* $(,)?) => {
            SolAttrs {
                $($id: Some($e),)*
                ..Default::default()
            }
        };
    }

    struct OuterAttribute(Vec<Attribute>);

    impl syn::parse::Parse for OuterAttribute {
        fn parse(input: syn::parse::ParseStream<'_>) -> Result<Self> {
            input.call(Attribute::parse_outer).map(Self)
        }
    }

    // Parses `attrs_s` and compares the result of `SolAttrs::parse` against
    // `expected`. For errors, the actual message only needs to contain
    // `expected`.
    fn run_test(
        attrs_s: &'static [&'static str],
        expected: std::result::Result<SolAttrs, &'static str>,
    ) {
        let attrs: Vec<Attribute> =
            attrs_s.iter().flat_map(|s| syn::parse_str::<OuterAttribute>(s).unwrap().0).collect();
        match (SolAttrs::parse(&attrs), expected) {
            (Ok((actual, _)), Ok(expected)) => assert_eq!(actual, expected, "{attrs_s:?}"),
            (Err(actual), Err(expected)) => {
                let actual = actual.to_string();
                if !actual.contains(expected) {
                    assert_eq!(actual, expected, "{attrs_s:?}")
                }
            }
            (a, b) => panic!("assertion failed: `{a:?} != {b:?}`: {attrs_s:?}"),
        }
    }

    test_sol_attrs! {
        top_level {
            #[cfg] => Ok(SolAttrs::default()),
            #[cfg()] => Ok(SolAttrs::default()),
            #[cfg = ""] => Ok(SolAttrs::default()),
            #[derive()] #[sol()] => Ok(SolAttrs::default()),
            #[sol()] => Ok(SolAttrs::default()),
            #[sol()] #[sol()] => Ok(SolAttrs::default()),
            #[sol = ""] => Err("expected `(`"),
            #[sol] => Err("expected attribute arguments in parentheses: `sol(...)`"),

            #[sol(() = "")] => Err("unexpected token in nested attribute, expected ident"),
            #[sol(? = "")] => Err("unexpected token in nested attribute, expected ident"),
            #[sol(::a)] => Err("expected ident"),
            #[sol(::a = "")] => Err("expected ident"),
            #[sol(a::b = "")] => Err("expected ident"),

            #[sol(foo)] => Err(UNKNOWN_ERROR),
            #[sol(foo = "bar")] => Err(UNKNOWN_ERROR),
        }

        extra {
            #[sol(all_derives)] => Ok(sol_attrs! { all_derives: true }),
            #[sol(all_derives = true)] => Ok(sol_attrs! { all_derives: true }),
            #[sol(all_derives = false)] => Ok(sol_attrs! { all_derives: false }),
            #[sol(all_derives = "false")] => Err("expected boolean literal"),
            #[sol(all_derives)] #[sol(all_derives)] => Err(DUPLICATE_ERROR),

            #[sol(extra_methods)] => Ok(sol_attrs! { extra_methods: true }),
            #[sol(extra_methods = true)] => Ok(sol_attrs! { extra_methods: true }),
            #[sol(extra_methods = false)] => Ok(sol_attrs! { extra_methods: false }),

            #[sol(docs)] => Ok(sol_attrs! { docs: true }),
            #[sol(docs = true)] => Ok(sol_attrs! { docs: true }),
            #[sol(docs = false)] => Ok(sol_attrs! { docs: false }),

            #[sol(abi)] => Ok(sol_attrs! { abi: true }),
            #[sol(abi = true)] => Ok(sol_attrs! { abi: true }),
            #[sol(abi = false)] => Ok(sol_attrs! { abi: false }),

            #[sol(rpc)] => Ok(sol_attrs! { rpc: true }),
            #[sol(rpc = true)] => Ok(sol_attrs! { rpc: true }),
            #[sol(rpc = false)] => Ok(sol_attrs! { rpc: false }),

            #[sol(alloy_sol_types)] => Err("expected `=`"),
            #[sol(alloy_sol_types = alloy_core::sol_types)] => Ok(sol_attrs! { alloy_sol_types: parse_quote!(alloy_core::sol_types) }),
            #[sol(alloy_sol_types = ::alloy_core::sol_types)] => Ok(sol_attrs! { alloy_sol_types: parse_quote!(::alloy_core::sol_types) }),
            #[sol(alloy_sol_types = alloy::sol_types)] => Ok(sol_attrs! { alloy_sol_types: parse_quote!(alloy::sol_types) }),
            #[sol(alloy_sol_types = ::alloy::sol_types)] => Ok(sol_attrs! { alloy_sol_types: parse_quote!(::alloy::sol_types) }),

            #[sol(alloy_contract)] => Err("expected `=`"),
            #[sol(alloy_contract = alloy::contract)] => Ok(sol_attrs! { alloy_contract: parse_quote!(alloy::contract) }),
            #[sol(alloy_contract = ::alloy::contract)] => Ok(sol_attrs! { alloy_contract: parse_quote!(::alloy::contract) }),
        }

        rename {
            #[sol(rename = "foo")] => Ok(sol_attrs! { rename: parse_quote!("foo") }),

            #[sol(rename_all = "foo")] => Err("unsupported casing: foo"),
            #[sol(rename_all = "camelcase")] => Ok(sol_attrs! { rename_all: CasingStyle::Camel }),
            #[sol(rename_all = "camelCase")] => Ok(sol_attrs! { rename_all: CasingStyle::Camel }),
            #[sol(rename_all = "kebab-case")] => Ok(sol_attrs! { rename_all: CasingStyle::Kebab }),
            #[sol(rename_all = "PascalCase")] => Ok(sol_attrs! { rename_all: CasingStyle::Pascal }),
            #[sol(rename_all = "SCREAMING_SNAKE_CASE")] => Ok(sol_attrs! { rename_all: CasingStyle::ScreamingSnake }),
            #[sol(rename_all = "snake_case")] => Ok(sol_attrs! { rename_all: CasingStyle::Snake }),
            #[sol(rename_all = "lowercase")] => Ok(sol_attrs! { rename_all: CasingStyle::Lower }),
            #[sol(rename_all = "UPPERCASE")] => Ok(sol_attrs! { rename_all: CasingStyle::Upper }),
            #[sol(rename_all = "verbatim")] => Ok(sol_attrs! { rename_all: CasingStyle::Verbatim }),
            #[sol(rename_all = "camelCase")] #[sol(rename_all = "PascalCase")] => Err(DUPLICATE_ERROR),
        }

        bytecode {
            #[sol(deployed_bytecode = "0x1234")] => Ok(sol_attrs! { deployed_bytecode: parse_quote!("0x1234") }),
            #[sol(bytecode = "0x1234")] => Ok(sol_attrs! { bytecode: parse_quote!("0x1234") }),
            #[sol(bytecode = "1234")] => Ok(sol_attrs! { bytecode: parse_quote!("1234") }),
            #[sol(bytecode = "0x123xyz")] => Err("invalid hex value: "),
            #[sol(bytecode = "12 34")] => Err("invalid hex value: "),
            #[sol(bytecode = "xyz")] => Err("invalid hex value: "),
            #[sol(bytecode = "123")] => Err("invalid hex value: "),
        }

        type_check {
            #[sol(type_check = "my_function")] => Ok(sol_attrs! { type_check: parse_quote!("my_function") }),
            #[sol(type_check = "my_function1")] #[sol(type_check = "my_function2")] => Err(DUPLICATE_ERROR),
        }
    }

    #[test]
    fn casing_apply() {
        let cases = [
            (CasingStyle::Camel, "fooBar"),
            (CasingStyle::Kebab, "foo-bar"),
            (CasingStyle::Pascal, "FooBar"),
            (CasingStyle::ScreamingSnake, "FOO_BAR"),
            (CasingStyle::Snake, "foo_bar"),
            (CasingStyle::Lower, "foobar"),
            (CasingStyle::Upper, "FOOBAR"),
            (CasingStyle::Verbatim, "foo_bar"),
        ];
        for (style, expected) in cases {
            assert_eq!(style.apply("foo_bar"), expected, "{style:?}");
        }
    }

    #[test]
    fn docs_str_joins_non_empty_lines() {
        let attrs: Vec<Attribute> = vec![
            parse_quote!(#[doc = " first"]),
            parse_quote!(#[doc = ""]),
            parse_quote!(#[derive(Debug)]),
            parse_quote!(#[doc(hidden)]),
            parse_quote!(#[doc = " second"]),
        ];
        assert_eq!(docs_str(&attrs), " first\n second");
        assert_eq!(docs_str(&[]), "");
    }
}