@@ -760,6 +760,48 @@ def _eval_tensor2bits(
760760 return {k2 : v for k , v in result .items () for k2 in k }
761761
762762
763+ _VALID_DIRECTIONS = {"up" , "down" , "up_and_down" }
764+
765+
766+ def _eval_direction (toeval , symbol_table : dict [str , Any ]) -> dict [str , str ]:
767+ """Evaluate a direction field. If a string, expand to all tensors. If a dict,
768+ resolve tensor expression keys."""
769+ if isinstance (toeval , str ):
770+ if toeval not in _VALID_DIRECTIONS :
771+ raise EvaluationError (
772+ f'Invalid direction: "{ toeval } ". '
773+ f"Must be one of { sorted (_VALID_DIRECTIONS )} ."
774+ )
775+ all_tensors = symbol_table ["All" ].instance
776+ return {t : toeval for t in all_tensors }
777+
778+ result = {}
779+ for key , value in toeval .items ():
780+ if value not in _VALID_DIRECTIONS :
781+ raise EvaluationError (
782+ f'Invalid direction for { key } : "{ value } ". '
783+ f"Must be one of { sorted (_VALID_DIRECTIONS )} ."
784+ )
785+ key_evaluated = eval_set_expression (
786+ expression = key ,
787+ symbol_table = symbol_table ,
788+ expected_space = TensorName ,
789+ location = f"direction key { key } " ,
790+ ).instance
791+ result [key_evaluated ] = value
792+
793+ all_tensors = symbol_table ["All" ].instance
794+ for k in result :
795+ all_tensors -= k
796+
797+ if all_tensors :
798+ raise EvaluationError (
799+ f"Missing direction for { all_tensors } . Have { result } ."
800+ )
801+
802+ return {t : v for k , v in result .items () for t in k }
803+
804+
763805class Tensors (EvalableModel ):
764806 """
765807 Fields that control which tensor(s) are kept in a :py:class:`~.TensorHolder` and in
@@ -955,6 +997,12 @@ class Memory(TensorHolder, ConcurrentlyBoundable):
955997 physical units may be flattened into only one logical level.
956998 """
957999
1000+ skip_initial_output_write : bool = True
1001+ """
1002+ If False, the initial value of output tensors will be fetched from above and used to
1003+ initalize outputs. If True, this initial fetch and fill is skipped.
1004+ """
1005+
9581006 def _render_node_shape (self ) -> str :
9591007 return "cylinder"
9601008
@@ -986,10 +1034,15 @@ class Toll(TensorHolder):
9861034 zero.
9871035 """
9881036
989- direction : Literal [ "up" , "down" , "up_and_down" ]
1037+ direction : TryEvalTo [ dict ]
9901038 """
991- The direction in which data flows through this `Toll`. If "up", then data
992- flows from below `TensorHolder`, through this `Toll` (plus paying
1039+ The direction in which data flows through this `Toll`. Can be:
1040+
1041+ - A string: ``"up"``, ``"down"``, or ``"up_and_down"`` — applies to all tensors.
1042+ - A dict mapping tensor expressions to direction strings, e.g.
1043+ ``{input: "down", output: "up"}`` — sets direction per tensor.
1044+
1045+ If "up", then data flows from below `TensorHolder`, through this `Toll` (plus paying
9931046 associated costs), and then to the next `TensorHolder` above it. Other data
9941047 movements are assumed to avoid this Toll.
9951048 """
@@ -1000,6 +1053,32 @@ class Toll(TensorHolder):
10001053 def model_post_init (self , __context__ = None ) -> None :
10011054 self ._update_actions (PROCESSING_STAGE_ACTIONS )
10021055
1056+ def _eval_expressions (self , * args , ** kwargs ):
1057+ if getattr (self , "_evaluated" , False ):
1058+ return super ()._eval_expressions (* args , ** kwargs )
1059+
1060+ # Override TensorHolder's _PostCall to also handle direction
1061+ class MyPostCall (_PostCall ):
1062+ def __call__ (self_pc , field , value , evaluated , symbol_table ):
1063+ if field == "bits_per_value_scale" :
1064+ evaluated = _eval_tensor2bits (
1065+ evaluated ,
1066+ location = "bits_per_value_scale" ,
1067+ symbol_table = symbol_table ,
1068+ )
1069+ if field == "direction" :
1070+ evaluated = _eval_direction (
1071+ evaluated ,
1072+ symbol_table = symbol_table ,
1073+ )
1074+ return evaluated
1075+
1076+ # Skip TensorHolder's _eval_expressions (which adds its own post_calls
1077+ # for bits_per_value_scale) since we handle it here too
1078+ return Component ._eval_expressions (
1079+ self , * args , ** kwargs , post_calls = (MyPostCall (),)
1080+ )
1081+
10031082 def _render_node_shape (self ) -> str :
10041083 return "rarrow"
10051084
@@ -1011,6 +1090,13 @@ class Compute(Component, Leaf, ConcurrentlyBoundable):
10111090 actions : EvalableList [Action ] = COMPUTE_ACTIONS
10121091 """ The actions that this `Compute` can perform. """
10131092
1093+ skip_initial_output_write : bool = True
1094+ """
1095+ If False, the initial value of output tensors will be fetched from above and used to
1096+ initalize outputs. If True, this initial fetch and fill is skipped.
1097+ """
1098+
1099+
10141100 def model_post_init (self , __context__ = None ) -> None :
10151101 self ._update_actions (COMPUTE_ACTIONS )
10161102
0 commit comments