diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py
index aebd8b7333..e3878c2a19 100644
--- a/devito/ir/cgen/printer.py
+++ b/devito/ir/cgen/printer.py
@@ -220,6 +220,11 @@ def _print_SafeInv(self, expr):
         val = self._print(expr.val)
         return f'SAFEINV({val}, {base})'
 
+    def _print_RoundUp(self, expr):
+        value = self._print(expr.value)
+        step = self._print(expr.step)
+        return f'ROUND_UP({value}, {step})'
+
     def _print_Mod(self, expr):
         """Print a Mod as a C-like %-based operation."""
         args = [f'({self._print(a)})' for a in expr.args]
diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py
index 0d8639a5be..540da6bb52 100644
--- a/devito/ir/support/utils.py
+++ b/devito/ir/support/utils.py
@@ -6,7 +6,7 @@
 from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search
 from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split
 from devito.types import (
-    Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension
+    Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove
 )
 
 __all__ = [
@@ -137,7 +137,14 @@ def detect_accesses(exprs):
     """
     # Compute M : F -> S
     mapper = defaultdict(Stencil)
-    for e in retrieve_indexed(exprs, deep=True):
+
+    # Search among the Indexeds (most accesses typically stem from Indexeds)
+    plain_indexeds = retrieve_indexed(exprs, deep=True)
+
+    # Search among higher order objects, which still represent meaningful accesses
+    high_order_indexeds = [i.indexed for i in search(exprs, TensorMove)]
+
+    for e in (*plain_indexeds, *high_order_indexeds):
         f = e.function
 
         for a, d0 in zip(e.indices, f.dimensions, strict=False):
@@ -164,13 +171,16 @@ def detect_accesses(exprs):
                     d, others = split(dims, lambda i: d0 in i._defines)  # noqa: B023
 
                     if any(i.is_Indexed for i in a.args) or len(d) != 1:
-                        # Case 1) -- with indirect accesses there's not much we can infer
+                        # Case 1) -- with indirect accesses there's not much we
+                        # can infer
                         continue
                     else:
                         # Case 2)
                         d, = d
                         _, o = split(others, lambda i: i.is_Custom)
-                        off = sum(i for i in a.args if i.is_integer or i.free_symbols & o)
+                        off = sum(
+                            i for i in a.args if i.is_integer or i.free_symbols & o
+                        )
                 else:
                     d, = dims
 
diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py
index 1dd2a7ad52..0a631bf3a2 100644
--- a/devito/passes/iet/misc.py
+++ b/devito/passes/iet/misc.py
@@ -14,7 +14,9 @@
 from devito.ir.iet.efunc import DeviceFunction, EntryFunction
 from devito.passes.iet.engine import iet_pass
 from devito.passes.iet.languages.C import CPrinter
-from devito.symbolics import Cast, ValueLimit, evalrel, has_integer_args, limits_mapper
+from devito.symbolics import (
+    Cast, RoundUp, ValueLimit, evalrel, has_integer_args, limits_mapper
+)
 from devito.tools import Bunch, as_mapper, as_tuple, filter_ordered, split
 from devito.types import FIndexed
 
@@ -255,6 +257,12 @@ def _(expr, langbb, printer):
              f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}
 
 
+@_lower_macro_math.register(RoundUp)
+def _(expr, langbb, printer):
+    return (('ROUND_UP(a,b)',
+             '((((a)%(b)) == 0) ? (a) : ((a) + (b) - ((a)%(b))))'),), {}
+
+
 @iet_pass
 def minimize_symbols(iet):
     """
diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py
index 7df8f430fc..be79fba3a4 100644
--- a/devito/symbolics/extended_sympy.py
+++ b/devito/symbolics/extended_sympy.py
@@ -23,9 +23,9 @@
            'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
            'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
            'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword',
-           'String', 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace',
-           'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit',
-           'VectorAccess']
+           'String', 'Macro', 'Class', 'MacroArgument', 'RoundUp', 'Deref',
+           'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin',
+           'ValueLimit', 'VectorAccess']
 
 
 class CondEq(sympy.Eq):
