File tree Expand file tree Collapse file tree 2 files changed +43
-1
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +43
-1
lines changed Original file line number Diff line number Diff line change 1818 ScalarFromTensor ,
1919 Split ,
2020 TensorFromScalar ,
21+ Tri ,
2122 get_underlying_scalar_constant_value ,
2223)
2324from pytensor .tensor .exceptions import NotScalarConstantError
2627ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
2728to be constants. The graph that you defined thus cannot be JIT-compiled
2829by JAX. An example of a graph that can be compiled to JAX:
29-
3030>>> import pytensor.tensor basic
3131>>> at.arange(1, 10, 2)
3232"""
@@ -193,3 +193,18 @@ def scalar_from_tensor(x):
193193 return jnp .array (x ).flatten ()[0 ]
194194
195195 return scalar_from_tensor
196+
197+
198+ @jax_funcify .register (Tri )
199+ def jax_funcify_Tri (op , node , ** kwargs ):
200+ # node.inputs is N, M, k
201+ const_args = [getattr (x , "data" , None ) for x in node .inputs ]
202+
203+ def tri (* args ):
204+ # args is N, M, k
205+ args = [
206+ x if const_x is None else const_x for x , const_x in zip (args , const_args )
207+ ]
208+ return jnp .tri (* args , dtype = op .dtype )
209+
210+ return tri
Original file line number Diff line number Diff line change @@ -191,3 +191,30 @@ def test_jax_eye():
191191 out_fg = FunctionGraph ([], [out ])
192192
193193 compare_jax_and_py (out_fg , [])
194+
195+
196+ def test_tri ():
197+ out = at .tri (10 , 10 , 0 )
198+ fgraph = FunctionGraph ([], [out ])
199+ compare_jax_and_py (fgraph , [])
200+
201+
202+ def test_tri_nonconcrete ():
203+ """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
204+
205+ m , n , k = (
206+ scalar ("a" , dtype = "int64" ),
207+ scalar ("n" , dtype = "int64" ),
208+ scalar ("k" , dtype = "int64" ),
209+ )
210+ m .tag .test_value = 10
211+ n .tag .test_value = 10
212+ k .tag .test_value = 0
213+
214+ out = at .tri (m , n , k )
215+
216+ # The actual error the user will see should be jax.errors.ConcretizationTypeError, but
217+ # the error handler raises an Attribute error first, so that's what this test needs to pass
218+ with pytest .raises (AttributeError ):
219+ fgraph = FunctionGraph ([m , n , k ], [out ])
220+ compare_jax_and_py (fgraph , [get_test_value (i ) for i in fgraph .inputs ])
You can’t perform that action at this time.
0 commit comments