Skip to content

Commit ca2d413

Browse files
committed
šŸ› fix branching parameter update semantics and finalize 2.0.1 release notes
1 parent 3f105ce commit ca2d413

5 files changed

Lines changed: 157 additions & 5 deletions

File tree

ā€ŽCHANGELOG.mdā€Ž

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) to build the
88

99
<!-- towncrier release notes start -->
1010

11+
## [2.0.1](https://github.com/Infineon/StreamGen/tree/2.0.1) - 2026-03-31
12+
13+
### šŸ› Fixed
14+
15+
- fixed `SamplingTree.update` and `SamplingTree.set_update_step` so shared parameters in branching trees are updated exactly once per step.
16+
- added regression tests for branching update behavior to ensure schedule progression stays consistent.
17+
- stabilized plotting tests by forcing a non-interactive matplotlib backend during test runs.
18+
19+
1120
## [2.0.0](https://github.com/Infineon/StreamGen/tree/2.0.0) - 2026-03-05
1221

1322
### āž– Removed

ā€Žpyproject.tomlā€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "streamgen"
3-
version = "2.0.0"
3+
version = "2.0.1"
44
description = "🌌 a framework for generating streams of labeled data."
55
authors = [
66
{name = "Laurenz Farthofer", email = "laurenz@hey.com"}

ā€Žstreamgen/samplers/tree.pyā€Ž

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,33 @@ def collect(self, num_samples: int, strategy: SamplingStrategy | SamplingStrateg
325325

326326
return tuple(map(self.collate_func, zip(*samples, strict=True))) if self.collate_func else samples
327327

328+
def _get_unique_parameters(self) -> list[Parameter]:
329+
"""Collect every unique parameter object referenced by the tree nodes."""
330+
unique_parameters = []
331+
seen_parameter_ids = set()
332+
333+
for node in anytree.PreOrderIter(self.root):
334+
if isinstance(node, BranchingNode) and node.probs is not None:
335+
parameter_id = id(node.probs)
336+
if parameter_id not in seen_parameter_ids:
337+
seen_parameter_ids.add(parameter_id)
338+
unique_parameters.append(node.probs)
339+
340+
if isinstance(node, TransformNode) and node.params is not None:
341+
for parameter_name in node.params.parameter_names:
342+
parameter = node.params[parameter_name]
343+
parameter_id = id(parameter)
344+
if parameter_id in seen_parameter_ids:
345+
continue
346+
seen_parameter_ids.add(parameter_id)
347+
unique_parameters.append(parameter)
348+
349+
return unique_parameters
350+
328351
def update(self) -> None:
329352
"""šŸ†™ updates every parameter."""
330-
for node in self.nodes:
331-
node.update()
353+
for parameter in self._get_unique_parameters():
354+
parameter.update()
332355

333356
def set_update_step(self, idx: int) -> None:
334357
"""šŸ• updates every parameter to a certain update step using `param[idx]`.
@@ -339,8 +362,8 @@ def set_update_step(self, idx: int) -> None:
339362
Returns:
340363
None: this function mutates `self`
341364
"""
342-
for node in self.nodes:
343-
node.set_update_step(idx)
365+
for parameter in self._get_unique_parameters():
366+
parameter[idx]
344367

345368
def get_params(self) -> ParameterStore | None:
346369
"""āš™ļø collects parameters from every node.

ā€Žtests/conftest.pyā€Ž

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
11
"""šŸ—ƒļø fixtures available in all tests."""
2+
3+
import matplotlib
4+
5+
# Use a non-interactive backend in tests to avoid GUI/Tk dependencies.
6+
matplotlib.use("Agg", force=True)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""🧪 regression tests for branching update behavior."""
2+
# ruff: noqa: S101, D103, ANN001, ANN201
3+
4+
from collections import Counter
5+
6+
from streamgen.nodes import TransformNode
7+
from streamgen.parameter import Parameter
8+
from streamgen.samplers.tree import SamplingTree
9+
from streamgen.transforms import noop
10+
11+
12+
# ---------------------------------------------------------------------------- #
13+
# * helper functions #
14+
# ---------------------------------------------------------------------------- #
15+
16+
17+
def _build_minimal_branch_tree() -> SamplingTree:
18+
"""Builds a small tree where one scoped parameter is attached multiple times.
19+
20+
Layout:
21+
root -> BranchingNode(left|right) -> shared_scope
22+
23+
The node after the branching node is deep-copied into each branch by
24+
streamgen's tree construction shorthand. All copies fetch params from the
25+
same scope name ("shared_scope").
26+
"""
27+
nodes = [
28+
TransformNode(noop, name="root"),
29+
{
30+
"name": "selector",
31+
"left": [TransformNode(noop, name="left_leaf")],
32+
"right": [TransformNode(noop, name="right_leaf")],
33+
},
34+
TransformNode(noop, name="shared_scope"),
35+
]
36+
37+
params = {
38+
"shared_scope": {
39+
"sweep": {
40+
"schedule": [10, 20, 30],
41+
"strategy": "hold",
42+
},
43+
},
44+
}
45+
46+
return SamplingTree(nodes=nodes, params=params, rng=0)
47+
48+
49+
def _collect_schedule_values(mode: str) -> list[int]:
50+
"""Collects two update steps from either tree.update or tree.params.update."""
51+
tree = _build_minimal_branch_tree()
52+
parameter = tree.get_params()["shared_scope.sweep"]
53+
values = [int(parameter.value)]
54+
55+
for _ in range(2):
56+
if mode == "tree":
57+
tree.update()
58+
elif mode == "params":
59+
tree.params.update()
60+
else:
61+
raise ValueError(f"unknown mode: {mode}")
62+
values.append(int(parameter.value))
63+
64+
return values
65+
66+
67+
# ---------------------------------------------------------------------------- #
68+
# * tests #
69+
# ---------------------------------------------------------------------------- #
70+
71+
72+
def test_tree_update_calls_each_parameter_once_in_branching_tree(monkeypatch) -> None:
73+
"""Tests that tree.update touches each parameter object once."""
74+
tree = _build_minimal_branch_tree()
75+
target = tree.get_params()["shared_scope.sweep"]
76+
77+
call_counts: Counter[int] = Counter()
78+
original_update = Parameter.update
79+
80+
def wrapped_update(self, *args, **kwargs):
81+
call_counts[id(self)] += 1
82+
return original_update(self, *args, **kwargs)
83+
84+
monkeypatch.setattr(Parameter, "update", wrapped_update)
85+
tree.update()
86+
87+
assert call_counts[id(target)] == 1
88+
89+
90+
def test_parameter_store_update_calls_each_parameter_once(monkeypatch) -> None:
91+
"""Tests that ParameterStore.update touches each parameter object once."""
92+
tree = _build_minimal_branch_tree()
93+
target = tree.get_params()["shared_scope.sweep"]
94+
95+
call_counts: Counter[int] = Counter()
96+
original_update = Parameter.update
97+
98+
def wrapped_update(self, *args, **kwargs):
99+
call_counts[id(self)] += 1
100+
return original_update(self, *args, **kwargs)
101+
102+
monkeypatch.setattr(Parameter, "update", wrapped_update)
103+
tree.params.update()
104+
105+
assert call_counts[id(target)] == 1
106+
107+
108+
def test_tree_update_matches_parameter_store_update_progression() -> None:
109+
"""Tests that tree.update progresses one step and matches params.update."""
110+
tree_values = _collect_schedule_values("tree")
111+
params_values = _collect_schedule_values("params")
112+
113+
assert tree_values == [10, 20, 30]
114+
assert params_values == [10, 20, 30]
115+
assert tree_values == params_values

0 commit comments

Comments
Ā (0)