@@ -623,6 +623,66 @@ def __str__(self):
     __repr__ = __str__
 
 
+class RoundUp(Function):
+
+    """
+    Symbolic representation of rounding a value up to the next multiple of a
+    given step.
+
+    If both `value` and `step` are numbers, the result is computed eagerly
+    and returned in place of a RoundUp object.
+
+    Parameters
+    ----------
+    value : expr-like
+        The value to be rounded up.
+    step : int
+        The rounding granularity; must be a positive integer.
+
+    Raises
+    ------
+    ValueError
+        If `step` is not an integer or is smaller than 1.
+    """
+
+    def __new__(cls, value, step, **kwargs):
+        value = sympify(value)
+        step = sympify(step)
+
+        # Check the type first: comparing a non-numeric `step` against 1
+        # cannot be resolved to a boolean
+        if not is_integer(step):
+            raise ValueError("`step` must be an integer")
+        if step < 1:
+            raise ValueError("`step` must be a positive integer")
+
+        if value.is_number and step.is_number:
+            remainder = value % step
+            if remainder == 0:
+                return value
+            else:
+                return value + step - remainder
+
+        return super().__new__(cls, value, step, **kwargs)
+
+    @property
+    def value(self):
+        return self.args[0]
+
+    @property
+    def step(self):
+        return self.args[1]
+
+    @property
+    def is_commutative(self):
+        return self.value.is_commutative and self.step.is_commutative
+
+    def __str__(self):
+        return f"ROUND_UP({self.value}, {self.step})"
+
+    __repr__ = __str__
+
+
 class ValueLimit(ReservedWord, sympy.Expr):
 
     """
diff --git a/devito/types/basic.py b/devito/types/basic.py
index e6dea0e8e0..babd3eaf76 100644
--- a/devito/types/basic.py
+++ b/devito/types/basic.py
@@ -921,12 +921,17 @@ def __padding_setup_smart__(self, **kwargs):
             return nopadding
 
         mmts = configuration['platform'].max_mem_trans_size(self.__padding_dtype__)
-        remainder = self._size_nopad[d] % mmts
+
+        snp = self._size_nopad[d]
+        remainder = snp % mmts
         if remainder == 0:
             # Already a multiple of `mmts`, no need to pad
             return nopadding
+        else:
+            from devito.symbolics import RoundUp  # noqa
+            v = RoundUp(snp, mmts) - snp
 
-        dpadding = (0, (mmts - remainder))
+        dpadding = (0, v)
         padding = [(0, 0)]*self.ndim
         padding[self.dimensions.index(d)] = dpadding
 
diff --git a/devito/types/parallel.py b/devito/types/parallel.py
index c6aceb42e5..9670c766a8 100644
--- a/devito/types/parallel.py
+++ b/devito/types/parallel.py
@@ -407,9 +407,47 @@ class TensorMove(Expr, Reserved, Terminal):
 
     """
     Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
