66.. currentmodule:: arraycontext
77
88.. autofunction:: with_container_arithmetic
9+
10+ .. autoclass:: BcastUntilActxArray
911"""
1012
1113
3436"""
3537
3638import enum
39+ import operator
3740from collections .abc import Callable
41+ from dataclasses import dataclass , field
42+ from functools import partialmethod
43+ from numbers import Number
3844from typing import Any , TypeVar
3945from warnings import warn
4046
4147import numpy as np
4248
49+ from arraycontext .container import (
50+ NotAnArrayContainerError ,
51+ deserialize_container ,
52+ serialize_container ,
53+ )
54+ from arraycontext .context import ArrayContext , ArrayOrContainer
55+
4356
4457# {{{ with_container_arithmetic
4558
@@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142155 warn (
143156 "Broadcasting container against non-object numpy array. "
144157 "This was never documented to work and will now stop working in "
145- "2025. Convert the array to an object array to preserve the "
146- "current semantics." , DeprecationWarning , stacklevel = 3 )
158+ "2025. Convert the array to an object array or use "
159+ "arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
160+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
147161 return True
148162 else :
149163 return False
@@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207221
208222 Each operator class also includes the "reverse" operators if applicable.
209223
224+ .. note::
225+
226+ For the generated binary arithmetic operators, if certain types
227+ should be broadcast over the container (with the container as the
228+ 'outer' structure) but are not handled in this way by their types,
229+ you may wrap them in :class:`BcastUntilActxArray` to achieve
230+ the desired semantics.
231+
210232 .. note::
211233
212234 To generate the code implementing the operators, this function relies on
@@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any:
402424 warn (
403425 f"Broadcasting array context array types across { cls } "
404426 "has been explicitly "
405- "enabled. As of 2025, this will stop working. "
406- "There is no replacement as of right now. "
427+ "enabled. As of 2026, this will stop working. "
428+ "Use arraycontext.Bcast* object wrappers for "
429+ "roughly equivalent functionality. "
407430 "See the discussion in "
408431 "https://github.com/inducer/arraycontext/pull/190. "
409432 "To opt out now (and avoid this warning), "
@@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any:
413436 warn (
414437 f"Broadcasting array context array types across { cls } "
415438 "has been implicitly "
416- "enabled. As of 2025, this will no longer work. "
417- "There is no replacement as of right now. "
439+ "enabled. As of 2026, this will no longer work. "
440+ "Use arraycontext.Bcast* object wrappers for "
441+ "roughly equivalent functionality. "
418442 "See the discussion in "
419443 "https://github.com/inducer/arraycontext/pull/190. "
420444 "To opt out now (and avoid this warning), "
@@ -603,8 +627,9 @@ def {fname}(arg1):
603627 if isinstance(arg2, { tup_str (bcast_actx_ary_types )} ):
604628 warn("Broadcasting { cls } over array "
605629 f"context array type {{type(arg2)}} is deprecated "
606- "and will no longer work in 2025. "
607- "There is no replacement as of right now. "
630+ "and will no longer work in 2026. "
631+ "Use arraycontext.Bcast* object wrappers for "
632+ "roughly equivalent functionality. "
608633 "See the discussion in "
609634 "https://github.com/inducer/arraycontext/"
610635 "pull/190. ",
@@ -654,8 +679,10 @@ def {fname}(arg2, arg1):
654679 warn("Broadcasting { cls } over array "
655680 f"context array type {{type(arg1)}} "
656681 "is deprecated "
657- "and will no longer work in 2025."
658- "There is no replacement as of right now. "
682+ "and will no longer work in 2026."
683+ "Use arraycontext.Bcast* object "
684+ "wrappers for roughly equivalent "
685+ "functionality. "
659686 "See the discussion in "
660687 "https://github.com/inducer/arraycontext/"
661688 "pull/190. ",
@@ -687,4 +714,110 @@ def {fname}(arg2, arg1):
687714# }}}
688715
689716
717+ # {{{ Bcast object-ified broadcast rules
718+
719+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
720+ #
721+ # - If one rule does not fit the user's need, they can straightforwardly use
722+ # another.
723+ #
724+ # - It's straightforward to find where certain broadcast rules are used.
725+ #
726+ # - The broadcast rule can contain more state. For example, it's now easy
727+ # for the rule to know what array context should be used to determine
728+ # actx array types.
729+ #
730+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
731+ #
732+ # - User code is a bit more wordy.
733+
734+ @dataclass (frozen = True )
735+ class BcastUntilActxArray :
736+ """
737+ A broadcast rule that broadcasts *broadcastee* across array containers until
738+ the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types`
739+ of *actx* or a :class:`~numbers.Number`.
740+
741+ Suggested usage pattern::
742+
743+ bcast = functools.partial(BcastUntilActxArray, actx)
744+
745+ container + bcast(actx_array)
746+
747+ .. automethod:: __init__
748+ """
749+
750+ array_context : ArrayContext
751+ broadcastee : ArrayOrContainer
752+
753+ _stop_types : tuple [type , ...] = field (init = False )
754+
755+ def __post_init__ (self ) -> None :
756+ object .__setattr__ (
757+ self , "_stop_types" , (* self .array_context .array_types , Number ))
758+
759+ def _binary_op (self ,
760+ op : Callable [
761+ [ArrayOrContainer , ArrayOrContainer ],
762+ ArrayOrContainer
763+ ],
764+ right : ArrayOrContainer
765+ ) -> ArrayOrContainer :
766+ try :
767+ serialized = serialize_container (right )
768+ except NotAnArrayContainerError :
769+ return op (self .broadcastee , right )
770+
771+ return deserialize_container (right , [
772+ (k , op (self .broadcastee , right_v )
773+ if isinstance (right_v , self ._stop_types ) else
774+ self ._binary_op (op , right_v )
775+ )
776+ for k , right_v in serialized ])
777+
778+ def _rev_binary_op (self ,
779+ op : Callable [
780+ [ArrayOrContainer , ArrayOrContainer ],
781+ ArrayOrContainer
782+ ],
783+ left : ArrayOrContainer
784+ ) -> ArrayOrContainer :
785+ try :
786+ serialized = serialize_container (left )
787+ except NotAnArrayContainerError :
788+ return op (left , self .broadcastee )
789+
790+ return deserialize_container (left , [
791+ (k , op (left_v , self .broadcastee )
792+ if isinstance (left_v , self ._stop_types ) else
793+ self ._rev_binary_op (op , left_v )
794+ )
795+ for k , left_v in serialized ])
796+
797+ __add__ = partialmethod (_binary_op , operator .add )
798+ __radd__ = partialmethod (_rev_binary_op , operator .add )
799+ __sub__ = partialmethod (_binary_op , operator .sub )
800+ __rsub__ = partialmethod (_rev_binary_op , operator .sub )
801+ __mul__ = partialmethod (_binary_op , operator .mul )
802+ __rmul__ = partialmethod (_rev_binary_op , operator .mul )
803+ __truediv__ = partialmethod (_binary_op , operator .truediv )
804+ __rtruediv__ = partialmethod (_rev_binary_op , operator .truediv )
805+ __floordiv__ = partialmethod (_binary_op , operator .floordiv )
806+ __rfloordiv__ = partialmethod (_rev_binary_op , operator .floordiv )
807+ __mod__ = partialmethod (_binary_op , operator .mod )
808+ __rmod__ = partialmethod (_rev_binary_op , operator .mod )
809+ __pow__ = partialmethod (_binary_op , operator .pow )
810+ __rpow__ = partialmethod (_rev_binary_op , operator .pow )
811+
812+ __lshift__ = partialmethod (_binary_op , operator .lshift )
813+ __rlshift__ = partialmethod (_rev_binary_op , operator .lshift )
814+ __rshift__ = partialmethod (_binary_op , operator .rshift )
815+ __rrshift__ = partialmethod (_rev_binary_op , operator .rshift )
816+ __and__ = partialmethod (_binary_op , operator .and_ )
817+ __rand__ = partialmethod (_rev_binary_op , operator .and_ )
818+ __or__ = partialmethod (_binary_op , operator .or_ )
819+ __ror__ = partialmethod (_rev_binary_op , operator .or_ )
820+
821+ # }}}
822+
690823# vim: foldmethod=marker
0 commit comments