Commit 40055d6

addressing more feedback
1 parent 9ec8abe commit 40055d6

1 file changed: +171 -150 lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 171 additions & 150 deletions
@@ -58,19 +58,25 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
     at_one
 }
 
-// The meaning of the __tgt_offload_entry (as per llvm docs) is
-// Type, Identifier, Description
-// void*, addr, Address of global symbol within device image (function or global)
-// char*, name, Name of the symbol
-// size_t, size, Size of the entry info (0 if it is a function)
-// int32_t, flags, Flags associated with the entry (see Target Region Entry Flags)
-// int32_t, reserved, Reserved, to be used by the runtime library.
 pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
     let tptr = cx.type_ptr();
     let ti64 = cx.type_i64();
     let ti32 = cx.type_i32();
     let ti16 = cx.type_i16();
+    // For each kernel to run on the gpu, we will later generate one entry of this type.
+    // Copied from LLVM:
+    // typedef struct {
+    //   uint64_t Reserved;
+    //   uint16_t Version;
+    //   uint16_t Kind;
+    //   uint32_t Flags;    // Flags associated with the entry (see Target Region Entry Flags)
+    //   void *Address;     // Address of global symbol within device image (function or global)
+    //   char *SymbolName;
+    //   uint64_t Size;     // Size of the entry info (0 if it is a function)
+    //   uint64_t Data;
+    //   void *AuxAddr;
+    // } __tgt_offload_entry;
     let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
     cx.set_struct_body(offload_entry_ty, &entry_elements, false);
     offload_entry_ty
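
Note: from the `entry_elements` vec above, this function should lower to roughly the following IR type (a sketch for reference; the struct name comes from the `type_named_struct` call, the field order directly from the vec):

    %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }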
@@ -83,19 +89,30 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
     let ti32 = cx.type_i32();
     let tarr = cx.type_array(ti32, 3);
 
-    // For each kernel to run on the gpu, we will later generate one entry of this type.
-    // coppied from LLVM
-    // typedef struct {
-    //   uint64_t Reserved;
-    //   uint16_t Version;
-    //   uint16_t Kind;
-    //   uint32_t Flags;
-    //   void *Address;
-    //   char *SymbolName;
-    //   uint64_t Size;
-    //   uint64_t Data;
-    //   void *AuxAddr;
-    // } __tgt_offload_entry;
+    // Taken from the LLVM APITypes.h declaration:
+    // struct KernelArgsTy {
+    //   uint32_t Version = 0; // Version of this struct for ABI compatibility.
+    //   uint32_t NumArgs = 0; // Number of arguments in each input pointer.
+    //   void **ArgBasePtrs =
+    //       nullptr; // Base pointer of each argument (e.g. a struct).
+    //   void **ArgPtrs = nullptr; // Pointer to the argument data.
+    //   int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
+    //   int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
+    //   void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
+    //   void **ArgMappers = nullptr; // User-defined mappers, possibly null.
+    //   uint64_t Tripcount =
+    //       0; // Tripcount for the teams / distribute loop, 0 otherwise.
+    //   struct {
+    //     uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
+    //     uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
+    //     uint64_t Unused : 62;
+    //   } Flags = {0, 0, 0};
+    //   // The number of teams (for x,y,z dimension).
+    //   uint32_t NumTeams[3] = {0, 0, 0};
+    //   // The number of threads (for x,y,z dimension).
+    //   uint32_t ThreadLimit[3] = {0, 0, 0};
+    //   uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
+    // };
     let kernel_elements =
         vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
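
For reference, the `kernel_elements` list above should lower to an IR aggregate along these lines (a sketch; the `struct.__tgt_kernel_arguments` name is an assumption following LLVM's own naming, since the `type_named_struct` call for this type is outside the hunk):

    %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr,
                                            i64, i64, [3 x i32], [3 x i32], i32 }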

@@ -180,7 +197,7 @@ fn gen_define_handling<'ll>(
 
     // We do not know their size anymore at this level, so hardcode a placeholder.
     // A follow-up pr will track these from the frontend, where we still have Rust types.
-    // Then, we will be able to figure out that e.g. `&[f32;1024]` will result in 32*1024 bytes.
+    // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
     // I decided that 1024 bytes is a great placeholder value for now.
     add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
     // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
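
To make the placeholder globals concrete: for three pointer arguments that are copied both to (=1) and from (=2) the gpu, i.e. maptype 1|2 = 3, the generated arrays would look roughly like this (a sketch; the `.1` suffix and the element count of three are illustrative, not taken from this diff):

    @.offload_sizes.1 = private unnamed_addr constant [3 x i64] [i64 1024, i64 1024, i64 1024]
    @.offload_maptypes.1 = private unnamed_addr constant [3 x i64] [i64 3, i64 3, i64 3]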
@@ -285,135 +302,139 @@ fn gen_call_handling<'ll>(
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
-    if let Some(main_fn) = main_fn {
-        let kernel_name = "kernel_1";
-        let call = unsafe {
-            llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
-        };
-        let kernel_call = if call.is_some() {
-            call.unwrap()
-        } else {
-            return;
-        };
-        let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
-        let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
-        let mut builder = SBuilder::build(cx, kernel_call_bb);
-
-        let types = cx.func_params_types(cx.get_type_of_global(called));
-        let num_args = types.len() as u64;
-
-        // Step 0)
-        // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
-        // %6 = alloca %struct.__tgt_bin_desc, align 8
-        unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
-
-        let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
-
-        let ty = cx.type_array(cx.type_ptr(), num_args);
-        // Baseptr are just the input pointer to the kernel, stored in a local alloca
-        let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
-        // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
-        let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
-        // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
-        let ty2 = cx.type_array(cx.type_i64(), num_args);
-        let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
-        // Now we allocate once per function param, a copy to be passed to one of our maps.
-        let mut vals = vec![];
-        let mut geps = vec![];
-        let i32_0 = cx.get_const_i32(0);
-        for (index, in_ty) in types.iter().enumerate() {
-            // get function arg, store it into the alloca, and read it.
-            let p = llvm::get_param(called, index as u32);
-            let name = llvm::get_value_name(p);
-            let name = str::from_utf8(name).unwrap();
-            let arg_name = CString::new(format!("{name}.addr")).unwrap();
-            let alloca =
-                unsafe { llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name.as_ptr()) };
-            builder.store(p, alloca, Align::EIGHT);
-            let val = builder.load(in_ty, alloca, Align::EIGHT);
-            let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
-            vals.push(val);
-            geps.push(gep);
-        }
-
-        // Step 1)
-        unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
-        builder.memset(
-            tgt_bin_desc_alloca,
-            cx.get_const_i8(0),
-            cx.get_const_i64(32),
-            Align::from_bytes(8).unwrap(),
-        );
-
-        let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
-        let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
-        let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
-        let init_ty = cx.type_func(&[], cx.type_void());
-        let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
-
-        // call void @__tgt_register_lib(ptr noundef %6)
-        builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
-        // call void @__tgt_init_all_rtls()
-        builder.call(init_ty, init_rtls_decl, &[], None);
-
-        for i in 0..num_args {
-            let idx = cx.get_const_i32(i);
-            let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
-            builder.store(vals[i as usize], gep1, Align::EIGHT);
-            let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
-            builder.store(geps[i as usize], gep2, Align::EIGHT);
-            let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
-            builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
-        }
+    let Some(main_fn) = main_fn else { return };
+    let kernel_name = "kernel_1";
+    let call = unsafe {
+        llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
+    };
+    let Some(kernel_call) = call else {
+        return;
+    };
+    let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
+    let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
+    let mut builder = SBuilder::build(cx, kernel_call_bb);
+
+    let types = cx.func_params_types(cx.get_type_of_global(called));
+    let num_args = types.len() as u64;
+
+    // Step 0)
+    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
+    // %6 = alloca %struct.__tgt_bin_desc, align 8
+    unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
+
+    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
+
+    let ty = cx.type_array(cx.type_ptr(), num_args);
+    // Baseptrs are just the input pointers to the kernel, stored in a local alloca.
+    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
+    // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
+    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
+    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
+    let ty2 = cx.type_array(cx.type_i64(), num_args);
+    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
+    // Now we allocate once per function param, a copy to be passed to one of our maps.
+    let mut vals = vec![];
+    let mut geps = vec![];
+    let i32_0 = cx.get_const_i32(0);
+    for (index, in_ty) in types.iter().enumerate() {
+        // Get the function arg, store it into the alloca, and read it back.
+        let p = llvm::get_param(called, index as u32);
+        let name = llvm::get_value_name(p);
+        let name = str::from_utf8(name).unwrap();
+        let arg_name = format!("{name}.addr");
+        let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
+
+        builder.store(p, alloca, Align::EIGHT);
+        let val = builder.load(in_ty, alloca, Align::EIGHT);
+        let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
+        vals.push(val);
+        geps.push(gep);
+    }
 
-        // Step 2)
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let s_ident_t = generate_at_one(&cx);
-        let args = vec![
-            s_ident_t,
-            cx.get_const_i64(u64::MAX),
-            cx.get_const_i32(num_args),
-            gep1,
-            gep2,
-            gep3,
-            o_type,
-            nullptr,
-            nullptr,
-        ];
-        builder.call(fn_ty, begin_mapper_decl, &args, None);
-
-        // Step 4)
-        unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![
-            s_ident_t,
-            cx.get_const_i64(u64::MAX),
-            cx.get_const_i32(num_args),
-            gep1,
-            gep2,
-            gep3,
-            o_type,
-            nullptr,
-            nullptr,
-        ];
-        builder.call(fn_ty, end_mapper_decl, &args, None);
-        builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
-
-        // With this we generated the following begin and end mappers. We could easily generate the
-        // update mapper in an update.
-        // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-        // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-        // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
+    // Step 1)
+    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
+    builder.memset(
+        tgt_bin_desc_alloca,
+        cx.get_const_i8(0),
+        cx.get_const_i64(32),
+        Align::from_bytes(8).unwrap(),
+    );
+
+    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
+    let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
+    let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
+    let init_ty = cx.type_func(&[], cx.type_void());
+    let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
+
+    // call void @__tgt_register_lib(ptr noundef %6)
+    builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
+    // call void @__tgt_init_all_rtls()
+    builder.call(init_ty, init_rtls_decl, &[], None);
+
+    for i in 0..num_args {
+        let idx = cx.get_const_i32(i);
+        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
+        builder.store(vals[i as usize], gep1, Align::EIGHT);
+        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
+        builder.store(geps[i as usize], gep2, Align::EIGHT);
+        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
+        // As mentioned above, we don't use Rust type information yet. So for now we will just
+        // assume that we have 1024 bytes, 256 f32 values.
+        // FIXME(offload): write an offload frontend and handle arbitrary types.
+        builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
     }
+
+    // Step 2)
+    let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
+    let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
+    let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
+
+    let nullptr = cx.const_null(cx.type_ptr());
+    let o_type = o_types[0];
+    let s_ident_t = generate_at_one(&cx);
+    let args = vec![
+        s_ident_t,
+        cx.get_const_i64(u64::MAX),
+        cx.get_const_i32(num_args),
+        gep1,
+        gep2,
+        gep3,
+        o_type,
+        nullptr,
+        nullptr,
+    ];
+    builder.call(fn_ty, begin_mapper_decl, &args, None);
+
+    // Step 3)
+    // Here we will add code for the actual kernel launches in a follow-up PR.
+    // FIXME(offload): launch kernels
+
+    // Step 4)
+    unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
+
+    let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
+    let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
+    let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
+
+    let nullptr = cx.const_null(cx.type_ptr());
+    let o_type = o_types[0];
+    let args = vec![
+        s_ident_t,
+        cx.get_const_i64(u64::MAX),
+        cx.get_const_i32(num_args),
+        gep1,
+        gep2,
+        gep3,
+        o_type,
+        nullptr,
+        nullptr,
+    ];
+    builder.call(fn_ty, end_mapper_decl, &args, None);
+    builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
+
+    // With this we generated the following begin and end mappers. We could easily generate the
+    // update mapper in an update.
+    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
+    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
+    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
 }
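
Putting the steps together, the instrumented `main` should wrap the kernel call roughly like this (a sketch assembled from the calls and IR comments above; the value numbers like %27 are illustrative, and the bare `@kernel_1` call stands in for the real launch that Step 3 will add in a follow-up PR):

    call void @__tgt_register_lib(ptr %EmptyDesc)
    call void @__tgt_init_all_rtls()
    call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
    call void @kernel_1(...)  ; Step 3: actual kernel launch lands in a follow-up PR
    call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
    call void @__tgt_unregister_lib(ptr %EmptyDesc)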
