Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _normalize_kwargs(cls, **kwargs):

# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
Expand Down
1 change: 1 addition & 0 deletions devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _normalize_kwargs(cls, **kwargs):

# Code generation options for derivatives
o['expand'] = oo.pop('expand', cls.EXPAND)
o['eval-mul-first'] = oo.pop('eval-mul-first', cls.MUL_FIRST)
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
Expand Down
6 changes: 6 additions & 0 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class BasicOperator(Operator):
finite-difference derivatives.
"""

MUL_FIRST = False
"""
When evaluating expressions location, prioritize multiplication
operations.
"""

DERIV_COLLECT = True
"""
Factorize finite-difference derivatives exploiting the linearity of the FD
Expand Down
6 changes: 3 additions & 3 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def T(self):

return self._rebuild(transpose=adjoint)

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
"""
Evaluates the derivative at the location of `func`. It is necessary for staggered
setup where one could have Eq(u(x + h_x/2), v(x).dx)) in which case v(x).dx
Expand Down Expand Up @@ -521,7 +521,7 @@ def _eval_at(self, func):
return self._rebuild(self.expr, **rkw)
args = [self.expr.func(*v) for v in mapper.values()]
args.extend([a for a in self.expr.args if a not in self.expr._args_diff])
args = [self._rebuild(a)._eval_at(func) for a in args]
args = [self._rebuild(a)._eval_at(func, **kwargs) for a in args]
return self.expr.func(*args)
elif self.expr.is_Mul:
# For Mul, We treat the basic case `u(x + h_x/2) * v(x) which is what appear
Expand Down Expand Up @@ -594,7 +594,7 @@ def _eval_fd(self, expr, **kwargs):
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
self.deriv_order[0], weights=self.weights,
side=self.side, matvec=self.transpose,
x0=self.x0, expand=expand)
x0=x0_deriv, expand=expand)

# Step 4: Apply substitutions
for e in self._ppsubs:
Expand Down
67 changes: 61 additions & 6 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@ def coefficients(self):
key = lambda x: coeff_priority.get(x, -1)
return sorted(coefficients, key=key, reverse=True)[0]

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
if not func.is_Staggered:
# Cartesian grid, do no waste time
return self
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func) for a in self.args])
return self.func(*[getattr(a, '_eval_at', lambda x, **kw: a)(func, **kwargs)
for a in self.args])

def _subs(self, old, new, **hints):
if old == self:
Expand Down Expand Up @@ -454,7 +455,11 @@ def highest_priority(DiffOp):
# set of dimensions is used when multiple ones with the same
# priority appear
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
args = DiffOp._args_diff
if not args:
return DiffOp
else:
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


class DifferentiableOp(Differentiable):
Expand Down Expand Up @@ -520,7 +525,7 @@ class DifferentiableFunction(DifferentiableOp):
def __new__(cls, *args, **kwargs):
return cls.__sympy_class__.__new__(cls, *args, **kwargs)

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
return self


Expand Down Expand Up @@ -629,6 +634,56 @@ def _gather_for_diff(self):

return self.func(*new_args, evaluate=False)

def _eval_at(self, func, mul_first=False, **kwargs):
# Dont evaluate mul first
if not mul_first:
return super()._eval_at(func, mul_first=mul_first)

# Not a basic a*b*c... expression, just defer to superclass
if any(isinstance(f, DifferentiableOp) for f in self.args):
return super()._eval_at(func, mul_first=mul_first)

# Split Derivative and Differentiable args
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))

if derivs:
derivs = Differentiable._eval_at(self.func(*derivs), func,
mul_first=mul_first)
else:
derivs = 1

if not other:
return derivs
elif len(other) > 1:
expr = self.func(*other)._gather_for_diff
else:
expr = other[0]

# Non differentiable expr (e.g., number)
if not isinstance(expr, Differentiable):
return self.func(derivs, expr)

# Build mapper for dimensions that need to be interpolated
mapper = {}
for d in self.dimensions:
try:
if self.indices_ref[d] is not func.indices_ref[d]:
mapper[d] = func.indices_ref[d]
except KeyError:
pass

# Nothing to interpolate
if not mapper:
return super()._eval_at(func, mul_first=mul_first)

# Interpolate expr at the required indices
interp = expr.diff(*mapper.keys(), deriv_order=[0 for _ in mapper],
fd_order=[self.interp_order for _ in mapper],
x0=mapper)

# Return the full expression with Derivatives
return self.func(derivs, interp)


class Pow(DifferentiableOp, sympy.Pow):
_fd_priority = 0
Expand Down Expand Up @@ -971,7 +1026,7 @@ def _subs(self, old, new, **hints):

class DiffDerivative(IndexDerivative, DifferentiableOp):

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
# Like EvalDerivative, a DiffDerivative must have already been evaluated
# at a valid x0 and should not be re-evaluated at a different location
return self
Expand Down Expand Up @@ -1022,7 +1077,7 @@ def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
# An EvalDerivative must have already been evaluated at a valid x0
# and should not be re-evaluated at a different location
return self
Expand Down
4 changes: 3 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def _lower_exprs(cls, expressions, **kwargs):
* Shift indices for domain alignment.
"""
expand = kwargs['options'].get('expand', True)
mul_first = kwargs['options'].get('eval-mul-first', False)

