diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 0272304ee3..9dee3dd793 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -1038,7 +1038,7 @@ def supports(self, query, language=None): warning(f"Couldn't establish if `query={query}` is supported on this " "system. Assuming it is not.") return False - elif query == 'async-loads' and cc >= 80: + elif query == 'async-pipe' and cc >= 80: # Asynchronous pipeline loads -- introduced in Ampere return True elif query in ('tma', 'thread-block-cluster') and cc >= 90: @@ -1055,7 +1055,7 @@ class Volta(NvidiaDevice): class Ampere(Volta): def supports(self, query, language=None): - if query == 'async-loads': + if query == 'async-pipe': return True else: return super().supports(query, language) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index bd740adab7..5abfab3244 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -273,6 +273,16 @@ def _print_Abs(self, expr): return f"fabs({self._print(arg)})" return self._print_fmath_func('abs', expr) + def _print_BitwiseNot(self, expr): + # Unary function, single argument + arg = expr.args[0] + return f'~{self._print(arg)}' + + def _print_BitwiseXor(self, expr): + # Binary function + arg0, arg1 = expr.args + return f'{self._print(arg0)} ^ {self._print(arg1)}' + def _print_Add(self, expr, order=None): """" Print an addition. diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f83dc39c94..85162b8a5f 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -83,9 +83,11 @@ def __repr__(self): if not self.is_Reduction: return super().__repr__() elif self.operation is OpInc: - return '%s += %s' % (self.lhs, self.rhs) + return f'Inc({self.lhs}, {self.rhs})' else: - return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs) + return f'Eq({self.lhs}, {self.operation}({self.rhs}))' + + __str__ = __repr__ # Pickling support __reduce_ex__ = Pickable.__reduce_ex__ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 3dcabc5602..d590e4ac75 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -305,6 +305,9 @@ def _gen_value(self, obj, mode=1, masked=()): qualifiers = [v for k, v in self._qualifiers_mapper.items() if getattr(obj.function, k, False) and v not in masked] + if obj.is_LocalObject and mode == 2: + qualifiers.extend(as_tuple(obj._C_tag)) + if (obj._mem_stack or obj._mem_constant) and mode == 1: strtype = self.ccode(obj._C_typedata) strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 64329fbbbd..258c5d3e42 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -775,17 +775,19 @@ def __init__(self, intervals, sub_iterators=None, directions=None): super().__init__(intervals) # Normalize sub-iterators - sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v)))) - for k, v in (sub_iterators or {}).items()]) + sub_iterators = sub_iterators or {} + sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) + for d, v in sub_iterators.items() if d in self.intervals} sub_iterators.update({i.dim: () for i in self.intervals if i.dim not in sub_iterators}) self._sub_iterators = frozendict(sub_iterators) # Normalize directions - if directions is None: - self._directions = frozendict([(i.dim, Any) for i in self.intervals]) - else: - self._directions = frozendict(directions) + directions = directions or {} + directions = {d: v for d, v in directions.items() if d in self.intervals} + directions.update({i.dim: Any for i in self.intervals + if i.dim not in directions}) + self._directions = frozendict(directions) def __repr__(self): ret = ', '.join(["%s%s" % (repr(i), repr(self.directions[i.dim])) @@ -807,8 +809,7 @@ def __lt__(self, other): return len(self.itintervals) < len(other.itintervals) def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, - self.directions)) + return hash((super().__hash__(), self.sub_iterators, self.directions)) def __contains__(self, d): try: diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 509490a2e1..8c09a2ee45 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -12,7 +12,7 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass -from devito.symbolics import estimate_cost, q_leaf, q_terminal +from devito.symbolics import Reserved, estimate_cost, q_leaf, q_terminal from devito.symbolics.search import search from devito.symbolics.manipulation import _uxreplace from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype @@ -401,6 +401,7 @@ def _(expr): @_catch.register(Indexed) @_catch.register(Symbol) +@_catch.register(Reserved) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 5af92a3208..f1fad43164 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -3,7 +3,7 @@ from sympy import S import numpy as np -from devito.finite_differences import IndexDerivative +from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace @@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs): @_core.register(Symbol) -@_core.register(Indexed) @_core.register(BasicWrapperMixin) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): return expr, [] +@_core.register(Indexed) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + if not isinstance(expr.function, Weights): + return expr, [] + + # Lower or reuse a previously lowered Weights array + sregistry = kwargs['sregistry'] + subs_user = kwargs['subs'] + + w0 = expr.function + k = tuple(w0.weights) + try: + w = weights[k] + except KeyError: + name = sregistry.make_name(prefix='w') + dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 + initvalue = tuple(i.subs(subs_user) for i in k) + w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + + rebuilt = expr._subs(w0.indexed, w.indexed) + + return rebuilt, [] + + @_core.register(IndexDerivative) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): sregistry = kwargs['sregistry'] options = kwargs['options'] - subs_user = kwargs['subs'] try: cbk0 = deriv_schedule_registry[options['deriv-schedule']] @@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # Create the concrete Weights array, or reuse an already existing one # if possible - name = sregistry.make_name(prefix='w') - w0 = ideriv.weights.function - dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 - k = tuple(w0.weights) - try: - w = weights[k] - except KeyError: - initvalue = tuple(i.subs(subs_user) for i in k) - w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs) # Replace the abstract Weights array with the concrete one - subs = {w0.indexed: w.indexed} + subs = {ideriv.weights.base: w.base} init = uxreplace(init, subs) ideriv = uxreplace(ideriv, subs) @@ -155,13 +169,13 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) # The Symbol that will hold the result of the IndexDerivative computation - # NOTE: created before recurring so that we ultimately get a sound ordering + # NOTE: created before recursing so that we ultimately get a sound ordering try: s = reusables.pop() - assert np.can_cast(s.dtype, dtype) + assert np.can_cast(s.dtype, w.dtype) except KeyError: name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=dtype) + s = Symbol(name=name, dtype=w.dtype) # Go inside `expr` and recursively lower any nested IndexDerivatives expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 978c093eed..7b16663a9b 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -5,7 +5,6 @@ from collections import OrderedDict from ctypes import c_uint64 -from functools import singledispatch from operator import itemgetter import numpy as np @@ -98,17 +97,29 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ decl = Definition(obj) - if obj._C_init: - definition = (decl, obj._C_init) + init = obj._C_init + if not init: + definition = decl + efuncs = () + elif isinstance(init, (list, tuple)): + assert len(init) == 2, "Expected (efunc, call)" + init, definition = init + efuncs = (init,) + elif init.is_Callable: + definition = Call(init.name, init.parameters, + retobj=obj if init.retval else None) + efuncs = (init,) else: - definition = (decl) + definition = (decl, init) + efuncs = () frees = obj._C_free if obj.free_symbols - {obj}: - storage.update(obj, site, objs=definition, frees=frees) + storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees) else: - storage.update(obj, site, standalones=definition, frees=frees) + storage.update(obj, site, standalones=definition, efuncs=efuncs, + frees=frees) def _alloc_array_on_low_lat_mem(self, site, obj, storage): """ @@ -555,7 +566,7 @@ class DeviceAwareDataManager(DataManager): def __init__(self, options=None, **kwargs): self.gpu_fit = options['gpu-fit'] self.gpu_create = options['gpu-create'] - self.pmode = options.get('place-transfers') + self.gpu_place_transfers = options.get('place-transfers') super().__init__(**kwargs) @@ -588,7 +599,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage): storage.update(obj, site, maps=mmap, unmaps=unmap) - def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False): + def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, + read_only=False, **kwargs): """ Map a Function already defined in the host memory in to the device high bandwidth memory. @@ -621,42 +633,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs) @iet_pass - def place_transfers(self, iet, data_movs=None, **kwargs): + def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs): """ Create a new IET with host-device data transfers. This requires mapping symbols to the suitable memory spaces. """ - if not self.pmode: + if not self.gpu_place_transfers: return iet, {} - @singledispatch - def _place_transfers(iet, data_movs): + if not isinstance(iet, EntryFunction): return iet, {} - @_place_transfers.register(EntryFunction) - def _(iet, data_movs): - reads, writes = data_movs + reads, writes = data_movs - # Special symbol which gives user code control over data deallocations - devicerm = DeviceRM() + # Special symbol which gives user code control over data deallocations + devicerm = DeviceRM() - storage = Storage() - for i in filter_sorted(writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm) - for i in filter_sorted(reads - writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True) - - iet, efuncs = self._inject_definitions(iet, storage) + storage = Storage() + for i in filter_sorted(writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, ctx=ctx + ) + for i in filter_sorted(reads - writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, read_only=True, ctx=ctx + ) - return iet, {'efuncs': efuncs} + iet, efuncs = self._inject_definitions(iet, storage) - return _place_transfers(iet, data_movs=data_movs) + return iet, {'efuncs': efuncs} @iet_pass def place_devptr(self, iet, **kwargs): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 3f002b2c85..0022fc64d6 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -17,7 +17,7 @@ search) from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass from devito.types import ( - Array, Bundle, ComponentAccess, CompositeObject, Lock, IncrDimension, + Array, Bundle, ComponentAccess, CompositeObject, IncrDimension, FunctionMap, ModuloDimension, Indirection, Pointer, SharedData, ThreadArray, Symbol, Temp, NPThreads, NThreadsBase, Wildcard ) @@ -555,12 +555,19 @@ def _(i, mapper, sregistry): @abstract_object.register(Array) def _(i, mapper, sregistry): - if isinstance(i, Lock): - name = sregistry.make_name(prefix='lock') + name = sregistry.make_name(prefix=i._symbol_prefix) + + if i.initvalue is not None: + initvalue = [] + for v in i.initvalue: + try: + initvalue.append(v.xreplace(mapper)) + except AttributeError: + initvalue.append(v) else: - name = sregistry.make_name(prefix='a') + initvalue = None - v = i._rebuild(name=name, alias=True) + v = i._rebuild(name=name, initvalue=initvalue, alias=True) mapper.update({ i: v, @@ -667,6 +674,16 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ptr')) +@abstract_object.register(FunctionMap) +def _(i, mapper, sregistry): + name = sregistry.make_name(prefix=i._symbol_prefix) + tensor = mapper.get(i.tensor, i.tensor) + + v = i._rebuild(name, tensor) + + mapper[i] = v + + @abstract_object.register(NPThreads) def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='npthreads')) diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index 9cefc4786b..7f83659fad 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -419,7 +419,7 @@ def _make_parallel(self, iet, sync_mapper=None): return iet, {'includes': [self.langbb['header']]} - def make_parallel(self, graph): + def make_parallel(self, graph, **kwargs): return self._make_parallel(graph, sync_mapper=graph.sync_mapper) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index eda71a0b74..23f5c33bf0 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,8 +7,8 @@ from devito.tools.dtypes_lowering import dtype_mapper __all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'LONG'] + 'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex', + 'c_double_complex'] limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 4a8d2df206..c92098015b 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ import sympy from sympy import Expr, Function, Number, Tuple, cacheit, sympify from sympy.core.decorators import call_highest_priority +from sympy.logic.boolalg import BooleanFunction from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa @@ -16,13 +17,13 @@ from devito.types import Symbol from devito.types.basic import Basic -__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa - 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', - 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', - 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', - 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', - 'VectorAccess'] +__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'IntDiv', # noqa + 'Terminal', 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer', + 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', + 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'Reserved', + 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', + 'Deref', 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', + 'BasicWrapperMixin', 'ValueLimit', 'VectorAccess'] class CondEq(sympy.Eq): @@ -63,6 +64,17 @@ def negated(self): return CondEq(*self.args, evaluate=False) +class BitwiseNot(BooleanFunction): + pass + + +class BitwiseXor(BooleanFunction): + + # Enforce two args + def __new__(cls, arg0, arg1, **kwargs): + return super().__new__(cls, arg0, arg1, **kwargs) + + class IntDiv(sympy.Expr): """ @@ -117,6 +129,17 @@ def __mul__(self, other): return super().__mul__(other) +class Terminal: + + """ + Abstract base class for all terminal objects, that is, those objects + collected by `retrieve_terminals` in addition to all other SymPy atoms + such as `Symbol`, `Number`, etc. + """ + + pass + + class BasicWrapperMixin: """ @@ -158,7 +181,7 @@ def _sympystr(self, printer): return str(self) -class CallFromPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class CallFromPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``pointer->call(params)``. @@ -226,7 +249,7 @@ def free_symbols(self): __reduce_ex__ = Pickable.__reduce_ex__ -class CallFromComposite(CallFromPointer, Pickable): +class CallFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.call(params)``. @@ -239,7 +262,7 @@ def __str__(self): __repr__ = __str__ -class FieldFromPointer(CallFromPointer, Pickable): +class FieldFromPointer(CallFromPointer): """ Symbolic representation of the C notation ``pointer->field``. @@ -260,7 +283,7 @@ def field(self): __repr__ = __str__ -class FieldFromComposite(CallFromPointer, Pickable): +class FieldFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.field``, @@ -322,7 +345,7 @@ def is_numeric(self): __reduce_ex__ = Pickable.__reduce_ex__ -class UnaryOp(sympy.Expr, Pickable, BasicWrapperMixin): +class UnaryOp(Expr, Pickable, BasicWrapperMixin): """ Symbolic representation of a unary C operator. @@ -460,7 +483,7 @@ def __str__(self): return f"{self._op}{self.base}" -class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class IndexedPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``symbol[...]`` @@ -507,7 +530,21 @@ def __str__(self): __reduce_ex__ = Pickable.__reduce_ex__ -class ReservedWord(sympy.Atom, Pickable): +class Reserved(Pickable): + + """ + A base class for all reserved words used throughout the lowering process, + including the final stage of code generation itself. + + Reserved objects have the following properties: + + * `estimate_cost(o) = 0`, where `o` is an instance of Reserved + """ + + pass + + +class ReservedWord(sympy.Atom, Reserved): """ A `ReservedWord` carries a value that has special meaning in the diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 0ff7fcf6ba..24fffcac64 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -9,8 +9,9 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.logger import warning from devito.symbolics.extended_dtypes import INT -from devito.symbolics.extended_sympy import (CallFromPointer, Cast, - DefFunction, ReservedWord) +from devito.symbolics.extended_sympy import ( + CallFromPointer, Cast, DefFunction, Reserved +) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, prod from devito.tools.dtypes_lowering import infer_dtype @@ -174,7 +175,7 @@ def _(expr, estimate, seen): @_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) -@_estimate_cost.register(ReservedWord) +@_estimate_cost.register(Reserved) def _(expr, estimate, seen): return 0, False diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index 2496a0aeb9..a52fa16aaf 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -1,7 +1,6 @@ from sympy import Eq, IndexedBase, Mod, S, diff, nan -from devito.symbolics.extended_sympy import (FieldFromComposite, FieldFromPointer, - IndexedPointer, IntDiv) +from devito.symbolics.extended_sympy import IntDiv, Terminal from devito.tools import as_tuple, is_integer from devito.types.basic import AbstractFunction from devito.types.constant import Constant @@ -16,13 +15,9 @@ 'q_dimension', 'q_positive', 'q_negative'] -# The following SymPy objects are considered tree leaves: -# -# * Number -# * Symbol -# * Indexed -extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject, - IndexedPointer) +# The following SymPy objects are considered tree leaves in addition to the classic +# SymPy atoms such as Number, Symbol, Indexed, etc +extra_leaves = (IndexedBase, AbstractObject, Terminal) def q_symbol(expr): diff --git a/devito/types/array.py b/devito/types/array.py index c2755a4c42..43f450fc88 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -127,6 +127,8 @@ class Array(ArrayBasic): is_Array = True + _symbol_prefix = 'a' + __rkwargs__ = (ArrayBasic.__rkwargs__ + ('dimensions', 'scope', 'initvalue')) diff --git a/devito/types/basic.py b/devito/types/basic.py index b30da1a872..1836411d16 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1891,8 +1891,17 @@ def _mem_internal_lazy(self): return self._liveness == 'lazy' """ - A modifier added to the subclass C declaration when it appears - in a function signature. For example, a subclass might define `_C_modifier = '&'` + A modifier added to the declaration of the LocalType when it appears in a + function signature. For example, a subclass might define `_C_modifier = '&'` to impose pass-by-reference semantics. """ _C_modifier = None + + """ + One or more optional keywords added to the declaration of the LocalType + in between the type and the variable name when it appears in a function + signature. For example, some languages support these to modify the way + the compiler generates code for passing the parameter and how the + runtime accesses it. + """ + _C_tag = None diff --git a/devito/types/misc.py b/devito/types/misc.py index d9293de2c7..5a37ffca69 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -14,8 +14,8 @@ __all__ = ['Timer', 'Pointer', 'VolatileInt', 'FIndexed', 'Wildcard', 'Fence', 'Global', 'Hyperplane', 'Indirection', 'Temp', 'TempArray', 'Jump', - 'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'auto', - 'size_t'] + 'nop', 'WeakFence', 'CriticalRegion', 'Auto', 'AutoRef', 'FunctionMap', + 'auto', 'size_t'] class Timer(CompositeObject): @@ -345,6 +345,30 @@ def closing(self): """ +class FunctionMap(LocalObject): + + """ + Wrap a Function in a LocalObject. + """ + + __rargs__ = ('name', 'tensor') + + def __init__(self, name, tensor, **kwargs): + super().__init__(name, **kwargs) + self.tensor = tensor + + def _hashable_content(self): + return super()._hashable_content() + (self.tensor,) + + @property + def free_symbols(self): + """ + The free symbols of a FunctionMap are the free symbols of the + underlying Function. + """ + return super().free_symbols | {self.tensor} + + # *** C/CXX support types size_t = CustomDtype('size_t') diff --git a/devito/types/object.py b/devito/types/object.py index 637e19dea0..a883fd8d51 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -176,10 +176,10 @@ class LocalObject(AbstractObject, LocalType): """ __rargs__ = ('name',) - __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'is_global') + __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'scope') def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', - is_global=False, **kwargs): + scope='stack', **kwargs): self.name = name self.cargs = as_tuple(cargs) @@ -191,16 +191,17 @@ def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', assert liveness in ['eager', 'lazy'] self._liveness = liveness - self._is_global = is_global + assert scope in ['stack', 'shared', 'global'] + self._scope = scope def _hashable_content(self): return (super()._hashable_content() + self.cargs + - (self.initvalue, self.liveness, self.is_global)) + (self.initvalue, self.liveness, self.scope)) @property - def is_global(self): - return self._is_global + def scope(self): + return self._scope @property def free_symbols(self): @@ -236,6 +237,10 @@ def _C_free(self): """ return None + @property + def _mem_shared(self): + return self._scope == 'shared' + @property def _mem_global(self): - return self._is_global + return self._scope == 'global' diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 9f127e1263..407443760d 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -226,6 +226,8 @@ class Lock(Array): is_volatile = True + _symbol_prefix = 'lock' + # Not a performance-sensitive object _data_alignment = False diff --git a/tests/test_iet.py b/tests/test_iet.py index 4b3961d9b4..b0e4594f42 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -16,9 +16,10 @@ from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class, - String, FLOAT) + String, ListInitializer, SizeOf, FLOAT) from devito.tools import CustomDtype, as_tuple, dtype_to_ctype from devito.types import CustomDimension, Array, LocalObject, Symbol +from devito.types.misc import FunctionMap @pytest.fixture @@ -297,6 +298,52 @@ def _C_free(self): }""" +def test_make_cuda_tensor_map(): + + class CUTensorMap(FunctionMap): + + dtype = CustomDtype('CUtensorMap') + + @property + def _C_init(self): + symsizes = list(reversed(self.tensor.symbolic_shape)) + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) + + sizes = ListInitializer(symsizes) + strides = ListInitializer([ + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) + ]) + + arguments = [ + Byref(self), + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), + 4, self.tensor.dmap, sizes, strides, + ] + call = Call('cuTensorMapEncodeTiled', arguments) + + return call + + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid) + + tmap = CUTensorMap('tmap', u) + + iet = Call('foo', tmap) + iet = ElementalFunction('foo', iet, parameters=()) + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert str(iet) == """\ +static void foo() +{ + CUtensorMap tmap; + cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); + + foo(tmap); +}""" # noqa + + def test_cpp_local_object(): """ Test C++ support for LocalObjects. @@ -309,7 +356,7 @@ class MyObject(LocalObject): lo0 = MyObject('obj0') # Globally-scoped objects must not be declared in the function body - lo1 = MyObject('obj1', is_global=True) + lo1 = MyObject('obj1', scope='global') # A LocalObject using both a template and a modifier class SpecialObject(LocalObject):