Skip to content

Implement autodiff using intrinsics #142640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ mod llvm_enzyme {
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
.count() as u32;
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
let d_body = gen_enzyme_body(

// UNUSED
let _d_body = gen_enzyme_body(
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
&generics,
);
Expand All @@ -342,7 +344,7 @@ mod llvm_enzyme {
ident: first_ident(&meta_item_vec[0]),
generics,
contract: None,
body: Some(d_body),
body: None, // This leads to an error when the ad function is inside a traits
define_opaque: None,
});
let mut rustc_ad_attr =
Expand Down Expand Up @@ -429,12 +431,18 @@ mod llvm_enzyme {
tokens: ts,
});

let rustc_intrinsic_attr =
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic)));
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span);

let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
let d_annotatable = match &item {
Annotatable::AssocItem(_, _) => {
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
let d_fn = P(ast::AssocItem {
attrs: thin_vec![d_attr, inline_never],
attrs: thin_vec![d_attr, intrinsic_attr],
id: ast::DUMMY_NODE_ID,
span,
vis,
Expand All @@ -444,13 +452,15 @@ mod llvm_enzyme {
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
}
Annotatable::Item(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn =
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Item(d_fn)
}
Annotatable::Stmt(_) => {
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
let mut d_fn =
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
d_fn.vis = vis;

Annotatable::Stmt(P(ast::Stmt {
Expand Down
18 changes: 18 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,24 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
Some(instance),
)
}
_ if tcx.has_attr(def_id, sym::rustc_autodiff) => {
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
// so we can focus on the lowering of the intrinsic call

// `diff_items` is empty even when autodiff is enabled, and if we're here,
// it's because some function was marked as intrinsic and had the `rustc_autodiff` attr
let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items;

// this shouldn't happen?
if diff_items.is_empty() {
bug!("no autodiff items found for {def_id:?}");
}

// TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs`

// Just gen the fallback body for now
return Err(ty::Instance::new_raw(def_id, instance.args));
}
Comment on lines +200 to +217
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is the part you got stuck on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda. I tried to use the collect_and_partition_mono_items just as a hacky way of getting autodiff_items, so i can focus on the declaration. Once that was working, I would focus on how to get that information in the best way possible.

Copy link
Contributor Author

@Sa4dUs Sa4dUs Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for context, during yesterday's meeting, we came to the conclusion that, for the time being, it's acceptable to copy and paste the collector logic from within the provider's code, just to focus on the declaration, but still nice to have feedback since, in some moment, we'll need to get to AutoDiffItems in a decent way. collect_and_partition_mono_items(()).autodiff_items being [] in autodiff code seems weird, so it would be nice to know why it's happening, or if it is some kind of bug.

sym::is_val_statically_known => {
let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx);
let kind = self.type_kind(intrinsic_type);
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ pub(crate) fn check_intrinsic_type(
};
let name_str = intrinsic_name.as_str();

let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff);

let bound_vars = tcx.mk_bound_variable_kinds(&[
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon),
Expand Down Expand Up @@ -229,6 +231,17 @@ pub(crate) fn check_intrinsic_type(
//
// so: two type params, 0 lifetime param, 0 const params, two inputs, no return
(2, 0, 0, vec![param(0), param(1)], param(1), hir::Safety::Safe)
} else if has_autodiff {
let sig = tcx.fn_sig(intrinsic_id.to_def_id());
let sig = sig.skip_binder();
let n_tps = generics.own_counts().types;
let n_lts = generics.own_counts().lifetimes;
let n_cts = generics.own_counts().consts;

let inputs = sig.skip_binder().inputs().to_vec();
let output = sig.skip_binder().output();

(n_tps, n_lts, n_cts, inputs, output, hir::Safety::Safe)
} else {
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
Expand Down
136 changes: 33 additions & 103 deletions tests/pretty/autodiff/autodiff_forward.pp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//@ needs-enzyme

#![feature(autodiff)]
#![feature(intrinsics)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
Expand Down Expand Up @@ -36,163 +37,92 @@
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
#[inline(never)]
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f64, f64)>::default())
}
#[rustc_intrinsic]
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64);
#[rustc_autodiff]
#[inline(never)]
pub fn f2(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f2(x, y))
}
#[rustc_intrinsic]
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64;
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f3(x, y))
}
#[rustc_intrinsic]
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64;
#[rustc_autodiff]
#[inline(never)]
pub fn f4() {}
#[rustc_autodiff(Forward, 1, None)]
#[inline(never)]
pub fn df4() -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4());
::core::hint::black_box(());
}
#[rustc_intrinsic]
pub fn df4() -> ();
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
#[inline(never)]
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((by_0,));
::core::hint::black_box(f5(x, y))
}
#[rustc_intrinsic]
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64;
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
#[inline(never)]
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((bx_0,));
::core::hint::black_box(f5(x, y))
}
#[rustc_intrinsic]
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64;
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f5(x, y))
}
#[rustc_intrinsic]
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
struct DoesNotImplDefault;
#[rustc_autodiff]
#[inline(never)]
pub fn f6() -> DoesNotImplDefault {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Forward, 1, Const)]
#[inline(never)]
pub fn df6() -> DoesNotImplDefault {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f6());
::core::hint::black_box(());
::core::hint::black_box(f6())
}
#[rustc_intrinsic]
pub fn df6() -> DoesNotImplDefault;
#[rustc_autodiff]
#[inline(never)]
pub fn f7(x: f32) -> () {}
#[rustc_autodiff(Forward, 1, Const, None)]
#[inline(never)]
pub fn df7(x: f32) -> () {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f7(x));
::core::hint::black_box(());
}
#[rustc_intrinsic]
pub fn df7(x: f32) -> ();
#[no_mangle]
#[rustc_autodiff]
#[inline(never)]
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Forward, 4, Dual, Dual)]
#[inline(never)]
#[rustc_intrinsic]
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 5usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 5usize]>::default())
}
-> [f32; 5usize];
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
#[inline(never)]
#[rustc_intrinsic]
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
-> [f32; 4usize] {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
::core::hint::black_box(<[f32; 4usize]>::default())
}
-> [f32; 4usize];
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f8(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
}
#[rustc_intrinsic]
fn f8_1(x: &f32, bx_0: &f32) -> f32;
pub fn f9() {
#[rustc_autodiff]
#[inline(never)]
fn inner(x: f32) -> f32 { x * x }
#[rustc_autodiff(Forward, 1, Dual, Dual)]
#[inline(never)]
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<(f32, f32)>::default())
}
#[rustc_intrinsic]
fn d_inner_2(x: f32, bx_0: f32)
-> (f32, f32);
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
#[inline(never)]
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(inner(x));
::core::hint::black_box((bx_0,));
::core::hint::black_box(<f32>::default())
}
#[rustc_intrinsic]
fn d_inner_1(x: f32, bx_0: f32)
-> f32;
}
#[rustc_autodiff]
#[inline(never)]
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
#[rustc_intrinsic]
pub fn d_square<T: std::ops::Mul<Output = T> +
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f10::<T>(x));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f10::<T>(x))
}
Copy>(x: &T, dx_0: &mut T, dret: T) -> T;
fn main() {}
1 change: 1 addition & 0 deletions tests/pretty/autodiff/autodiff_forward.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//@ needs-enzyme