-    level of the memory hierarchy
+    level of the memory hierarchy.
+
+    Parameters
+    ----------
+    base : IndexedBase
+        The base of the AbstractFunction subject of the TensorMove.
+    tid0 : Dimension
+        A representation of thread(s) issuing the TensorMove.
+    coords : tuple
+        The base address of the TensorMove (one point per Dimension).
     """
 
+    __rargs__ = ('base', 'tid0', 'coords')
+
+    def __new__(cls, base, tid0, coords, **kwargs):
+        return super().__new__(cls, base, tid0, coords)
+
+    @property
+    def base(self):
+        return self.args[0]
+
+    @property
+    def tid0(self):
+        return self.args[1]
+
+    @property
+    def coords(self):
+        return self.args[2]
+
+    @property
+    def function(self):
+        return self.base.function
+
+    @cached_property
+    def indexed(self):
+        return self.function[self.coords]
+
+    @property
+    def ndim(self):
+        return self.function.ndim
+
     func = Reserved._rebuild
 
     def _ccode(self, printer):
diff --git a/tests/test_data.py b/tests/test_data.py
index 73a76fcc79..f5d0e8d177 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -8,8 +8,10 @@
 )
 from devito.data import LEFT, RIGHT, Decomposition, convert_index, loc_data_idx
 from devito.data.allocators import DataReference
+from devito.ir import ccode
 from devito.tools import as_tuple
 from devito.types import Scalar
+from devito.types.misc import TempArray
 
 
 class TestDataBasic:
@@ -336,6 +338,34 @@ def test_w_halo_w_autopadding(self):
         assert u1._size_nodomain == ((3, 3), (3, 3), (3, 9))
         assert u1.shape_allocated == (10, 10, 16)
 
+    @switchconfig(autopadding=True, platform='bdw')
+    def test_temp_array_smart_padding_no_overshoot(self):
+        mmts = configuration['platform'].max_mem_trans_size(np.float32)
+        halo = 4
+        z_size = 2*mmts - 2*halo
+
+        grid = Grid(shape=(4, 4, z_size))
+        u = Function(name='u', grid=grid, space_order=halo)
+        r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo, dtype=u.dtype)
+
+        z = grid.dimensions[-1]
+        mapper = {z.symbolic_size: z_size}
+
+        assert r.padding[z][1].subs(mapper) == 0
+        assert r.shape_allocated[-1].subs(mapper) == u.shape_allocated[-1]
+
+    @switchconfig(autopadding=True, platform='bdw')
+    def test_temp_array_smart_padding_codegen_avoids_negative_mod(self):
+        grid = Grid(shape=(4, 4, 592))
+        u = Function(name='u', grid=grid, space_order=0)
+        r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo, dtype=u.dtype)
+
+        code = ccode(r.shape_allocated[-1])
+
+        assert 'ROUND_UP(' in code
+        assert '(-z_size)' not in code
+        assert 'z_size' in code
+
     def test_w_halo_custom(self):
         grid = Grid(shape=(4, 4))
 
diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py
index 6f1675977e..91e1b3e572 100644
--- a/tests/test_symbolics.py
+++ b/tests/test_symbolics.py
@@ -8,7 +8,7 @@
 from devito import (  # noqa
     Abs, Conj, Constant, Dimension, Eq, Function, Ge, Grid, Gt, Imag, Le, Lt, Max, Min,
     Operator, Real, SubDimension, SubDomain, TimeFunction, configuration, cos, norm, sin,
-    solve
+    solve, switchconfig
 )
 from devito.finite_differences.differentiable import Mul, SafeInv, Weights
 from devito.ir import Expression, FindNodes, ccode
@@ -16,8 +16,8 @@
 from devito.mpi.halo_scheme import HaloTouch
 from devito.symbolics import (  # noqa
     INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite,
-    FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, Rvalue, SizeOf,
-    VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
+    FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue,
+    SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
     retrieve_indexed, uxreplace
 )
 from devito.tools import CustomDtype, as_tuple
@@ -390,6 +390,20 @@ def test_safeinv():
     assert str(v) == 'u[x, y]'
 
 
+def test_roundup():
+    grid = Grid(shape=(11, 11))
+    u = Function(name='u', grid=grid)
+    a = dSymbol('a', dtype=np.int32)
+
+    expr = RoundUp(a, 16)
+    with switchconfig(platform='bdw', language='openmp'):
+        op = Operator(Eq(u, u + expr))
+
+    assert ccode(expr) == 'ROUND_UP(a, 16)'
+    assert '#define ROUND_UP(a,b)' in str(op)
+    assert 'ROUND_UP(a, 16)' in str(op)
+
+
 def test_def_function():
     foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
     foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])