Skip to content

Commit 8c448dc

Browse files
Merge pull request #2878 from devitocodes/more-TMA-enhance
compiler: Misc enhancements for lowering of parlang backends
2 parents de5a162 + aa388f3 commit 8c448dc

9 files changed

Lines changed: 166 additions & 60 deletions

File tree

devito/ir/clusters/cluster.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from devito.symbolics import estimate_cost
1717
from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype
1818
from devito.types import (
19-
CriticalRegion, Fence, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence
19+
CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit,
20+
ThreadPoolSync, ThreadWait, WeakFence
2021
)
2122

2223
__all__ = ["Cluster", "ClusterGroup"]
@@ -302,6 +303,10 @@ def is_fence(self):
302303
def is_weak_fence(self):
303304
return self._is_type(WeakFence)
304305

306+
@cached_property
307+
def is_phase_marker(self):
308+
return self._is_type(PhaseMarker)
309+
305310
@cached_property
306311
def is_critical_region(self):
307312
return self._is_type(CriticalRegion)
@@ -310,14 +315,47 @@ def is_critical_region(self):
310315
def is_thread_pool_sync(self):
311316
return self._is_type(ThreadPoolSync)
312317

318+
@cached_property
319+
def is_shm_write(self):
320+
return all(w._mem_shared for w in self.scope.writes)
321+
313322
@cached_property
314323
def is_thread_commit(self):
315324
return self._is_type(ThreadCommit)
316325

326+
@cached_property
327+
def is_thread_arrive(self):
328+
return self._is_type(ThreadArrive)
329+
317330
@cached_property
318331
def is_thread_wait(self):
319332
return self._is_type(ThreadWait)
320333

334+
@cached_property
335+
def is_thread_sync(self):
336+
return self.is_thread_pool_sync or self.is_thread_wait
337+
338+
@cached_property
339+
def is_word_move(self):
340+
return (self._is_type(Indexed) and
341+
all(e.rhs.function._mem_heap for e in self.exprs))
342+
343+
@cached_property
344+
def is_tensor_move(self):
345+
return self._is_type(TensorMove)
346+
347+
@cached_property
348+
def is_word_move_to_mem_shared(self):
349+
return self.is_word_move and self.is_shm_write
350+
351+
@cached_property
352+
def is_tensor_move_to_mem_shared(self):
353+
return self.is_tensor_move and self.is_shm_write
354+
355+
@cached_property
356+
def is_glb_load_to_mem_shared(self):
357+
return self.is_word_move_to_mem_shared or self.is_tensor_move_to_mem_shared
358+
321359
@cached_property
322360
def is_async(self):
323361
"""
@@ -557,6 +595,10 @@ def dspace(self):
557595
def is_halo_touch(self):
558596
return all(i.is_halo_touch for i in self)
559597

598+
@cached_property
599+
def is_glb_load_to_mem_shared(self):
600+
return all(i.is_glb_load_to_mem_shared for i in self)
601+
560602
@cached_property
561603
def dtype(self):
562604
"""

devito/ir/iet/efunc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from dataclasses import dataclass
12
from functools import cached_property
3+
from itertools import chain
24

35
from devito.ir.iet.nodes import Call, Callable
46
from devito.ir.iet.utils import derive_parameters
@@ -11,6 +13,7 @@
1113
'CommCallable',
1214
'DeviceCall',
1315
'DeviceFunction',
16+
'EFuncMeta',
1417
'ElementalCall',
1518
'ElementalFunction',
1619
'EntryFunction',
@@ -21,6 +24,38 @@
2124
]
2225

2326

27+
@dataclass(frozen=True)
28+
class EFuncMeta:
29+
30+
body: object = None
31+
efuncs: tuple = ()
32+
includes: tuple = ()
33+
namespaces: tuple = ()
34+
libs: tuple = ()
35+
36+
@classmethod
37+
def compose(cls, *items):
38+
items = tuple(items)
39+
40+
if not items:
41+
return cls()
42+
43+
return cls(
44+
body=items[-1].body,
45+
efuncs=tuple(chain.from_iterable(i.efuncs for i in items)),
46+
includes=tuple(chain.from_iterable(i.includes for i in items)),
47+
namespaces=tuple(chain.from_iterable(i.namespaces for i in items)),
48+
libs=tuple(chain.from_iterable(i.libs for i in items))
49+
)
50+
51+
def __iter__(self):
52+
yield self.body
53+
yield self.efuncs
54+
yield self.includes
55+
yield self.namespaces
56+
yield self.libs
57+
58+
2459
# ElementalFunction machinery
2560

