|
| 1 | +use proc_macro::TokenStream; |
| 2 | +use proc_macro2::TokenStream as TokenStream2; |
| 3 | +use quote::quote; |
| 4 | +use syn::{ |
| 5 | + parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Error, FnArg, Ident, |
| 6 | + ItemFn, Result, ReturnType, Type, Visibility, |
| 7 | +}; |
| 8 | + |
| 9 | +/// Enum representing the supported runtime function attributes |
| 10 | +pub enum Fn { |
| 11 | + PostInit, |
| 12 | +} |
| 13 | + |
| 14 | +impl Fn { |
| 15 | + /// Convenience method to generate the token stream for the `post_init` attribute |
| 16 | + pub fn post_init(args: TokenStream, input: TokenStream) -> TokenStream { |
| 17 | + match Self::PostInit.check_args_empty(args) { |
| 18 | + Ok(_) => Self::PostInit.quote_fn(input), |
| 19 | + Err(e) => e.to_compile_error().into(), |
| 20 | + } |
| 21 | + } |
| 22 | + |
| 23 | + /// Generate the token stream for the function with the given attribute |
| 24 | + fn quote_fn(&self, item: TokenStream) -> TokenStream { |
| 25 | + let mut func = parse_macro_input!(item as ItemFn); |
| 26 | + |
| 27 | + if let Err(e) = self.check_fn(&func) { |
| 28 | + return e.to_compile_error().into(); |
| 29 | + } |
| 30 | + |
| 31 | + let export_name = self.export_name(&func); |
| 32 | + let link_section = self.link_section(&func); |
| 33 | + |
| 34 | + // Append to function name the prefix __riscv_rt_ (to prevent users from calling it directly) |
| 35 | + // Note that we do not change the export name, only the internal function name in the Rust code. |
| 36 | + func.sig.ident = Ident::new( |
| 37 | + &format!("__riscv_rt_{}", func.sig.ident), |
| 38 | + func.sig.ident.span(), |
| 39 | + ); |
| 40 | + |
| 41 | + quote! { |
| 42 | + #export_name |
| 43 | + #link_section |
| 44 | + #func |
| 45 | + } |
| 46 | + .into() |
| 47 | + } |
| 48 | + |
| 49 | + /// Check if the function signature is valid for the given attribute |
| 50 | + fn check_fn(&self, f: &ItemFn) -> Result<()> { |
| 51 | + // First, check that the function is private |
| 52 | + if f.vis != Visibility::Inherited { |
| 53 | + let attr = self.attr_name(); |
| 54 | + return Err(Error::new( |
| 55 | + f.vis.span(), |
| 56 | + format!("`#[{attr}]` function must be private"), |
| 57 | + )); |
| 58 | + } |
| 59 | + let sig = &f.sig; |
| 60 | + |
| 61 | + // Next, check common aspects of the signature (constness, asyncness, generics, etc.) |
| 62 | + let valid_signature = sig.constness.is_none() |
| 63 | + && sig.asyncness.is_none() |
| 64 | + && sig.abi.is_none() |
| 65 | + && sig.generics.params.is_empty() |
| 66 | + && sig.generics.where_clause.is_none() |
| 67 | + && sig.variadic.is_none(); |
| 68 | + if !valid_signature { |
| 69 | + let attr = self.attr_name(); |
| 70 | + let expected = self.expected_signature(); |
| 71 | + return Err(Error::new( |
| 72 | + sig.span(), |
| 73 | + format!("`#[{attr}]` function signature must be `{expected}`"), |
| 74 | + )); |
| 75 | + } |
| 76 | + |
| 77 | + // Finally, check that input arguments and output type are valid |
| 78 | + self.check_inputs(&sig.inputs)?; |
| 79 | + self.check_output(&sig.output) |
| 80 | + } |
| 81 | + |
| 82 | + /// Utility method for printing attribute name in error messages |
| 83 | + const fn attr_name(&self) -> &'static str { |
| 84 | + // Use this match to specify attribute names for different functions in the future |
| 85 | + match self { |
| 86 | + Self::PostInit => "post_init", |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + /// Utility method for printing expected function signature in error messages |
| 91 | + const fn expected_signature(&self) -> &'static str { |
| 92 | + // Use this match to specify expected signatures for different functions in the future |
| 93 | + match self { |
| 94 | + Self::PostInit => "[unsafe] fn([usize])", |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + /// Check if the function has valid input arguments for the given attribute |
| 99 | + fn check_inputs(&self, inputs: &Punctuated<FnArg, Comma>) -> Result<()> { |
| 100 | + // Use this match to specify expected input arguments for different functions in the future |
| 101 | + match self { |
| 102 | + Self::PostInit => self.check_fn_args(inputs, &["usize"]), |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + /// Check if the function has a valid output type for the given attribute |
| 107 | + fn check_output(&self, output: &ReturnType) -> Result<()> { |
| 108 | + // Use this match to specify expected output types for different functions in the future |
| 109 | + match self { |
| 110 | + Self::PostInit => match output { |
| 111 | + // post_init return type is () |
| 112 | + ReturnType::Default => Ok(()), |
| 113 | + ReturnType::Type(_, ty) => match **ty { |
| 114 | + Type::Tuple(ref tuple) => { |
| 115 | + if tuple.elems.is_empty() { |
| 116 | + Ok(()) |
| 117 | + } else { |
| 118 | + Err(Error::new(tuple.span(), "return type must be ()")) |
| 119 | + } |
| 120 | + } |
| 121 | + _ => Err(Error::new(ty.span(), "return type must be ()")), |
| 122 | + }, |
| 123 | + }, |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + /// The export name for the given attribute |
| 128 | + fn export_name(&self, _f: &ItemFn) -> Option<TokenStream2> { |
| 129 | + // Use this match to specify export names for different functions in the future |
| 130 | + let export_name = match self { |
| 131 | + Self::PostInit => Some("__post_init".to_string()), |
| 132 | + }; |
| 133 | + |
| 134 | + export_name.map(|name| quote! { #[export_name = #name] }) |
| 135 | + } |
| 136 | + |
| 137 | + /// The link section attribute for the given attribute (if any) |
| 138 | + fn link_section(&self, _f: &ItemFn) -> Option<TokenStream2> { |
| 139 | + // Use this match to specify section names for different functions in the future |
| 140 | + let section_name: Option<String> = match self { |
| 141 | + Self::PostInit => None, |
| 142 | + }; |
| 143 | + |
| 144 | + section_name.map(|section| quote! { |
| 145 | + #[cfg_attr(any(target_arch = "riscv32", target_arch = "riscv64"), link_section = #section)] |
| 146 | + }) |
| 147 | + } |
| 148 | + |
| 149 | + /// Check that no arguments were provided to the macro attribute |
| 150 | + fn check_args_empty(&self, args: TokenStream) -> Result<()> { |
| 151 | + if args.is_empty() { |
| 152 | + Ok(()) |
| 153 | + } else { |
| 154 | + let args: TokenStream2 = args.into(); |
| 155 | + let attr = self.attr_name(); |
| 156 | + Err(Error::new( |
| 157 | + args.span(), |
| 158 | + format!("`#[{attr}]` function does not accept any arguments"), |
| 159 | + )) |
| 160 | + } |
| 161 | + } |
| 162 | + |
| 163 | + /// Iterates through the input arguments and checks that their types match the expected types |
| 164 | + fn check_fn_args( |
| 165 | + &self, |
| 166 | + inputs: &Punctuated<FnArg, Comma>, |
| 167 | + expected_types: &[&str], |
| 168 | + ) -> Result<()> { |
| 169 | + let mut expected_iter = expected_types.iter(); |
| 170 | + for arg in inputs.iter() { |
| 171 | + match expected_iter.next() { |
| 172 | + Some(expected) => check_arg_type(arg, expected)?, |
| 173 | + None => { |
| 174 | + let attr = self.attr_name(); |
| 175 | + return Err(Error::new( |
| 176 | + arg.span(), |
| 177 | + format!("`#[{attr}]` function has too many input arguments"), |
| 178 | + )); |
| 179 | + } |
| 180 | + } |
| 181 | + } |
| 182 | + Ok(()) |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +/// Check if a function argument matches the expected type |
| 187 | +fn check_arg_type(arg: &FnArg, expected: &str) -> Result<()> { |
| 188 | + match arg { |
| 189 | + FnArg::Typed(argument) => { |
| 190 | + if !is_correct_type(&argument.ty, expected) { |
| 191 | + Err(Error::new( |
| 192 | + argument.ty.span(), |
| 193 | + format!("argument type must be {expected}"), |
| 194 | + )) |
| 195 | + } else { |
| 196 | + Ok(()) |
| 197 | + } |
| 198 | + } |
| 199 | + FnArg::Receiver(_) => Err(Error::new(arg.span(), "invalid argument")), |
| 200 | + } |
| 201 | +} |
| 202 | + |
| 203 | +/// Check if a type matches the expected type name |
| 204 | +fn is_correct_type(ty: &Type, expected: &str) -> bool { |
| 205 | + let correct: Type = syn::parse_str(expected).unwrap(); |
| 206 | + if let Some(ty) = strip_type_path(ty) { |
| 207 | + ty == correct |
| 208 | + } else { |
| 209 | + false |
| 210 | + } |
| 211 | +} |
| 212 | + |
| 213 | +/// Strip the path of a type, returning only the last segment (e.g., `core::usize` -> `usize`) |
| 214 | +fn strip_type_path(ty: &Type) -> Option<Type> { |
| 215 | + match ty { |
| 216 | + Type::Ptr(ty) => { |
| 217 | + let mut ty = ty.clone(); |
| 218 | + *ty.elem = strip_type_path(&ty.elem)?; |
| 219 | + Some(Type::Ptr(ty)) |
| 220 | + } |
| 221 | + Type::Path(ty) => { |
| 222 | + let mut ty = ty.clone(); |
| 223 | + let last_segment = ty.path.segments.last().unwrap().clone(); |
| 224 | + ty.path.segments = Punctuated::new(); |
| 225 | + ty.path.segments.push_value(last_segment); |
| 226 | + Some(Type::Path(ty)) |
| 227 | + } |
| 228 | + _ => None, |
| 229 | + } |
| 230 | +} |
0 commit comments