-
Notifications
You must be signed in to change notification settings - Fork 13.6k
TypeTree support in autodiff #144197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
TypeTree support in autodiff #144197
Changes from all commits
bc955e6
c84ea22
9d410e3
9e9c72a
cb5cbd7
dd73495
620304c
f52e34d
b045bbd
eb518c9
3c978eb
b376c7b
9dc3e78
d28cfef
c9e130b
9f438c2
461f7c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -599,13 +599,11 @@ pub(crate) fn run_pass_manager( | |
for function in cx.get_functions() { | ||
let enzyme_marker = "enzyme_marker"; | ||
if attributes::has_string_attr(function, enzyme_marker) { | ||
// Sanity check: Ensure 'noinline' is present before replacing it. | ||
assert!( | ||
attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), | ||
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" | ||
); | ||
// Remove 'noinline' if present (it should be there in most cases) | ||
if attributes::has_attr(function, Function, llvm::AttributeKind::NoInline) { | ||
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd probably merge the frontend rework first, then you can drop these lines. |
||
} | ||
|
||
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); | ||
attributes::remove_string_attr_from_llfn(function, enzyme_marker); | ||
|
||
assert!( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
use std::os::raw::{c_char, c_uint}; | ||
use std::ptr; | ||
|
||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; | ||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; | ||
use rustc_codegen_ssa::ModuleCodegen; | ||
use rustc_codegen_ssa::common::TypeKind; | ||
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; | ||
|
@@ -14,7 +16,7 @@ use crate::context::SimpleCx; | |
use crate::declare::declare_simple_fn; | ||
use crate::errors::{AutoDiffWithoutEnable, LlvmError}; | ||
use crate::llvm::AttributePlace::Function; | ||
use crate::llvm::{Metadata, True}; | ||
use crate::llvm::{Metadata, True, TypeTree}; | ||
use crate::value::Value; | ||
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; | ||
|
||
|
@@ -512,3 +514,141 @@ pub(crate) fn differentiate<'ll>( | |
|
||
Ok(()) | ||
} | ||
|
||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format | ||
/// | ||
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) | ||
/// and converts it to Enzyme's internal C++ TypeTree representation that | ||
/// Enzyme can understand during differentiation analysis. | ||
#[cfg(llvm_enzyme)] | ||
fn to_enzyme_typetree( | ||
rust_typetree: RustTypeTree, | ||
data_layout: &str, | ||
llcx: &llvm::Context, | ||
) -> TypeTree { | ||
// Start with an empty TypeTree | ||
let mut enzyme_tt = TypeTree::new(); | ||
|
||
// Convert each Type in the Rust TypeTree to Enzyme format | ||
for rust_type in rust_typetree.0 { | ||
let concrete_type = match rust_type.kind { | ||
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, | ||
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, | ||
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, | ||
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, | ||
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, | ||
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, | ||
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, | ||
}; | ||
|
||
// Create a TypeTree for this specific type | ||
let type_tt = TypeTree::from_type(concrete_type, llcx); | ||
|
||
// Apply offset if specified | ||
let type_tt = if rust_type.offset == -1 { | ||
type_tt // -1 means everywhere/no specific offset | ||
} else { | ||
// Apply specific offset positioning | ||
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) | ||
}; | ||
|
||
// Merge this type into the main TypeTree | ||
enzyme_tt = enzyme_tt.merge(type_tt); | ||
} | ||
|
||
enzyme_tt | ||
} | ||
|
||
#[cfg(not(llvm_enzyme))] | ||
#[allow(dead_code)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this need a dead_code? Are you also gating the caller side? Then you shouldn't need this dummy. |
||
fn to_enzyme_typetree( | ||
_rust_typetree: RustTypeTree, | ||
_data_layout: &str, | ||
_llcx: &llvm::Context, | ||
) -> ! { | ||
unimplemented!("TypeTree conversion not available without llvm_enzyme support") | ||
} | ||
|
||
// Attaches TypeTree information to LLVM function as enzyme_type attributes. | ||
#[cfg(llvm_enzyme)] | ||
pub(crate) fn add_tt<'ll>( | ||
llmod: &'ll llvm::Module, | ||
llcx: &'ll llvm::Context, | ||
fn_def: &'ll Value, | ||
tt: FncTree, | ||
) { | ||
let inputs = tt.args; | ||
let ret_tt: RustTypeTree = tt.ret; | ||
|
||
// Get LLVM data layout string for TypeTree conversion | ||
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; | ||
let llvm_data_layout = | ||
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) | ||
.expect("got a non-UTF8 data-layout from LLVM"); | ||
|
||
// Attribute name that Enzyme recognizes for TypeTree information | ||
let attr_name = "enzyme_type"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @KMJ-007 Here you mark your attribute as enzyme_type, but in the test below I can't spot any metadata with that name. Are you sure it is getting lowered and not just dropped? Maybe add a few dbg statements to see till where exists. |
||
let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); | ||
|
||
// Attach TypeTree attributes to each input parameter | ||
// Enzyme uses these to understand parameter memory layouts during differentiation | ||
for (i, input) in inputs.iter().enumerate() { | ||
unsafe { | ||
// Convert Rust TypeTree to Enzyme's internal format | ||
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); | ||
|
||
// Serialize TypeTree to string format that Enzyme can parse | ||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); | ||
let c_str = std::ffi::CStr::from_ptr(c_str); | ||
|
||
// Create LLVM string attribute with TypeTree information | ||
let attr = llvm::LLVMCreateStringAttribute( | ||
llcx, | ||
c_attr_name.as_ptr(), | ||
c_attr_name.as_bytes().len() as c_uint, | ||
c_str.as_ptr(), | ||
c_str.to_bytes().len() as c_uint, | ||
); | ||
|
||
// Attach attribute to the specific function parameter | ||
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments | ||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); | ||
|
||
// Free the C string to prevent memory leaks | ||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); | ||
} | ||
} | ||
|
||
// Attach TypeTree attribute to the return type | ||
// Enzyme needs this to understand how to handle return value derivatives | ||
unsafe { | ||
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); | ||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); | ||
let c_str = std::ffi::CStr::from_ptr(c_str); | ||
|
||
let ret_attr = llvm::LLVMCreateStringAttribute( | ||
llcx, | ||
c_attr_name.as_ptr(), | ||
c_attr_name.as_bytes().len() as c_uint, | ||
c_str.as_ptr(), | ||
c_str.to_bytes().len() as c_uint, | ||
); | ||
|
||
// Attach to function return type | ||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); | ||
|
||
// Free the C string | ||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); | ||
} | ||
} | ||
|
||
// Fallback implementation when Enzyme is not available | ||
#[cfg(not(llvm_enzyme))] | ||
pub(crate) fn add_tt<'ll>( | ||
_llmod: &'ll llvm::Module, | ||
_llcx: &'ll llvm::Context, | ||
_fn_def: &'ll Value, | ||
_tt: FncTree, | ||
) { | ||
unimplemented!() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The gcc people would probably appreciate a commend clarifying, that these tt are an LLVM only feature and thus shouldn't be set for gcc