2661
class ElementalCall(Call):

devito/ir/support/properties.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,6 @@ def __init__(self, name, val=None):
9797
A Dimension along which prefetching is feasible and beneficial.
9898
"""
9999

100-
PREFETCHABLE_SHM = Property('prefetchable-shm')
101-
"""
102-
A Dimension along which shared-memory prefetching is feasible and beneficial.
103-
"""
104-
105100
INIT_CORE_SHM = Property('init-core-shm')
106101
"""
107102
A Dimension along which the shared-memory CORE data region is initialized.
@@ -190,32 +185,6 @@ def update_properties(properties, exprs):
190185
if not exprs:
191186
return properties
192187

193-
# Auto-detect prefetchable Dimensions
194-
dims = set()
195-
flag = False
196-
for e in as_tuple(exprs):
197-
w, r = e.args
198-
199-
# Ensure it's in the form `Indexed = Indexed`
200-
try:
201-
wf, rf = w.function, r.function
202-
except AttributeError:
203-
break
204-
205-
if not rf or not wf._mem_shared:
206-
break
207-
dims.update({d.parent for d in wf.dimensions if d.parent in properties})
208-
209-
if not rf._mem_heap:
210-
break
211-
else:
212-
flag = True
213-
214-
if flag:
215-
properties = properties.prefetchable_shm(dims)
216-
else:
217-
properties = properties.drop(properties=PREFETCHABLE_SHM)
218-
219188
# Remove properties that are trivially incompatible with `exprs`
220189
if not all(e.lhs.function._mem_shared for e in as_tuple(exprs)):
221190
drop = {INIT_CORE_SHM, INIT_HALO_LEFT_SHM, INIT_HALO_RIGHT_SHM}
@@ -284,9 +253,6 @@ def prefetchable(self, dims, v=PREFETCHABLE):
284253
m[d] = self.get(d, set()) | {v}
285254
return Properties(m)
286255

287-
def prefetchable_shm(self, dims):
288-
return self.prefetchable(dims, PREFETCHABLE_SHM)
289-
290256
def block(self, dims, kind='default'):
291257
if kind == 'default':
292258
p = TILABLE
@@ -357,9 +323,6 @@ def _is_property_any(self, dims, v):
357323
def is_prefetchable(self, dims=None, v=PREFETCHABLE):
358324
return self._is_property_any(dims, PREFETCHABLE)
359325

360-
def is_prefetchable_shm(self, dims=None):
361-
return self._is_property_any(dims, PREFETCHABLE_SHM)
362-
363326
def is_core_init(self, dims=None):
364327
return self._is_property_any(dims, INIT_CORE_SHM)
365328

devito/passes/clusters/aliases.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
139139
# [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140140
schedule, exprs = self._select(variants)
141141

142+
# Schedule -> Schedule (optimization: trade fusion for collapse-parallelism)
143+
if self.opt_maxpar:
144+
schedule = optimize_schedule_maxpar(schedule)
145+
142146
# Schedule -> Schedule (optimization)
143147
if self.opt_rotate:
144148
schedule = optimize_schedule_rotations(schedule, self.sregistry)
@@ -664,7 +668,6 @@ def lower_aliases(aliases, meta, maxpar):
664668
"""
665669
Create a Schedule from an AliasList.
666670
"""
667-
stampcache = {}
668671
dmapper = {}
669672
processed = []
670673
for a in aliases:
@@ -704,12 +707,6 @@ def lower_aliases(aliases, meta, maxpar):
704707
# use `<1>` as stamp, which is what appears in `ispace`
705708
interval = interval.lift(i.stamp)
706709

707-
# We further bump the interval stamp if we were requested to trade
708-
# fusion for more collapse-parallelism
709-
if maxpar:
710-
stamp = stampcache.setdefault(interval.dim, Stamp())
711-
interval = interval.lift(stamp)
712-
713710
writeto.append(interval)
714711
intervals.append(interval)
715712

@@ -853,6 +850,30 @@ def optimize_schedule_rotations(schedule, sregistry):
853850
return schedule.rebuild(*processed, rmapper=rmapper)
854851

855852

