Skip to content

Commit 2e7054d

Browse files
authored
Improve wvsplitK tile and balance heristics. (#29937)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
1 parent 3c680f4 commit 2e7054d

File tree

1 file changed

+48
-49
lines changed

1 file changed

+48
-49
lines changed

csrc/rocm/skinny_gemms.cu

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
12411241
}
12421242
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
12431243

1244+
// Find the min val of div2 that doesn't increase N/(div1*div2)
12441245
int mindiv(int N, int div1, int div2) {
12451246
int nPrRnd = div1 * div2;
1246-
int rnds0 = N / nPrRnd;
1247-
nPrRnd -= div1 * 3;
1248-
int rnds3 = N / nPrRnd;
1249-
nPrRnd -= div1;
1250-
int rnds4 = N / nPrRnd;
1251-
nPrRnd -= div1;
1252-
int rnds5 = N / nPrRnd;
1253-
nPrRnd -= div1;
1254-
int rnds6 = N / nPrRnd;
1255-
nPrRnd -= div1;
1256-
int rnds7 = N / nPrRnd;
1257-
nPrRnd -= div1;
1258-
int rnds8 = N / nPrRnd;
1259-
nPrRnd -= div1;
1260-
int rnds9 = N / nPrRnd;
1261-
nPrRnd -= div1;
1262-
int rtn = div2;
1263-
if (rnds0 == rnds3) rtn = div2 - 3;
1264-
if (rnds0 == rnds4) rtn = div2 - 4;
1265-
if (rnds0 == rnds5) rtn = div2 - 5;
1266-
if (rnds0 == rnds6) rtn = div2 - 6;
1267-
if (rnds0 == rnds7) rtn = div2 - 7;
1268-
if (rnds0 == rnds8) rtn = div2 - 8;
1269-
if (rnds0 == rnds9) rtn = div2 - 9;
1270-
return rtn;
1247+
int rnds[13];
1248+
for (int i = 0; i < 13; i++) {
1249+
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
1250+
nPrRnd -= div1;
1251+
}
1252+
for (int i = 12; i >= 0; i--)
1253+
if (rnds[0] == rnds[i]) return (div2 - i);
12711254
}
12721255

12731256
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
13001283
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
13011284
const int max_lds_len = get_lds_size() / 2;
13021285

1303-
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
1304-
_N) \
1305-
{ \
1306-
dim3 block(64, _WvPrGrp); \
1307-
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
1308-
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
1309-
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
1310-
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1311-
biasf4, c, __wvPrGrp, CuCount); \
1312-
} else if (K_in * N_in <= max_lds_len * 1.2) { \
1313-
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
1314-
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
1315-
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1316-
biasf4, c, __wvPrGrp, CuCount); \
1317-
} else { \
1318-
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
1319-
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
1320-
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1321-
biasf4, c, __wvPrGrp, CuCount); \
1322-
} \
1286+
#define WVSPLITK(_YTILE, _UNRL, _N) \
1287+
{ \
1288+
dim3 block(64, 16); \
1289+
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
1290+
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
1291+
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
1292+
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1293+
biasf4, c, __wvPrGrp, CuCount); \
1294+
else if (K_in * N_in <= max_lds_len * 1.2) \
1295+
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
1296+
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1297+
biasf4, c, __wvPrGrp, CuCount); \
1298+
else \
1299+
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
1300+
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
1301+
biasf4, c, __wvPrGrp, CuCount); \
1302+
}
1303+
1304+
#define WVSPLIT_TILE(_sYT, __N) \
1305+
{ \
1306+
bool fit_lds = (K_in * N_in <= max_lds_len); \
1307+
if (_sYT <= 1) \
1308+
WVSPLITK(1, 4, __N) \
1309+
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
1310+
WVSPLITK(2, 2, __N) \
1311+
else if (_sYT <= 4 * 3) \
1312+
WVSPLITK(3, 2, __N) \
1313+
else if (__N == 4) \
1314+
WVSPLITK(4, 1, __N) \
1315+
else \
1316+
WVSPLITK(4, 2, __N) \
13231317
}
13241318

13251319
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
13311325
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
13321326
: nullptr;
13331327
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
1328+
1329+
// first shoot for biggest tile-size that keeps all simd busy,
1330+
// then cut the active waves to balance their distribution...
1331+
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
1332+
13341333
switch (N_in) {
13351334
case 1:
1336-
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
1335+
WVSPLIT_TILE(sYT, 1)
13371336
break;
13381337
case 2:
1339-
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
1338+
WVSPLIT_TILE(sYT, 2)
13401339
break;
13411340
case 3:
1342-
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
1341+
WVSPLIT_TILE(sYT, 3)
13431342
break;
13441343
case 4:
1345-
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
1344+
WVSPLIT_TILE(sYT, 4)
13461345
break;
13471346
default:
13481347
throw std::runtime_error(

0 commit comments

Comments
 (0)