From 063665bbe6f629dca053c225caeb3497e0169151 Mon Sep 17 00:00:00 2001 From: Markus Mohrhard Date: Tue, 21 Oct 2025 03:16:59 +0800 Subject: [PATCH] prevent failure in pow_to_mul when using Piecewise in an equation Previously the test would fail with: ``` def test_piecewise(): grid = Grid(shape=(11,)) u = Function(name='u', grid=grid, space_order=2) v = Function(name='u', grid=grid, space_order=2) eq_u = Eq(u, Piecewise((1, v < 10), (2, True))) > op = Operator(eq_u) ^^^^^^^^^^^^^^ tests/test_symbolics.py:1112: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ devito/operator/operator.py:173: in __new__ op = cls._build(expressions, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/operator/operator.py:206: in _build irs, byproduct = cls._lower(expressions, profiler=profiler, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/operator/operator.py:288: in _lower clusters = cls._lower_clusters(expressions, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/tools/timing.py:76: in __call__ retval = self.func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/operator/operator.py:397: in _lower_clusters clusters = cls._specialize_clusters(clusters, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/tools/timing.py:76: in __call__ retval = self.func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/core/cpu.py:178: in _specialize_clusters clusters = optimize_pows(clusters) ^^^^^^^^^^^^^^^^^^^^^^^ devito/ir/clusters/visitors.py:174: in __call__ processed = [maybe_timed(c, *args) if self.cond(c) else c ^^^^^^^^^^^^^^^^^^^^^ devito/ir/clusters/visitors.py:167: in timed_pass(self.func, self.func.__name__)(*_args, **kwargs) devito/tools/timing.py:76: in __call__ retval = self.func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^ devito/passes/clusters/misc.py:423: in optimize_pows return cluster.rebuild(exprs=[pow_to_mul(e) for e in cluster.exprs]) ^^^^^^^^^^^^^ devito/symbolics/manipulation.py:360: in pow_to_mul args = [pow_to_mul(i) for i in expr.args] ^^^^^^^^^^^^^ devito/symbolics/manipulation.py:360: in pow_to_mul args = [pow_to_mul(i) for i in expr.args] ^^^^^^^^^^^^^ E TypeError: ExprCondPair.__new__() got an unexpected keyword argument 'evaluate' devito/symbolics/manipulation.py:373: TypeError ``` --- devito/symbolics/manipulation.py | 3 +++ tests/test_symbolics.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index e43494974d..ec241fc967 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -6,6 +6,7 @@ from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort +from sympy.functions.elementary.piecewise import ExprCondPair from devito.finite_differences.differentiable import ( EvalDerivative, IndexDerivative @@ -356,6 +357,8 @@ def pow_to_mul(expr): else: # Default. We should not end up here as all cases are handled return expr + elif expr.func is ExprCondPair: + return expr.func(*[pow_to_mul(i) for i in expr.args]) else: args = [pow_to_mul(i) for i in expr.args] diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 367090d0d6..8daf598249 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -7,7 +7,7 @@ from sympy import And, Expr, Number, Symbol from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, - Min, Max, Real, Imag, Conj, SubDomain, configuration) + Piecewise, Min, Max, Real, Imag, Conj, SubDomain, configuration) from devito.finite_differences.differentiable import SafeInv, Weights, Mul from devito.ir import Expression, FindNodes, ccode from devito.ir.support.guards import GuardExpr, simplify_and @@ -1104,6 +1104,15 @@ def test_print_div(): assert cstr == 'sizeof(int)/sizeof(long)' +def test_piecewise(): + grid = Grid(shape=(11,)) + u = Function(name='u', grid=grid, space_order=2) + v = Function(name='v', grid=grid, space_order=2) + eq_u = Eq(u, Piecewise((1, v < 10), (2, True))) + op = Operator(eq_u) + # check that the code generated a condition + assert "v[x + 2] < 10" in str(op.ccode) + def test_customdtype_complex(): """ Test that `CustomDtype` doesn't brak is_imag