diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f83dc39c94..69c4aad4d3 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -5,8 +5,9 @@ from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.finite_differences.differentiable import diff2sympy -from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace, +from devito.ir.support import (Interval, IntervalGroup, IterationSpace, Stencil, detect_io, detect_accesses) +from devito.ir.support.guards import GuardFactorEq from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min @@ -210,11 +211,11 @@ def __new__(cls, *args, **kwargs): if not d.is_Conditional: continue if d.condition is None: - conditionals[d] = GuardFactor(d) + conditionals[d] = GuardFactorEq.new_from_dim(d) else: cond = diff2sympy(lower_exprs(d.condition)) if d._factor is not None: - cond = sympy.And(cond, GuardFactor(d)) + cond = sympy.And(cond, GuardFactorEq.new_from_dim(d)) conditionals[d] = cond # Replace dimension with index index = d.index diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 8f6ae8f02f..446c6ee0d0 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -11,10 +11,11 @@ import ctypes import cgen as c -from sympy import IndexedBase +from sympy import IndexedBase, Number from sympy.core.function import Application from devito.exceptions import CompilationError +from devito.symbolics import IndexedPointer from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle, Call, Lambda, BlankLine, Section, ListMajor) from devito.ir.support.space import Backward @@ -1499,6 +1500,70 @@ def visit_KernelLaunch(self, o): arguments=arguments) +class Specializer(Uxreplace): + """ + A Transformer to "specialize" a pre-built Operator - that is to replace a given + set of (scalar) symbols with hard-coded values to free up registers. This will + yield a "specialized" version of the Operator, specific to a particular setup. + + Note that the Operator is not re-optimized in response to this replacement - this + transformation could nominally result in expressions of the form `f + 0` in the + generated code. If one wants to construct an Operator where such expressions are + considered, then use of `subs=...` is a better choice. + """ + + def __init__(self, mapper, nested=False): + super().__init__(mapper, nested=nested) + + # Sanity check + for k, v in self.mapper.items(): + if not isinstance(k, (AbstractSymbol, IndexedPointer)): + raise ValueError(f"Attempted to specialize non-scalar symbol: {k}") + + if not isinstance(v, Number): + raise ValueError("Only SymPy Numbers can used to replace values during " + f"specialization. Value {v} was supplied for symbol " + f"{k}, but is of type {type(v)}.") + + def visit_Operator(self, o, **kwargs): + # Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this + # is the intended use case + body = self._visit(o.body) + + # NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in + # the Operator parameters. Perhaps this check wants relaxing. + not_params = tuple(i for i in self.mapper + if i not in o.parameters and isinstance(i, AbstractSymbol)) + if not_params: + raise ValueError(f"Attempted to specialize symbols {not_params} which are not" + " found in the Operator parameters") + + # FIXME: Should also type-check the values supplied against the symbols they are + # replacing (and cast them if needed?) -> use a try-except on the cast in + # python-land + + parameters = tuple(i for i in o.parameters if i not in self.mapper) + + # Note: the following is not dissimilar to unpickling an Operator + state = o.__getstate__() + state['parameters'] = parameters + state['body'] = body + + try: + state.pop('ccode') + except KeyError: + # C code has not previously been generated for this Operator + pass + + # FIXME: These names aren't great + newargs, newkwargs = o.__getnewargs_ex__() + newop = o.__class__(*newargs, **newkwargs) + + newop.__setstate__(state) + + return newop + + # Utils blankline = c.Line("") diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index a014db8abb..443d5d5f19 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -38,35 +38,34 @@ def canonical(self): @property def negated(self): - return negations[self.__class__](*self._args_rebuild, evaluate=False) + try: + return negations[self.__class__](*self._args_rebuild, evaluate=False) + except KeyError: + raise ValueError(f"Class {self.__class__.__name__} does not have a negation") # *** GuardFactor -class GuardFactor(Guard, CondEq, Pickable): +class GuardFactor(Guard, Pickable): """ A guard for factor-based ConditionalDimensions. - Given the ConditionalDimension `d` with factor `k`, create the - symbolic relational `d.parent % k == 0`. + Introduces a constructor where, given the ConditionalDimension `d` with factor `k`, + the symbolic relational `d.parent % k == 0` is created. """ - __rargs__ = ('d',) + __rargs__ = ('lhs', 'rhs') - def __new__(cls, d, **kwargs): + @classmethod + def new_from_dim(cls, d, **kwargs): assert d.is_Conditional obj = super().__new__(cls, d.parent % d.symbolic_factor, 0) - obj.d = d return obj - @property - def _args_rebuild(self): - return (self.d,) - class GuardFactorEq(GuardFactor, CondEq): pass @@ -76,9 +75,6 @@ class GuardFactorNe(GuardFactor, CondNe): pass -GuardFactor = GuardFactorEq - - # *** GuardBound diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 3a6e61742c..b0d820e3bc 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -35,7 +35,8 @@ class CondEq(sympy.Eq): """ def __new__(cls, *args, **kwargs): - return sympy.Eq.__new__(cls, *args, evaluate=False) + kwargs['evaluate'] = False + return sympy.Eq.__new__(cls, *args, **kwargs) @property def canonical(self): diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index e43494974d..4b0631fab0 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -136,9 +136,9 @@ def _(mapper, rule): @singledispatch def _uxreplace_handle(expr, args, kwargs): try: - return expr.func(*args, evaluate=False) + return expr.func(*args, evaluate=False, **kwargs) except TypeError: - return expr.func(*args) + return expr.func(*args, **kwargs) @_uxreplace_handle.register(Min) diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 6e000349e9..8e74926ac9 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -175,6 +175,11 @@ def symbolic_max(self): """Symbol defining the maximum point of the Dimension.""" return Scalar(name=self.max_name, dtype=np.int32, is_const=True) + @property + def symbolic_extrema(self): + """Symbols for the minimum and maximum points of the Dimension""" + return (self.symbolic_min, self.symbolic_max) + @property def symbolic_incr(self): """The increment value while iterating over the Dimension.""" diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 40e4c14ffe..70e549ca00 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -10,7 +10,8 @@ Dimension, SubDimension, ConditionalDimension, IncrDimension, TimeDimension, SteppingDimension, Operator, MPI, Min, solve, PrecomputedSparseTimeFunction, SubDomain) -from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext +from devito.ir import Backward, Forward, GuardBound, GuardBoundNext +from devito.ir.support.guards import GuardFactorEq from devito.data import LEFT, OWNED from devito.finite_differences.tools import direct, transpose, left, right, centered from devito.mpi.halo_scheme import Halo @@ -500,7 +501,7 @@ def test_guard_factor(self, pickle): d = Dimension(name='d') cd = ConditionalDimension(name='cd', parent=d, factor=4) - gf = GuardFactor(cd) + gf = GuardFactorEq.new_from_dim(cd) pkl_gf = pickle.dumps(gf) new_gf = pickle.loads(pkl_gf) diff --git a/tests/test_specialization.py b/tests/test_specialization.py new file mode 100644 index 0000000000..759b6cc2db --- /dev/null +++ b/tests/test_specialization.py @@ -0,0 +1,216 @@ +import sympy +import pytest + +import numpy as np + +from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension, + ConditionalDimension) +from devito.ir.iet.visitors import Specializer + +# Test that specializer replaces symbols as expected + +# Create a couple of arbitrary operators +# Reference bounds, subdomains, spacings, constants, conditionaldimensions with symbolic +# factor +# Create a couple of different substitution sets + +# Check that all the instances in the kernel are replaced +# Check that all the instances in the parameters are removed + +# Check that sanity check catches attempts to specialize non-scalar types +# Check that trying to specialize symbols not in the Operator parameters results +# in an error being thrown + +# Check that sizes and strides get specialized when using `linearize=True` + + +class TestSpecializer: + """Tests for the Specializer transformer""" + + @pytest.mark.parametrize('pre_gen', [True, False]) + @pytest.mark.parametrize('expand', [True, False]) + def test_bounds(self, pre_gen, expand): + """Test specialization of dimension bounds""" + grid = Grid(shape=(11, 11)) + + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + time_m = grid.time_dim.symbolic_min + minima = (x_m, y_m, time_m) + maxima = (x_M, y_M) + + def check_op(mapper, operator): + for k, v in mapper.items(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + # Check that the loop bounds are modified correctly + if k in minima: + assert f"{k.name.split('_')[0]} = {v}" in str(operator.ccode) + elif k in maxima: + assert f"{k.name.split('_')[0]} <= {v}" in str(operator.ccode) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + h = TimeFunction(name='h', grid=grid) + + eq0 = Eq(f, f + 1) + eq1 = Eq(g, f.dx) + eq2 = Eq(h.forward, (g + x_m).dy) + eq3 = Eq(f, x_M) + + # Check behaviour with expansion since we have a replaced symbol inside a + # derivative + if expand: + kwargs = {'opt': ('advanced', {'expand': True})} + else: + kwargs = {'opt': ('advanced', {'expand': False})} + + op = Operator([eq0, eq1, eq2, eq3], **kwargs) + + if pre_gen: + # Generate C code for the unspecialized Operator - the result should be + # the same regardless, but it ensures that the old generated code is + # purged and replaced in the specialized Operator + _ = op.ccode + + mapper0 = {x_m: sympy.S.Zero} + mapper1 = {x_M: sympy.Integer(20), y_m: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + mapper3 = {y_M: sympy.Integer(10), time_m: sympy.Integer(5)} + + mappers = (mapper0, mapper1, mapper2, mapper3) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + def test_subdomain(self): + """Test that SubDomain thicknesses can be specialized""" + + def check_op(mapper, operator): + for k in mapper.keys(): + assert k not in operator.parameters + assert k.name not in str(operator.ccode) + + class SD(SubDomain): + name = 'sd' + + def define(self, dimensions): + x, y = dimensions + return {x: ('middle', 1, 1), y: ('right', 2)} + + grid = Grid(shape=(11, 11)) + sd = SD(grid=grid) + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=sd) + + eqs = [Eq(f, f+1, subdomain=sd), + Eq(g, g+1, subdomain=sd)] + + op = Operator(eqs) + + subdims = [d for d in op.dimensions if d.is_Sub] + ((xltkn, xrtkn), (_, yrtkn)) = [d.thickness for d in subdims] + + mapper0 = {xltkn: sympy.S.Zero} + mapper1 = {xrtkn: sympy.Integer(2), yrtkn: sympy.S.Zero} + mapper2 = {**mapper0, **mapper1} + + mappers = (mapper0, mapper1, mapper2) + ops = tuple(Specializer(m).visit(op) for m in mappers) + + for m, o in zip(mappers, ops): + check_op(m, o) + + def test_factor(self): + """Test that ConditionalDimensions can have their symbolic factors specialized""" + size = 16 + factor = 4 + i = Dimension(name='i') + ci = ConditionalDimension(name='ci', parent=i, factor=factor) + + g = Function(name='g', shape=(size,), dimensions=(i,)) + f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,)) + + op0 = Operator([Eq(f, g)]) + + mapper = {ci.symbolic_factor: sympy.Integer(factor)} + + op1 = Specializer(mapper).visit(op0) + + assert ci.symbolic_factor not in op1.parameters + assert ci.symbolic_factor.name not in str(op1.ccode) + assert "if ((i)%(4) == 0)" in str(op1.ccode) + + def test_spacing(self): + """Test that grid spacings can be specialized""" + grid = Grid(shape=(11,)) + f = Function(name='f', grid=grid) + + op0 = Operator(Eq(f, f.dx)) + + mapper = {grid.dimensions[0].spacing: sympy.Float(grid.spacing[0])} + op1 = Specializer(mapper).visit(op0) + + assert grid.dimensions[0].spacing not in op1.parameters + assert grid.dimensions[0].spacing.name not in str(op1.ccode) + assert "/1.0e-1F;" in str(op1.ccode) + + def test_sizes(self): + """Test that strides generated for linearization can be specialized""" + grid = Grid(shape=(11, 11)) + + f = TimeFunction(name='f', grid=grid, space_order=2) + + op0 = Operator(Eq(f.forward, f.dx2), + opt=('advanced', {'expand': True, 'linearize': True})) + + mapper = {f.symbolic_shape[1]: sympy.Integer(11), + f.symbolic_shape[2]: sympy.Integer(11)} + + op1 = Specializer(mapper).visit(op0) + + assert "const int x_fsz0 = 11;" in str(op1.ccode) + assert "const int y_fsz0 = 11;" in str(op1.ccode) + + # TODO: Should strides get specialized? If so, how? + + def test_apply_basic(self): + """ + Test that a trivial operator runs and returns the same results when + specialized. + """ + grid = Grid(shape=(11, 11)) + ((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions] + f = Function(name='f', grid=grid, dtype=np.int32) + + op0 = Operator(Eq(f, f+1)) + + mapper = {x_m: sympy.Integer(2), x_M: sympy.Integer(7), + y_m: sympy.Integer(3), y_M: sympy.Integer(8)} + + op1 = Specializer(mapper).visit(op0) + + assert op1.cfunction is not op0.cfunction + + op1.apply() + + check = np.array(f.data[:]) + f.data[:] = 0 + + op0.apply(x_m=2, x_M=7, y_m=3, y_M=8) + + assert np.all(check == f.data) + + +# class TestApply: +# """Tests for specialization of operators at apply time""" + +# def test_basic(self): +# grid = Grid(shape=(11, 11)) + +# f = TimeFunction(name='f', grid=grid, space_order=2) + +# op = Operator(Eq(f.forward, f + 1)) + +# op.apply(time_M=10, specialize=('x_m', 'x_M'))