Skip to content

Commit 21ec963

Browse files
Support decorator wrapper (#60)
- Remove CustomFunctionCompiler and directly modify the pipeline class in the numba.jit decorator implementation, to support decorator wrappers
1 parent a7770d6 commit 21ec963

3 files changed

Lines changed: 15 additions & 25 deletions

File tree

src/numba/openmp/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@
3939
omp_get_initial_device,
4040
)
4141

42-
from .compiler import (
43-
CustomCompiler,
44-
CustomFunctionCompiler,
45-
)
46-
4742
from .exceptions import ( # noqa: F401
4843
UnspecifiedVarInDefaultNone,
4944
ParallelForExtraCode,

src/numba/openmp/compiler.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,14 @@ def _finalize_specific(self):
321321
raise RuntimeError("error registering OpenMP offloading descriptor")
322322

323323

324-
class CustomFunctionCompiler(_FunctionCompiler):
325-
def _customize_flags(self, flags):
326-
# We need to disable SSA form for OpenMP analysis to detect variables
327-
# used within regions.
324+
class CustomCompiler(compiler.CompilerBase):
325+
def __init__(self, typingctx, targetctx, library, args, return_type, flags, locals):
326+
# Ensure SSA form is disabled for OpenMP analysis to detect variables used within regions.
328327
flags.enable_ssa = False
329-
return flags
330-
328+
super().__init__(
329+
typingctx, targetctx, library, args, return_type, flags, locals
330+
)
331331

332-
class CustomCompiler(compiler.CompilerBase):
333332
@staticmethod
334333
def custom_untyped_pipeline(state, name="untyped-openmp"):
335334
"""Returns an untyped part of the nopython OpenMP pipeline"""
@@ -366,14 +365,15 @@ def custom_untyped_pipeline(state, name="untyped-openmp"):
366365
pm.add_pass(FindLiterallyCalls, "find literally calls")
367366
pm.add_pass(LiteralUnroll, "handles literal_unroll")
368367

369-
if state.flags.enable_ssa:
370-
assert False, "SSA form is not supported in OpenMP"
368+
assert not state.flags.enable_ssa, (
369+
"SSA form is not supported in OpenMP compilation"
370+
)
371371

372372
pm.add_pass(LiteralPropagationSubPipelinePass, "Literal propagation")
373-
# Run WithLifting late to for make_implicit_explicit to work. TODO: We
374-
# should create a pass that does this instead of replicating and hacking
375-
# the untyped pipeline. This handling may also negatively affect
376-
# optimizations.
373+
# Run WithLifting late to for make_implicit_explicit to work.
374+
# TODO: We should create a pass that does this instead of replicating
375+
# and hacking the untyped pipeline. This handling may also negatively
376+
# affect optimizations.
377377
pm.add_pass(WithLifting, "Handle with contexts")
378378

379379
pm.finalize()

src/numba/openmp/decorators.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import warnings
22
import numba
33

4-
from .compiler import (
5-
CustomCompiler,
6-
CustomFunctionCompiler,
7-
)
4+
from .compiler import CustomCompiler
85

96

107
def jit(*args, **kws):
@@ -16,10 +13,8 @@ def jit(*args, **kws):
1613
if "forceobj" in kws:
1714
warnings.warn("forceobj is set for njit and is ignored", RuntimeWarning)
1815
del kws["forceobj"]
19-
kws.update({"nopython": True, "nogil": True})
16+
kws.update({"nopython": True, "nogil": True, "pipeline_class": CustomCompiler})
2017
dispatcher = numba.jit(*args, **kws)
21-
dispatcher._compiler.__class__ = CustomFunctionCompiler
22-
dispatcher._compiler.pipeline_class = CustomCompiler
2318
return dispatcher
2419

2520

0 commit comments

Comments
 (0)