#![feature(autodiff)]
#![feature(intrinsics)]
//@ pretty-mode:expanded
//@ pretty-compare-only
//@ pp-exact:autodiff_forward.pp
Expand Down
43 changes: 11 additions & 32 deletions tests/pretty/autodiff/autodiff_reverse.pp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//@ needs-enzyme

#![feature(autodiff)]
#![feature(intrinsics)]
#[prelude_import]
use ::std::prelude::rust_2015::*;
#[macro_use]
Expand All @@ -29,58 +30,36 @@
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f1(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f1(x, y))
}
#[rustc_intrinsic]
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
#[rustc_autodiff]
#[inline(never)]
pub fn f2() {}
#[rustc_autodiff(Reverse, 1, None)]
#[inline(never)]
pub fn df2() {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f2());
::core::hint::black_box(());
}
#[rustc_intrinsic]
pub fn df2();
#[rustc_autodiff]
#[inline(never)]
pub fn f3(x: &[f64], y: f64) -> f64 {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
#[inline(never)]
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f3(x, y));
::core::hint::black_box((dx_0, dret));
::core::hint::black_box(f3(x, y))
}
#[rustc_intrinsic]
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
enum Foo { Reverse, }
use Foo::Reverse;
#[rustc_autodiff]
#[inline(never)]
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
#[rustc_autodiff(Reverse, 1, Const, None)]
#[inline(never)]
pub fn df4(x: f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f4(x));
::core::hint::black_box(());
}
#[rustc_intrinsic]
pub fn df4(x: f32);
#[rustc_autodiff]
#[inline(never)]
pub fn f5(x: *const f32, y: &f32) {
::core::panicking::panic("not implemented")
}
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
#[inline(never)]
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
unsafe { asm!("NOP", options(pure, nomem)); };
::core::hint::black_box(f5(x, y));
::core::hint::black_box((dx_0, dy_0));
}
#[rustc_intrinsic]
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32);
fn main() {}
Loading