Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ def _print_SafeInv(self, expr):
val = self._print(expr.val)
return f'SAFEINV({val}, {base})'

def _print_RoundUp(self, expr):
value = self._print(expr.value)
step = self._print(expr.step)
return f'ROUND_UP({value}, {step})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
args = [f'({self._print(a)})' for a in expr.args]
Expand Down
18 changes: 14 additions & 4 deletions devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search
from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split
from devito.types import (
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove
)

__all__ = [
Expand Down Expand Up @@ -137,7 +137,14 @@ def detect_accesses(exprs):
"""
# Compute M : F -> S
mapper = defaultdict(Stencil)
for e in retrieve_indexed(exprs, deep=True):

# Search among the Indexeds (Most accesses typically stem from Indexeds)
plain_indexeds = retrieve_indexed(exprs, deep=True)

# Search among higher order objects, which still represent meaningful accesses
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This look very specific and seems like that's something retrieve_index should catch

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the idea is that other objects can end up there, in the future, maybe...

as for retrieve_indexed catching it: disagree, it's not an implicit Indexed, it's rarther a logical representation of the base address of the TensorMove -- as an Indexed, for homogeneity

high_order_indexeds = [i.indexed for i in search(exprs, TensorMove)]

for e in (*plain_indexeds, *high_order_indexeds):
f = e.function

for a, d0 in zip(e.indices, f.dimensions, strict=False):
Expand All @@ -164,13 +171,16 @@ def detect_accesses(exprs):
d, others = split(dims, lambda i: d0 in i._defines) # noqa: B023

if any(i.is_Indexed for i in a.args) or len(d) != 1:
# Case 1) -- with indirect accesses there's not much we can infer
# Case 1) -- with indirect accesses there's not much we
# can infer
continue
else:
# Case 2)
d, = d
_, o = split(others, lambda i: i.is_Custom)
off = sum(i for i in a.args if i.is_integer or i.free_symbols & o)
off = sum(
i for i in a.args if i.is_integer or i.free_symbols & o
)
else:
d, = dims

Expand Down
10 changes: 9 additions & 1 deletion devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.passes.iet.engine import iet_pass
from devito.passes.iet.languages.C import CPrinter
from devito.symbolics import Cast, ValueLimit, evalrel, has_integer_args, limits_mapper
from devito.symbolics import (
Cast, RoundUp, ValueLimit, evalrel, has_integer_args, limits_mapper
)
from devito.tools import Bunch, as_mapper, as_tuple, filter_ordered, split
from devito.types import FIndexed

Expand Down Expand Up @@ -255,6 +257,12 @@ def _(expr, langbb, printer):
f'(0.0{ext}) : ((1.0{ext}) / (a)))'),), {}


@_lower_macro_math.register(RoundUp)
def _(expr, langbb, printer):
return (('ROUND_UP(a,b)',
'((((a)%(b)) == 0) ? (a) : ((a) + (b) - ((a)%(b))))'),), {}


@iet_pass
def minimize_symbols(iet):
"""
Expand Down
49 changes: 46 additions & 3 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword',
'String', 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace',
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit',
'VectorAccess']
'String', 'Macro', 'Class', 'MacroArgument', 'RoundUp', 'Deref',
'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin',
'ValueLimit', 'VectorAccess']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -623,6 +623,49 @@ def __str__(self):
__repr__ = __str__


class RoundUp(Function):

"""
Symbolic representation of rounding a value up to the next multiple of a
given step.
"""

def __new__(cls, value, step, **kwargs):
value = sympify(value)
step = sympify(step)

if step < 1:
raise ValueError("Cannot round up with negative `step`")
if not is_integer(step):
raise ValueError("`step` must be an integer")

if value.is_number and step.is_number:
remainder = value % step
if remainder == 0:
return value
else:
return value + step - remainder

return super().__new__(cls, value, step, **kwargs)

@property
def value(self):
return self.args[0]

@property
def step(self):
return self.args[1]

@property
def is_commutative(self):
return self.value.is_commutative and self.step.is_commutative

def __str__(self):
return f"ROUND_UP({self.value}, {self.step})"

__repr__ = __str__


class ValueLimit(ReservedWord, sympy.Expr):

"""
Expand Down
9 changes: 7 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,12 +921,17 @@ def __padding_setup_smart__(self, **kwargs):
return nopadding

mmts = configuration['platform'].max_mem_trans_size(self.__padding_dtype__)
remainder = self._size_nopad[d] % mmts

snp = self._size_nopad[d]
remainder = snp % mmts
if remainder == 0:
# Already a multiple of `mmts`, no need to pad
return nopadding
else:
from devito.symbolics import RoundUp # noqa
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe implies that extended_sympy should be moved somewhere else?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it's a long standing issue, documented somewhere

v = RoundUp(snp, mmts) - snp

dpadding = (0, (mmts - remainder))
dpadding = (0, v)
padding = [(0, 0)]*self.ndim
padding[self.dimensions.index(d)] = dpadding

Expand Down
40 changes: 39 additions & 1 deletion devito/types/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,47 @@ class TensorMove(Expr, Reserved, Terminal):

"""
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
level of the memory hierarchy
level of the memory hierarchy.

