Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def alloc(self, shape, dtype, padding=0):
raise RuntimeError(f"Unable to allocate {size} elements in memory")

# Compute the pointer to the user data
padleft_bytes = padleft * ctypes.sizeof(ctype)
padleft_bytes = int(padleft * ctypes.sizeof(ctype))
c_pointer = ctypes.c_void_p(padleft_pointer.value + padleft_bytes)

# Cast to 1D array of the specified `datasize`
Expand Down
16 changes: 14 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import inspect
import warnings
from contextlib import suppress
from ctypes import POINTER, Structure, _Pointer, c_char, c_char_p
from functools import cached_property, reduce
Expand All @@ -9,6 +10,7 @@
import sympy
from sympy.core.assumptions import _assume_rules
from sympy.core.decorators import call_highest_priority
from sympy.utilities.exceptions import SymPyDeprecationWarning

from devito.data import default_allocator
from devito.parameters import configuration
Expand Down Expand Up @@ -1533,10 +1535,20 @@ def _sympify(self, arg):
# This is used internally by sympy to process arguments at rebuilt. And since
# some of our properties are non-sympyfiable we need to have a fallback
try:
return super()._sympify(arg)
except sympy.SympifyError:
# Pure sympy object
return arg._sympy_()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This prevents sympy from forcing some weird type conversion because of

https://github.com/sympy/sympy/blob/16fa855354eb7bcabd3fe10993841e03b1382692/sympy/core/sympify.py#L414

That for example converts EnrichedTuples to sympy's tuples. By using _sympy_ explicitly it makes sure it only convert object with explicity conversion defined (so mostly pure sympy objects or our objects with it defined)

except AttributeError:
return arg

@classmethod
def _eval_from_dok(cls, rows, cols, dok):
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=SymPyDeprecationWarning
)
return super()._eval_from_dok(rows, cols, dok)

@property
def grid(self):
"""
Expand Down
5 changes: 5 additions & 0 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,9 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
* 0 to non-staggered dimensions;
* 1 to staggered dimensions.
"""
if isinstance(staggered, Staggering):
staggered = staggered._ref

if not staggered:
processed = ()
elif staggered is CELL:
Expand All @@ -1154,6 +1157,8 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
assert len(staggered) == len(dimensions)
processed = staggered
else:
# Staggering is not NODE or CELL or None
# therefore it's a tuple of dimensions
processed = []
for d in dimensions:
if d in as_tuple(staggered):
Expand Down
15 changes: 13 additions & 2 deletions devito/types/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@
__all__ = ['TensorFunction', 'TensorTimeFunction', 'VectorFunction', 'VectorTimeFunction']


def staggering(stagg, i, j, d, dims):
if stagg is None:
# No input
return NODE if i == j else (d, dims[j])
elif isinstance(stagg, (tuple, list)):
# User input as list or tuple
return stagg[i][j]
elif isinstance(stagg, AbstractTensor):
# From rebuild/tensor property. Indexed as a sympy Matrix
return stagg[i, j]


class TensorFunction(AbstractTensor):
"""
Tensor valued Function represented as a Matrix.
Expand Down Expand Up @@ -128,8 +140,7 @@ def __subfunc_setup__(cls, *args, **kwargs):
start = i if (symm or diag) else 0
stop = i + 1 if diag else len(dims)
for j in range(start, stop):
staggj = (stagg[i][j] if stagg is not None
else (NODE if i == j else (d, dims[j])))
staggj = staggering(stagg, i, j, d, dims)
sub_kwargs = cls._component_kwargs((i, j), **kwargs)
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
'staggered': staggj})
Expand Down
9 changes: 9 additions & 0 deletions devito/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ class Staggering(DimensionTuple):
def on_node(self):
return not self or all(s == 0 for s in self)

@property
def _ref(self):
if not self:
return None
elif self.on_node:
return NODE
else:
return tuple(d for d, s in zip(self.getters, self, strict=True) if s == 1)


class IgnoreDimSort(tuple):
"""A tuple subclass used to wrap the implicit_dims to indicate
Expand Down
27 changes: 20 additions & 7 deletions tests/test_staggered_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from sympy import simplify

from devito import (
CELL, NODE, Dimension, Eq, Function, Grid, Operator, TimeFunction, VectorTimeFunction,
div
CELL, NODE, Eq, Function, Grid, Operator, TimeFunction, VectorTimeFunction, div
)
from devito.tools import as_tuple, powerset

Expand Down Expand Up @@ -173,16 +172,15 @@ def test_staggered_rebuild(stagg):
f = Function(name='f', grid=grid, space_order=4, staggered=stagg)
assert tuple(f.staggered.getters.keys()) == grid.dimensions

new_dims = (Dimension('x1'), Dimension('y1'), Dimension('z1'))
f2 = f.func(dimensions=new_dims)
f2 = f.func(name="f2")

assert f2.dimensions == new_dims
assert f2.dimensions == f.dimensions
assert tuple(f2.staggered) == tuple(f.staggered)
assert tuple(f2.staggered.getters.keys()) == new_dims
assert tuple(f2.staggered.getters.keys()) == f.dimensions

# Check that rebuild correctly set the staggered indices
# with the new dimensions
for (d, nd) in zip(grid.dimensions, new_dims, strict=True):
for (d, nd) in zip(grid.dimensions, f.dimensions, strict=True):
if d in as_tuple(stagg) or stagg is CELL:
assert f2.indices[nd] == nd + nd.spacing / 2
else:
Expand All @@ -200,3 +198,18 @@ def test_eval_at_different_dim():
eq = Eq(tau.forward, v).evaluate

assert grid.time_dim not in eq.rhs.free_symbols


def test_new_from_staggering():
grid = Grid(shape=(31, 17, 25))
x, _, _ = grid.dimensions

f = TimeFunction(name="f", grid=grid, staggered=x)
# This used to fail since f.staggered as 4 elements (0, 1, 0, 0)
# but it is processed for Dimension only.
# Now properly converts Staggering to the ref (x,) at init
g = TimeFunction(name="g", grid=grid, staggered=f.staggered)

assert g.staggered._ref == (x,)
assert g.staggered == (0, 1, 0, 0)
assert g.staggered == f.staggered
Loading