derive_more_impl/
from.rs

1//! Implementation of a [`From`] derive macro.
2
3use std::{
4    any::{Any, TypeId},
5    iter,
6};
7
8use proc_macro2::{Span, TokenStream};
9use quote::{format_ident, quote, ToTokens as _, TokenStreamExt as _};
10use syn::{
11    parse::{Parse, ParseStream},
12    parse_quote,
13    spanned::Spanned as _,
14    token,
15};
16
17use crate::utils::{
18    attr::{self, ParseMultiple as _},
19    polyfill, Either, Spanning,
20};
21
22/// Expands a [`From`] derive macro.
23pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStream> {
24    let attr_name = format_ident!("from");
25
26    match &input.data {
27        syn::Data::Struct(data) => Expansion {
28            attrs: StructAttribute::parse_attrs_with(
29                &input.attrs,
30                &attr_name,
31                &ConsiderLegacySyntax {
32                    fields: &data.fields,
33                },
34            )?
35            .map(|attr| attr.into_inner().into())
36            .as_ref(),
37            ident: &input.ident,
38            variant: None,
39            fields: &data.fields,
40            generics: &input.generics,
41            has_explicit_from: false,
42        }
43        .expand(),
44        syn::Data::Enum(data) => {
45            let mut has_explicit_from = false;
46            let attrs = data
47                .variants
48                .iter()
49                .map(|variant| {
50                    let attr = VariantAttribute::parse_attrs_with(
51                        &variant.attrs,
52                        &attr_name,
53                        &ConsiderLegacySyntax {
54                            fields: &variant.fields,
55                        },
56                    )?
57                    .map(Spanning::into_inner);
58                    if matches!(
59                        attr,
60                        Some(
61                            VariantAttribute::Empty(_)
62                                | VariantAttribute::Types(_)
63                                | VariantAttribute::Forward(_)
64                        ),
65                    ) {
66                        has_explicit_from = true;
67                    }
68                    Ok(attr)
69                })
70                .collect::<syn::Result<Vec<_>>>()?;
71
72            data.variants
73                .iter()
74                .zip(&attrs)
75                .map(|(variant, attrs)| {
76                    Expansion {
77                        attrs: attrs.as_ref(),
78                        ident: &input.ident,
79                        variant: Some(&variant.ident),
80                        fields: &variant.fields,
81                        generics: &input.generics,
82                        has_explicit_from,
83                    }
84                    .expand()
85                })
86                .collect()
87        }
88        syn::Data::Union(data) => Err(syn::Error::new(
89            data.union_token.span(),
90            "`From` cannot be derived for unions",
91        )),
92    }
93}
94
95/// Representation of a [`From`] derive macro struct container attribute.
96///
97/// ```rust,ignore
98/// #[from(forward)]
99/// #[from(<types>)]
100/// ```
101type StructAttribute = attr::Conversion;
102
103/// Representation of a [`From`] derive macro enum variant attribute.
104///
105/// ```rust,ignore
106/// #[from]
107/// #[from(skip)] #[from(ignore)]
108/// #[from(forward)]
109/// #[from(<types>)]
110/// ```
111type VariantAttribute = attr::FieldConversion;
112
113/// Expansion of a macro for generating [`From`] implementation of a struct or
114/// enum.
115struct Expansion<'a> {
116    /// [`From`] attributes.
117    ///
118    /// As a [`VariantAttribute`] is superset of a [`StructAttribute`], we use
119    /// it for both derives.
120    attrs: Option<&'a VariantAttribute>,
121
122    /// Struct or enum [`syn::Ident`].
123    ///
124    /// [`syn::Ident`]: struct@syn::Ident
125    ident: &'a syn::Ident,
126
127    /// Variant [`syn::Ident`] in case of enum expansion.
128    ///
129    /// [`syn::Ident`]: struct@syn::Ident
130    variant: Option<&'a syn::Ident>,
131
132    /// Struct or variant [`syn::Fields`].
133    fields: &'a syn::Fields,
134
135    /// Struct or enum [`syn::Generics`].
136    generics: &'a syn::Generics,
137
138    /// Indicator whether one of the enum variants has
139    /// [`VariantAttribute::Empty`], [`VariantAttribute::Types`] or
140    /// [`VariantAttribute::Forward`].
141    ///
142    /// Always [`false`] for structs.
143    has_explicit_from: bool,
144}
145
146impl Expansion<'_> {
147    /// Expands [`From`] implementations for a struct or an enum variant.
148    fn expand(&self) -> syn::Result<TokenStream> {
149        use crate::utils::FieldsExt as _;
150
151        let ident = self.ident;
152        let field_tys = self.fields.iter().map(|f| &f.ty).collect::<Vec<_>>();
153        let (impl_gens, ty_gens, where_clause) = self.generics.split_for_impl();
154
155        let skip_variant = self.has_explicit_from
156            || (self.variant.is_some() && self.fields.is_empty());
157        match (self.attrs, skip_variant) {
158            (Some(VariantAttribute::Types(tys)), _) => {
159                tys.0.iter().map(|ty| {
160                    let variant = self.variant.iter();
161
162                    let mut from_tys = self.fields.validate_type(ty)?;
163                    let init = self.expand_fields(|ident, ty, index| {
164                        let ident = ident.into_iter();
165                        let index = index.into_iter();
166                        let from_ty = from_tys.next().unwrap_or_else(|| unreachable!());
167                        quote! {
168                            #( #ident: )* <#ty as derive_more::core::convert::From<#from_ty>>::from(
169                                value #( .#index )*
170                            ),
171                        }
172                    });
173
174                    Ok(quote! {
175                        #[allow(unreachable_code)] // omit warnings for `!` and unreachable types
176                        #[automatically_derived]
177                        impl #impl_gens derive_more::core::convert::From<#ty>
178                         for #ident #ty_gens #where_clause {
179                            #[inline]
180                            fn from(value: #ty) -> Self {
181                                #ident #( :: #variant )* #init
182                            }
183                        }
184                    })
185                })
186                .collect()
187            }
188            (Some(VariantAttribute::Empty(_)), _) | (None, false) => {
189                let variant = self.variant.iter();
190                let init = self.expand_fields(|ident, _, index| {
191                    let ident = ident.into_iter();
192                    let index = index.into_iter();
193                    quote! { #( #ident: )* value #( . #index )*, }
194                });
195
196                Ok(quote! {
197                    #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
198                    #[automatically_derived]
199                    impl #impl_gens derive_more::core::convert::From<(#( #field_tys ),*)>
200                     for #ident #ty_gens #where_clause {
201                        #[inline]
202                        fn from(value: (#( #field_tys ),*)) -> Self {
203                            #ident #( :: #variant )* #init
204                        }
205                    }
206                })
207            }
208            (Some(VariantAttribute::Forward(_)), _) => {
209                let mut i = 0;
210                let mut gen_idents = Vec::with_capacity(self.fields.len());
211                let init = self.expand_fields(|ident, ty, index| {
212                    let ident = ident.into_iter();
213                    let index = index.into_iter();
214                    let gen_ident = format_ident!("__FromT{i}");
215                    let out = quote! {
216                        #( #ident: )* <#ty as derive_more::core::convert::From<#gen_ident>>::from(
217                            value #( .#index )*
218                        ),
219                    };
220                    gen_idents.push(gen_ident);
221                    i += 1;
222                    out
223                });
224
225                let variant = self.variant.iter();
226                let generics = {
227                    let mut generics = self.generics.clone();
228                    for (ty, ident) in field_tys.iter().zip(&gen_idents) {
229                        generics
230                            .make_where_clause()
231                            .predicates
232                            .push(parse_quote! { #ty: derive_more::core::convert::From<#ident> });
233                        generics
234                            .params
235                            .push(syn::TypeParam::from(ident.clone()).into());
236                    }
237                    generics
238                };
239                let (impl_gens, _, where_clause) = generics.split_for_impl();
240
241                Ok(quote! {
242                    #[allow(unreachable_code)] // omit warnings for `!` and other unreachable types
243                    #[automatically_derived]
244                    impl #impl_gens derive_more::core::convert::From<(#( #gen_idents ),*)>
245                     for #ident #ty_gens #where_clause {
246                        #[inline]
247                        fn from(value: (#( #gen_idents ),*)) -> Self {
248                            #ident #(:: #variant)* #init
249                        }
250                    }
251                })
252            }
253            (Some(VariantAttribute::Skip(_)), _) | (None, true) => {
254                Ok(TokenStream::new())
255            }
256        }
257    }
258
259    /// Expands fields initialization wrapped into [`token::Brace`]s in case of
260    /// [`syn::FieldsNamed`], or [`token::Paren`] in case of
261    /// [`syn::FieldsUnnamed`].
262    ///
263    /// [`token::Brace`]: struct@token::Brace
264    /// [`token::Paren`]: struct@token::Paren
265    fn expand_fields(
266        &self,
267        mut wrap: impl FnMut(
268            Option<&syn::Ident>,
269            &syn::Type,
270            Option<syn::Index>,
271        ) -> TokenStream,
272    ) -> TokenStream {
273        let surround = match self.fields {
274            syn::Fields::Named(_) | syn::Fields::Unnamed(_) => {
275                Some(|tokens| match self.fields {
276                    syn::Fields::Named(named) => {
277                        let mut out = TokenStream::new();
278                        named
279                            .brace_token
280                            .surround(&mut out, |out| out.append_all(tokens));
281                        out
282                    }
283                    syn::Fields::Unnamed(unnamed) => {
284                        let mut out = TokenStream::new();
285                        unnamed
286                            .paren_token
287                            .surround(&mut out, |out| out.append_all(tokens));
288                        out
289                    }
290                    syn::Fields::Unit => unreachable!(),
291                })
292            }
293            syn::Fields::Unit => None,
294        };
295
296        surround
297            .map(|surround| {
298                surround(if self.fields.len() == 1 {
299                    let field = self
300                        .fields
301                        .iter()
302                        .next()
303                        .unwrap_or_else(|| unreachable!("self.fields.len() == 1"));
304                    wrap(field.ident.as_ref(), &field.ty, None)
305                } else {
306                    self.fields
307                        .iter()
308                        .enumerate()
309                        .map(|(i, field)| {
310                            wrap(field.ident.as_ref(), &field.ty, Some(i.into()))
311                        })
312                        .collect()
313                })
314            })
315            .unwrap_or_default()
316    }
317}
318
319/// [`attr::Parser`] considering legacy syntax for [`attr::Types`] and emitting [`legacy_error`], if
320/// any occurs.
321struct ConsiderLegacySyntax<'a> {
322    /// [`syn::Fields`] of a struct or enum variant, the attribute is parsed for.
323    fields: &'a syn::Fields,
324}
325
326impl attr::Parser for ConsiderLegacySyntax<'_> {
327    fn parse<T: Parse + Any>(&self, input: ParseStream<'_>) -> syn::Result<T> {
328        if TypeId::of::<T>() == TypeId::of::<attr::Types>() {
329            let ahead = input.fork();
330            if let Ok(p) = ahead.parse::<syn::Path>() {
331                if p.is_ident("types") {
332                    return legacy_error(&ahead, input.span(), self.fields);
333                }
334            }
335        }
336        T::parse(input)
337    }
338}
339
340/// Constructs a [`syn::Error`] for legacy syntax: `#[from(types(i32, "&str"))]`.
341fn legacy_error<T>(
342    tokens: ParseStream<'_>,
343    span: Span,
344    fields: &syn::Fields,
345) -> syn::Result<T> {
346    let content;
347    syn::parenthesized!(content in tokens);
348
349    let types = content
350        .parse_terminated(polyfill::NestedMeta::parse, token::Comma)?
351        .into_iter()
352        .map(|meta| {
353            let value = match meta {
354                polyfill::NestedMeta::Meta(meta) => {
355                    meta.into_token_stream().to_string()
356                }
357                polyfill::NestedMeta::Lit(syn::Lit::Str(str)) => str.value(),
358                polyfill::NestedMeta::Lit(_) => unreachable!(),
359            };
360            if fields.len() > 1 {
361                format!(
362                    "({})",
363                    fields
364                        .iter()
365                        .map(|_| value.clone())
366                        .collect::<Vec<_>>()
367                        .join(", "),
368                )
369            } else {
370                value
371            }
372        })
373        .chain(match fields.len() {
374            0 => Either::Left(iter::empty()),
375            1 => Either::Right(iter::once(
376                fields
377                    .iter()
378                    .next()
379                    .unwrap_or_else(|| unreachable!("fields.len() == 1"))
380                    .ty
381                    .to_token_stream()
382                    .to_string(),
383            )),
384            _ => Either::Right(iter::once(format!(
385                "({})",
386                fields
387                    .iter()
388                    .map(|f| f.ty.to_token_stream().to_string())
389                    .collect::<Vec<_>>()
390                    .join(", ")
391            ))),
392        })
393        .collect::<Vec<_>>()
394        .join(", ");
395
396    Err(syn::Error::new(
397        span,
398        format!("legacy syntax, remove `types` and use `{types}` instead"),
399    ))
400}