Skip to content

Commit d651974

Browse files
committed
Macro expansion cleanup
1 parent b728326 commit d651974

File tree

4 files changed

+56
-151
lines changed

4 files changed

+56
-151
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 27 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ mod llvm_enzyme {
2121
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
2222
};
2323
use rustc_expand::base::{Annotatable, ExtCtxt};
24-
use rustc_span::{Ident, Span, Symbol, kw, sym};
24+
use rustc_span::{Ident, Span, Symbol, sym};
2525
use thin_vec::{ThinVec, thin_vec};
2626
use tracing::{debug, trace};
2727

@@ -183,11 +183,8 @@ mod llvm_enzyme {
183183
}
184184

185185
/// We expand the autodiff macro to generate a new placeholder function which passes
186-
/// type-checking and can be called by users. The function body of the placeholder function will
187-
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
188-
/// should just prevent early inlining and optimizations which alter the function signature.
189-
/// The exact signature of the generated function depends on the configuration provided by the
190-
/// user, but here is an example:
186+
/// type-checking and can be called by users. The exact signature of the generated function
187+
/// depends on the configuration provided by the user, but here is an example:
191188
///
192189
/// ```
193190
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -203,14 +200,8 @@ mod llvm_enzyme {
203200
/// f32::sin(**x)
204201
/// }
205202
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
206-
/// #[inline(never)]
207203
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
208-
/// unsafe {
209-
/// asm!("NOP");
210-
/// };
211-
/// ::core::hint::black_box(sin(x));
212-
/// ::core::hint::black_box((dx, dret));
213-
/// ::core::hint::black_box(sin(x))
204+
/// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
214205
/// }
215206
/// ```
216207
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -330,22 +321,20 @@ mod llvm_enzyme {
330321
}
331322
let span = ecx.with_def_site_ctxt(expand_span);
332323

333-
let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
324+
let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
334325

335326
let d_body = gen_enzyme_body(
336327
ecx,
337328
&d_sig,
338329
primal,
339330
span,
340-
idents,
341-
errored,
342331
first_ident(&meta_item_vec[0]),
343332
&generics,
344333
impl_of_trait,
345334
);
346335

347336
// The first element of it is the name of the function to be generated
348-
let asdf = Box::new(ast::Fn {
337+
let d_fn = Box::new(ast::Fn {
349338
defaultness: ast::Defaultness::Final,
350339
sig: d_sig,
351340
ident: first_ident(&meta_item_vec[0]),
@@ -442,7 +431,7 @@ mod llvm_enzyme {
442431
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
443432
let d_annotatable = match &item {
444433
Annotatable::AssocItem(_, _) => {
445-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
434+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
446435
let d_fn = P(ast::AssocItem {
447436
attrs: thin_vec![d_attr],
448437
id: ast::DUMMY_NODE_ID,
@@ -454,13 +443,13 @@ mod llvm_enzyme {
454443
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
455444
}
456445
Annotatable::Item(_) => {
457-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
446+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
458447
d_fn.vis = vis;
459448

460449
Annotatable::Item(d_fn)
461450
}
462451
Annotatable::Stmt(_) => {
463-
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf));
452+
let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn));
464453
d_fn.vis = vis;
465454

466455
Annotatable::Stmt(P(ast::Stmt {
@@ -525,14 +514,8 @@ mod llvm_enzyme {
525514
.into(),
526515
);
527516

528-
let enzyme_path = ecx.path(
529-
span,
530-
vec![
531-
Ident::from_str("std"),
532-
Ident::from_str("intrinsics"),
533-
Ident::with_dummy_span(sym::enzyme_autodiff),
534-
],
535-
);
517+
let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::enzyme_autodiff]);
518+
let enzyme_path = ecx.path(span, enzyme_path_idents);
536519
let call_expr = ecx.expr_call(
537520
span,
538521
ecx.expr_path(enzyme_path),
@@ -591,25 +574,6 @@ mod llvm_enzyme {
591574
ecx.expr_path(path)
592575
}
593576

594-
// Will generate a body of the type:
595-
// ```
596-
// primal(args);
597-
// std::intrinsics::enzyme_autodiff(primal, diff, (args))
598-
// }
599-
// ```
600-
fn init_body_helper(
601-
ecx: &ExtCtxt<'_>,
602-
span: Span,
603-
primal: Ident,
604-
idents: &[Ident],
605-
_errored: bool,
606-
generics: &Generics,
607-
) -> P<ast::Block> {
608-
let _primal_call = gen_primal_call(ecx, span, primal, idents, generics);
609-
let body = ecx.block(span, ThinVec::new());
610-
body
611-
}
612-
613577
/// We only want this function to type-check, since we will replace the body
614578
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
615579
/// so instead we manually build something that should pass the type checker.
@@ -623,8 +587,6 @@ mod llvm_enzyme {
623587
d_sig: &ast::FnSig,
624588
primal: Ident,
625589
span: Span,
626-
idents: Vec<Ident>,
627-
errored: bool,
628590
diff_ident: Ident,
629591
generics: &Generics,
630592
is_impl: bool,
@@ -633,87 +595,22 @@ mod llvm_enzyme {
633595

634596
// Add a call to the primal function to prevent it from being inlined
635597
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
636-
let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics);
637-
638-
body.stmts.push(call_enzyme_autodiff(
639-
ecx,
640-
primal,
641-
diff_ident,
642-
new_decl_span,
643-
d_sig,
644-
generics,
645-
is_impl,
646-
));
598+
let body = ecx.block(
599+
span,
600+
thin_vec![call_enzyme_autodiff(
601+
ecx,
602+
primal,
603+
diff_ident,
604+
new_decl_span,
605+
d_sig,
606+
generics,
607+
is_impl,
608+
)],
609+
);
647610

648611
body
649612
}
650613

651-
fn gen_primal_call(
652-
ecx: &ExtCtxt<'_>,
653-
span: Span,
654-
primal: Ident,
655-
idents: &[Ident],
656-
generics: &Generics,
657-
) -> P<ast::Expr> {
658-
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
659-
660-
if has_self {
661-
let args: ThinVec<_> =
662-
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
663-
let self_expr = ecx.expr_self(span);
664-
ecx.expr_method_call(span, self_expr, primal, args)
665-
} else {
666-
let args: ThinVec<_> =
667-
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
668-
let mut primal_path = ecx.path_ident(span, primal);
669-
670-
let is_generic = !generics.params.is_empty();
671-
672-
match (is_generic, primal_path.segments.last_mut()) {
673-
(true, Some(function_path)) => {
674-
let primal_generic_types = generics
675-
.params
676-
.iter()
677-
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
678-
679-
let generated_generic_types = primal_generic_types
680-
.map(|type_param| {
681-
let generic_param = TyKind::Path(
682-
None,
683-
ast::Path {
684-
span,
685-
segments: thin_vec![ast::PathSegment {
686-
ident: type_param.ident,
687-
args: None,
688-
id: ast::DUMMY_NODE_ID,
689-
}],
690-
tokens: None,
691-
},
692-
);
693-
694-
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
695-
id: type_param.id,
696-
span,
697-
kind: generic_param,
698-
tokens: None,
699-
})))
700-
})
701-
.collect();
702-
703-
function_path.args =
704-
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
705-
span,
706-
args: generated_generic_types,
707-
})));
708-
}
709-
_ => {}
710-
}
711-
712-
let primal_call_expr = ecx.expr_path(primal_path);
713-
ecx.expr_call(span, primal_call_expr, args)
714-
}
715-
}
716-
717614
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
718615
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
719616
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -730,7 +627,7 @@ mod llvm_enzyme {
730627
sig: &ast::FnSig,
731628
x: &AutoDiffAttrs,
732629
span: Span,
733-
) -> (ast::FnSig, Vec<Ident>, bool) {
630+
) -> ast::FnSig {
734631
let dcx = ecx.sess.dcx();
735632
let has_ret = has_ret(&sig.decl.output);
736633
let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
@@ -742,7 +639,7 @@ mod llvm_enzyme {
742639
found: num_activities,
743640
});
744641
// This is not the right signature, but we can continue parsing.
745-
return (sig.clone(), vec![], true);
642+
return sig.clone();
746643
}
747644
assert!(sig.decl.inputs.len() == x.input_activity.len());
748645
assert!(has_ret == x.has_ret_activity());
@@ -785,7 +682,7 @@ mod llvm_enzyme {
785682

786683
if errors {
787684
// This is not the right signature, but we can continue parsing.
788-
return (sig.clone(), idents, true);
685+
return sig.clone();
789686
}
790687

791688
let unsafe_activities = x
@@ -993,7 +890,7 @@ mod llvm_enzyme {
993890
}
994891
let d_sig = FnSig { header: d_header, decl: d_decl, span };
995892
trace!("Generated signature: {:?}", d_sig);
996-
(d_sig, idents, false)
893+
d_sig
997894
}
998895
}
999896

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
}
3939
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
4040
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
41-
std::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y))
41+
::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y))
4242
}
4343
#[rustc_autodiff]
4444
#[inline(never)]
@@ -47,7 +47,7 @@
4747
}
4848
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
4949
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
50-
std::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y))
50+
::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y))
5151
}
5252
#[rustc_autodiff]
5353
#[inline(never)]
@@ -56,29 +56,32 @@
5656
}
5757
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
5858
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
59-
std::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y))
59+
::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y))
6060
}
6161
#[rustc_autodiff]
6262
#[inline(never)]
6363
pub fn f4() {}
6464
#[rustc_autodiff(Forward, 1, None)]
65-
pub fn df4() -> () { std::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ()) }
65+
pub fn df4() -> () {
66+
::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ())
67+
}
6668
#[rustc_autodiff]
6769
#[inline(never)]
6870
pub fn f5(x: &[f64], y: f64) -> f64 {
6971
::core::panicking::panic("not implemented")
7072
}
7173
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
7274
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
73-
std::intrinsics::enzyme_autodiff(f5::<>, df5_y::<>, (x, y, by_0))
75+
::core::intrinsics::enzyme_autodiff(f5::<>, df5_y::<>, (x, y, by_0))
7476
}
7577
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
7678
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
77-
std::intrinsics::enzyme_autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
79+
::core::intrinsics::enzyme_autodiff(f5::<>, df5_x::<>, (x, bx_0, y))
7880
}
7981
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
8082
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
81-
std::intrinsics::enzyme_autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret))
83+
::core::intrinsics::enzyme_autodiff(f5::<>, df5_rev::<>,
84+
(x, dx_0, y, dret))
8285
}
8386
struct DoesNotImplDefault;
8487
#[rustc_autodiff]
@@ -88,14 +91,14 @@
8891
}
8992
#[rustc_autodiff(Forward, 1, Const)]
9093
pub fn df6() -> DoesNotImplDefault {
91-
std::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ())
94+
::core::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ())
9295
}
9396
#[rustc_autodiff]
9497
#[inline(never)]
9598
pub fn f7(x: f32) -> () {}
9699
#[rustc_autodiff(Forward, 1, Const, None)]
97100
pub fn df7(x: f32) -> () {
98-
std::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,))
101+
::core::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,))
99102
}
100103
#[no_mangle]
101104
#[rustc_autodiff]
@@ -104,30 +107,32 @@
104107
#[rustc_autodiff(Forward, 4, Dual, Dual)]
105108
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
106109
-> [f32; 5usize] {
107-
std::intrinsics::enzyme_autodiff(f8::<>, f8_3::<>,
110+
::core::intrinsics::enzyme_autodiff(f8::<>, f8_3::<>,
108111
(x, bx_0, bx_1, bx_2, bx_3))
109112
}
110113
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
111114
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
112115
-> [f32; 4usize] {
113-
std::intrinsics::enzyme_autodiff(f8::<>, f8_2::<>,
116+
::core::intrinsics::enzyme_autodiff(f8::<>, f8_2::<>,
114117
(x, bx_0, bx_1, bx_2, bx_3))
115118
}
116119
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
117120
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
118-
std::intrinsics::enzyme_autodiff(f8::<>, f8_1::<>, (x, bx_0))
121+
::core::intrinsics::enzyme_autodiff(f8::<>, f8_1::<>, (x, bx_0))
119122
}
120123
pub fn f9() {
121124
#[rustc_autodiff]
122125
#[inline(never)]
123126
fn inner(x: f32) -> f32 { x * x }
124127
#[rustc_autodiff(Forward, 1, Dual, Dual)]
125128
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
126-
std::intrinsics::enzyme_autodiff(inner::<>, d_inner_2::<>, (x, bx_0))
129+
::core::intrinsics::enzyme_autodiff(inner::<>, d_inner_2::<>,
130+
(x, bx_0))
127131
}
128132
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
129133
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
130-
std::intrinsics::enzyme_autodiff(inner::<>, d_inner_1::<>, (x, bx_0))
134+
::core::intrinsics::enzyme_autodiff(inner::<>, d_inner_1::<>,
135+
(x, bx_0))
131136
}
132137
}
133138
#[rustc_autodiff]
@@ -136,6 +141,7 @@
136141
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
137142
pub fn d_square<T: std::ops::Mul<Output = T> +
138143
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
139-
std::intrinsics::enzyme_autodiff(f10::<T>, d_square::<T>, (x, dx_0, dret))
144+
::core::intrinsics::enzyme_autodiff(f10::<T>, d_square::<T>,
145+
(x, dx_0, dret))
140146
}
141147
fn main() {}

0 commit comments

Comments
 (0)