@@ -799,7 +799,7 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
799799 o_types
800800}
801801
802- fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
802+ fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm :: Value , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
803803
804804 let main_fn = cx. get_function ( "main" ) ;
805805 if let Some ( main_fn) = main_fn {
@@ -814,32 +814,50 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, be
814814 let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
815815 let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
816816
817+ let types = cx. func_params_types ( cx. val_ty ( kernel) ) ;
818+ let num_args = types. len ( ) ;
819+
817820 // First we generate a few variables used for the data mappers below.
818821 // %.offload_baseptrs = alloca [3 x ptr], align 8
819822 // %.offload_ptrs = alloca [3 x ptr], align 8
820823 // %.offload_mappers = alloca [3 x ptr], align 8
821824 // %.offload_sizes = alloca [3 x i64], align 8
822825 unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
823- let ty = cx. type_array ( cx. type_ptr ( ) , 3 ) ;
824- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
825- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
826- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
827- let ty = cx. type_array ( cx. type_i64 ( ) , 3 ) ;
828- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_sizes" ) ;
829-
826+ let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
827+ let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
828+ let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
829+ let a3 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
830+ let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
831+ let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
830832
831833 // Now we generate the __tgt_target_data calls
832834 unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
833835 dbg ! ( "positioned builder, ready" ) ;
834836
837+ // %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
838+ // %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
839+ // %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
840+ let i32_0 = cx. get_const_i32 ( 0 ) ;
841+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
842+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
843+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
844+
835845 let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
836846 let o_type = o_types[ 0 ] ;
837- let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , nullptr, nullptr, nullptr, o_type, nullptr, nullptr] ;
838- dbg ! ( & fn_ty) ;
839- dbg ! ( & begin) ;
840- dbg ! ( & args) ;
847+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
841848 builder. call ( fn_ty, begin, & args, None ) ;
842- dbg ! ( "called begin" ) ;
849+
850+ unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
851+ dbg ! ( "re-positioned builder, ready" ) ;
852+
853+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
854+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
855+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
856+
857+ let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
858+ let o_type = o_types[ 0 ] ;
859+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
860+ builder. call ( fn_ty, end, & args, None ) ;
843861
844862 // 1. set insert point before kernel call.
845863 // 2. generate all the GEPS and stores.
@@ -907,7 +925,7 @@ pub(crate) fn run_pass_manager(
907925 SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
908926 if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
909927
910- let ( offload_entry_ty, at_one, foo , bar , baz , fn_ty) = gen_globals ( & cx) ;
928+ let ( offload_entry_ty, at_one, begin , update , end , fn_ty) = gen_globals ( & cx) ;
911929
912930 dbg ! ( "created struct" ) ;
913931 let mut o_types = vec ! [ ] ;
@@ -918,7 +936,8 @@ pub(crate) fn run_pass_manager(
918936 // TODO: replace num by proper fn name
919937 o_types. push ( gen_define_handling ( & cx, offload_entry_ty, num) ) ;
920938 }
921- gen_call_handling ( & cx, at_one, foo, bar, baz, fn_ty, & o_types) ;
939+ let kernel = cx. get_function ( "kernel_1" ) . unwrap ( ) ;
940+ gen_call_handling ( & cx, kernel, at_one, begin, update, end, fn_ty, & o_types) ;
922941 } else {
923942 dbg ! ( "no marker found" ) ;
924943 }
0 commit comments