extension_trait/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use std::borrow::Cow;
6use syn::ext::IdentExt;
7use syn::parse::{Nothing, Parse, ParseStream};
8use syn::spanned::Spanned;
9use syn::{
10    parse_macro_input, Attribute, FnArg, Generics, Ident, ImplItem, ImplItemConst, ImplItemFn,
11    ImplItemType, ItemImpl, Pat, PatIdent, PatReference, PatTuple, PatType, Receiver, Signature,
12    Visibility,
13};
14
15struct ItemImplWithVisibility {
16    attrs: Vec<Attribute>,
17    visibility: Visibility,
18    impl_item: ItemImpl,
19}
20
21impl Parse for ItemImplWithVisibility {
22    fn parse(input: ParseStream) -> syn::Result<Self> {
23        let attrs = input.call(Attribute::parse_outer)?;
24        let visibility = input.parse()?;
25        let impl_item = input.parse()?;
26        Ok(Self {
27            attrs,
28            visibility,
29            impl_item,
30        })
31    }
32}
33
34/// Declares an extension trait
35///
36/// # Example
37///
38/// ```
39/// #[macro_use]
40/// extern crate extension_trait;
41///
42/// #[extension_trait]
43/// pub impl DoubleExt for str {
44///    fn double(&self) -> String {
45///        self.repeat(2)
46///    }
47/// }
48///
49/// fn main() {
50///     assert_eq!("Hello".double(), "HelloHello");
51/// }
52/// ```
53#[proc_macro_attribute]
54pub fn extension_trait(args: TokenStream, input: TokenStream) -> TokenStream {
55    parse_macro_input!(args as Nothing);
56    let ItemImplWithVisibility {
57        attrs,
58        visibility,
59        impl_item,
60    } = parse_macro_input!(input as ItemImplWithVisibility);
61    let ItemImpl {
62        impl_token,
63        unsafety,
64        trait_,
65        items,
66        ..
67    } = &impl_item;
68    let items = items.iter().map(|item| match item {
69        ImplItem::Const(ImplItemConst {
70            attrs,
71            vis: _,
72            defaultness: None,
73            const_token,
74            ident,
75            generics: Generics {
76                lt_token: None,
77                params: _,
78                gt_token: None,
79                where_clause: None,
80            },
81            colon_token,
82            ty,
83            eq_token: _,
84            expr: _,
85            semi_token,
86        }) => quote! { #(#attrs)* #const_token #ident #colon_token #ty #semi_token },
87        ImplItem::Fn(ImplItemFn {
88            attrs,
89            vis: _,
90            defaultness: None,
91            sig: Signature {
92                constness,
93                asyncness,
94                unsafety,
95                abi,
96                fn_token,
97                ident,
98                generics,
99                paren_token: _,
100                inputs,
101                variadic,
102                output
103            },
104            block: _,
105        }) => {
106            let inputs = inputs.into_iter().map(|arg| {
107                let span = arg.span();
108                match arg {
109                    FnArg::Typed(PatType { attrs, pat, colon_token, ty }) => {
110                        let ident = extract_ident(pat).unwrap_or_else(|| Cow::Owned(Ident::new("_", span)));
111                        quote! { #(#attrs)* #ident #colon_token #ty }
112                    },
113                    FnArg::Receiver(Receiver {
114                        attrs,
115                        reference: None,
116                        mutability: _,
117                        self_token,
118                        colon_token: Some(colon),
119                        ty,
120                    }) => quote! { #(#attrs)* #self_token #colon #ty },
121                    FnArg::Receiver(receiver) => receiver.into_token_stream(),
122                }
123            });
124            let where_clause = &generics.where_clause;
125            quote! {
126                #(#attrs)*
127                #constness #asyncness #unsafety #abi #fn_token #ident #generics (#(#inputs,)* #variadic) #output #where_clause;
128            }
129        },
130        ImplItem::Type(ImplItemType {
131            attrs,
132            type_token,
133            ident,
134            generics,
135            semi_token,
136            ..
137        }) => quote! { #(#attrs)* #type_token #ident #generics #semi_token },
138        _ => syn::Error::new(item.span(), "unsupported item type").to_compile_error(),
139    });
140    if let Some((None, path, _)) = trait_ {
141        (quote! {
142            #(#attrs)*
143            #visibility #unsafety trait #path {
144                #(#items)*
145            }
146            #impl_item
147        })
148        .into()
149    } else {
150        syn::Error::new(impl_token.span(), "extension trait name was not provided")
151            .to_compile_error()
152            .into()
153    }
154}
155
156fn extract_ident(pat: &Pat) -> Option<Cow<'_, Ident>> {
157    match pat {
158        Pat::Reference(PatReference { pat, .. }) => extract_ident(pat),
159        Pat::Ident(PatIdent { ident, .. }) => Some(Cow::Borrowed(ident)),
160        Pat::Tuple(PatTuple { elems, .. }) => {
161            if elems.len() <= 1 {
162                extract_ident(elems.into_iter().next()?)
163            } else {
164                let span = elems.span();
165                let elems = elems
166                    .into_iter()
167                    .map(extract_ident)
168                    .map(|o| o.map(|ident| ident.unraw().to_string()))
169                    .collect::<Option<Vec<String>>>()?;
170                let joined = elems.join("_");
171                Some(Cow::Owned(Ident::new(&joined, span)))
172            }
173        }
174        _ => None,
175    }
176}