Skip to content

Commit c81ec47

Browse files
authored
[eudsl-python-extras] add support for partially specialized/evaled generics (#260)
This PR adds support for "partially specialized generics". I.e., ```python @func def maxpool2d[ kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1, dtype ]( input: "T.memref(S, S, S, S, dtype)", output: "T.memref(S, S, S, S, dtype)", ): .... kernel_sizes = [2, 3] strides = [4, 5] dilations = [6, 7] maxpool2d_k_dtype = maxpool2d[ kernel_sizes[0], kernel_sizes[1], strides[0], strides[1], dilations[0], dilations[1], # NOTE: missing dtype ] # fill in dtype maxpool2d_k_f32 = maxpool2d_k_dtype[T.f32()].emit() maxpool2d_k_f64 = maxpool2d_k_dtype[T.f64()].emit() ``` produces ```mlir func.func @maxpool2d_int_2_int_3_int_4_int_5_int_6_int_7_type_f32(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>) { ... } func.func @maxpool2d_int_2_int_3_int_4_int_5_int_6_int_7_type_f64(%arg0: memref<?x?x?x?xf64>, %arg1: memref<?x?x?x?xf64>) { ... } ```
1 parent 4536799 commit c81ec47

File tree

5 files changed

+973
-713
lines changed

5 files changed

+973
-713
lines changed

.github/workflows/build_test_release_eudsl_python_extras.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ jobs:
181181
182182
IGNORE=""
183183
if [[ $(python -c "print(__import__('sys').version_info < (3, 13))") == "True" ]]; then
184-
IGNORE="--ignore projects/eudsl-python-extras/tests/dialect/test_generics.py"
184+
IGNORE="--ignore projects/eudsl-python-extras/tests/test_generics.py"
185185
fi
186186
187187
python -m pytest projects/eudsl-python-extras/tests $IGNORE

projects/eudsl-python-extras/examples/cuda_matmul_opt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ def sgemm_shared_mem_1d_block_tiling[
394394
BN,
395395
BK,
396396
TM,
397+
# make the tests happy which expect 9 generic params...
398+
# see test_generics.py/test_wrong_generics_types
399+
_TN,
397400
A_t = T.memref(M, K, dtype),
398401
B_t = T.memref(K, N, dtype),
399402
C_t = T.memref(M, N, dtype),

0 commit comments

Comments
 (0)