# Specialization is performed on unevaluated expressions
expressions = cls._specialize_dsl(expressions, **kwargs)
Expand All @@ -348,7 +349,8 @@ def _lower_exprs(cls, expressions, **kwargs):
# ModuloDimensions
if not expand:
expand = lambda d: d.is_Stepping
expressions = flatten([i._evaluate(expand=expand) for i in expressions])
expressions = flatten([i._evaluate(expand=expand, mul_first=mul_first)
for i in expressions])

# Scalarize the tensor equations, if any
expressions = [j for i in expressions for j in i._flatten]
Expand Down
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def c0(self):
def _eval_deriv(self):
return self

@property
@cached_property
def _grid_map(self):
"""
Mapper of off-grid interpolation points indices for each dimension.
Expand Down
6 changes: 3 additions & 3 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,9 +1070,9 @@ def __fd_setup__(self):

@cached_property
def _fd_priority(self):
return 1 if self.staggered.on_node else 2
return 1.2 if self.staggered.on_node else 1.1

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
if self.staggered == func.staggered:
return self

Expand Down Expand Up @@ -1491,7 +1491,7 @@ def __shape_setup__(cls, **kwargs):

@cached_property
def _fd_priority(self):
return 2.1 if self.staggered.on_node else 2.2
return 2.1 if self.staggered.on_node else 2

@property
def time_order(self):
Expand Down
2 changes: 1 addition & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _evaluate(self, **kwargs):
"""
try:
lhs = self.lhs._evaluate(**kwargs)
rhs = self.rhs._eval_at(self.lhs)._evaluate(**kwargs)
rhs = self.rhs._eval_at(self.lhs, **kwargs)._evaluate(**kwargs)
except AttributeError:
lhs, rhs = self._evaluate_args(**kwargs)
eq = self.func(lhs, rhs, subdomain=self.subdomain,
Expand Down
2 changes: 1 addition & 1 deletion devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def _dist_scatter(self, alias=None, data=None):
mapper.update(self._dist_subfunc_scatter(sf))
return mapper

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
return self

def _halo_exchange(self):
Expand Down
5 changes: 3 additions & 2 deletions devito/types/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ def __getattr__(self, name):
except:
raise AttributeError("%r object has no attribute %r" % (self.__class__, name))

def _eval_at(self, func):
def _eval_at(self, func, **kwargs):
"""
Evaluate tensor at func location
"""
def entries(i, j, func):
return getattr(self[i, j], '_eval_at', lambda x: self[i, j])(func[i, j])
return getattr(self[i, j], '_eval_at',
lambda x: self[i, j])(func[i, j], **kwargs)
entry = lambda i, j: entries(i, j, func)
return self._new(self.rows, self.cols, entry)

Expand Down
Loading