Skip to content

Commit eca6b4d

Browse files
committed
Numba linalg: Handle empty inputs
1 parent 26e703e commit eca6b4d

File tree

6 files changed

+97
-7
lines changed

6 files changed

+97
-7
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from numba.core.extending import overload
6+
from numba.core.types import Float
67
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
78
from scipy import linalg
89

@@ -35,7 +36,7 @@ def getrf_impl(
3536
A: np.ndarray, overwrite_a: bool = False
3637
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
3738
ensure_lapack()
38-
_check_linalg_matrix(A, ndim=2, func_name="getrf")
39+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
3940
dtype = A.dtype
4041
numba_getrf = _LAPACK().numba_xgetrf(dtype)
4142

@@ -75,7 +76,7 @@ def lu_factor_impl(
7576
A: np.ndarray, overwrite_a: bool = False
7677
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
7778
ensure_lapack()
78-
_check_linalg_matrix(A, ndim=2, func_name="lu_factor")
79+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")
7980

8081
def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
8182
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)

pytensor/link/numba/dispatch/linalg/solve/lu_solve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def impl(
8787
NRHS,
8888
LU.ctypes,
8989
LDA,
90+
# TODO: Does this work with any int dtype?
9091
IPIV.ctypes,
9192
B_copy.ctypes,
9293
LDB,

pytensor/link/numba/dispatch/linalg/solve/norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from numba.core.extending import overload
5+
from numba.core.types import Float
56
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
67

78
from pytensor.link.numba.dispatch.linalg._LAPACK import (
@@ -27,7 +28,7 @@ def xlange_impl(
2728
largest absolute value of a matrix A.
2829
"""
2930
ensure_lapack()
30-
_check_linalg_matrix(A, ndim=2, func_name="norm")
31+
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="norm")
3132
dtype = A.dtype
3233
numba_lange = _LAPACK().numba_xlange(dtype)
3334

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,15 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
355355

356356
@numba_basic.numba_njit(cache=False)
357357
def lu_factor_tridiagonal(dl, d, du):
358+
if d.size == 0:
359+
return (
360+
np.zeros(dl.shape, dtype=out_dtype),
361+
np.zeros(d.shape, dtype=out_dtype),
362+
np.zeros(du.shape, dtype=out_dtype),
363+
np.zeros(d.shape, dtype=out_dtype),
364+
np.zeros(d.shape, dtype="int32"),
365+
)
366+
358367
if d.dtype != out_dtype:
359368
d.dtype = out_dtype
360369
if dl.dtype != out_dtype:
@@ -382,11 +391,18 @@ def numba_funcify_SolveLUFactorTridiagonal(
382391
return generate_fallback_impl(op, node=node)
383392
out_dtype = node.outputs[0].type.numpy_dtype
384393

394+
b_ndim = op.b_ndim
385395
overwrite_b = op.overwrite_b
386396
transposed = op.transposed
387397

388398
@numba_basic.numba_njit(cache=False)
389399
def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
400+
if d.size == 0:
401+
if b_ndim == 1:
402+
return np.zeros(d.shape, dtype=out_dtype)
403+
else:
404+
return np.zeros((d.shape[0], b.shape[1]), dtype=out_dtype)
405+
390406
if dl.dtype != out_dtype:
391407
dl = dl.astype(out_dtype)
392408
if d.dtype != out_dtype:

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def numba_funcify_Cholesky(op, node, **kwargs):
7474

7575
@numba_basic.numba_njit
7676
def cholesky(a):
77+
if a.size == 0:
78+
return np.zeros(a.shape, dtype=out_dtype)
79+
7780
if discrete_inp:
7881
a = a.astype(out_dtype)
7982
elif check_finite:
@@ -114,7 +117,8 @@ def numba_pivot_to_permutation(piv):
114117

115118
return np.argsort(p_inv)
116119

117-
return numba_pivot_to_permutation
120+
cache_key = 1
121+
return numba_pivot_to_permutation, cache_key
118122

119123

120124
@numba_funcify.register(LU)
@@ -134,6 +138,18 @@ def numba_funcify_LU(op, node, **kwargs):
134138

135139
@numba_basic.numba_njit
136140
def lu(a):
141+
if a.size == 0:
142+
L = np.zeros(a.shape, dtype=a.dtype)
143+
U = np.zeros(a.shape, dtype=a.dtype)
144+
if permute_l:
145+
return L, U
146+
elif p_indices:
147+
P = np.zeros(a.shape[0], dtype="int32")
148+
return P, L, U
149+
else:
150+
P = np.zeros(a.shape, dtype=a.dtype)
151+
return P, L, U
152+
137153
if discrete_inp:
138154
a = a.astype(out_dtype)
139155
elif check_finite:
@@ -187,6 +203,12 @@ def numba_funcify_LUFactor(op, node, **kwargs):
187203

188204
@numba_basic.numba_njit
189205
def lu_factor(a):
206+
if a.size == 0:
207+
return (
208+
np.zeros(a.shape, dtype=out_dtype),
209+
np.zeros(a.shape[0], dtype="int32"),
210+
)
211+
190212
if discrete_inp:
191213
a = a.astype(out_dtype)
192214
elif check_finite:
@@ -226,7 +248,7 @@ def block_diag(*arrs):
226248

227249
@numba_funcify.register(Solve)
228250
def numba_funcify_Solve(op, node, **kwargs):
229-
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
251+
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
230252
out_dtype = node.outputs[0].type.numpy_dtype
231253

232254
if A_dtype.kind == "c" or b_dtype.kind == "c":
@@ -269,6 +291,9 @@ def numba_funcify_Solve(op, node, **kwargs):
269291

270292
@numba_basic.numba_njit
271293
def solve(a, b):
294+
if b.size == 0:
295+
return np.zeros(b.shape, dtype=out_dtype)
296+
272297
if must_cast_A:
273298
a = a.astype(out_dtype)
274299
if must_cast_B:
@@ -297,7 +322,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
297322
overwrite_b = op.overwrite_b
298323
b_ndim = op.b_ndim
299324

300-
A_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
325+
A_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
301326
out_dtype = node.outputs[0].type.numpy_dtype
302327

303328
if A_dtype.kind == "c" or b_dtype.kind == "c":
@@ -311,6 +336,8 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
311336

312337
@numba_basic.numba_njit
313338
def solve_triangular(a, b):
339+
if b.size == 0:
340+
return np.zeros(b.shape, dtype=out_dtype)
314341
if must_cast_A:
315342
a = a.astype(out_dtype)
316343
if must_cast_B:
@@ -346,7 +373,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
346373
overwrite_b = op.overwrite_b
347374
check_finite = op.check_finite
348375

349-
c_dtype, b_dtype = (i.numpy_dtype for i in node.inputs)
376+
c_dtype, b_dtype = (i.type.numpy_dtype for i in node.inputs)
350377
out_dtype = node.outputs[0].type.numpy_dtype
351378

352379
if c_dtype.kind == "c" or b_dtype.kind == "c":
@@ -360,6 +387,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
360387

361388
@numba_basic.numba_njit
362389
def cho_solve(c, b):
390+
if b.size == 0:
391+
return np.zeros(b.shape, dtype=out_dtype)
363392
if must_cast_c:
364393
c = c.astype(out_dtype)
365394
if check_finite:

tests/link/numba/test_slinalg.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
LUFactor,
1717
Solve,
1818
SolveTriangular,
19+
cho_solve,
20+
cholesky,
21+
lu,
22+
lu_factor,
23+
lu_solve,
24+
solve,
25+
solve_triangular,
1926
)
2027
from tests.link.numba.test_basic import compare_numba_and_py, numba_inplace_mode
2128

@@ -483,6 +490,27 @@ def test_lu_solve(
483490
# Can never destroy non-contiguous inputs
484491
np.testing.assert_allclose(b_val_not_contig, b_val)
485492

493+
@pytest.mark.parametrize(
494+
"solve_op",
495+
[solve, solve_triangular, cho_solve, lu_solve],
496+
ids=lambda x: x.__name__,
497+
)
498+
def test_empty(self, solve_op):
499+
a = pt.matrix("x")
500+
b = pt.vector("b")
501+
if solve_op is cho_solve:
502+
out = solve_op((a, True), b)
503+
elif solve_op is lu_solve:
504+
out = solve_op((a, b.astype("int32")), b)
505+
else:
506+
out = solve_op(a, b)
507+
compare_numba_and_py(
508+
[a, b],
509+
[out],
510+
[np.zeros((0, 0)), np.zeros(0)],
511+
eval_obj_mode=False, # pivot_to_permutation seems to still be jitted despite the monkey patching
512+
)
513+
486514

487515
class TestDecompositions:
488516
@pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}")
@@ -750,6 +778,20 @@ def test_qr(self, mode, pivoting, overwrite_a):
750778
# Cannot destroy non-contiguous input
751779
np.testing.assert_allclose(val_not_contig, A_val)
752780

781+
@pytest.mark.parametrize(
782+
"decomp_op", (cholesky, lu, lu_factor), ids=lambda x: x.__name__
783+
)
784+
def test_empty(self, decomp_op):
785+
x = pt.matrix("x")
786+
outs = decomp_op(x)
787+
if not isinstance(outs, tuple | list):
788+
outs = [outs]
789+
compare_numba_and_py(
790+
[x],
791+
outs,
792+
[np.zeros((0, 0))],
793+
)
794+
753795

754796
def test_block_diag():
755797
A = pt.matrix("A")

0 commit comments

Comments
 (0)