diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 8f7d35e155..eb5c9e4de2 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -11,7 +11,7 @@ ) from devito.symbolics import IntDiv, limits_mapper, uxreplace from devito.tools import Pickable, Tag, frozendict -from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min +from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min, relational_shift __all__ = [ 'ClusterizedEq', @@ -213,7 +213,7 @@ def __new__(cls, *args, **kwargs): relations=ordering.relations, mode='partial') ispace = IterationSpace(intervals, iterators) - # Construct the conditionals and replace the ConditionalDimensions in `expr` + # Construct the conditionals conditionals = {} for d in ordering: if not d.is_Conditional: @@ -225,14 +225,30 @@ def __new__(cls, *args, **kwargs): if d._factor is not None: cond = d.relation(cond, GuardFactor(d)) conditionals[d] = cond - # Replace dimension with index - index = d.index - if d.condition is not None and d in expr.free_symbols: - index = index - relational_min(d.condition, d.parent) - expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)}) + + # Merge conditionals when possible. 
E.g., if we have an implicit_dim + # and there is another dimension with the same parent, we can merge + # its condition into that dimension's + for d in input_expr.implicit_dims: + if d not in conditionals: + continue + for cd in dict(conditionals): + if cd.parent == d.parent and cd != d: + cond = conditionals.pop(d) + mode = cd.relation and d.relation + conditionals[cd] = mode(cond, conditionals[cd]) + break conditionals = frozendict(conditionals) + # Replace the ConditionalDimensions in `expr` + for d, cond in conditionals.items(): + # Replace dimension with index + index = d.index + index = index - relational_min(cond, d.parent) + shift = relational_shift(cond, d.parent) + expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift}) + # Lower all Differentiable operations into SymPy operations rhs = diff2sympy(expr.rhs) diff --git a/devito/ir/support/vector.py b/devito/ir/support/vector.py index 02e26e2a02..79d84605ba 100644 --- a/devito/ir/support/vector.py +++ b/devito/ir/support/vector.py @@ -128,6 +128,7 @@ def __lt__(self, other): return True elif q_positive(i): return False + raise TypeError("Non-comparable index functions") from e return False @@ -164,6 +165,7 @@ def __gt__(self, other): return True elif q_negative(i): return False + raise TypeError("Non-comparable index functions") from e return False @@ -203,6 +205,7 @@ def __le__(self, other): return True elif q_positive(i): return False + raise TypeError("Non-comparable index functions") from e # Note: unlike `__lt__`, if we end up here, then *it is* <=. 
For example, diff --git a/devito/passes/clusters/buffering.py b/devito/passes/clusters/buffering.py index 2a8c78e7fc..57e0ad591a 100644 --- a/devito/passes/clusters/buffering.py +++ b/devito/passes/clusters/buffering.py @@ -3,7 +3,7 @@ from itertools import chain import numpy as np -from sympy import S +from sympy import S, simplify from devito.exceptions import CompilationError from devito.ir import ( @@ -775,7 +775,7 @@ def infer_buffer_size(f, dim, clusters): slots = [Vector(i) for i in slots] size = int((vmax(*slots) - vmin(*slots) + 1)[0]) - return size + return simplify(size) def offset_from_centre(d, indices): diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index d4d7f0a8b8..ba2c089ee9 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -36,6 +36,18 @@ def retrieve_ctemps(exprs, mode='all'): return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs') +def cse_dtype(exprdtype, cdtype): + """ + Return the dtype of a CSE temporary given the dtype of the expression to be + captured and the cluster's dtype. 
+ """ + if np.issubdtype(cdtype, np.complexfloating): + return np.promote_types(exprdtype, cdtype(0).real.__class__).type + else: + # Real cluster, can safely promote to the largest precision + return np.promote_types(exprdtype, cdtype).type + + @cluster_pass def cse(cluster, sregistry=None, options=None, **kwargs): """ @@ -86,7 +98,7 @@ def cse(cluster, sregistry=None, options=None, **kwargs): if cluster.is_fence: return cluster - make_dtype = lambda e: np.promote_types(e.dtype, dtype).type + make_dtype = lambda e: cse_dtype(e.dtype, dtype) make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e)) exprs = _cse(cluster, make, min_cost=min_cost, mode=mode) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 2c352d50d4..2ea8f45c93 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -17,6 +17,7 @@ ) from devito.types import Symbol from devito.types.basic import Basic +from devito.types.relational import Ge __all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa 'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer', @@ -46,6 +47,11 @@ def canonical(self): def negated(self): return CondNe(*self.args, evaluate=False) + @property + def _as_min(self): + from devito.symbolics.extended_dtypes import INT + return INT(Ge(*self.args)) + class CondNe(sympy.Ne): diff --git a/devito/types/relational.py b/devito/types/relational.py index 731ec29bc7..1fb3e19355 100644 --- a/devito/types/relational.py +++ b/devito/types/relational.py @@ -3,7 +3,8 @@ import sympy -__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min'] +__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min', + 'relational_shift'] class AbstractRel: @@ -291,3 +292,32 @@ def _(expr, s): return expr.gts else: return sympy.S.Infinity + + +def relational_shift(expr, s): + """ + Infer the shift incurred by the expression. 
Generally only + applies when a CondEq is used as it adds a single value. + """ + if not expr.has(s): + return 0 + + return _relational_shift(expr, s) + + +@singledispatch +def _relational_shift(s, expr): + return 0 + + +@_relational_shift.register(sympy.Or) +@_relational_shift.register(sympy.And) +def _(expr, s): + return sum([_relational_shift(e, s) for e in expr.args]) + + +@_relational_shift.register(sympy.Eq) +def _(expr, s): + if isinstance(expr.lhs, sympy.Mod): + return 0 + return expr._as_min diff --git a/tests/test_buffering.py b/tests/test_buffering.py index 0a0037c223..f89f7262d4 100644 --- a/tests/test_buffering.py +++ b/tests/test_buffering.py @@ -754,7 +754,7 @@ def test_buffer_reuse(): assert all(np.all(vsave.data[i-1] == i + 1) for i in range(1, nt + 1)) -def test_multi_cond(): +def test_multi_cond_v0(): grid = Grid((3, 3)) nt = 5 @@ -774,14 +774,47 @@ def test_multi_cond(): T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0) eqs = [Eq(T, grid.time_dim)] - # this to save times from 0 to nt - 2 + # This saves + # - All subsampled times since ct1 is the dimension of f + # - The last time step (ntmod - 2) through ctend (since it's set as ct1 or ctend) + eqs.append(Eq(f, T, implicit_dims=ctend)) + + # run operator with buffering + op = Operator(eqs, opt='buffering') + op.apply(time_m=0, time_M=ntmod-2) + + for i in range(nt-1): + assert np.allclose(f.data[i], i*2) + assert np.allclose(f.data[nt-1], ntmod - 2) + + +def test_multi_cond_v1(): + grid = Grid((3, 3)) + nt = 5 + + x, y = grid.dimensions + + factor = 2 + ntmod = (nt - 1) * factor + 1 + + ct1 = ConditionalDimension(name="ct1", parent=grid.time_dim, + factor=factor, relation=Or, + condition=CondEq(grid.time_dim, ntmod - 2)) + + f = TimeFunction(grid=grid, name='f', time_order=0, + space_order=0, save=nt, time_dim=ct1) + T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0) + + eqs = [Eq(T, grid.time_dim)] + # This saves + # - All subsampled times since ct1 is the 
dimension of f with factor 2 + # - The last time step (ntmod - 2) since ct1 also has the condition for ntmod - 2 eqs.append(Eq(f, T)) - # this to save the last time sample nt - 1 - eqs.append(Eq(f.forward, T+1, implicit_dims=ctend)) # run operator with buffering op = Operator(eqs, opt='buffering') op.apply(time_m=0, time_M=ntmod-2) - for i in range(nt): + for i in range(nt-1): assert np.allclose(f.data[i], i*2) + assert np.allclose(f.data[nt-1], ntmod - 2) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 1cde2e13e2..8f59ca0076 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -855,6 +855,21 @@ def test_point_symbol_types(dtype, expected): assert point_symbol.dtype is expected +@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) +def test_interp_complex(dtype): + grid = Grid((11, 11, 11)) + + sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype) + sc.coordinates.data[:] = [.5, .5, .5] + + fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) + fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + opC = Operator([sc.interpolate(expr=fc)], name="OpC") + opC() + + assert np.isclose(sc.data[0], fc.data[5, 5, 5]) + + class SD0(SubDomain): name = 'sd0'