@@ -28,7 +28,7 @@ use llvm::Linkage::*;
2828use crate :: back:: write:: {
2929 self , CodegenDiagnosticsStage , DiagnosticHandlers , bitcode_section_name, save_temp_bitcode,
3030} ;
31- use crate :: builder:: SBuilder ;
31+ use crate :: builder:: { SBuilder , UNNAMED } ;
3232use crate :: errors:: {
3333 DynamicLinkingWithLTO , LlvmError , LtoBitcodeFromRlib , LtoDisallowed , LtoDylib , LtoProcMacro ,
3434} ;
@@ -806,6 +806,27 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
806806 o_types
807807}
808808
809+
810+ // For each kernel *call*, we now use some of our previously declared globals to move data to and from
811+ // the GPU. We don't have a proper frontend yet, so we assume that every call to a kernel function
812+ // from main is intended to run on the GPU. For now, we only handle the data transfer part of it.
813+ // If two consecutive kernels use the same memory, we still move it to the host and back to the GPU.
814+ // Since in our frontend users (by default) don't have to specify data transfer, this is something
815+ // we should optimize in the future! We also assume that everything should be copied back and forth,
816+ // but sometimes we can directly zero-allocate on the device and only move back, or if something is
817+ // immutable, we might only copy it to the device, but not back.
818+ //
819+ // Current steps:
820+ // 0. Alloca some variables for the following steps
821+ // 1. set insert point before kernel call.
822+ // 2. generate all the GEPS and stores, to be used in 3)
823+ // 3. generate __tgt_target_data_begin calls to move data to the GPU
824+ //
825+ // unchanged: keep kernel call. Later move the kernel to the GPU
826+ //
827+ // 4. set insert point after kernel call.
828+ // 5. generate all the GEPS and stores, to be used in 6)
829+ // 6. generate __tgt_target_data_end calls to move data from the GPU
809830fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernels : & [ & ' ll llvm:: Value ] , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
810831
811832 let main_fn = cx. get_function ( "main" ) ;
@@ -819,30 +840,39 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
819840 return ;
820841 } ;
821842 let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
843+ let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) } ;
822844 let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
823845
824- let types = cx. func_params_types ( cx. get_type_of_global ( kernels [ 0 ] ) ) ;
846+ let types = cx. func_params_types ( cx. get_type_of_global ( called ) ) ;
825847 dbg ! ( & types) ;
826848 let num_args = types. len ( ) as u64 ;
849+ let mut names: Vec < & llvm:: Value > = Vec :: with_capacity ( num_args) ;
827850
828- // First we generate a few variables used for the data mappers below.
851+ // Step 0)
829852 unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
830853 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
831-
832854 // Baseptrs are just the input pointers to the kernel, stored in a local alloca
833855 let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
834-
835856 // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
836857 let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
837-
838858 // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
839859 let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
840860 let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
861+ // Now we allocate, once per function param, a copy to be passed to one of our maps.
862+ for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
863+ // Todo:
864+ let p = llvm:: get_param ( called, index as u32 ) ;
865+ let name = llvm:: get_value_name ( p) ;
866+ let arg_name = format ! ( "{name}.addr" ) ;
867+ let alloca = unsafe { llvm:: LLVMBuildAlloca ( builder. llbuilder , in_ty, arg_name) } ;
868+ // get function arg, store it into the alloca, and read it.
869+ }
841870
842- // Now we generate the __tgt_target_data calls
871+
872+ // Step 1)
843873 unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
844- dbg ! ( "positioned builder, ready" ) ;
845874
875+ // Step 2)
846876 let i32_0 = cx. get_const_i32 ( 0 ) ;
847877 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
848878 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
@@ -853,8 +883,8 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
853883 let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
854884 builder. call ( fn_ty, begin, & args, None ) ;
855885
886+ // Step 4)
856887 unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
857- dbg ! ( "re-positioned builder, ready" ) ;
858888
859889 let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
860890 let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
@@ -865,15 +895,6 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
865895 let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
866896 builder. call ( fn_ty, end, & args, None ) ;
867897
868- // 1. set insert point before kernel call.
869- // 2. generate all the GEPS and stores.
870- // 3. generate __tgt_target_data calls.
871- //
872- // unchanged: keep kernel call.
873- //
874- // 4. generate all the GEPS and stores.
875- // 5. generate __tgt_target_data calls
876-
877898 // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
878899 // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
879900 // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
0 commit comments