derive_more_impl/
from_str.rs

1use crate::utils::{DeriveType, HashMap};
2use crate::utils::{SingleFieldData, State};
3use proc_macro2::TokenStream;
4use quote::quote;
5use syn::{parse::Result, DeriveInput};
6
7/// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr`
8pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
9    let state = State::new(input, trait_name, trait_name.to_lowercase())?;
10
11    if state.derive_type == DeriveType::Enum {
12        Ok(enum_from(input, state, trait_name))
13    } else {
14        Ok(struct_from(&state, trait_name))
15    }
16}
17
18pub fn struct_from(state: &State, trait_name: &'static str) -> TokenStream {
19    // We cannot set defaults for fields, once we do we can remove this check
20    if state.fields.len() != 1 || state.enabled_fields().len() != 1 {
21        panic_one_field(trait_name);
22    }
23
24    let single_field_data = state.assert_single_enabled_field();
25    let SingleFieldData {
26        input_type,
27        field_type,
28        trait_path,
29        casted_trait,
30        impl_generics,
31        ty_generics,
32        where_clause,
33        ..
34    } = single_field_data.clone();
35
36    let initializers = [quote! { #casted_trait::from_str(src)? }];
37    let body = single_field_data.initializer(&initializers);
38    let error = quote! { <#field_type as #trait_path>::Err };
39
40    quote! {
41        #[automatically_derived]
42        impl #impl_generics #trait_path for #input_type #ty_generics #where_clause {
43            type Err = #error;
44
45            #[inline]
46            fn from_str(src: &str) -> derive_more::core::result::Result<Self, #error> {
47                derive_more::core::result::Result::Ok(#body)
48            }
49        }
50    }
51}
52
53fn enum_from(
54    input: &DeriveInput,
55    state: State,
56    trait_name: &'static str,
57) -> TokenStream {
58    let mut variants_caseinsensitive = HashMap::default();
59    for variant_state in state.enabled_variant_data().variant_states {
60        let variant = variant_state.variant.unwrap();
61        if !variant.fields.is_empty() {
62            panic!("Only enums with no fields can derive({trait_name})")
63        }
64
65        variants_caseinsensitive
66            .entry(variant.ident.to_string().to_lowercase())
67            .or_insert_with(Vec::new)
68            .push(variant.ident.clone());
69    }
70
71    let input_type = &input.ident;
72    let input_type_name = input_type.to_string();
73
74    let mut cases = vec![];
75
76    // if a case insensitive match is unique match do that
77    // otherwise do a case sensitive match
78    for (ref canonical, ref variants) in variants_caseinsensitive {
79        if variants.len() == 1 {
80            let variant = &variants[0];
81            cases.push(quote! {
82                #canonical => #input_type::#variant,
83            })
84        } else {
85            for variant in variants {
86                let variant_str = variant.to_string();
87                cases.push(quote! {
88                    #canonical if(src == #variant_str) => #input_type::#variant,
89                })
90            }
91        }
92    }
93
94    let trait_path = state.trait_path;
95
96    quote! {
97        impl #trait_path for #input_type {
98            type Err = derive_more::FromStrError;
99
100            #[inline]
101            fn from_str(src: &str) -> derive_more::core::result::Result<Self, derive_more::FromStrError> {
102                Ok(match src.to_lowercase().as_str() {
103                    #(#cases)*
104                    _ => return Err(derive_more::FromStrError::new(#input_type_name)),
105                })
106            }
107        }
108    }
109}
110
/// Aborts the derive, reporting that `trait_name` can only be derived for structs that have
/// exactly one field.
fn panic_one_field(trait_name: &str) -> ! {
    let message = format!("Only structs with one field can derive({})", trait_name);
    panic!("{}", message)
}