diff --git a/enum-display-macro/Cargo.toml b/enum-display-macro/Cargo.toml index a2bf804..6e7fe89 100644 --- a/enum-display-macro/Cargo.toml +++ b/enum-display-macro/Cargo.toml @@ -12,9 +12,11 @@ repository = "https://github.com/SeedyROM/enum-display/tree/main/enum-display-de # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +proc-macro2 = "1.0.95" convert_case = "0.6.0" quote = "1.0.21" syn = { version = "1.0.101", features = ["full"] } +regex = "1.11.1" [lib] proc-macro = true diff --git a/enum-display-macro/src/lib.rs b/enum-display-macro/src/lib.rs index 5dc0c7b..2c7b337 100644 --- a/enum-display-macro/src/lib.rs +++ b/enum-display-macro/src/lib.rs @@ -1,86 +1,275 @@ use convert_case::{Case, Casing}; use proc_macro::{self, TokenStream}; +use proc_macro2::Span; use quote::quote; -use syn::{parse_macro_input, DeriveInput}; - -fn parse_case_name(case_name: &str) -> Case { - match case_name { - "Upper" => Case::Upper, - "Lower" => Case::Lower, - "Title" => Case::Title, - "Toggle" => Case::Toggle, - "Camel" => Case::Camel, - "Pascal" => Case::Pascal, - "UpperCamel" => Case::UpperCamel, - "Snake" => Case::Snake, - "UpperSnake" => Case::UpperSnake, - "ScreamingSnake" => Case::ScreamingSnake, - "Kebab" => Case::Kebab, - "Cobol" => Case::Cobol, - "UpperKebab" => Case::UpperKebab, - "Train" => Case::Train, - "Flat" => Case::Flat, - "UpperFlat" => Case::UpperFlat, - "Alternating" => Case::Alternating, - _ => panic!("Unrecognized case name: {}", case_name), - } +use regex::Regex; +use syn::{parse_macro_input, Attribute, DeriveInput, FieldsNamed, FieldsUnnamed, Ident, Variant}; + +// Enum attributes +struct EnumAttrs { + case_transform: Option, } -#[proc_macro_derive(EnumDisplay, attributes(enum_display))] -pub fn derive(input: TokenStream) -> TokenStream { - // Parse the input tokens into a syntax tree - let DeriveInput { - ident, data, attrs, .. - } = parse_macro_input!(input); +impl EnumAttrs { + fn from_attrs(attrs: Vec) -> Self { + let mut case_transform: Option = None; - // Should we transform the case of the enum variants? - let mut case_transform: Option = None; - - // Find the enum_display attribute - for attr in attrs.into_iter() { - if attr.path.is_ident("enum_display") { - let meta = attr.parse_meta().unwrap(); - if let syn::Meta::List(list) = meta { - for nested in list.nested { - if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested { - if name_value.path.is_ident("case") { - if let syn::Lit::Str(lit_str) = name_value.lit { - // Set the case transform - case_transform = Some(parse_case_name(lit_str.value().as_str())); + for attr in attrs.into_iter() { + if attr.path.is_ident("enum_display") { + let meta = attr.parse_meta().unwrap(); + if let syn::Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested { + if name_value.path.is_ident("case") { + if let syn::Lit::Str(lit_str) = name_value.lit { + case_transform = + Some(Self::parse_case_name(lit_str.value().as_str())); + } } } } } } } + + Self { case_transform } } - // Build the match arms - let variants = match data { - syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, - _ => panic!("EnumDisplay can only be derived for enums"), + fn parse_case_name(case_name: &str) -> Case { + match case_name { + "Upper" => Case::Upper, + "Lower" => Case::Lower, + "Title" => Case::Title, + "Toggle" => Case::Toggle, + "Camel" => Case::Camel, + "Pascal" => Case::Pascal, + "UpperCamel" => Case::UpperCamel, + "Snake" => Case::Snake, + "UpperSnake" => Case::UpperSnake, + "ScreamingSnake" => Case::ScreamingSnake, + "Kebab" => Case::Kebab, + "Cobol" => Case::Cobol, + "UpperKebab" => Case::UpperKebab, + "Train" => Case::Train, + "Flat" => Case::Flat, + "UpperFlat" => Case::UpperFlat, + "Alternating" => Case::Alternating, + _ => panic!("Unrecognized case name: {}", case_name), + } } - .into_iter() - .map(|variant| { - let ident = variant.ident; - let ident_str = if case_transform.is_some() { - ident.to_string().to_case(case_transform.unwrap()) + + fn transform_case(&self, ident: String) -> String { + if let Some(case) = self.case_transform { + ident.to_case(case) } else { - ident.to_string() - }; + ident + } + } +} + +// Variant attributes +struct VariantAttrs { + format: Option, +} + +impl VariantAttrs { + fn from_attrs(attrs: Vec) -> Self { + let mut format = None; + + // Find the variant_display attribute + for attr in attrs.into_iter() { + if attr.path.is_ident("variant_display") { + let meta = attr.parse_meta().unwrap(); + if let syn::Meta::List(list) = meta { + for nested in list.nested { + if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested { + if name_value.path.is_ident("format") { + if let syn::Lit::Str(lit_str) = name_value.lit { + format = Some(Self::translate_numeric_placeholders(&lit_str.value())); + } + } + } + } + } + } + } + + Self { format } + } + + // Translates {123:?} to {_unnamed_123:?} for safer format arg usage + fn translate_numeric_placeholders(fmt: &str) -> String { + let re = Regex::new(r"\{\s*(\d+)\s*([^}]*)\}").unwrap(); + re.replace_all(fmt, |caps: ®ex::Captures| { + let idx = &caps[1]; + let fmt_spec = &caps[2]; + format!("{{_unnamed_{}{} }}", idx, fmt_spec) + }) + .to_string() + } +} + +// Shared intermediate variant info +struct VariantInfo { + ident: Ident, + ident_transformed: String, + attrs: VariantAttrs, +} + +// Intermediate Named variant info +struct NamedVariantIR { + info: VariantInfo, + fields: Vec, +} + +impl NamedVariantIR { + fn from_fields_named(fields_named: FieldsNamed, info: VariantInfo) -> Self { + let fields = fields_named + .named + .into_iter() + .filter_map(|field| field.ident) + .collect(); + Self { info, fields } + } + + fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { ident, ident_transformed, attrs } = self.info; + let fields = self.fields; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => quote! { #ident { #(#fields),* } => { let variant = #ident_transformed; format!(#fmt) } }, + (true, None) => quote! { #ident { .. } => String::from(#ident_transformed), }, + (false, None) => quote! { #ident { .. } => #ident_transformed, }, + _ => unreachable!("`any_has_format` should never be false when a variant has format string"), + } + } +} +// Intermediate Unnamed variant info +struct UnnamedVariantIR { + info: VariantInfo, + fields: Vec, +} + +impl UnnamedVariantIR { + fn from_fields_unnamed(fields_unnamed: FieldsUnnamed, info: VariantInfo) -> Self { + let fields: Vec = fields_unnamed + .unnamed + .into_iter() + .enumerate() + .map(|(i, _)| Ident::new(format!("_unnamed_{i}").as_str(), Span::call_site())) + .collect(); + Self { info, fields } + } + + fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { ident, ident_transformed, attrs } = self.info; + let fields = self.fields; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => quote! { #ident(#(#fields),*) => { let variant = #ident_transformed; format!(#fmt) } }, + (true, None) => quote! { #ident(..) => String::from(#ident_transformed), }, + (false, None) => quote! { #ident(..) => #ident_transformed, }, + _ => unreachable!("`any_has_format` should never be false when a variant has format string"), + } + } +} + +// Intermediate Unit variant info +struct UnitVariantIR { + info: VariantInfo, +} + +impl UnitVariantIR { + fn new(info: VariantInfo) -> Self { + Self { info } + } + + fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream { + let VariantInfo { ident, ident_transformed, attrs } = self.info; + match (any_has_format, attrs.format) { + (true, Some(fmt)) => quote! { #ident => { let variant = #ident_transformed; format!(#fmt) } }, + (true, None) => quote! { #ident => String::from(#ident_transformed), }, + (false, None) => quote! { #ident => #ident_transformed, }, + _ => unreachable!("`any_has_format` should never be false when a variant has format string"), + } + } +} + +// Intermediate version of Variant +enum VariantIR { + Named(NamedVariantIR), + Unnamed(UnnamedVariantIR), + Unit(UnitVariantIR), +} + +impl VariantIR { + fn from_variant(variant: Variant, enum_attrs: &EnumAttrs) -> Self { + let ident_str = variant.ident.to_string(); + let info = VariantInfo { + ident: variant.ident, + ident_transformed: enum_attrs.transform_case(ident_str), + attrs: VariantAttrs::from_attrs(variant.attrs), + }; match variant.fields { - syn::Fields::Named(_) => quote! { - #ident { .. } => #ident_str, + syn::Fields::Named(fields_named) => { + Self::Named(NamedVariantIR::from_fields_named(fields_named, info)) }, - syn::Fields::Unnamed(_) => quote! { - #ident(..) => #ident_str, - }, - syn::Fields::Unit => quote! { - #ident => #ident_str, + syn::Fields::Unnamed(fields_unnamed) => { + Self::Unnamed(UnnamedVariantIR::from_fields_unnamed(fields_unnamed, info)) }, + syn::Fields::Unit => Self::Unit(UnitVariantIR::new(info)), + } + } + + fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream { + match self { + VariantIR::Named(named_variant) => named_variant.gen(any_has_format), + VariantIR::Unnamed(unnamed_variant) => unnamed_variant.gen(any_has_format), + VariantIR::Unit(unit_variant) => unit_variant.gen(any_has_format), } - }); + } + + fn has_format(&self) -> bool { + match self { + VariantIR::Named(named_variant) => &named_variant.info, + VariantIR::Unnamed(unnamed_variant) => &unnamed_variant.info, + VariantIR::Unit(unit_variant) => &unit_variant.info, + }.attrs.format.is_some() + } +} + +#[proc_macro_derive(EnumDisplay, attributes(enum_display, variant_display))] +pub fn derive(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let DeriveInput { + ident, + data, + attrs, + generics, + .. + } = parse_macro_input!(input); + + // Copy generics and bounds + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // Read enum attrs + let enum_attrs = EnumAttrs::from_attrs(attrs); + + // Read variants and variant attrs into an intermediate format + let intermediate_variants: Vec = match data { + syn::Data::Enum(syn::DataEnum { variants, .. }) => variants, + _ => panic!("EnumDisplay can only be derived for enums"), + } + .into_iter() + .map(|variant| VariantIR::from_variant(variant, &enum_attrs)) + .collect(); + + // If any variants have a format string, the output of all match arms must be String instead of &str + // This is because we can't return a reference to the temporary output of format!() + let any_has_format = intermediate_variants.iter().any(|v| v.has_format()); + let post_fix = if any_has_format { quote!{ .as_str() } } else { quote! { } }; + + // Build the match arms + let variants = intermediate_variants.into_iter().map(|v| v.gen(any_has_format)); // #[allow(unused_qualifications)] is needed // due to https://github.com/SeedyROM/enum-display/issues/1 @@ -88,13 +277,13 @@ pub fn derive(input: TokenStream) -> TokenStream { let output = quote! { #[automatically_derived] #[allow(unused_qualifications)] - impl ::core::fmt::Display for #ident { + impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause { fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { ::core::fmt::Formatter::write_str( f, match self { - #(#ident::#variants)* - }, + #(Self::#variants)* + }#post_fix ) } } diff --git a/src/lib.rs b/src/lib.rs index 2014389..8340cdd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,13 +44,40 @@ mod tests { #[derive(EnumDisplay)] enum TestEnum { Name, + + #[variant_display(format = "Unit: {variant}")] + NameFullFormat, + Address { street: String, city: String, state: String, zip: String, }, + + #[variant_display(format = "Named: {variant} {{{street}, {zip}}}")] + AddressPartialFormat { + street: String, + city: String, + state: String, + zip: String, + }, + + #[variant_display(format = "Named: {variant} {{{street}, {city}, {state}, {zip}}}")] + AddressFullFormat { + street: String, + city: String, + state: String, + zip: String, + }, + DateOfBirth(u32, u32, u32), + + #[variant_display(format = "Unnamed: {variant}({2})")] + DateOfBirthPartialFormat(u32, u32, u32), + + #[variant_display(format = "Unnamed: {variant}({0}, {1}, {2})")] + DateOfBirthFullFormat(u32, u32, u32), } #[allow(dead_code)] @@ -67,9 +94,23 @@ mod tests { DateOfBirth(u32, u32, u32), } + #[allow(dead_code)] + #[derive(EnumDisplay)] + enum TestEnumWithGenerics<'a, T: Clone> where T: std::fmt::Display { + Name, + Address { + street: &'a T, + city: &'a T, + state: &'a T, + zip: &'a T, + }, + DateOfBirth(u32, u32, u32), + } + #[test] fn test_unit_field_variant() { assert_eq!(TestEnum::Name.to_string(), "Name"); + assert_eq!(TestEnum::NameFullFormat.to_string(), "Unit: NameFullFormat"); } #[test] @@ -84,11 +125,33 @@ mod tests { .to_string(), "Address" ); + assert_eq!( + TestEnum::AddressPartialFormat { + street: "123 Main St".to_string(), + city: "Any Town".to_string(), + state: "CA".to_string(), + zip: "12345".to_string() + } + .to_string(), + "Named: AddressPartialFormat {123 Main St, 12345}" + ); + assert_eq!( + TestEnum::AddressFullFormat { + street: "123 Main St".to_string(), + city: "Any Town".to_string(), + state: "CA".to_string(), + zip: "12345".to_string() + } + .to_string(), + "Named: AddressFullFormat {123 Main St, Any Town, CA, 12345}" + ); } #[test] fn test_unnamed_fields_variant() { - assert_eq!(TestEnum::DateOfBirth(1, 1, 2000).to_string(), "DateOfBirth"); + assert_eq!(TestEnum::DateOfBirth(1, 2, 1999).to_string(), "DateOfBirth"); + assert_eq!(TestEnum::DateOfBirthPartialFormat(1, 2, 1999).to_string(), "Unnamed: DateOfBirthPartialFormat(1999)"); + assert_eq!(TestEnum::DateOfBirthFullFormat(1, 2, 1999).to_string(), "Unnamed: DateOfBirthFullFormat(1, 2, 1999)"); } #[test] @@ -117,4 +180,28 @@ mod tests { "date-of-birth" ); } + + #[test] + fn test_unit_field_variant_with_generics() { + assert_eq!(TestEnumWithGenerics::<'_, String>::Name.to_string(), "Name"); + } + + #[test] + fn test_named_fields_variant_with_generics() { + assert_eq!( + TestEnumWithGenerics::Address { + street: &"123 Main St".to_string(), + city: &"Any Town".to_string(), + state: &"CA".to_string(), + zip: &"12345".to_string() + } + .to_string(), + "Address" + ); + } + + #[test] + fn test_unnamed_fields_variant_with_generics() { + assert_eq!(TestEnumWithGenerics::<'_, String>::DateOfBirth(1, 1, 2000).to_string(), "DateOfBirth"); + } }