Parameters
----------
base : IndexedBase
The base of the AbstractFunction subject of the TensorMove.
tid0 : Dimension
A representation of thread(s) issuing the TensorMove.
coords : tuple
The base address of the TensorMove (one point per Dimension).
"""

__rargs__ = ('base', 'tid0', 'coords')

def __new__(cls, base, tid0, coords, **kwargs):
return super().__new__(cls, base, tid0, coords)

@property
def base(self):
return self.args[0]

@property
def tid0(self):
return self.args[1]

@property
def coords(self):
return self.args[2]

@property
def function(self):
return self.base.function

@cached_property
def indexed(self):
return self.function[self.coords]

@property
def ndim(self):
return self.function.ndim

func = Reserved._rebuild

def _ccode(self, printer):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
)
from devito.data import LEFT, RIGHT, Decomposition, convert_index, loc_data_idx
from devito.data.allocators import DataReference
from devito.ir import ccode
from devito.tools import as_tuple
from devito.types import Scalar
from devito.types.misc import TempArray


class TestDataBasic:
Expand Down Expand Up @@ -336,6 +338,34 @@ def test_w_halo_w_autopadding(self):
assert u1._size_nodomain == ((3, 3), (3, 3), (3, 9))
assert u1.shape_allocated == (10, 10, 16)

@switchconfig(autopadding=True, platform='bdw')
def test_temp_array_smart_padding_no_overshoot(self):
mmts = configuration['platform'].max_mem_trans_size(np.float32)
halo = 4
z_size = 2*mmts - 2*halo

grid = Grid(shape=(4, 4, z_size))
u = Function(name='u', grid=grid, space_order=halo)
r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo, dtype=u.dtype)

z = grid.dimensions[-1]
mapper = {z.symbolic_size: z_size}

assert r.padding[z][1].subs(mapper) == 0
assert r.shape_allocated[-1].subs(mapper) == u.shape_allocated[-1]

@switchconfig(autopadding=True, platform='bdw')
def test_temp_array_smart_padding_codegen_avoids_negative_mod(self):
grid = Grid(shape=(4, 4, 592))
u = Function(name='u', grid=grid, space_order=0)
r = TempArray(name='r', dimensions=grid.dimensions, halo=u.halo, dtype=u.dtype)

code = ccode(r.shape_allocated[-1])

assert 'ROUND_UP(' in code
assert '(-z_size)' not in code
assert 'z_size' in code

def test_w_halo_custom(self):
grid = Grid(shape=(4, 4))

Expand Down
20 changes: 17 additions & 3 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
from devito import ( # noqa
Abs, Conj, Constant, Dimension, Eq, Function, Ge, Grid, Gt, Imag, Le, Lt, Max, Min,
Operator, Real, SubDimension, SubDomain, TimeFunction, configuration, cos, norm, sin,
solve
solve, switchconfig
)
from devito.finite_differences.differentiable import Mul, SafeInv, Weights
from devito.ir import Expression, FindNodes, ccode
from devito.ir.support.guards import GuardExpr, pairwise_or, simplify_and
from devito.mpi.halo_scheme import HaloTouch
from devito.symbolics import ( # noqa
INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite,
FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, Rvalue, SizeOf,
VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue,
SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions,
retrieve_indexed, uxreplace
)
from devito.tools import CustomDtype, as_tuple
Expand Down Expand Up @@ -390,6 +390,20 @@ def test_safeinv():
assert str(v) == 'u[x, y]'


def test_roundup():
grid = Grid(shape=(11, 11))
u = Function(name='u', grid=grid)
a = dSymbol('a', dtype=np.int32)

expr = RoundUp(a, 16)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the round up factor also be symbolic?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, I cooked up something simple for my needs after days of frustration

Copy link
Copy Markdown
Contributor Author

@FabioLuporini FabioLuporini Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see comment above -- now ensuring it's an integer number

with switchconfig(platform='bdw', language='openmp'):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'bdw' is oddly specific

op = Operator(Eq(u, u + expr))

assert ccode(expr) == 'ROUND_UP(a, 16)'
assert '#define ROUND_UP(a,b)' in str(op)
assert 'ROUND_UP(a, 16)' in str(op)


def test_def_function():
foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
Expand Down
Loading