@@ -58,19 +58,25 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
58
58
at_one
59
59
}
60
60
61
- // The meaning of the __tgt_offload_entry (as per llvm docs) is
62
- // Type, Identifier, Description
63
- // void*, addr, Address of global symbol within device image (function or global)
64
- // char*, name, Name of the symbol
65
- // size_t, size, Size of the entry info (0 if it is a function)
66
- // int32_t, flags, Flags associated with the entry (see Target Region Entry Flags)
67
- // int32_t, reserved, Reserved, to be used by the runtime library.
68
61
pub ( crate ) fn add_tgt_offload_entry < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
69
62
let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
70
63
let tptr = cx. type_ptr ( ) ;
71
64
let ti64 = cx. type_i64 ( ) ;
72
65
let ti32 = cx. type_i32 ( ) ;
73
66
let ti16 = cx. type_i16 ( ) ;
67
+ // For each kernel to run on the gpu, we will later generate one entry of this type.
68
+ // copied from LLVM
69
+ // typedef struct {
70
+ // uint64_t Reserved;
71
+ // uint16_t Version;
72
+ // uint16_t Kind;
73
+ // uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags)
74
+ // void *Address; Address of global symbol within device image (function or global)
75
+ // char *SymbolName;
76
+ // uint64_t Size; Size of the entry info (0 if it is a function)
77
+ // uint64_t Data;
78
+ // void *AuxAddr;
79
+ // } __tgt_offload_entry;
74
80
let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
75
81
cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
76
82
offload_entry_ty
@@ -83,19 +89,30 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
83
89
let ti32 = cx. type_i32 ( ) ;
84
90
let tarr = cx. type_array ( ti32, 3 ) ;
85
91
86
- // For each kernel to run on the gpu, we will later generate one entry of this type.
87
- // coppied from LLVM
88
- // typedef struct {
89
- // uint64_t Reserved;
90
- // uint16_t Version;
91
- // uint16_t Kind;
92
- // uint32_t Flags;
93
- // void *Address;
94
- // char *SymbolName;
95
- // uint64_t Size;
96
- // uint64_t Data;
97
- // void *AuxAddr;
98
- // } __tgt_offload_entry;
92
+ // Taken from the LLVM APITypes.h declaration:
93
+ //struct KernelArgsTy {
94
+ // uint32_t Version = 0; // Version of this struct for ABI compatibility.
95
+ // uint32_t NumArgs = 0; // Number of arguments in each input pointer.
96
+ // void **ArgBasePtrs =
97
+ // nullptr; // Base pointer of each argument (e.g. a struct).
98
+ // void **ArgPtrs = nullptr; // Pointer to the argument data.
99
+ // int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
100
+ // int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
101
+ // void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
102
+ // void **ArgMappers = nullptr; // User-defined mappers, possibly null.
103
+ // uint64_t Tripcount =
104
+ // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
105
+ // struct {
106
+ // uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
107
+ // uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
108
+ // uint64_t Unused : 62;
109
+ // } Flags = {0, 0, 0};
110
+ // // The number of teams (for x,y,z dimension).
111
+ // uint32_t NumTeams[3] = {0, 0, 0};
112
+ // // The number of threads (for x,y,z dimension).
113
+ // uint32_t ThreadLimit[3] = {0, 0, 0};
114
+ // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
115
+ //};
99
116
let kernel_elements =
100
117
vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
101
118
@@ -180,7 +197,7 @@ fn gen_define_handling<'ll>(
180
197
181
198
// We do not know their size anymore at this level, so hardcode a placeholder.
182
199
// A follow-up pr will track these from the frontend, where we still have Rust types.
183
- // Then, we will be able to figure out that e.g. `&[f32;1024 ]` will result in 32*1024 bytes.
200
+ // Then, we will be able to figure out that e.g. `&[f32; 256]` will result in 4*256 bytes.
184
201
// I decided that 1024 bytes is a great placeholder value for now.
185
202
add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 1024 ; num_ptr_types] ) ;
186
203
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
@@ -285,135 +302,139 @@ fn gen_call_handling<'ll>(
285
302
let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
286
303
287
304
let main_fn = cx. get_function ( "main" ) ;
288
- if let Some ( main_fn) = main_fn {
289
- let kernel_name = "kernel_1" ;
290
- let call = unsafe {
291
- llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) )
292
- } ;
293
- let kernel_call = if call. is_some ( ) {
294
- call. unwrap ( )
295
- } else {
296
- return ;
297
- } ;
298
- let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
299
- let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) . unwrap ( ) } ;
300
- let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
301
-
302
- let types = cx. func_params_types ( cx. get_type_of_global ( called) ) ;
303
- let num_args = types. len ( ) as u64 ;
304
-
305
- // Step 0)
306
- // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
307
- // %6 = alloca %struct.__tgt_bin_desc, align 8
308
- unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
309
-
310
- let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
311
-
312
- let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
313
- // Baseptr are just the input pointer to the kernel, stored in a local alloca
314
- let a1 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
315
- // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
316
- let a2 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
317
- // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
318
- let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
319
- let a4 = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
320
- // Now we allocate once per function param, a copy to be passed to one of our maps.
321
- let mut vals = vec ! [ ] ;
322
- let mut geps = vec ! [ ] ;
323
- let i32_0 = cx. get_const_i32 ( 0 ) ;
324
- for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
325
- // get function arg, store it into the alloca, and read it.
326
- let p = llvm:: get_param ( called, index as u32 ) ;
327
- let name = llvm:: get_value_name ( p) ;
328
- let name = str:: from_utf8 ( name) . unwrap ( ) ;
329
- let arg_name = CString :: new ( format ! ( "{name}.addr" ) ) . unwrap ( ) ;
330
- let alloca =
331
- unsafe { llvm:: LLVMBuildAlloca ( builder. llbuilder , in_ty, arg_name. as_ptr ( ) ) } ;
332
- builder. store ( p, alloca, Align :: EIGHT ) ;
333
- let val = builder. load ( in_ty, alloca, Align :: EIGHT ) ;
334
- let gep = builder. inbounds_gep ( cx. type_f32 ( ) , val, & [ i32_0] ) ;
335
- vals. push ( val) ;
336
- geps. push ( gep) ;
337
- }
338
-
339
- // Step 1)
340
- unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
341
- builder. memset (
342
- tgt_bin_desc_alloca,
343
- cx. get_const_i8 ( 0 ) ,
344
- cx. get_const_i64 ( 32 ) ,
345
- Align :: from_bytes ( 8 ) . unwrap ( ) ,
346
- ) ;
347
-
348
- let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
349
- let register_lib_decl = declare_offload_fn ( & cx, "__tgt_register_lib" , mapper_fn_ty) ;
350
- let unregister_lib_decl = declare_offload_fn ( & cx, "__tgt_unregister_lib" , mapper_fn_ty) ;
351
- let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
352
- let init_rtls_decl = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
353
-
354
- // call void @__tgt_register_lib(ptr noundef %6)
355
- builder. call ( mapper_fn_ty, register_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
356
- // call void @__tgt_init_all_rtls()
357
- builder. call ( init_ty, init_rtls_decl, & [ ] , None ) ;
358
-
359
- for i in 0 ..num_args {
360
- let idx = cx. get_const_i32 ( i) ;
361
- let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
362
- builder. store ( vals[ i as usize ] , gep1, Align :: EIGHT ) ;
363
- let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, idx] ) ;
364
- builder. store ( geps[ i as usize ] , gep2, Align :: EIGHT ) ;
365
- let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, idx] ) ;
366
- builder. store ( cx. get_const_i64 ( 1024 ) , gep3, Align :: EIGHT ) ;
367
- }
305
+ let Some ( main_fn) = main_fn else { return } ;
306
+ let kernel_name = "kernel_1" ;
307
+ let call = unsafe {
308
+ llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) )
309
+ } ;
310
+ let Some ( kernel_call) = call else {
311
+ return ;
312
+ } ;
313
+ let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
314
+ let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) . unwrap ( ) } ;
315
+ let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
316
+
317
+ let types = cx. func_params_types ( cx. get_type_of_global ( called) ) ;
318
+ let num_args = types. len ( ) as u64 ;
319
+
320
+ // Step 0)
321
+ // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
322
+ // %6 = alloca %struct.__tgt_bin_desc, align 8
323
+ unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
324
+
325
+ let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
326
+
327
+ let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
328
+ // Baseptrs are just the input pointers to the kernel, stored in a local alloca
329
+ let a1 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
330
+ // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
331
+ let a2 = builder. direct_alloca ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
332
+ // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
333
+ let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
334
+ let a4 = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
335
+ // Now we allocate once per function param, a copy to be passed to one of our maps.
336
+ let mut vals = vec ! [ ] ;
337
+ let mut geps = vec ! [ ] ;
338
+ let i32_0 = cx. get_const_i32 ( 0 ) ;
339
+ for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
340
+ // get function arg, store it into the alloca, and read it.
341
+ let p = llvm:: get_param ( called, index as u32 ) ;
342
+ let name = llvm:: get_value_name ( p) ;
343
+ let name = str:: from_utf8 ( name) . unwrap ( ) ;
344
+ let arg_name = format ! ( "{name}.addr" ) ;
345
+ let alloca = builder. direct_alloca ( in_ty, Align :: EIGHT , & arg_name) ;
346
+
347
+ builder. store ( p, alloca, Align :: EIGHT ) ;
348
+ let val = builder. load ( in_ty, alloca, Align :: EIGHT ) ;
349
+ let gep = builder. inbounds_gep ( cx. type_f32 ( ) , val, & [ i32_0] ) ;
350
+ vals. push ( val) ;
351
+ geps. push ( gep) ;
352
+ }
368
353
369
- // Step 2)
370
- let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
371
- let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
372
- let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
373
-
374
- let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
375
- let o_type = o_types[ 0 ] ;
376
- let s_ident_t = generate_at_one ( & cx) ;
377
- let args = vec ! [
378
- s_ident_t,
379
- cx. get_const_i64( u64 :: MAX ) ,
380
- cx. get_const_i32( num_args) ,
381
- gep1,
382
- gep2,
383
- gep3,
384
- o_type,
385
- nullptr,
386
- nullptr,
387
- ] ;
388
- builder. call ( fn_ty, begin_mapper_decl, & args, None ) ;
389
-
390
- // Step 4)
391
- unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
392
-
393
- let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
394
- let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
395
- let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
396
-
397
- let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
398
- let o_type = o_types[ 0 ] ;
399
- let args = vec ! [
400
- s_ident_t,
401
- cx. get_const_i64( u64 :: MAX ) ,
402
- cx. get_const_i32( num_args) ,
403
- gep1,
404
- gep2,
405
- gep3,
406
- o_type,
407
- nullptr,
408
- nullptr,
409
- ] ;
410
- builder. call ( fn_ty, end_mapper_decl, & args, None ) ;
411
- builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
412
-
413
- // With this we generated the following begin and end mappers. We could easily generate the
414
- // update mapper in an update.
415
- // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
416
- // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
417
- // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
354
+ // Step 1)
355
+ unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
356
+ builder. memset (
357
+ tgt_bin_desc_alloca,
358
+ cx. get_const_i8 ( 0 ) ,
359
+ cx. get_const_i64 ( 32 ) ,
360
+ Align :: from_bytes ( 8 ) . unwrap ( ) ,
361
+ ) ;
362
+
363
+ let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
364
+ let register_lib_decl = declare_offload_fn ( & cx, "__tgt_register_lib" , mapper_fn_ty) ;
365
+ let unregister_lib_decl = declare_offload_fn ( & cx, "__tgt_unregister_lib" , mapper_fn_ty) ;
366
+ let init_ty = cx. type_func ( & [ ] , cx. type_void ( ) ) ;
367
+ let init_rtls_decl = declare_offload_fn ( cx, "__tgt_init_all_rtls" , init_ty) ;
368
+
369
+ // call void @__tgt_register_lib(ptr noundef %6)
370
+ builder. call ( mapper_fn_ty, register_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
371
+ // call void @__tgt_init_all_rtls()
372
+ builder. call ( init_ty, init_rtls_decl, & [ ] , None ) ;
373
+
374
+ for i in 0 ..num_args {
375
+ let idx = cx. get_const_i32 ( i) ;
376
+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, idx] ) ;
377
+ builder. store ( vals[ i as usize ] , gep1, Align :: EIGHT ) ;
378
+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, idx] ) ;
379
+ builder. store ( geps[ i as usize ] , gep2, Align :: EIGHT ) ;
380
+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, idx] ) ;
381
+ // As mentioned above, we don't use Rust type information yet. So for now we will just
382
+ // assume that we have 1024 bytes, 256 f32 values.
383
+ // FIXME(offload): write an offload frontend and handle arbitrary types.
384
+ builder. store ( cx. get_const_i64 ( 1024 ) , gep3, Align :: EIGHT ) ;
418
385
}
386
+
387
+ // Step 2)
388
+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
389
+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
390
+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
391
+
392
+ let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
393
+ let o_type = o_types[ 0 ] ;
394
+ let s_ident_t = generate_at_one ( & cx) ;
395
+ let args = vec ! [
396
+ s_ident_t,
397
+ cx. get_const_i64( u64 :: MAX ) ,
398
+ cx. get_const_i32( num_args) ,
399
+ gep1,
400
+ gep2,
401
+ gep3,
402
+ o_type,
403
+ nullptr,
404
+ nullptr,
405
+ ] ;
406
+ builder. call ( fn_ty, begin_mapper_decl, & args, None ) ;
407
+
408
+ // Step 3)
409
+ // Here we will add code for the actual kernel launches in a follow-up PR.
410
+ // FIXME(offload): launch kernels
411
+
412
+ // Step 4)
413
+ unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
414
+
415
+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
416
+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
417
+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
418
+
419
+ let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
420
+ let o_type = o_types[ 0 ] ;
421
+ let args = vec ! [
422
+ s_ident_t,
423
+ cx. get_const_i64( u64 :: MAX ) ,
424
+ cx. get_const_i32( num_args) ,
425
+ gep1,
426
+ gep2,
427
+ gep3,
428
+ o_type,
429
+ nullptr,
430
+ nullptr,
431
+ ] ;
432
+ builder. call ( fn_ty, end_mapper_decl, & args, None ) ;
433
+ builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
434
+
435
+ // With this we generated the following begin and end mappers. We could easily generate the
436
+ // update mapper in an update.
437
+ // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
438
+ // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
439
+ // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
419
440
}
0 commit comments