Skip to content

Formatting #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions enum-display-macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ repository = "https://github.com/SeedyROM/enum-display/tree/main/enum-display-de
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
proc-macro2 = "1.0.95"
convert_case = "0.6.0"
quote = "1.0.21"
syn = { version = "1.0.101", features = ["full"] }
regex = "1.11.1"

[lib]
proc-macro = true
319 changes: 254 additions & 65 deletions enum-display-macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,100 +1,289 @@
use convert_case::{Case, Casing};
use proc_macro::{self, TokenStream};
use proc_macro2::Span;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

fn parse_case_name(case_name: &str) -> Case {
match case_name {
"Upper" => Case::Upper,
"Lower" => Case::Lower,
"Title" => Case::Title,
"Toggle" => Case::Toggle,
"Camel" => Case::Camel,
"Pascal" => Case::Pascal,
"UpperCamel" => Case::UpperCamel,
"Snake" => Case::Snake,
"UpperSnake" => Case::UpperSnake,
"ScreamingSnake" => Case::ScreamingSnake,
"Kebab" => Case::Kebab,
"Cobol" => Case::Cobol,
"UpperKebab" => Case::UpperKebab,
"Train" => Case::Train,
"Flat" => Case::Flat,
"UpperFlat" => Case::UpperFlat,
"Alternating" => Case::Alternating,
_ => panic!("Unrecognized case name: {}", case_name),
}
use regex::Regex;
use syn::{parse_macro_input, Attribute, DeriveInput, FieldsNamed, FieldsUnnamed, Ident, Variant};

// Enum attributes
struct EnumAttrs {
case_transform: Option<Case>,
}

#[proc_macro_derive(EnumDisplay, attributes(enum_display))]
pub fn derive(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree
let DeriveInput {
ident, data, attrs, ..
} = parse_macro_input!(input);
impl EnumAttrs {
fn from_attrs(attrs: Vec<Attribute>) -> Self {
let mut case_transform: Option<Case> = None;

// Should we transform the case of the enum variants?
let mut case_transform: Option<Case> = None;

// Find the enum_display attribute
for attr in attrs.into_iter() {
if attr.path.is_ident("enum_display") {
let meta = attr.parse_meta().unwrap();
if let syn::Meta::List(list) = meta {
for nested in list.nested {
if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
if name_value.path.is_ident("case") {
if let syn::Lit::Str(lit_str) = name_value.lit {
// Set the case transform
case_transform = Some(parse_case_name(lit_str.value().as_str()));
for attr in attrs.into_iter() {
if attr.path.is_ident("enum_display") {
let meta = attr.parse_meta().unwrap();
if let syn::Meta::List(list) = meta {
for nested in list.nested {
if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
if name_value.path.is_ident("case") {
if let syn::Lit::Str(lit_str) = name_value.lit {
case_transform =
Some(Self::parse_case_name(lit_str.value().as_str()));
}
}
}
}
}
}
}

Self { case_transform }
}

// Build the match arms
let variants = match data {
syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
_ => panic!("EnumDisplay can only be derived for enums"),
fn parse_case_name(case_name: &str) -> Case {
match case_name {
"Upper" => Case::Upper,
"Lower" => Case::Lower,
"Title" => Case::Title,
"Toggle" => Case::Toggle,
"Camel" => Case::Camel,
"Pascal" => Case::Pascal,
"UpperCamel" => Case::UpperCamel,
"Snake" => Case::Snake,
"UpperSnake" => Case::UpperSnake,
"ScreamingSnake" => Case::ScreamingSnake,
"Kebab" => Case::Kebab,
"Cobol" => Case::Cobol,
"UpperKebab" => Case::UpperKebab,
"Train" => Case::Train,
"Flat" => Case::Flat,
"UpperFlat" => Case::UpperFlat,
"Alternating" => Case::Alternating,
_ => panic!("Unrecognized case name: {}", case_name),
}
}
.into_iter()
.map(|variant| {
let ident = variant.ident;
let ident_str = if case_transform.is_some() {
ident.to_string().to_case(case_transform.unwrap())

fn transform_case(&self, ident: String) -> String {
if let Some(case) = self.case_transform {
ident.to_case(case)
} else {
ident.to_string()
};
ident
}
}
}

// Variant attributes
struct VariantAttrs {
format: Option<String>,
}

impl VariantAttrs {
fn from_attrs(attrs: Vec<Attribute>) -> Self {
let mut format = None;

// Find the variant_display attribute
for attr in attrs.into_iter() {
if attr.path.is_ident("variant_display") {
let meta = attr.parse_meta().unwrap();
if let syn::Meta::List(list) = meta {
for nested in list.nested {
if let syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested {
if name_value.path.is_ident("format") {
if let syn::Lit::Str(lit_str) = name_value.lit {
format = Some(Self::translate_numeric_placeholders(&lit_str.value()));
}
}
}
}
}
}
}

Self { format }
}

// Translates {123:?} to {_unnamed_123:?} for safer format arg usage
fn translate_numeric_placeholders(fmt: &str) -> String {
let re = Regex::new(r"\{\s*(\d+)\s*([^}]*)\}").unwrap();
re.replace_all(fmt, |caps: &regex::Captures| {
let idx = &caps[1];
let fmt_spec = &caps[2];
format!("{{_unnamed_{}{} }}", idx, fmt_spec)
})
.to_string()
}
}

// Shared intermediate variant info
struct VariantInfo {
ident: Ident,
ident_transformed: String,
attrs: VariantAttrs,
}

// Intermediate Named variant info
struct NamedVariantIR {
info: VariantInfo,
fields: Vec<Ident>,
}

impl NamedVariantIR {
fn from_fields_named(fields_named: FieldsNamed, info: VariantInfo) -> Self {
let fields = fields_named
.named
.into_iter()
.filter_map(|field| field.ident)
.collect();
Self { info, fields }
}

fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream {
let VariantInfo { ident, ident_transformed, attrs } = self.info;
let fields = self.fields;
match (any_has_format, attrs.format) {
(true, Some(fmt)) => quote! { #ident { #(#fields),* } => { let variant = #ident_transformed; format!(#fmt) } },
(true, None) => quote! { #ident { .. } => String::from(#ident_transformed), },
(false, None) => quote! { #ident { .. } => #ident_transformed, },
_ => unreachable!("`any_has_format` should never be false when a variant has format string"),
}
}
}

// Intermediate Unnamed variant info
struct UnnamedVariantIR {
info: VariantInfo,
fields: Vec<Ident>,
}

impl UnnamedVariantIR {
fn from_fields_unnamed(fields_unnamed: FieldsUnnamed, info: VariantInfo) -> Self {
let fields: Vec<Ident> = fields_unnamed
.unnamed
.into_iter()
.enumerate()
.map(|(i, _)| Ident::new(format!("_unnamed_{i}").as_str(), Span::call_site()))
.collect();
Self { info, fields }
}

fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream {
let VariantInfo { ident, ident_transformed, attrs } = self.info;
let fields = self.fields;
match (any_has_format, attrs.format) {
(true, Some(fmt)) => quote! { #ident(#(#fields),*) => { let variant = #ident_transformed; format!(#fmt) } },
(true, None) => quote! { #ident(..) => String::from(#ident_transformed), },
(false, None) => quote! { #ident(..) => #ident_transformed, },
_ => unreachable!("`any_has_format` should never be false when a variant has format string"),
}
}
}

// Intermediate Unit variant info
struct UnitVariantIR {
info: VariantInfo,
}

impl UnitVariantIR {
fn new(info: VariantInfo) -> Self {
Self { info }
}

fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream {
let VariantInfo { ident, ident_transformed, attrs } = self.info;
match (any_has_format, attrs.format) {
(true, Some(fmt)) => quote! { #ident => { let variant = #ident_transformed; format!(#fmt) } },
(true, None) => quote! { #ident => String::from(#ident_transformed), },
(false, None) => quote! { #ident => #ident_transformed, },
_ => unreachable!("`any_has_format` should never be false when a variant has format string"),
}
}
}

// Intermediate version of Variant
enum VariantIR {
Named(NamedVariantIR),
Unnamed(UnnamedVariantIR),
Unit(UnitVariantIR),
}

impl VariantIR {
fn from_variant(variant: Variant, enum_attrs: &EnumAttrs) -> Self {
let ident_str = variant.ident.to_string();
let info = VariantInfo {
ident: variant.ident,
ident_transformed: enum_attrs.transform_case(ident_str),
attrs: VariantAttrs::from_attrs(variant.attrs),
};
match variant.fields {
syn::Fields::Named(_) => quote! {
#ident { .. } => #ident_str,
syn::Fields::Named(fields_named) => {
Self::Named(NamedVariantIR::from_fields_named(fields_named, info))
},
syn::Fields::Unnamed(_) => quote! {
#ident(..) => #ident_str,
},
syn::Fields::Unit => quote! {
#ident => #ident_str,
syn::Fields::Unnamed(fields_unnamed) => {
Self::Unnamed(UnnamedVariantIR::from_fields_unnamed(fields_unnamed, info))
},
syn::Fields::Unit => Self::Unit(UnitVariantIR::new(info)),
}
}

fn gen(self, any_has_format: bool) -> proc_macro2::TokenStream {
match self {
VariantIR::Named(named_variant) => named_variant.gen(any_has_format),
VariantIR::Unnamed(unnamed_variant) => unnamed_variant.gen(any_has_format),
VariantIR::Unit(unit_variant) => unit_variant.gen(any_has_format),
}
});
}

fn has_format(&self) -> bool {
match self {
VariantIR::Named(named_variant) => &named_variant.info,
VariantIR::Unnamed(unnamed_variant) => &unnamed_variant.info,
VariantIR::Unit(unit_variant) => &unit_variant.info,
}.attrs.format.is_some()
}
}

#[proc_macro_derive(EnumDisplay, attributes(enum_display, variant_display))]
pub fn derive(input: TokenStream) -> TokenStream {
// Parse the input tokens into a syntax tree
let DeriveInput {
ident,
data,
attrs,
generics,
..
} = parse_macro_input!(input);

// Copy generics and bounds
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

// Read enum attrs
let enum_attrs = EnumAttrs::from_attrs(attrs);

// Read variants and variant attrs into an intermediate format
let intermediate_variants: Vec<VariantIR> = match data {
syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,
_ => panic!("EnumDisplay can only be derived for enums"),
}
.into_iter()
.map(|variant| VariantIR::from_variant(variant, &enum_attrs))
.collect();

// If any variants have a format string, the output of all match arms must be String instead of &str
// This is because we can't return a reference to the temporary output of format!()
let any_has_format = intermediate_variants.iter().any(|v| v.has_format());
let post_fix = if any_has_format { quote!{ .as_str() } } else { quote! { } };

// Build the match arms
let variants = intermediate_variants.into_iter().map(|v| v.gen(any_has_format));

// #[allow(unused_qualifications)] is needed
// due to https://github.com/SeedyROM/enum-display/issues/1
// Possibly related to https://github.com/rust-lang/rust/issues/96698
let output = quote! {
#[automatically_derived]
#[allow(unused_qualifications)]
impl ::core::fmt::Display for #ident {
impl #impl_generics ::core::fmt::Display for #ident #ty_generics #where_clause {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
::core::fmt::Formatter::write_str(
f,
match self {
#(#ident::#variants)*
},
#(Self::#variants)*
}#post_fix
)
}
}
Expand Down
Loading