Skip to content

Commit 8a6b131

Browse files
committed
Merge branch 'main' into pipeline
2 parents f8a0430 + 794ae87 commit 8a6b131

56 files changed

Lines changed: 13495 additions & 2349 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ DOCKER_EXE ?= docker
22
DOCKER_NAME ?= accelforge
33
DOCKER_BUILD ?= ${DOCKER_EXE} buildx build --load --pull
44

5-
VERSION := 0.1.5
5+
VERSION := 0.1.6
66

77
USER := timeloopaccelergy
88
REPO := accelforge

accelforge/_accelerated_imports.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if os.environ.get("ACCELFORGE_ACCELERATED_IMPORTS", "0") == "1":
66
import cudf as pd
77
import cupy as np
8+
89
pandas = pd
910
numpy = np
1011
# import cupy as scipy
@@ -13,6 +14,7 @@
1314
else:
1415
import pandas as pd
1516
import numpy as np
17+
1618
pandas = pd
1719
numpy = np
1820
# import scipy

accelforge/frontend/arch/components.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,48 @@ def _eval_tensor2bits(
760760
return {k2: v for k, v in result.items() for k2 in k}
761761

762762

763+
_VALID_DIRECTIONS = {"up", "down", "up_and_down"}
764+
765+
766+
def _eval_direction(toeval, symbol_table: dict[str, Any]) -> dict[str, str]:
767+
"""Evaluate a direction field. If a string, expand to all tensors. If a dict,
768+
resolve tensor expression keys."""
769+
if isinstance(toeval, str):
770+
if toeval not in _VALID_DIRECTIONS:
771+
raise EvaluationError(
772+
f'Invalid direction: "{toeval}". '
773+
f"Must be one of {sorted(_VALID_DIRECTIONS)}."
774+
)
775+
all_tensors = symbol_table["All"].instance
776+
return {t: toeval for t in all_tensors}
777+
778+
result = {}
779+
for key, value in toeval.items():
780+
if value not in _VALID_DIRECTIONS:
781+
raise EvaluationError(
782+
f'Invalid direction for {key}: "{value}". '
783+
f"Must be one of {sorted(_VALID_DIRECTIONS)}."
784+
)
785+
key_evaluated = eval_set_expression(
786+
expression=key,
787+
symbol_table=symbol_table,
788+
expected_space=TensorName,
789+
location=f"direction key {key}",
790+
).instance
791+
result[key_evaluated] = value
792+
793+
all_tensors = symbol_table["All"].instance
794+
for k in result:
795+
all_tensors -= k
796+
797+
if all_tensors:
798+
raise EvaluationError(
799+
f"Missing direction for {all_tensors}. Have {result}."
800+
)
801+
802+
return {t: v for k, v in result.items() for t in k}
803+
804+
763805
class Tensors(EvalableModel):
764806
"""
765807
Fields that control which tensor(s) are kept in a :py:class:`~.TensorHolder` and in
@@ -955,6 +997,12 @@ class Memory(TensorHolder, ConcurrentlyBoundable):
955997
physical units may be flattened into only one logical level.
956998
"""
957999

1000+
skip_initial_output_write: bool = True
1001+
"""
1002+
If False, the initial value of output tensors will be fetched from above and used to
1003+
initalize outputs. If True, this initial fetch and fill is skipped.
1004+
"""
1005+
9581006
def _render_node_shape(self) -> str:
9591007
return "cylinder"
9601008

@@ -986,10 +1034,15 @@ class Toll(TensorHolder):
9861034
zero.
9871035
"""
9881036

989-
direction: Literal["up", "down", "up_and_down"]
1037+
direction: TryEvalTo[dict]
9901038
"""
991-
The direction in which data flows through this `Toll`. If "up", then data
992-
flows from below `TensorHolder`, through this `Toll` (plus paying
1039+
The direction in which data flows through this `Toll`. Can be:
1040+
1041+
- A string: ``"up"``, ``"down"``, or ``"up_and_down"`` — applies to all tensors.
1042+
- A dict mapping tensor expressions to direction strings, e.g.
1043+
``{input: "down", output: "up"}`` — sets direction per tensor.
1044+
1045+
If "up", then data flows from below `TensorHolder`, through this `Toll` (plus paying
9931046
associated costs), and then to the next `TensorHolder` above it. Other data
9941047
movements are assumed to avoid this Toll.
9951048
"""
@@ -1000,6 +1053,32 @@ class Toll(TensorHolder):
10001053
def model_post_init(self, __context__=None) -> None:
10011054
self._update_actions(PROCESSING_STAGE_ACTIONS)
10021055

1056+
def _eval_expressions(self, *args, **kwargs):
1057+
if getattr(self, "_evaluated", False):
1058+
return super()._eval_expressions(*args, **kwargs)
1059+
1060+
# Override TensorHolder's _PostCall to also handle direction
1061+
class MyPostCall(_PostCall):
1062+
def __call__(self_pc, field, value, evaluated, symbol_table):
1063+
if field == "bits_per_value_scale":
1064+
evaluated = _eval_tensor2bits(
1065+
evaluated,
1066+
location="bits_per_value_scale",
1067+
symbol_table=symbol_table,
1068+
)
1069+
if field == "direction":
1070+
evaluated = _eval_direction(
1071+
evaluated,
1072+
symbol_table=symbol_table,
1073+
)
1074+
return evaluated
1075+
1076+
# Skip TensorHolder's _eval_expressions (which adds its own post_calls
1077+
# for bits_per_value_scale) since we handle it here too
1078+
return Component._eval_expressions(
1079+
self, *args, **kwargs, post_calls=(MyPostCall(),)
1080+
)
1081+
10031082
def _render_node_shape(self) -> str:
10041083
return "rarrow"
10051084

@@ -1011,6 +1090,13 @@ class Compute(Component, Leaf, ConcurrentlyBoundable):
10111090
actions: EvalableList[Action] = COMPUTE_ACTIONS
10121091
""" The actions that this `Compute` can perform. """
10131092

1093+
skip_initial_output_write: bool = True
1094+
"""
1095+
If False, the initial value of output tensors will be fetched from above and used to
1096+
initalize outputs. If True, this initial fetch and fill is skipped.
1097+
"""
1098+
1099+
10141100
def model_post_init(self, __context__=None) -> None:
10151101
self._update_actions(COMPUTE_ACTIONS)
10161102

accelforge/frontend/arch/spatialable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ class Spatial(EvalableModel):
3636
Note: Loops may be removed if they are constrained to only one iteration.
3737
"""
3838

39-
min_usage: int | float | str = 0.0
39+
min_usage: EvalsTo[int | float] = 0.0
4040
""" The minimum usage of spatial instances, as a value from 0 to 1. A mapping
4141
is invalid if less than this porportion of this dimension's fanout is utilized.
4242
Mappers that support it (e.g., FFM) may, if no mappings satisfy this constraint,
43-
return the highest-usage mappings.
43+
return the highest-usage mappings. These constraints are disabled for copy Einsums.
4444
"""
4545

4646
reuse: TryEvalTo[InvertibleSet[TensorName]] = "Nothing"

0 commit comments

Comments
 (0)