Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from devito.ir.equations.algorithms import dimension_sort, lower_exprs
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
from devito.ir.support import (Interval, IntervalGroup, IterationSpace,
Stencil, detect_io, detect_accesses)
from devito.ir.support.guards import GuardFactorEq
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
Expand Down Expand Up @@ -210,11 +211,11 @@ def __new__(cls, *args, **kwargs):
if not d.is_Conditional:
continue
if d.condition is None:
conditionals[d] = GuardFactor(d)
conditionals[d] = GuardFactorEq.new_from_dim(d)
else:
cond = diff2sympy(lower_exprs(d.condition))
if d._factor is not None:
cond = sympy.And(cond, GuardFactor(d))
cond = sympy.And(cond, GuardFactorEq.new_from_dim(d))
conditionals[d] = cond
# Replace dimension with index
index = d.index
Expand Down
67 changes: 66 additions & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import ctypes

import cgen as c
from sympy import IndexedBase
from sympy import IndexedBase, Number
from sympy.core.function import Application

from devito.exceptions import CompilationError
from devito.symbolics import IndexedPointer
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
Call, Lambda, BlankLine, Section, ListMajor)
from devito.ir.support.space import Backward
Expand Down Expand Up @@ -1499,6 +1500,70 @@ def visit_KernelLaunch(self, o):
arguments=arguments)


class Specializer(Uxreplace):
"""
A Transformer to "specialize" a pre-built Operator - that is to replace a given
set of (scalar) symbols with hard-coded values to free up registers. This will
yield a "specialized" version of the Operator, specific to a particular setup.

Note that the Operator is not re-optimized in response to this replacement - this
transformation could nominally result in expressions of the form `f + 0` in the
generated code. If one wants to construct an Operator where such expressions are
considered, then use of `subs=...` is a better choice.
"""

def __init__(self, mapper, nested=False):
super().__init__(mapper, nested=nested)

# Sanity check
for k, v in self.mapper.items():
if not isinstance(k, (AbstractSymbol, IndexedPointer)):
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")

if not isinstance(v, Number):
raise ValueError("Only SymPy Numbers can used to replace values during "
f"specialization. Value {v} was supplied for symbol "
f"{k}, but is of type {type(v)}.")

def visit_Operator(self, o, **kwargs):
# Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
# is the intended use case
body = self._visit(o.body)

# NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
# the Operator parameters. Perhaps this check wants relaxing.
not_params = tuple(i for i in self.mapper
if i not in o.parameters and isinstance(i, AbstractSymbol))
if not_params:
raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
" found in the Operator parameters")

# FIXME: Should also type-check the values supplied against the symbols they are
# replacing (and cast them if needed?) -> use a try-except on the cast in
# python-land

parameters = tuple(i for i in o.parameters if i not in self.mapper)

# Note: the following is not dissimilar to unpickling an Operator
state = o.__getstate__()
state['parameters'] = parameters
state['body'] = body

try:
state.pop('ccode')
except KeyError:
# C code has not previously been generated for this Operator
pass

# FIXME: These names aren't great
newargs, newkwargs = o.__getnewargs_ex__()
newop = o.__class__(*newargs, **newkwargs)

newop.__setstate__(state)

return newop


# Utils

blankline = c.Line("")
Expand Down
24 changes: 10 additions & 14 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,35 +38,34 @@ def canonical(self):

@property
def negated(self):
return negations[self.__class__](*self._args_rebuild, evaluate=False)
try:
return negations[self.__class__](*self._args_rebuild, evaluate=False)
except KeyError:
raise ValueError(f"Class {self.__class__.__name__} does not have a negation")


# *** GuardFactor


class GuardFactor(Guard, CondEq, Pickable):
class GuardFactor(Guard, Pickable):

"""
A guard for factor-based ConditionalDimensions.

Given the ConditionalDimension `d` with factor `k`, create the
symbolic relational `d.parent % k == 0`.
Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
the symbolic relational `d.parent % k == 0` is created.
"""

__rargs__ = ('d',)
__rargs__ = ('lhs', 'rhs')

def __new__(cls, d, **kwargs):
@classmethod
def new_from_dim(cls, d, **kwargs):
assert d.is_Conditional

obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
obj.d = d

return obj

@property
def _args_rebuild(self):
return (self.d,)


class GuardFactorEq(GuardFactor, CondEq):
pass
Expand All @@ -76,9 +75,6 @@ class GuardFactorNe(GuardFactor, CondNe):
pass


GuardFactor = GuardFactorEq


# *** GuardBound


Expand Down
3 changes: 2 additions & 1 deletion devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class CondEq(sympy.Eq):
"""

def __new__(cls, *args, **kwargs):
return sympy.Eq.__new__(cls, *args, evaluate=False)
kwargs['evaluate'] = False
return sympy.Eq.__new__(cls, *args, **kwargs)

@property
def canonical(self):
Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def _(mapper, rule):
@singledispatch
def _uxreplace_handle(expr, args, kwargs):
try:
return expr.func(*args, evaluate=False)
return expr.func(*args, evaluate=False, **kwargs)
except TypeError:
return expr.func(*args)
return expr.func(*args, **kwargs)


@_uxreplace_handle.register(Min)
Expand Down
5 changes: 5 additions & 0 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def symbolic_max(self):
"""Symbol defining the maximum point of the Dimension."""
return Scalar(name=self.max_name, dtype=np.int32, is_const=True)

@property
def symbolic_extrema(self):
"""Symbols for the minimum and maximum points of the Dimension"""
return (self.symbolic_min, self.symbolic_max)

@property
def symbolic_incr(self):
"""The increment value while iterating over the Dimension."""
Expand Down
5 changes: 3 additions & 2 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
Dimension, SubDimension, ConditionalDimension, IncrDimension,
TimeDimension, SteppingDimension, Operator, MPI, Min, solve,
PrecomputedSparseTimeFunction, SubDomain)
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext
from devito.ir.support.guards import GuardFactorEq
from devito.data import LEFT, OWNED
from devito.finite_differences.tools import direct, transpose, left, right, centered
from devito.mpi.halo_scheme import Halo
Expand Down Expand Up @@ -500,7 +501,7 @@ def test_guard_factor(self, pickle):
d = Dimension(name='d')
cd = ConditionalDimension(name='cd', parent=d, factor=4)

gf = GuardFactor(cd)
gf = GuardFactorEq.new_from_dim(cd)

pkl_gf = pickle.dumps(gf)
new_gf = pickle.loads(pkl_gf)
Expand Down
Loading
Loading