@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
 }
 #endif  // defined(__HIP__GFX9__) TODO: Add NAVI support
 
+// Find the min val of div2 that doesn't increase N/(div1*div2)
 int mindiv(int N, int div1, int div2) {
   int nPrRnd = div1 * div2;
-  int rnds0 = N / nPrRnd;
-  nPrRnd -= div1 * 3;
-  int rnds3 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds4 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds5 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds6 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds7 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds8 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rnds9 = N / nPrRnd;
-  nPrRnd -= div1;
-  int rtn = div2;
-  if (rnds0 == rnds3) rtn = div2 - 3;
-  if (rnds0 == rnds4) rtn = div2 - 4;
-  if (rnds0 == rnds5) rtn = div2 - 5;
-  if (rnds0 == rnds6) rtn = div2 - 6;
-  if (rnds0 == rnds7) rtn = div2 - 7;
-  if (rnds0 == rnds8) rtn = div2 - 8;
-  if (rnds0 == rnds9) rtn = div2 - 9;
-  return rtn;
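+  // rnds[i] holds the round count ceil(N / (div1 * (div2 - i))), i.e. how
+  // many passes are needed if div2 is reduced by i.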
+  int rnds[13];
+  for (int i = 0; i < 13; i++) {
+    rnds[i] = (N + nPrRnd - 1) / nPrRnd;
+    nPrRnd -= div1;
+  }
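+  // Scan from the most-reduced candidate (i = 12) upward and return the
+  // smallest div2 - i whose round count still matches the baseline rnds[0].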
+  for (int i = 12; i >= 0; i--)
+    if (rnds[0] == rnds[i]) return (div2 - i);
 }
 
 torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   const int max_lds_len = get_lds_size() / 2;
 
-  #define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
-                   _N) \
-  { \
-    dim3 block(64, _WvPrGrp); \
-    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
-      wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
-                                       biasf4, c, __wvPrGrp, CuCount); \
-    } else if (K_in * N_in <= max_lds_len * 1.2) { \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
-      wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
-                                       biasf4, c, __wvPrGrp, CuCount); \
-    } else { \
-      int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
-      wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
-          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
-                                       biasf4, c, __wvPrGrp, CuCount); \
-    } \
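+  // Choose among the three kernel variants by whether the K_in * N_in block
+  // fits in LDS: _sml_ when it fits and M_in is a multiple of the tile,
+  // the plain kernel within ~1.2x of the LDS budget, _big_ otherwise.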
+  #define WVSPLITK(_YTILE, _UNRL, _N) \
+  { \
+    dim3 block(64, 16); \
+    int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
+    if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
+      wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
+                                       biasf4, c, __wvPrGrp, CuCount); \
+    else if (K_in * N_in <= max_lds_len * 1.2) \
+      wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
+                                       biasf4, c, __wvPrGrp, CuCount); \
+    else \
+      wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
+          <<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
+                                       biasf4, c, __wvPrGrp, CuCount); \
+  }
+
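+  // Map the estimated rounds-per-CU (_sYT) and batch width (__N) to a
+  // (YTILE, UNRL) pair: tiny problems keep YTILE=1 with deeper unrolling,
+  // larger ones grow the tile up to YTILE=4.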
+  #define WVSPLIT_TILE(_sYT, __N) \
+  { \
+    bool fit_lds = (K_in * N_in <= max_lds_len); \
+    if (_sYT <= 1) \
+      WVSPLITK(1, 4, __N) \
+    else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
+      WVSPLITK(2, 2, __N) \
+    else if (_sYT <= 4 * 3) \
+      WVSPLITK(3, 2, __N) \
+    else if (__N == 4) \
+      WVSPLITK(4, 1, __N) \
+    else \
+      WVSPLITK(4, 2, __N) \
   }
 
   AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
             ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
             : nullptr;
     fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
+
+    // First shoot for the biggest tile size that keeps all SIMDs busy,
+    // then cut the active waves to balance their distribution...
+    int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
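+    // sYT ~ rounds each CU would need at the largest tile (YTILE = 4);
+    // WVSPLIT_TILE below trades tile size against unroll depth based on it.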
+
     switch (N_in) {
       case 1:
-        WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
+        WVSPLIT_TILE(sYT, 1)
         break;
       case 2:
-        WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
+        WVSPLIT_TILE(sYT, 2)
         break;
       case 3:
-        WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
+        WVSPLIT_TILE(sYT, 3)
         break;
       case 4:
-        WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
+        WVSPLIT_TILE(sYT, 4)
         break;
       default:
         throw std::runtime_error(