derive_more_impl/
error.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6    self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7    State,
8};
9
10pub fn expand(
11    input: &syn::DeriveInput,
12    trait_name: &'static str,
13) -> Result<TokenStream> {
14    let syn::DeriveInput {
15        ident, generics, ..
16    } = input;
17
18    let state = State::with_attr_params(
19        input,
20        trait_name,
21        trait_name.to_lowercase(),
22        allowed_attr_params(),
23    )?;
24
25    let type_params: HashSet<_> = generics
26        .params
27        .iter()
28        .filter_map(|generic| match generic {
29            syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
30            _ => None,
31        })
32        .collect();
33
34    let (bounds, source, provide) = match state.derive_type {
35        DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
36        DeriveType::Enum => render_enum(&type_params, &state)?,
37    };
38
39    let source = source.map(|source| {
40        // Not using `#[inline]` here on purpose, since this is almost never part
41        // of a hot codepath.
42        quote! {
43            // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is
44            //       stabilized.
45            fn source(&self) -> Option<&(dyn derive_more::with_trait::Error + 'static)> {
46                use derive_more::__private::AsDynError;
47                #source
48            }
49        }
50    });
51
52    let provide = provide.map(|provide| {
53        // Not using `#[inline]` here on purpose, since this is almost never part
54        // of a hot codepath.
55        quote! {
56            fn provide<'_request>(
57                &'_request self,
58                request: &mut derive_more::core::error::Request<'_request>,
59            ) {
60                #provide
61            }
62        }
63    });
64
65    let mut generics = generics.clone();
66
67    if !type_params.is_empty() {
68        let (_, ty_generics, _) = generics.split_for_impl();
69        generics = utils::add_extra_where_clauses(
70            &generics,
71            quote! {
72                where
73                    #ident #ty_generics: derive_more::core::fmt::Debug
74                                         + derive_more::core::fmt::Display
75            },
76        );
77    }
78
79    if !bounds.is_empty() {
80        let bounds = bounds.iter();
81        generics = utils::add_extra_where_clauses(
82            &generics,
83            quote! {
84                where #(
85                    #bounds: derive_more::core::fmt::Debug
86                             + derive_more::core::fmt::Display
87                             // TODO: Use `derive_more::core::error::Error` once `error_in_core`
88                             //       Rust feature is stabilized.
89                             + derive_more::with_trait::Error
90                             + 'static
91                ),*
92            },
93        );
94    }
95
96    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
97
98    let render = quote! {
99        #[automatically_derived]
100        // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is
101        //       stabilized.
102        impl #impl_generics derive_more::with_trait::Error for #ident #ty_generics #where_clause {
103            #source
104            #provide
105        }
106    };
107
108    Ok(render)
109}
110
111fn render_struct(
112    type_params: &HashSet<syn::Ident>,
113    state: &State,
114) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
115    let parsed_fields = parse_fields(type_params, state)?;
116
117    let source = parsed_fields.render_source_as_struct();
118    let provide = parsed_fields.render_provide_as_struct();
119
120    Ok((parsed_fields.bounds, source, provide))
121}
122
123fn render_enum(
124    type_params: &HashSet<syn::Ident>,
125    state: &State,
126) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
127    let mut bounds = HashSet::default();
128    let mut source_match_arms = Vec::new();
129    let mut provide_match_arms = Vec::new();
130
131    for variant in state.enabled_variant_data().variants {
132        let default_info = FullMetaInfo {
133            enabled: true,
134            ..FullMetaInfo::default()
135        };
136
137        let state = State::from_variant(
138            state.input,
139            state.trait_name,
140            state.trait_attr.clone(),
141            allowed_attr_params(),
142            variant,
143            default_info,
144        )?;
145
146        let parsed_fields = parse_fields(type_params, &state)?;
147
148        if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
149            source_match_arms.push(expr);
150        }
151
152        if let Some(expr) = parsed_fields.render_provide_as_enum_variant_match_arm() {
153            provide_match_arms.push(expr);
154        }
155
156        bounds.extend(parsed_fields.bounds.into_iter());
157    }
158
159    let render = |match_arms: &mut Vec<TokenStream>, unmatched| {
160        if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
161            match_arms.push(quote! { _ => #unmatched });
162        }
163
164        (!match_arms.is_empty()).then(|| {
165            quote! {
166                match self {
167                    #(#match_arms),*
168                }
169            }
170        })
171    };
172
173    let source = render(&mut source_match_arms, quote! { None });
174    let provide = render(&mut provide_match_arms, quote! { () });
175
176    Ok((bounds, source, provide))
177}
178
179fn allowed_attr_params() -> AttrParams {
180    AttrParams {
181        enum_: vec!["ignore"],
182        struct_: vec!["ignore"],
183        variant: vec!["ignore"],
184        field: vec!["ignore", "source", "backtrace"],
185    }
186}
187
/// Result of analysing the fields of a single struct or enum variant.
struct ParsedFields<'input, 'state> {
    /// Field data of the struct/variant, as returned by
    /// `State::enabled_fields_data()`.
    data: MultiFieldData<'input, 'state>,
    /// Index of the field used as the error source, if any.
    /// Used to index `data.members` when rendering `source()`.
    source: Option<usize>,
    /// Index of the field used as the backtrace, if any.
    /// Used to index `data.members` when rendering `provide()`.
    backtrace: Option<usize>,
    /// Types that must receive additional trait bounds in the generated `impl`.
    bounds: HashSet<syn::Type>,
}
194
195impl<'input, 'state> ParsedFields<'input, 'state> {
196    fn new(data: MultiFieldData<'input, 'state>) -> Self {
197        Self {
198            data,
199            source: None,
200            backtrace: None,
201            bounds: HashSet::default(),
202        }
203    }
204}
205
206impl ParsedFields<'_, '_> {
207    fn render_source_as_struct(&self) -> Option<TokenStream> {
208        let source = self.source?;
209        let ident = &self.data.members[source];
210        Some(render_some(quote! { #ident }))
211    }
212
213    fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
214        let source = self.source?;
215        let pattern = self.data.matcher(&[source], &[quote! { source }]);
216        let expr = render_some(quote! { source });
217        Some(quote! { #pattern => #expr })
218    }
219
220    fn render_provide_as_struct(&self) -> Option<TokenStream> {
221        let backtrace = self.backtrace?;
222
223        let source_provider = self.source.map(|source| {
224            let source_expr = &self.data.members[source];
225            quote! {
226                // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust feature is
227                //       stabilized.
228                derive_more::with_trait::Error::provide(&#source_expr, request);
229            }
230        });
231        let backtrace_provider = self
232            .source
233            .filter(|source| *source == backtrace)
234            .is_none()
235            .then(|| {
236                let backtrace_expr = &self.data.members[backtrace];
237                quote! {
238                    request.provide_ref::<::std::backtrace::Backtrace>(&#backtrace_expr);
239                }
240            });
241
242        (source_provider.is_some() || backtrace_provider.is_some()).then(|| {
243            quote! {
244                #backtrace_provider
245                #source_provider
246            }
247        })
248    }
249
250    fn render_provide_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
251        let backtrace = self.backtrace?;
252
253        match self.source {
254            Some(source) if source == backtrace => {
255                let pattern = self.data.matcher(&[source], &[quote! { source }]);
256                Some(quote! {
257                    #pattern => {
258                        // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust
259                        //       feature is stabilized.
260                        derive_more::with_trait::Error::provide(source, request);
261                    }
262                })
263            }
264            Some(source) => {
265                let pattern = self.data.matcher(
266                    &[source, backtrace],
267                    &[quote! { source }, quote! { backtrace }],
268                );
269                Some(quote! {
270                    #pattern => {
271                        request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
272                        // TODO: Use `derive_more::core::error::Error` once `error_in_core` Rust
273                        //       feature is stabilized.
274                        derive_more::with_trait::Error::provide(source, request);
275                    }
276                })
277            }
278            None => {
279                let pattern = self.data.matcher(&[backtrace], &[quote! { backtrace }]);
280                Some(quote! {
281                    #pattern => {
282                        request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
283                    }
284                })
285            }
286        }
287    }
288}
289
290fn render_some<T>(expr: T) -> TokenStream
291where
292    T: quote::ToTokens,
293{
294    quote! { Some(#expr.as_dyn_error()) }
295}
296
297fn parse_fields<'input, 'state>(
298    type_params: &HashSet<syn::Ident>,
299    state: &'state State<'input>,
300) -> Result<ParsedFields<'input, 'state>> {
301    let mut parsed_fields = match state.derive_type {
302        DeriveType::Named => {
303            parse_fields_impl(state, |attr, field, _| {
304                // Unwrapping is safe, cause fields in named struct
305                // always have an ident
306                let ident = field.ident.as_ref().unwrap();
307
308                match attr {
309                    "source" => ident == "source",
310                    "backtrace" => {
311                        ident == "backtrace"
312                            || is_type_path_ends_with_segment(&field.ty, "Backtrace")
313                    }
314                    _ => unreachable!(),
315                }
316            })
317        }
318
319        DeriveType::Unnamed => {
320            let mut parsed_fields =
321                parse_fields_impl(state, |attr, field, len| match attr {
322                    "source" => {
323                        len == 1
324                            && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
325                    }
326                    "backtrace" => {
327                        is_type_path_ends_with_segment(&field.ty, "Backtrace")
328                    }
329                    _ => unreachable!(),
330                })?;
331
332            parsed_fields.source = parsed_fields
333                .source
334                .or_else(|| infer_source_field(&state.fields, &parsed_fields));
335
336            Ok(parsed_fields)
337        }
338
339        _ => unreachable!(),
340    }?;
341
342    if let Some(source) = parsed_fields.source {
343        add_bound_if_type_parameter_used_in_type(
344            &mut parsed_fields.bounds,
345            type_params,
346            &state.fields[source].ty,
347        );
348    }
349
350    Ok(parsed_fields)
351}
352
353/// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
354/// and doesn't contain any generic parameters.
355fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
356    let syn::Type::Path(ty) = ty else {
357        return false;
358    };
359
360    // Unwrapping is safe, cause 'syn::TypePath.path.segments'
361    // have to have at least one segment
362    let segment = ty.path.segments.last().unwrap();
363
364    if !matches!(segment.arguments, syn::PathArguments::None) {
365        return false;
366    }
367
368    segment.ident == tail
369}
370
371fn infer_source_field(
372    fields: &[&syn::Field],
373    parsed_fields: &ParsedFields,
374) -> Option<usize> {
375    // if we have exactly two fields
376    if fields.len() != 2 {
377        return None;
378    }
379
380    // no source field was specified/inferred
381    if parsed_fields.source.is_some() {
382        return None;
383    }
384
385    // but one of the fields was specified/inferred as backtrace field
386    if let Some(backtrace) = parsed_fields.backtrace {
387        // then infer *other field* as source field
388        let source = (backtrace + 1) % 2;
389        // unless it was explicitly marked as non-source
390        if parsed_fields.data.infos[source].info.source != Some(false) {
391            return Some(source);
392        }
393    }
394
395    None
396}
397
398fn parse_fields_impl<'input, 'state, P>(
399    state: &'state State<'input>,
400    is_valid_default_field_for_attr: P,
401) -> Result<ParsedFields<'input, 'state>>
402where
403    P: Fn(&str, &syn::Field, usize) -> bool,
404{
405    let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
406
407    let iter = fields
408        .iter()
409        .zip(infos.iter().map(|info| &info.info))
410        .enumerate()
411        .map(|(index, (field, info))| (index, *field, info));
412
413    let source = parse_field_impl(
414        &is_valid_default_field_for_attr,
415        state.fields.len(),
416        iter.clone(),
417        "source",
418        |info| info.source,
419    )?;
420
421    let backtrace = parse_field_impl(
422        &is_valid_default_field_for_attr,
423        state.fields.len(),
424        iter.clone(),
425        "backtrace",
426        |info| info.backtrace,
427    )?;
428
429    let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
430
431    if let Some((index, _, _)) = source {
432        parsed_fields.source = Some(index);
433    }
434
435    if let Some((index, _, _)) = backtrace {
436        parsed_fields.backtrace = Some(index);
437    }
438
439    Ok(parsed_fields)
440}
441
442fn parse_field_impl<'a, P, V>(
443    is_valid_default_field_for_attr: &P,
444    len: usize,
445    iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
446    attr: &str,
447    value: V,
448) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
449where
450    P: Fn(&str, &syn::Field, usize) -> bool,
451    V: Fn(&MetaInfo) -> Option<bool>,
452{
453    let explicit_fields = iter
454        .clone()
455        .filter(|(_, _, info)| matches!(value(info), Some(true)));
456
457    let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
458        None => is_valid_default_field_for_attr(attr, field, len),
459        _ => false,
460    });
461
462    let field = assert_iter_contains_zero_or_one_item(
463        explicit_fields,
464        &format!(
465            "Multiple `{attr}` attributes specified. \
466             Single attribute per struct/enum variant allowed.",
467        ),
468    )?;
469
470    let field = match field {
471        field @ Some(_) => field,
472        None => assert_iter_contains_zero_or_one_item(
473            inferred_fields,
474            "Conflicting fields found. Consider specifying some \
475             `#[error(...)]` attributes to resolve conflict.",
476        )?,
477    };
478
479    Ok(field)
480}
481
482fn assert_iter_contains_zero_or_one_item<'a>(
483    mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
484    error_msg: &str,
485) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
486    let Some(item) = iter.next() else {
487        return Ok(None);
488    };
489
490    if let Some((_, field, _)) = iter.next() {
491        return Err(Error::new(field.span(), error_msg));
492    }
493
494    Ok(Some(item))
495}
496
497fn add_bound_if_type_parameter_used_in_type(
498    bounds: &mut HashSet<syn::Type>,
499    type_params: &HashSet<syn::Ident>,
500    ty: &syn::Type,
501) {
502    if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
503        bounds.insert(ty);
504    }
505}