From f4ef75628e9605bcc44d7f47123f58aeef4724e0 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 17 Dec 2025 18:04:06 +0000 Subject: [PATCH 1/6] compiler: Start adding machinery to specialise operators with hardcoded values --- devito/ir/iet/visitors.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 8f6ae8f02f..7db3b1ae30 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1499,6 +1499,41 @@ def visit_KernelLaunch(self, o): arguments=arguments) +class Specializer(Uxreplace): + """ + A Transformer to "specialize" a pre-built Operator - that is to replace a given + set of (scalar) symbols with hard-coded values to free up registers. This will + yield a "specialized" version of the Operator, specific to a particular setup. + """ + + def __init__(self, mapper, nested=False): + super().__init__(mapper, nested=nested) + + # Sanity check + for k in self.mapper.keys(): + if not isinstance(k, AbstractSymbol): + raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") + + def visit_Operator(self, o, **kwargs): + # Entirely fine to apply this to an Operator + body = self._visit(o.body) + parameters = tuple(i for i in o.parameters if i not in self.mapper) + + # Note: the following is not dissimilar to unpickling an Operator + state = o.__getstate__() + state['parameters'] = parameters + state['body'] = body + state.pop('ccode') + + # FIXME: These names aren't great + newargs, newkwargs = o.__getnewargs_ex__() + newop = o.__class__(*newargs, **newkwargs) + + newop.__setstate__(state) + + return newop + + # Utils blankline = c.Line("") From df6cf2fd085652f07a5168f457c7549841bc6610 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 18 Dec 2025 16:39:36 +0000 Subject: [PATCH 2/6] tests: Start adding tests for operator specialization --- devito/ir/iet/visitors.py | 25 +++++- devito/types/dimension.py | 5 ++ tests/test_specialization.py | 146 +++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 tests/test_specialization.py diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 7db3b1ae30..f1aba31aab 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1504,6 +1504,11 @@ class Specializer(Uxreplace): A Transformer to "specialize" a pre-built Operator - that is to replace a given set of (scalar) symbols with hard-coded values to free up registers. This will yield a "specialized" version of the Operator, specific to a particular setup. + + Note that the Operator is not re-optimized in response to this replacement - this + transformation could nominally result in expressions of the form `f + 0` in the + generated code. If one wants to construct an Operator where such expressions are + considered, then use of `subs=...` is a better choice. """ def __init__(self, mapper, nested=False): @@ -1515,15 +1520,31 @@ def __init__(self, mapper, nested=False): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") def visit_Operator(self, o, **kwargs): - # Entirely fine to apply this to an Operator + # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this + # is the intended use case body = self._visit(o.body) + + not_params = tuple(i for i in self.mapper if i not in o.parameters) + if not_params: + raise ValueError(f"Attempted to specialize symbols {not_params} which are not" + " found in the Operator parameters") + + # FIXME: Should also type-check the values supplied against the symbols they are + # replacing (and cast them if needed?) -> use a try-except on the cast in + # python-land + parameters = tuple(i for i in o.parameters if i not in self.mapper) # Note: the following is not dissimilar to unpickling an Operator state = o.__getstate__() state['parameters'] = parameters state['body'] = body - state.pop('ccode') + + try: + state.pop('ccode') + except KeyError: + # C code has not previously been generated for this Operator + pass # FIXME: These names aren't great newargs, newkwargs = o.__getnewargs_ex__() diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 6e000349e9..8e74926ac9 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -175,6 +175,11 @@ def symbolic_max(self): """Symbol defining the maximum point of the Dimension.""" return Scalar(name=self.max_name, dtype=np.int32, is_const=True) + @property + def symbolic_extrema(self): + """Symbols for the minimum and maximum points of the Dimension""" + return (self.symbolic_min, self.symbolic_max) + @property def symbolic_incr(self): """The increment value while iterating over the Dimension.""" diff --git a/tests/test_specialization.py b/tests/test_specialization.py new file mode 100644 index 0000000000..0962ce33bb --- /dev/null +++ b/tests/test_specialization.py @@ -0,0 +1,146 @@ +import sympy +import pytest + +from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, + ConditionalDimension) +from devito.ir.iet.visitors import Specializer + +# Test that specializer replaces symbols as expected + +# Create a couple of arbitrary operators +# Reference bounds, subdomains, spacings, constants, conditionaldimensions with symbolic +# factor +# Create a couple of different substitution sets + +# Check that all the instances in the kernel are replaced +# Check that all the instances in the parameters are removed + +# Check that sanity check catches attempts to specialize non-scalar types +# Check that trying to specialize symbols not in the Operator parameters results +# in an error being thrown + +# Check that sizes and strides get specialized when using `linearize=True` + + +class TestSpecializer: + """Tests for the Specializer transformer""" + + @pytest.mark.parametrize('pre_gen', [True, False]) + @pytest.mark.parametrize('expand', [True, False]) + def test_bounds(self, pre_gen, expand): + """Test specialization of dimension bounds""" + grid = Grid(shape=(11, 11)) + + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + time_m = grid.time_dim.symbolic_min + minima = (x_m, y_m, time_m) + maxima = (x_M, y_M) + + def check_op(mapper, operator): + for k, v in mapper.items(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + # Check that the loop bounds are modified correctly + if k in minima: + assert f"{k.name.split('_')[0]} = {v}" in str(operator.ccode) + elif k in maxima: + assert f"{k.name.split('_')[0]} <= {v}" in str(operator.ccode) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = TimeFunction(name='h', grid=grid) + + eq0 = Eq(f, f + 1) + eq1 = Eq(g, f.dx) + eq2 = Eq(h.forward, (g + x_m).dy) + eq3 = Eq(f, x_M) + + # Check behaviour with expansion since we have a replaced symbol inside a + # derivative + if expand: + kwargs = {'opt': ('advanced', {'expand': True})} + else: + kwargs = {'opt': ('advanced', {'expand': False})} + + op = Operator([eq0, eq1, eq2, eq3], **kwargs) + + if pre_gen: + # Generate C code for the unspecialized Operator - the result should be + # the same regardless, but it ensures that the old generated code is + # purged and replaced in the specialized Operator + _ = op.ccode + + mapper0 = {x_m: sympy.S.Zero} + mapper1 = {x_M: sympy.Integer(20), y_m: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + mapper3 = {y_M: sympy.Integer(10), time_m: sympy.Integer(5)} + + mappers = (mapper0, mapper1, mapper2, mapper3) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + def test_subdomain(self): + """Test that SubDomain thicknesses can be specialized""" + + def check_op(mapper, operator): + for k in mapper.keys(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + + class SD(SubDomain): + name = 'sd' + + def define(self, dimensions): + x, y = dimensions + return {x: ('middle', 1, 1), y: ('right', 2)} + + grid = Grid(shape=(11, 11)) + sd = SD(grid=grid) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=sd) + + eqs = [Eq(f, f+1, subdomain=sd), + Eq(g, g+1, subdomain=sd)] + + op = Operator(eqs) + + subdims = [d for d in op.dimensions if d.is_Sub] + ((xltkn, xrtkn), (_, yrtkn)) = [d.thickness for d in subdims] + + mapper0 = {xltkn: sympy.S.Zero} + mapper1 = {xrtkn: sympy.Integer(2), yrtkn: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + + mappers = (mapper0, mapper1, mapper2) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + # FIXME: Currently throws an error + # def test_factor(self): + # """Test that ConditionalDimensions can have their symbolic factors specialized""" + # size = 16 + # factor = 4 + # i = Dimension(name='i') + # ci = ConditionalDimension(name='ci', parent=i, factor=factor) + + # g = Function(name='g', shape=(size,), dimensions=(i,)) + # f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) + + # op0 = Operator([Eq(f, g)]) + + # mapper = {ci.symbolic_factor: sympy.Integer(factor)} + + # op1 = Specializer(mapper).visit(op0) + + # assert ci.symbolic_factor not in op1.parameters + # assert ci.symbolic_factor.name not in str(op1.ccode) + # assert "if ((i)%(4) == 0)" in str(op1.ccode) + + # Spacings + + # Strides/sizes From cc1133264c713c80e5b80aa610a353a197e697b7 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 19 Dec 2025 11:36:38 +0000 Subject: [PATCH 3/6] tests: Introduce further tests --- devito/ir/iet/visitors.py | 11 +++++++++-- tests/test_specialization.py | 27 +++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index f1aba31aab..0b4936e14d 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -11,7 +11,7 @@ import ctypes import cgen as c -from sympy import IndexedBase +from sympy import IndexedBase, Number from sympy.core.function import Application from devito.exceptions import CompilationError @@ -1515,10 +1515,17 @@ def __init__(self, mapper, nested=False): super().__init__(mapper, nested=nested) # Sanity check - for k in self.mapper.keys(): + for k, v in self.mapper.items(): + # FIXME: Erronously blocks f_vec->size[1] + # Apparently this is an IndexedPointer if not isinstance(k, AbstractSymbol): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") + if not isinstance(v, Number): + raise ValueError("Only SymPy Numbers can used to replace values during " + f"specialization. Value {v} was supplied for symbol " + f"{k}, but is of type {type(v)}.") + def visit_Operator(self, o, **kwargs): # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this # is the intended use case diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 0962ce33bb..c8e8f54c59 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -120,7 +120,8 @@ def define(self, dimensions): for m, o in zip(mappers, ops): check_op(m, o) - # FIXME: Currently throws an error + # FIXME: Currently throws an error - probably a missing handler for GuardFactor + # in Uxreplace # def test_factor(self): # """Test that ConditionalDimensions can have their symbolic factors specialized""" # size = 16 @@ -141,6 +142,28 @@ def define(self, dimensions): # assert ci.symbolic_factor.name not in str(op1.ccode) # assert "if ((i)%(4) == 0)" in str(op1.ccode) - # Spacings + def test_spacing(self): + """Test that grid spacings can be specialized""" + grid = Grid(shape=(11,)) + f = Function(name='f', grid=grid) + + op0 = Operator(Eq(f, f.dx)) + + mapper = {grid.dimensions[0].spacing: sympy.Float(grid.spacing[0])} + op1 = Specializer(mapper).visit(op0) + + assert grid.dimensions[0].spacing not in op1.parameters + assert grid.dimensions[0].spacing.name not in str(op1.ccode) + assert "/1.0e-1F;" in str(op1.ccode) # Strides/sizes + def test_strides(self): + """Test that strides and sizes generated for linearization can be specialized""" + grid = Grid(shape=(11, 11)) + + f = TimeFunction(name='f', grid=grid, space_order=2) + + op0 = Operator(Eq(f.forward, f.dx2), + opt=('advanced', {'expand': True, 'linearize': True})) + + from IPython import embed; embed() \ No newline at end of file From b21cff2bb12897eae96552d856875f2ad17f3e03 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 12:28:47 +0000 Subject: [PATCH 4/6] tests: Add tests for specialising ConditionalDimension factors --- devito/ir/equations/equation.py | 7 +++-- devito/ir/iet/visitors.py | 10 ++++--- devito/ir/support/guards.py | 24 +++++++--------- devito/symbolics/extended_sympy.py | 3 +- devito/symbolics/manipulation.py | 4 +-- tests/test_pickle.py | 5 ++-- tests/test_specialization.py | 45 +++++++++++++++++------------- 7 files changed, 52 insertions(+), 46 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f83dc39c94..69c4aad4d3 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -5,8 +5,9 @@ from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.finite_differences.differentiable import diff2sympy -from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace, +from devito.ir.support import (Interval, IntervalGroup, IterationSpace, Stencil, detect_io, detect_accesses) +from devito.ir.support.guards import GuardFactorEq from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min @@ -210,11 +211,11 @@ def __new__(cls, *args, **kwargs): if not d.is_Conditional: continue if d.condition is None: - conditionals[d] = GuardFactor(d) + conditionals[d] = GuardFactorEq.new_from_dim(d) else: cond = diff2sympy(lower_exprs(d.condition)) if d._factor is not None: - cond = sympy.And(cond, GuardFactor(d)) + cond = sympy.And(cond, GuardFactorEq.new_from_dim(d)) conditionals[d] = cond # Replace dimension with index index = d.index diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 0b4936e14d..446c6ee0d0 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -15,6 +15,7 @@ from sympy.core.function import Application from devito.exceptions import CompilationError +from devito.symbolics import IndexedPointer from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward @@ -1516,9 +1517,7 @@ def __init__(self, mapper, nested=False): # Sanity check for k, v in self.mapper.items(): - # FIXME: Erronously blocks f_vec->size[1] - # Apparently this is an IndexedPointer - if not isinstance(k, AbstractSymbol): + if not isinstance(k, (AbstractSymbol, IndexedPointer)): raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") if not isinstance(v, Number): @@ -1531,7 +1530,10 @@ def visit_Operator(self, o, **kwargs): # is the intended use case body = self._visit(o.body) - not_params = tuple(i for i in self.mapper if i not in o.parameters) + # NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in + # the Operator parameters. Perhaps this check wants relaxing. + not_params = tuple(i for i in self.mapper + if i not in o.parameters and isinstance(i, AbstractSymbol)) if not_params: raise ValueError(f"Attempted to specialize symbols {not_params} which are not" " found in the Operator parameters") diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index a014db8abb..443d5d5f19 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -38,35 +38,34 @@ def canonical(self): @property def negated(self): - return negations[self.__class__](*self._args_rebuild, evaluate=False) + try: + return negations[self.__class__](*self._args_rebuild, evaluate=False) + except KeyError: + raise ValueError(f"Class {self.__class__.__name__} does not have a negation") # *** GuardFactor -class GuardFactor(Guard, CondEq, Pickable): +class GuardFactor(Guard, Pickable): """ A guard for factor-based ConditionalDimensions. - Given the ConditionalDimension `d` with factor `k`, create the - symbolic relational `d.parent % k == 0`. + Introduces a constructor where, given the ConditionalDimension `d` with factor `k`, + the symbolic relational `d.parent % k == 0` is created. """ - __rargs__ = ('d',) + __rargs__ = ('lhs', 'rhs') - def __new__(cls, d, **kwargs): + @classmethod + def new_from_dim(cls, d, **kwargs): assert d.is_Conditional obj = super().__new__(cls, d.parent % d.symbolic_factor, 0) - obj.d = d return obj - @property - def _args_rebuild(self): - return (self.d,) - class GuardFactorEq(GuardFactor, CondEq): pass @@ -76,9 +75,6 @@ class GuardFactorNe(GuardFactor, CondNe): pass -GuardFactor = GuardFactorEq - - # *** GuardBound diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 3a6e61742c..b0d820e3bc 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -35,7 +35,8 @@ class CondEq(sympy.Eq): """ def __new__(cls, *args, **kwargs): - return sympy.Eq.__new__(cls, *args, evaluate=False) + kwargs['evaluate'] = False + return sympy.Eq.__new__(cls, *args, **kwargs) @property def canonical(self): diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index e43494974d..4b0631fab0 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -136,9 +136,9 @@ def _(mapper, rule): @singledispatch def _uxreplace_handle(expr, args, kwargs): try: - return expr.func(*args, evaluate=False) + return expr.func(*args, evaluate=False, **kwargs) except TypeError: - return expr.func(*args) + return expr.func(*args, **kwargs) @_uxreplace_handle.register(Min) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 40e4c14ffe..70e549ca00 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -10,7 +10,8 @@ Dimension, SubDimension, ConditionalDimension, IncrDimension, TimeDimension, SteppingDimension, Operator, MPI, Min, solve, PrecomputedSparseTimeFunction, SubDomain) -from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext +from devito.ir import Backward, Forward, GuardBound, GuardBoundNext +from devito.ir.support.guards import GuardFactorEq from devito.data import LEFT, OWNED from devito.finite_differences.tools import direct, transpose, left, right, centered from devito.mpi.halo_scheme import Halo @@ -500,7 +501,7 @@ def test_guard_factor(self, pickle): d = Dimension(name='d') cd = ConditionalDimension(name='cd', parent=d, factor=4) - gf = GuardFactor(cd) + gf = GuardFactorEq.new_from_dim(cd) pkl_gf = pickle.dumps(gf) new_gf = pickle.loads(pkl_gf) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index c8e8f54c59..5350d7675f 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -120,27 +120,25 @@ def define(self, dimensions): for m, o in zip(mappers, ops): check_op(m, o) - # FIXME: Currently throws an error - probably a missing handler for GuardFactor - # in Uxreplace - # def test_factor(self): - # """Test that ConditionalDimensions can have their symbolic factors specialized""" - # size = 16 - # factor = 4 - # i = Dimension(name='i') - # ci = ConditionalDimension(name='ci', parent=i, factor=factor) + def test_factor(self): + """Test that ConditionalDimensions can have their symbolic factors specialized""" + size = 16 + factor = 4 + i = Dimension(name='i') + ci = ConditionalDimension(name='ci', parent=i, factor=factor) - # g = Function(name='g', shape=(size,), dimensions=(i,)) - # f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) + g = Function(name='g', shape=(size,), dimensions=(i,)) + f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) - # op0 = Operator([Eq(f, g)]) + op0 = Operator([Eq(f, g)]) - # mapper = {ci.symbolic_factor: sympy.Integer(factor)} + mapper = {ci.symbolic_factor: sympy.Integer(factor)} - # op1 = Specializer(mapper).visit(op0) + op1 = Specializer(mapper).visit(op0) - # assert ci.symbolic_factor not in op1.parameters - # assert ci.symbolic_factor.name not in str(op1.ccode) - # assert "if ((i)%(4) == 0)" in str(op1.ccode) + assert ci.symbolic_factor not in op1.parameters + assert ci.symbolic_factor.name not in str(op1.ccode) + assert "if ((i)%(4) == 0)" in str(op1.ccode) def test_spacing(self): """Test that grid spacings can be specialized""" @@ -156,9 +154,8 @@ def test_spacing(self): assert grid.dimensions[0].spacing.name not in str(op1.ccode) assert "/1.0e-1F;" in str(op1.ccode) - # Strides/sizes - def test_strides(self): - """Test that strides and sizes generated for linearization can be specialized""" + def test_sizes(self): + """Test that strides generated for linearization can be specialized""" grid = Grid(shape=(11, 11)) f = TimeFunction(name='f', grid=grid, space_order=2) @@ -166,4 +163,12 @@ def test_strides(self): op0 = Operator(Eq(f.forward, f.dx2), opt=('advanced', {'expand': True, 'linearize': True})) - from IPython import embed; embed() \ No newline at end of file + mapper = {f.symbolic_shape[1]: sympy.Integer(11), + f.symbolic_shape[2]: sympy.Integer(11)} + + op1 = Specializer(mapper).visit(op0) + + assert "const int x_fsz0 = 11;" in str(op1.ccode) + assert "const int y_fsz0 = 11;" in str(op1.ccode) + + # TODO: Should strides get linearized? If so, how? From f47f2b24e1df89bb1702bc39793003f3d9cba97b Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 16:57:03 +0000 Subject: [PATCH 5/6] tests: Added test applying a specialized operator --- tests/test_specialization.py | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 5350d7675f..40d3248cea 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -1,6 +1,8 @@ import sympy import pytest +import numpy as np + from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, ConditionalDimension) from devito.ir.iet.visitors import Specializer @@ -171,4 +173,44 @@ def test_sizes(self): assert "const int x_fsz0 = 11;" in str(op1.ccode) assert "const int y_fsz0 = 11;" in str(op1.ccode) - # TODO: Should strides get linearized? If so, how? + # TODO: Should strides get specialized? If so, how? + + def test_apply_basic(self): + """ + Test that a trivial operator runs and returns the same results when + specialized. + """ + grid = Grid(shape=(11, 11)) + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + f = Function(name='f', grid=grid, dtype=np.int32) + + op0 = Operator(Eq(f, f+1)) + + mapper = {x_m: sympy.Integer(2), x_M: sympy.Integer(7), + y_m: sympy.Integer(3), y_M: sympy.Integer(8)} + + op1 = Specializer(mapper).visit(op0) + + assert op1.cfunction is not op0.cfunction + + op1.apply() + + check = np.array(f.data[:]) + f.data[:] = 0 + + op0.apply(x_m=2, x_M=7, y_m=3, y_M=8) + + assert np.all(check == f.data) + + +# class TestApply: +# """Tests for specialization of operators at apply time""" + +# def test_basic(self): +# grid = Grid(shape=(11, 11)) + +# f = TimeFunction(name='f', grid=grid, space_order=2) + +# op = Operator(Eq(f.forward, f + 1)) + +# op.apply(time_M=10, specialize=('x_m', 'x_M')) \ No newline at end of file From e70b6d59216eff0dad8ef5a138702ffc6c794ab9 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 23 Dec 2025 16:59:20 +0000 Subject: [PATCH 6/6] misc: flake8 --- tests/test_specialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_specialization.py b/tests/test_specialization.py index 40d3248cea..759b6cc2db 100644 --- a/tests/test_specialization.py +++ b/tests/test_specialization.py @@ -213,4 +213,4 @@ def test_apply_basic(self): # op = Operator(Eq(f.forward, f + 1)) -# op.apply(time_M=10, specialize=('x_m', 'x_M')) \ No newline at end of file +# op.apply(time_M=10, specialize=('x_m', 'x_M'))