@@ -21,7 +21,7 @@ mod llvm_enzyme {
21
21
MetaItemInner , PatKind , Path , PathSegment , TyKind , Visibility ,
22
22
} ;
23
23
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
24
- use rustc_span:: { Ident , Span , Symbol , kw , sym} ;
24
+ use rustc_span:: { Ident , Span , Symbol , sym} ;
25
25
use thin_vec:: { ThinVec , thin_vec} ;
26
26
use tracing:: { debug, trace} ;
27
27
@@ -183,11 +183,8 @@ mod llvm_enzyme {
183
183
}
184
184
185
185
/// We expand the autodiff macro to generate a new placeholder function which passes
186
- /// type-checking and can be called by users. The function body of the placeholder function will
187
- /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
188
- /// should just prevent early inlining and optimizations which alter the function signature.
189
- /// The exact signature of the generated function depends on the configuration provided by the
190
- /// user, but here is an example:
186
+ /// type-checking and can be called by users. The exact signature of the generated function
187
+ /// depends on the configuration provided by the user, but here is an example:
191
188
///
192
189
/// ```
193
190
/// #[autodiff(cos_box, Reverse, Duplicated, Active)]
@@ -203,14 +200,8 @@ mod llvm_enzyme {
203
200
/// f32::sin(**x)
204
201
/// }
205
202
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
206
- /// #[inline(never)]
207
203
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
208
- /// unsafe {
209
- /// asm!("NOP");
210
- /// };
211
- /// ::core::hint::black_box(sin(x));
212
- /// ::core::hint::black_box((dx, dret));
213
- /// ::core::hint::black_box(sin(x))
204
+ /// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret))
214
205
/// }
215
206
/// ```
216
207
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -330,22 +321,20 @@ mod llvm_enzyme {
330
321
}
331
322
let span = ecx. with_def_site_ctxt ( expand_span) ;
332
323
333
- let ( d_sig, idents , errored ) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
324
+ let d_sig = gen_enzyme_decl ( ecx, & sig, & x, span) ;
334
325
335
326
let d_body = gen_enzyme_body (
336
327
ecx,
337
328
& d_sig,
338
329
primal,
339
330
span,
340
- idents,
341
- errored,
342
331
first_ident ( & meta_item_vec[ 0 ] ) ,
343
332
& generics,
344
333
impl_of_trait,
345
334
) ;
346
335
347
336
// The first element of it is the name of the function to be generated
348
- let asdf = Box :: new ( ast:: Fn {
337
+ let d_fn = Box :: new ( ast:: Fn {
349
338
defaultness : ast:: Defaultness :: Final ,
350
339
sig : d_sig,
351
340
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
@@ -442,7 +431,7 @@ mod llvm_enzyme {
442
431
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
443
432
let d_annotatable = match & item {
444
433
Annotatable :: AssocItem ( _, _) => {
445
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf ) ;
434
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( d_fn ) ;
446
435
let d_fn = P ( ast:: AssocItem {
447
436
attrs : thin_vec ! [ d_attr] ,
448
437
id : ast:: DUMMY_NODE_ID ,
@@ -454,13 +443,13 @@ mod llvm_enzyme {
454
443
Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
455
444
}
456
445
Annotatable :: Item ( _) => {
457
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
446
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
458
447
d_fn. vis = vis;
459
448
460
449
Annotatable :: Item ( d_fn)
461
450
}
462
451
Annotatable :: Stmt ( _) => {
463
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( asdf ) ) ;
452
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr] , ItemKind :: Fn ( d_fn ) ) ;
464
453
d_fn. vis = vis;
465
454
466
455
Annotatable :: Stmt ( P ( ast:: Stmt {
@@ -525,14 +514,8 @@ mod llvm_enzyme {
525
514
. into ( ) ,
526
515
) ;
527
516
528
- let enzyme_path = ecx. path (
529
- span,
530
- vec ! [
531
- Ident :: from_str( "std" ) ,
532
- Ident :: from_str( "intrinsics" ) ,
533
- Ident :: with_dummy_span( sym:: enzyme_autodiff) ,
534
- ] ,
535
- ) ;
517
+ let enzyme_path_idents = ecx. std_path ( & [ sym:: intrinsics, sym:: enzyme_autodiff] ) ;
518
+ let enzyme_path = ecx. path ( span, enzyme_path_idents) ;
536
519
let call_expr = ecx. expr_call (
537
520
span,
538
521
ecx. expr_path ( enzyme_path) ,
@@ -591,25 +574,6 @@ mod llvm_enzyme {
591
574
ecx. expr_path ( path)
592
575
}
593
576
594
- // Will generate a body of the type:
595
- // ```
596
- // primal(args);
597
- // std::intrinsics::enzyme_autodiff(primal, diff, (args))
598
- // }
599
- // ```
600
- fn init_body_helper (
601
- ecx : & ExtCtxt < ' _ > ,
602
- span : Span ,
603
- primal : Ident ,
604
- idents : & [ Ident ] ,
605
- _errored : bool ,
606
- generics : & Generics ,
607
- ) -> P < ast:: Block > {
608
- let _primal_call = gen_primal_call ( ecx, span, primal, idents, generics) ;
609
- let body = ecx. block ( span, ThinVec :: new ( ) ) ;
610
- body
611
- }
612
-
613
577
/// We only want this function to type-check, since we will replace the body
614
578
/// later on llvm level. Using `loop {}` does not cover all return types anymore,
615
579
/// so instead we manually build something that should pass the type checker.
@@ -623,8 +587,6 @@ mod llvm_enzyme {
623
587
d_sig : & ast:: FnSig ,
624
588
primal : Ident ,
625
589
span : Span ,
626
- idents : Vec < Ident > ,
627
- errored : bool ,
628
590
diff_ident : Ident ,
629
591
generics : & Generics ,
630
592
is_impl : bool ,
@@ -633,87 +595,22 @@ mod llvm_enzyme {
633
595
634
596
// Add a call to the primal function to prevent it from being inlined
635
597
// and call `enzyme_autodiff` intrinsic (this also covers the return type)
636
- let mut body = init_body_helper ( ecx, span, primal, & idents, errored, generics) ;
637
-
638
- body. stmts . push ( call_enzyme_autodiff (
639
- ecx,
640
- primal,
641
- diff_ident,
642
- new_decl_span,
643
- d_sig,
644
- generics,
645
- is_impl,
646
- ) ) ;
598
+ let body = ecx. block (
599
+ span,
600
+ thin_vec ! [ call_enzyme_autodiff(
601
+ ecx,
602
+ primal,
603
+ diff_ident,
604
+ new_decl_span,
605
+ d_sig,
606
+ generics,
607
+ is_impl,
608
+ ) ] ,
609
+ ) ;
647
610
648
611
body
649
612
}
650
613
651
- fn gen_primal_call (
652
- ecx : & ExtCtxt < ' _ > ,
653
- span : Span ,
654
- primal : Ident ,
655
- idents : & [ Ident ] ,
656
- generics : & Generics ,
657
- ) -> P < ast:: Expr > {
658
- let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
659
-
660
- if has_self {
661
- let args: ThinVec < _ > =
662
- idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
663
- let self_expr = ecx. expr_self ( span) ;
664
- ecx. expr_method_call ( span, self_expr, primal, args)
665
- } else {
666
- let args: ThinVec < _ > =
667
- idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
668
- let mut primal_path = ecx. path_ident ( span, primal) ;
669
-
670
- let is_generic = !generics. params . is_empty ( ) ;
671
-
672
- match ( is_generic, primal_path. segments . last_mut ( ) ) {
673
- ( true , Some ( function_path) ) => {
674
- let primal_generic_types = generics
675
- . params
676
- . iter ( )
677
- . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
678
-
679
- let generated_generic_types = primal_generic_types
680
- . map ( |type_param| {
681
- let generic_param = TyKind :: Path (
682
- None ,
683
- ast:: Path {
684
- span,
685
- segments : thin_vec ! [ ast:: PathSegment {
686
- ident: type_param. ident,
687
- args: None ,
688
- id: ast:: DUMMY_NODE_ID ,
689
- } ] ,
690
- tokens : None ,
691
- } ,
692
- ) ;
693
-
694
- ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
695
- id : type_param. id ,
696
- span,
697
- kind : generic_param,
698
- tokens : None ,
699
- } ) ) )
700
- } )
701
- . collect ( ) ;
702
-
703
- function_path. args =
704
- Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
705
- span,
706
- args : generated_generic_types,
707
- } ) ) ) ;
708
- }
709
- _ => { }
710
- }
711
-
712
- let primal_call_expr = ecx. expr_path ( primal_path) ;
713
- ecx. expr_call ( span, primal_call_expr, args)
714
- }
715
- }
716
-
717
614
// Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
718
615
// be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
719
616
// Active arguments must be scalars. Their shadow argument is added to the return type (and will be
@@ -730,7 +627,7 @@ mod llvm_enzyme {
730
627
sig : & ast:: FnSig ,
731
628
x : & AutoDiffAttrs ,
732
629
span : Span ,
733
- ) -> ( ast:: FnSig , Vec < Ident > , bool ) {
630
+ ) -> ast:: FnSig {
734
631
let dcx = ecx. sess . dcx ( ) ;
735
632
let has_ret = has_ret ( & sig. decl . output ) ;
736
633
let sig_args = sig. decl . inputs . len ( ) + if has_ret { 1 } else { 0 } ;
@@ -742,7 +639,7 @@ mod llvm_enzyme {
742
639
found : num_activities,
743
640
} ) ;
744
641
// This is not the right signature, but we can continue parsing.
745
- return ( sig. clone ( ) , vec ! [ ] , true ) ;
642
+ return sig. clone ( ) ;
746
643
}
747
644
assert ! ( sig. decl. inputs. len( ) == x. input_activity. len( ) ) ;
748
645
assert ! ( has_ret == x. has_ret_activity( ) ) ;
@@ -785,7 +682,7 @@ mod llvm_enzyme {
785
682
786
683
if errors {
787
684
// This is not the right signature, but we can continue parsing.
788
- return ( sig. clone ( ) , idents , true ) ;
685
+ return sig. clone ( ) ;
789
686
}
790
687
791
688
let unsafe_activities = x
@@ -993,7 +890,7 @@ mod llvm_enzyme {
993
890
}
994
891
let d_sig = FnSig { header : d_header, decl : d_decl, span } ;
995
892
trace ! ( "Generated signature: {:?}" , d_sig) ;
996
- ( d_sig, idents , false )
893
+ d_sig
997
894
}
998
895
}
999
896
0 commit comments