@@ -742,8 +742,20 @@ fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value
742742
743743fn gen_define_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm:: Value , offload_entry_ty : & ' ll llvm:: Type , num : i64 ) -> & ' ll llvm:: Value {
744744 let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
745- let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 8u64 , 0 , 16 , 0 ] ) ;
746- let o_types = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 800u64 , 544 , 547 , 544 ] ) ;
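745+ // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
746+ // reference) types. NOTE(review): `.map(..).count()` counts every parameter, not only the pointer ones; `.filter(..).count()` is probably intended -- confirm.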
747+ let num_ptr_types = types. iter ( ) . map ( |& x| matches ! ( cx. type_kind( x) , rustc_codegen_ssa:: common:: TypeKind :: Pointer ) ) . count ( ) ;
748+
749+ // We do not know their size anymore at this level, so hardcode a placeholder.
750+ // A follow-up PR will track these from the frontend, where we still have Rust types.
751+ // Then, we will be able to figure out that e.g. `&[f32; 1024]` will result in 4*1024 bytes.
752+ // I decided that 1024 bytes is a great placeholder value for now.
753+ let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 1024 ; num_ptr_types] ) ;
754+ // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
755+ // or both to and from the gpu (=3). Other values shouldn't affect us for now.
756+ // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
757+ // will be 2. For now, everything is 3, until we have our frontend set up.
758+ let o_types = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 3 ; num_ptr_types] ) ;
747759 // Next: For each function, generate these three entries. A weak constant,
748760 // the llvm.rodata entry name, and the omp_offloading_entries value
749761
@@ -794,11 +806,11 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
794806 o_types
795807}
796808
797- fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm:: Value , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
809+ fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernels : & [ & ' ll llvm:: Value ] , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
798810
799811 let main_fn = cx. get_function ( "main" ) ;
800812 if let Some ( main_fn) = main_fn {
801- let kernel_name = "kernel_1" ; //name.as_c_char_ptr(), name.len)
813+ let kernel_name = "kernel_1" ;
802814 let call = unsafe { llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) ) } ;
803815 let kernel_call = if call. is_some ( ) {
804816 dbg ! ( "found kernel call" ) ;
@@ -809,38 +821,36 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ide
809821 let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
810822 let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
811823
812- let types = cx. func_params_types ( cx. get_type_of_global ( kernel ) ) ;
824+ let types = cx. func_params_types ( cx. get_type_of_global ( kernels [ 0 ] ) ) ;
813825 dbg ! ( & types) ;
814826 let num_args = types. len ( ) as u64 ;
815827
816828 // First we generate a few variables used for the data mappers below.
817- // %.offload_baseptrs = alloca [3 x ptr], align 8
818- // %.offload_ptrs = alloca [3 x ptr], align 8
819- // %.offload_mappers = alloca [3 x ptr], align 8
820- // %.offload_sizes = alloca [3 x i64], align 8
821829 unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
822830 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
831+
832+ // Baseptrs are just the input pointers to the kernel, stored in a local alloca
823833 let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
834+
835+ // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
824836 let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
825- let a3 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
837+
838+ // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
826839 let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
827840 let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
828841
829842 // Now we generate the __tgt_target_data calls
830843 unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
831844 dbg ! ( "positioned builder, ready" ) ;
832845
833- // %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
834- // %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
835- // %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
836846 let i32_0 = cx. get_const_i32 ( 0 ) ;
837847 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
838848 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
839849 let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
840850
841851 let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
842852 let o_type = o_types[ 0 ] ;
843- let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
853+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
844854 builder. call ( fn_ty, begin, & args, None ) ;
845855
846856 unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
@@ -925,15 +935,16 @@ pub(crate) fn run_pass_manager(
925935
926936 dbg ! ( "created struct" ) ;
927937 let mut o_types = vec ! [ ] ;
938+ let mut kernels = vec ! [ ] ;
928939 for num in 0 ..9 {
929940 let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
930941 if let Some ( kernel) = kernel{
931942 o_types. push ( gen_define_handling ( & cx, kernel, offload_entry_ty, num) ) ;
943+ kernels. push ( kernel) ;
932944 }
933945 }
934- let kernel = cx. get_function ( "kernel_1" ) . unwrap ( ) ;
935946 dbg ! ( "gen_call_handling" ) ;
936- gen_call_handling ( & cx, kernel , at_one, begin, update, end, fn_ty, & o_types) ;
947+ gen_call_handling ( & cx, & kernels , at_one, begin, update, end, fn_ty, & o_types) ;
937948 } else {
938949 dbg ! ( "no marker found" ) ;
939950 }
0 commit comments