1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6 self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7 State,
8};
9
10pub fn expand(
11 input: &syn::DeriveInput,
12 trait_name: &'static str,
13) -> Result<TokenStream> {
14 let syn::DeriveInput {
15 ident, generics, ..
16 } = input;
17
18 let state = State::with_attr_params(
19 input,
20 trait_name,
21 trait_name.to_lowercase(),
22 allowed_attr_params(),
23 )?;
24
25 let type_params: HashSet<_> = generics
26 .params
27 .iter()
28 .filter_map(|generic| match generic {
29 syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
30 _ => None,
31 })
32 .collect();
33
34 let (bounds, source, provide) = match state.derive_type {
35 DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
36 DeriveType::Enum => render_enum(&type_params, &state)?,
37 };
38
39 let source = source.map(|source| {
40 quote! {
43 fn source(&self) -> Option<&(dyn derive_more::with_trait::Error + 'static)> {
46 use derive_more::__private::AsDynError;
47 #source
48 }
49 }
50 });
51
52 let provide = provide.map(|provide| {
53 quote! {
56 fn provide<'_request>(
57 &'_request self,
58 request: &mut derive_more::core::error::Request<'_request>,
59 ) {
60 #provide
61 }
62 }
63 });
64
65 let mut generics = generics.clone();
66
67 if !type_params.is_empty() {
68 let (_, ty_generics, _) = generics.split_for_impl();
69 generics = utils::add_extra_where_clauses(
70 &generics,
71 quote! {
72 where
73 #ident #ty_generics: derive_more::core::fmt::Debug
74 + derive_more::core::fmt::Display
75 },
76 );
77 }
78
79 if !bounds.is_empty() {
80 let bounds = bounds.iter();
81 generics = utils::add_extra_where_clauses(
82 &generics,
83 quote! {
84 where #(
85 #bounds: derive_more::core::fmt::Debug
86 + derive_more::core::fmt::Display
87 + derive_more::with_trait::Error
90 + 'static
91 ),*
92 },
93 );
94 }
95
96 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
97
98 let render = quote! {
99 #[automatically_derived]
100 impl #impl_generics derive_more::with_trait::Error for #ident #ty_generics #where_clause {
103 #source
104 #provide
105 }
106 };
107
108 Ok(render)
109}
110
111fn render_struct(
112 type_params: &HashSet<syn::Ident>,
113 state: &State,
114) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
115 let parsed_fields = parse_fields(type_params, state)?;
116
117 let source = parsed_fields.render_source_as_struct();
118 let provide = parsed_fields.render_provide_as_struct();
119
120 Ok((parsed_fields.bounds, source, provide))
121}
122
123fn render_enum(
124 type_params: &HashSet<syn::Ident>,
125 state: &State,
126) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
127 let mut bounds = HashSet::default();
128 let mut source_match_arms = Vec::new();
129 let mut provide_match_arms = Vec::new();
130
131 for variant in state.enabled_variant_data().variants {
132 let default_info = FullMetaInfo {
133 enabled: true,
134 ..FullMetaInfo::default()
135 };
136
137 let state = State::from_variant(
138 state.input,
139 state.trait_name,
140 state.trait_attr.clone(),
141 allowed_attr_params(),
142 variant,
143 default_info,
144 )?;
145
146 let parsed_fields = parse_fields(type_params, &state)?;
147
148 if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
149 source_match_arms.push(expr);
150 }
151
152 if let Some(expr) = parsed_fields.render_provide_as_enum_variant_match_arm() {
153 provide_match_arms.push(expr);
154 }
155
156 bounds.extend(parsed_fields.bounds.into_iter());
157 }
158
159 let render = |match_arms: &mut Vec<TokenStream>, unmatched| {
160 if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
161 match_arms.push(quote! { _ => #unmatched });
162 }
163
164 (!match_arms.is_empty()).then(|| {
165 quote! {
166 match self {
167 #(#match_arms),*
168 }
169 }
170 })
171 };
172
173 let source = render(&mut source_match_arms, quote! { None });
174 let provide = render(&mut provide_match_arms, quote! { () });
175
176 Ok((bounds, source, provide))
177}
178
179fn allowed_attr_params() -> AttrParams {
180 AttrParams {
181 enum_: vec!["ignore"],
182 struct_: vec!["ignore"],
183 variant: vec!["ignore"],
184 field: vec!["ignore", "source", "backtrace"],
185 }
186}
187
188struct ParsedFields<'input, 'state> {
189 data: MultiFieldData<'input, 'state>,
190 source: Option<usize>,
191 backtrace: Option<usize>,
192 bounds: HashSet<syn::Type>,
193}
194
195impl<'input, 'state> ParsedFields<'input, 'state> {
196 fn new(data: MultiFieldData<'input, 'state>) -> Self {
197 Self {
198 data,
199 source: None,
200 backtrace: None,
201 bounds: HashSet::default(),
202 }
203 }
204}
205
206impl ParsedFields<'_, '_> {
207 fn render_source_as_struct(&self) -> Option<TokenStream> {
208 let source = self.source?;
209 let ident = &self.data.members[source];
210 Some(render_some(quote! { #ident }))
211 }
212
213 fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
214 let source = self.source?;
215 let pattern = self.data.matcher(&[source], &[quote! { source }]);
216 let expr = render_some(quote! { source });
217 Some(quote! { #pattern => #expr })
218 }
219
220 fn render_provide_as_struct(&self) -> Option<TokenStream> {
221 let backtrace = self.backtrace?;
222
223 let source_provider = self.source.map(|source| {
224 let source_expr = &self.data.members[source];
225 quote! {
226 derive_more::with_trait::Error::provide(&#source_expr, request);
229 }
230 });
231 let backtrace_provider = self
232 .source
233 .filter(|source| *source == backtrace)
234 .is_none()
235 .then(|| {
236 let backtrace_expr = &self.data.members[backtrace];
237 quote! {
238 request.provide_ref::<::std::backtrace::Backtrace>(&#backtrace_expr);
239 }
240 });
241
242 (source_provider.is_some() || backtrace_provider.is_some()).then(|| {
243 quote! {
244 #backtrace_provider
245 #source_provider
246 }
247 })
248 }
249
250 fn render_provide_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
251 let backtrace = self.backtrace?;
252
253 match self.source {
254 Some(source) if source == backtrace => {
255 let pattern = self.data.matcher(&[source], &[quote! { source }]);
256 Some(quote! {
257 #pattern => {
258 derive_more::with_trait::Error::provide(source, request);
261 }
262 })
263 }
264 Some(source) => {
265 let pattern = self.data.matcher(
266 &[source, backtrace],
267 &[quote! { source }, quote! { backtrace }],
268 );
269 Some(quote! {
270 #pattern => {
271 request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
272 derive_more::with_trait::Error::provide(source, request);
275 }
276 })
277 }
278 None => {
279 let pattern = self.data.matcher(&[backtrace], &[quote! { backtrace }]);
280 Some(quote! {
281 #pattern => {
282 request.provide_ref::<::std::backtrace::Backtrace>(backtrace);
283 }
284 })
285 }
286 }
287 }
288}
289
290fn render_some<T>(expr: T) -> TokenStream
291where
292 T: quote::ToTokens,
293{
294 quote! { Some(#expr.as_dyn_error()) }
295}
296
297fn parse_fields<'input, 'state>(
298 type_params: &HashSet<syn::Ident>,
299 state: &'state State<'input>,
300) -> Result<ParsedFields<'input, 'state>> {
301 let mut parsed_fields = match state.derive_type {
302 DeriveType::Named => {
303 parse_fields_impl(state, |attr, field, _| {
304 let ident = field.ident.as_ref().unwrap();
307
308 match attr {
309 "source" => ident == "source",
310 "backtrace" => {
311 ident == "backtrace"
312 || is_type_path_ends_with_segment(&field.ty, "Backtrace")
313 }
314 _ => unreachable!(),
315 }
316 })
317 }
318
319 DeriveType::Unnamed => {
320 let mut parsed_fields =
321 parse_fields_impl(state, |attr, field, len| match attr {
322 "source" => {
323 len == 1
324 && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
325 }
326 "backtrace" => {
327 is_type_path_ends_with_segment(&field.ty, "Backtrace")
328 }
329 _ => unreachable!(),
330 })?;
331
332 parsed_fields.source = parsed_fields
333 .source
334 .or_else(|| infer_source_field(&state.fields, &parsed_fields));
335
336 Ok(parsed_fields)
337 }
338
339 _ => unreachable!(),
340 }?;
341
342 if let Some(source) = parsed_fields.source {
343 add_bound_if_type_parameter_used_in_type(
344 &mut parsed_fields.bounds,
345 type_params,
346 &state.fields[source].ty,
347 );
348 }
349
350 Ok(parsed_fields)
351}
352
353fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
356 let syn::Type::Path(ty) = ty else {
357 return false;
358 };
359
360 let segment = ty.path.segments.last().unwrap();
363
364 if !matches!(segment.arguments, syn::PathArguments::None) {
365 return false;
366 }
367
368 segment.ident == tail
369}
370
371fn infer_source_field(
372 fields: &[&syn::Field],
373 parsed_fields: &ParsedFields,
374) -> Option<usize> {
375 if fields.len() != 2 {
377 return None;
378 }
379
380 if parsed_fields.source.is_some() {
382 return None;
383 }
384
385 if let Some(backtrace) = parsed_fields.backtrace {
387 let source = (backtrace + 1) % 2;
389 if parsed_fields.data.infos[source].info.source != Some(false) {
391 return Some(source);
392 }
393 }
394
395 None
396}
397
398fn parse_fields_impl<'input, 'state, P>(
399 state: &'state State<'input>,
400 is_valid_default_field_for_attr: P,
401) -> Result<ParsedFields<'input, 'state>>
402where
403 P: Fn(&str, &syn::Field, usize) -> bool,
404{
405 let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
406
407 let iter = fields
408 .iter()
409 .zip(infos.iter().map(|info| &info.info))
410 .enumerate()
411 .map(|(index, (field, info))| (index, *field, info));
412
413 let source = parse_field_impl(
414 &is_valid_default_field_for_attr,
415 state.fields.len(),
416 iter.clone(),
417 "source",
418 |info| info.source,
419 )?;
420
421 let backtrace = parse_field_impl(
422 &is_valid_default_field_for_attr,
423 state.fields.len(),
424 iter.clone(),
425 "backtrace",
426 |info| info.backtrace,
427 )?;
428
429 let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
430
431 if let Some((index, _, _)) = source {
432 parsed_fields.source = Some(index);
433 }
434
435 if let Some((index, _, _)) = backtrace {
436 parsed_fields.backtrace = Some(index);
437 }
438
439 Ok(parsed_fields)
440}
441
442fn parse_field_impl<'a, P, V>(
443 is_valid_default_field_for_attr: &P,
444 len: usize,
445 iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
446 attr: &str,
447 value: V,
448) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
449where
450 P: Fn(&str, &syn::Field, usize) -> bool,
451 V: Fn(&MetaInfo) -> Option<bool>,
452{
453 let explicit_fields = iter
454 .clone()
455 .filter(|(_, _, info)| matches!(value(info), Some(true)));
456
457 let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
458 None => is_valid_default_field_for_attr(attr, field, len),
459 _ => false,
460 });
461
462 let field = assert_iter_contains_zero_or_one_item(
463 explicit_fields,
464 &format!(
465 "Multiple `{attr}` attributes specified. \
466 Single attribute per struct/enum variant allowed.",
467 ),
468 )?;
469
470 let field = match field {
471 field @ Some(_) => field,
472 None => assert_iter_contains_zero_or_one_item(
473 inferred_fields,
474 "Conflicting fields found. Consider specifying some \
475 `#[error(...)]` attributes to resolve conflict.",
476 )?,
477 };
478
479 Ok(field)
480}
481
482fn assert_iter_contains_zero_or_one_item<'a>(
483 mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
484 error_msg: &str,
485) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
486 let Some(item) = iter.next() else {
487 return Ok(None);
488 };
489
490 if let Some((_, field, _)) = iter.next() {
491 return Err(Error::new(field.span(), error_msg));
492 }
493
494 Ok(Some(item))
495}
496
497fn add_bound_if_type_parameter_used_in_type(
498 bounds: &mut HashSet<syn::Type>,
499 type_params: &HashSet<syn::Ident>,
500 ty: &syn::Type,
501) {
502 if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
503 bounds.insert(ty);
504 }
505}