853+
def optimize_schedule_maxpar(schedule):
854+
"""
855+
Bump the IterationSpace's stamp, trading fusion for more collapse-parallelism.
856+
"""
857+
key = lambda i: (i.writeto, i.ispace)
858+
859+
processed = []
860+
for (writeto0, ispace0), group in groupby(schedule, key=key):
861+
g = list(group)
862+
863+
stamp = Stamp()
864+
dims = writeto0.itdims
865+
866+
writeto = writeto0.lift(dims, stamp)
867+
ispace = ispace0.lift(dims, stamp)
868+
869+
processed.extend([
870+
ScheduledAlias(pivot, writeto, ispace, aliaseds, indicess)
871+
for pivot, _, _, aliaseds, indicess in g
872+
])
873+
874+
return schedule.rebuild(*processed)
875+
876+
856877
def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
857878
opt_minmem):
858879
"""

devito/passes/clusters/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _key(self, c):
232232
weak.append(c.properties.is_core_init())
233233

234234
# Prefetchable Clusters (global loads into shared memory) should get merged, if possible
235-
weak.append(c.properties.is_prefetchable_shm())
235+
weak.append(c.is_glb_load_to_mem_shared)
236236

237237
# Promoting adjacency of IndexDerivatives will maximize their reuse
238238
weak.append(any(search(c.exprs, IndexDerivative)))

devito/types/dense.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,15 +1593,6 @@ def _time_buffering(self):
15931593
def _time_buffering_default(self):
15941594
return self._time_buffering and not isinstance(self.save, Buffer)
15951595

1596-
def _evaluate(self, **kwargs):
1597-
retval = super()._evaluate(**kwargs)
1598-
if not self._time_buffering and not retval.is_Function:
1599-
# Saved TimeFunction might need streaming, expand interpolations
1600-
# for easier processing
1601-
return retval.evaluate
1602-
else:
1603-
return retval
1604-
16051596
def _arg_check(self, args, intervals, **kwargs):
16061597
super()._arg_check(args, intervals, **kwargs)
16071598

devito/types/misc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'Hyperplane',
2525
'Indirection',
2626
'Jump',
27+
'PhaseMarker',
2728
'Pointer',
2829
'Temp',
2930
'TempArray',
@@ -332,6 +333,17 @@ class WeakFence(sympy.Function, Fence):
332333
pass
333334

334335

336+
class PhaseMarker(WeakFence):
337+
338+
"""
339+
A special WeakFence acting as a marker to logically separate different compute phases.
340+
Thus, statements in different phases will not be reordered across the marker upon
341+
topological sorting.
342+
"""
343+
344+
pass
345+
346+
335347
class CriticalRegion(sympy.Function, Fence):
336348

337349
"""

devito/types/parallel.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from functools import cached_property
1212

1313
import numpy as np
14+
from sympy import Expr
1415

1516
from devito.exceptions import InvalidArgument
1617
from devito.parameters import configuration
17-
from devito.symbolics import search
18+
from devito.symbolics import Reserved, Terminal, search
1819
from devito.tools import as_list, as_tuple, is_integer
1920
from devito.types.array import Array, ArrayObject
2021
from devito.types.basic import Scalar, Symbol
@@ -35,7 +36,9 @@
3536
'QueueID',
3637
'SharedData',
3738
'TBArray',
39+
'TensorMove',
3840
'ThreadArray',
41+
'ThreadArrive',
3942
'ThreadCommit',
4043
'ThreadID',
4144
'ThreadPoolSync',
@@ -365,12 +368,24 @@ class ThreadCommit(Fence):
365368
pass
366369

367370

371+
class ThreadArrive(Fence):
372+
373+
"""
374+
A generic arrive operation for a single thread, typically used to signal
375+
the arrival at a certain point through a suitable synchronization object.
376+
"""
377+
378+
pass
379+
380+
368381
class ThreadWait(Fence):
369382

370383
"""
371384
A generic wait operation for a single thread, typically used to synchronize
372-
after a memory operation issued at a specific program point with a
373-
ThreadCommit operation.
385+
with other threads over:
386+
387+
* a memory operation issued by a prior ThreadCommit operation.
388+
* the consumption of a shared resource via a ThreadArrive operation.
374389
"""
375390

376391
pass
@@ -386,3 +401,18 @@ def __init_finalize__(self, *args, **kwargs):
386401
kwargs['liveness'] = 'eager'
387402

388403
super().__init_finalize__(*args, **kwargs)
404+
405+
406+
class TensorMove(Expr, Reserved, Terminal):
407+
408+
"""
409+
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
410+
level of the memory hierarchy.
411+
"""
412+
413+
func = Reserved._rebuild
414+
415+
def _ccode(self, printer):
416+
return str(self)
417+
418+
_sympystr = _ccode

0 commit comments

Comments
 (0)