-
Notifications
You must be signed in to change notification settings - Fork 254
compiler: Enhance detect_accesses and patch symbolic padding #2886
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -921,12 +921,17 @@ def __padding_setup_smart__(self, **kwargs): | |
| return nopadding | ||
|
|
||
| mmts = configuration['platform'].max_mem_trans_size(self.__padding_dtype__) | ||
| remainder = self._size_nopad[d] % mmts | ||
|
|
||
| snp = self._size_nopad[d] | ||
| remainder = snp % mmts | ||
| if remainder == 0: | ||
| # Already a multiple of `mmts`, no need to pad | ||
| return nopadding | ||
| else: | ||
| from devito.symbolics import RoundUp # noqa | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe implies that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it's a long standing issue, documented somewhere |
||
| v = RoundUp(snp, mmts) - snp | ||
|
|
||
| dpadding = (0, (mmts - remainder)) | ||
| dpadding = (0, v) | ||
| padding = [(0, 0)]*self.ndim | ||
| padding[self.dimensions.index(d)] = dpadding | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,16 +8,16 @@ | |
| from devito import ( # noqa | ||
| Abs, Conj, Constant, Dimension, Eq, Function, Ge, Grid, Gt, Imag, Le, Lt, Max, Min, | ||
| Operator, Real, SubDimension, SubDomain, TimeFunction, configuration, cos, norm, sin, | ||
| solve | ||
| solve, switchconfig | ||
| ) | ||
| from devito.finite_differences.differentiable import Mul, SafeInv, Weights | ||
| from devito.ir import Expression, FindNodes, ccode | ||
| from devito.ir.support.guards import GuardExpr, pairwise_or, simplify_and | ||
| from devito.mpi.halo_scheme import HaloTouch | ||
| from devito.symbolics import ( # noqa | ||
| INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite, | ||
| FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, Rvalue, SizeOf, | ||
| VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, | ||
| FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue, | ||
| SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, | ||
| retrieve_indexed, uxreplace | ||
| ) | ||
| from devito.tools import CustomDtype, as_tuple | ||
|
|
@@ -390,6 +390,20 @@ def test_safeinv(): | |
| assert str(v) == 'u[x, y]' | ||
|
|
||
|
|
||
| def test_roundup(): | ||
| grid = Grid(shape=(11, 11)) | ||
| u = Function(name='u', grid=grid) | ||
| a = dSymbol('a', dtype=np.int32) | ||
|
|
||
| expr = RoundUp(a, 16) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can the round up factor also be symbolic?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure, I cooked up something simple for my needs after days of frustration
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see comment above -- now ensuring it's an integer number |
||
| with switchconfig(platform='bdw', language='openmp'): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 'bdw' is oddly specific |
||
| op = Operator(Eq(u, u + expr)) | ||
|
|
||
| assert ccode(expr) == 'ROUND_UP(a, 16)' | ||
| assert '#define ROUND_UP(a,b)' in str(op) | ||
| assert 'ROUND_UP(a, 16)' in str(op) | ||
|
|
||
|
|
||
| def test_def_function(): | ||
| foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int']) | ||
| foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int']) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This look very specific and seems like that's something retrieve_index should catch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the idea is that other objects can end up there, in the future, maybe...
as for retrieve_indexed catching it: disagree, it's not an implicit Indexed, it's rarther a logical representation of the base address of the TensorMove -- as an Indexed, for homogeneity