diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index dc3bb8ab52a5d..ecd68e0015851 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -330,7 +330,9 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // UNUSED + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); @@ -342,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: Some(d_body), + body: None, // This leads to an error when the ad function is inside a traits define_opaque: None, }); let mut rustc_ad_attr = @@ -429,12 +431,18 @@ mod llvm_enzyme { tokens: ts, }); + let rustc_intrinsic_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr, intrinsic_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -444,13 +452,15 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index e8629aeebb95a..5813bdf8435e2 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -197,6 +197,24 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { Some(instance), ) } + _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + // NOTE(Sa4dUs): This is a hacky way to get the autodiff items + // so we can focus on the lowering of the intrinsic call + + // `diff_items` is empty even when autodiff is enabled, and if we're here, + // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr + let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + + // this shouldn't happen? + if diff_items.is_empty() { + bug!("no autodiff items found for {def_id:?}"); + } + + // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + + // Just gen the fallback body for now + return Err(ty::Instance::new_raw(def_id, instance.args)); + } sym::is_val_statically_known => { let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx); let kind = self.type_kind(intrinsic_type); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 54bb3ac411304..ea02ca7fec52f 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -174,6 +174,8 @@ pub(crate) fn check_intrinsic_type( }; let name_str = intrinsic_name.as_str(); + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -229,6 +231,17 @@ pub(crate) fn check_intrinsic_type( // // so: two type params, 0 lifetime param, 0 const params, two inputs, no return (2, 0, 0, vec![param(0), param(1)], param(1), hir::Safety::Safe) + } else if has_autodiff { + let sig = tcx.fn_sig(intrinsic_id.to_def_id()); + let sig = sig.skip_binder(); + let n_tps = generics.own_counts().types; + let n_lts = generics.own_counts().lifetimes; + let n_cts = generics.own_counts().consts; + + let inputs = sig.skip_binder().inputs().to_vec(); + let output = sig.skip_binder().output(); + + (n_tps, n_lts, n_cts, inputs, output, hir::Safety::Safe) } else { let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let (n_tps, n_cts, inputs, output) = match intrinsic_name { diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index a2525abc83207..787c2e517492c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -36,78 +37,44 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[inline(never)] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f64, f64)>::default()) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f2(x, y)) -} +#[rustc_intrinsic] +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[inline(never)] -pub fn df4() -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4() -> (); #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[inline(never)] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((by_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -115,84 +82,47 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[inline(never)] -pub fn df6() -> DoesNotImplDefault { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f6()); - ::core::hint::black_box(()); - ::core::hint::black_box(f6()) -} +#[rustc_intrinsic] +pub fn df6() -> DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[inline(never)] -pub fn df7(x: f32) -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f7(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df7(x: f32) -> (); #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[inline(never)] +#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 5usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 5usize]>::default()) -} +-> [f32; 5usize]; #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[inline(never)] +#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 4usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 4usize]>::default()) -} +-> [f32; 4usize]; #[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[inline(never)] -fn f8_1(x: &f32, bx_0: &f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) -} +#[rustc_intrinsic] +fn f8_1(x: &f32, bx_0: &f32) -> f32; pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[inline(never)] - fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f32, f32)>::default()) - } + #[rustc_intrinsic] + fn d_inner_2(x: f32, bx_0: f32) + -> (f32, f32); #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[inline(never)] - fn d_inner_1(x: f32, bx_0: f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) - } + #[rustc_intrinsic] + fn d_inner_1(x: f32, bx_0: f32) + -> f32; } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[inline(never)] +#[rustc_intrinsic] pub fn d_square + - Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10::(x)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10::(x)) -} +Copy>(x: &T, dx_0: &mut T, dret: T) -> T; fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e23a1b3e241e9..b003d87dccfa7 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index e67c3443ddef1..6f368c74f1a26 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -29,58 +30,36 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f1(x, y)) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[inline(never)] -pub fn df2() { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df2(); #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[inline(never)] -pub fn df4(x: f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4(x: f32); #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[inline(never)] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dy_0)); -} +#[rustc_intrinsic] +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index d37e5e3eb4cec..fc95ba2e5a63e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp @@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 { unimplemented!() } -enum Foo { Reverse } +enum Foo { + Reverse, +} use Foo::Reverse; // What happens if we already have Reverse in type (enum variant decl) and value (enum variant // constructor) namespace? > It's expected to work normally. diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index d18061b2dbdef..4bc8dac0dc758 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -31,7 +32,7 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[inline(never)] + #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { unsafe { asm!("NOP", options(pure, nomem)); }; ::core::hint::black_box(self.f(x)); diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 11ff209f9d89e..9f00ff5eb02c1 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp