Merged
41 commits
b562098
starts on coupled baiting curriculum
micahwoodard Feb 27, 2026
8ad3f02
adds coupled baiting curriculum
micahwoodard Feb 28, 2026
34a649c
adds unit tests
micahwoodard Feb 28, 2026
e53258c
adds tests
micahwoodard Feb 28, 2026
daddc4f
lints
micahwoodard Feb 28, 2026
f2e8b9b
merges with main
micahwoodard Mar 10, 2026
7d3deba
separate curriculum
micahwoodard Mar 10, 2026
1be3c90
upath comment
micahwoodard Mar 11, 2026
c2b2519
merges with main
micahwoodard Mar 11, 2026
ac0c587
removes whitespace
micahwoodard Mar 20, 2026
2022b71
fixes formatting
micahwoodard Mar 20, 2026
7234afc
fixes pyproject.toml
micahwoodard Mar 20, 2026
fd495f9
updates uv lock
micahwoodard Mar 20, 2026
db5ad04
adds computing metrics from dataset
micahwoodard Mar 20, 2026
ba513f2
adds stage_changes flag
micahwoodard Mar 20, 2026
8ad1ad1
adds basic unit tests
micahwoodard Mar 20, 2026
4bce296
updates coupled stages
micahwoodard Mar 23, 2026
59a7927
updates tests
micahwoodard Mar 23, 2026
fdb5ddb
computes metrics with previous metrics and trainer state contained in…
micahwoodard Mar 23, 2026
7f39253
restructures project
micahwoodard Mar 23, 2026
6f6fc65
rounds distribution rates
micahwoodard Mar 23, 2026
582ec35
updates readme
micahwoodard Mar 23, 2026
2821f59
runs codespell
micahwoodard Mar 23, 2026
46ca96d
fixes example
micahwoodard Mar 23, 2026
27cfbf2
updates quiescent period
micahwoodard Mar 23, 2026
d45c445
remove test
micahwoodard Mar 23, 2026
3d0030c
replaces dashes with underscores
micahwoodard Mar 23, 2026
2b33e38
regenerates
micahwoodard Mar 23, 2026
6b62c59
adds trainer state model
micahwoodard Mar 24, 2026
851529e
Update aind-behavior-services to generate sgen references
bruno-f-cruz Mar 25, 2026
41e13f0
Update packages and install missing distributions package
bruno-f-cruz Mar 25, 2026
aa198a3
Add missing scripting references
bruno-f-cruz Mar 25, 2026
49a66d4
Update references to manipulator calibration
bruno-f-cruz Mar 25, 2026
e359915
updates metrics from dataset
micahwoodard Mar 27, 2026
db3a34d
Merge branch 'feat-adding-curriculum' of github.com:AllenNeuralDynami…
micahwoodard Mar 27, 2026
8f91f4d
regenerates
micahwoodard Mar 27, 2026
2ad6167
Merge branch 'main' into feat-adding-curriculum
micahwoodard Mar 27, 2026
efddded
regenerates
micahwoodard Mar 27, 2026
9b419e2
adds version
micahwoodard Mar 27, 2026
d82c0f2
fixes warmup bug
micahwoodard Mar 31, 2026
db65298
restructures and guards update calculate metrics
micahwoodard Apr 1, 2026
103 changes: 101 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
[![ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)

A repository for the Dynamic Foraging task.
A repository for the Dynamic Foraging task and its associated curricula.

---

@@ -34,7 +34,7 @@ from the root of the repository.

## ⚙️ Generating settings files

The Dynamic Foraging tasks is instantiated by a set of three settings files that strictly follow a DSL schema. These files are:
The Dynamic Foraging task is instantiated by a set of three settings files that strictly follow a DSL schema. These files are:

- `task_logic.json`
- `rig.json`
@@ -52,6 +52,8 @@ However, for a better experiment management user experience, it is recommended t

## [> ] CLI tools

### Task CLI

The platform exposes a few CLI tools to facilitate various tasks. Tools are available via:

```powershell
@@ -66,6 +68,103 @@ uv run dynamic-foraging -h

You may need to install optional dependencies depending on the sub-commands you run.

### Curriculum CLI

Curricula are available via the `curriculum` CLI entry point. For a full list of commands:

```powershell
uv run curriculum -h
```

#### `list` - List Available Curricula

```bash
uv run curriculum list
```

#### `init` - Initialize a Curriculum

Creates an initial trainer state for enrolling a subject in a curriculum.

```bash
# Start at the first stage
uv run curriculum init --curriculum coupled_baiting --output initial_state.json

# Start at a specific stage
uv run curriculum init --curriculum coupled_baiting --stage s_stage_1 --output initial_state.json
```

#### `run` - Run a Curriculum

Evaluates a curriculum based on session data and current trainer state.

```bash
uv run curriculum run \
--data-directory /path/to/session/data \
--input-trainer-state current_state.json \
--output-suggestion /path/to/output
```

Force a specific curriculum:

```bash
uv run curriculum run \
--data-directory /path/to/session/data \
--input-trainer-state current_state.json \
--curriculum coupled_baiting \
--output-suggestion /path/to/output
```

#### `version` / `dsl-version` - Show Versions

```bash
uv run curriculum version # Package version
uv run curriculum dsl-version # Underlying DSL library version
```

---

## Typical curriculum workflow

1. **List available curricula:**
```bash
uv run curriculum list
```

2. **Initialize a subject:**
```bash
uv run curriculum init --curriculum coupled_baiting --output trainer_state.json
```

3. **After a session, evaluate progress:**
```bash
uv run curriculum run \
--data-directory /path/to/session/data \
--input-trainer-state trainer_state.json \
--output-suggestion /path/to/output
```

4. **Use the suggestion for the next session:**
The `suggestion.json` output can be passed as `--input-trainer-state` for the next session.
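   Concretely, promoting a suggestion to the next session's input state might look like this (the directory layout and the JSON content are placeholders, not the CLI's real output; the fake `echo` stands in for a prior `curriculum run` so the snippet is self-contained):

   ```shell
   # Stand-in for output written by a previous `uv run curriculum run`
   # (the real suggestion.json is produced by the CLI, not by hand):
   mkdir -p ./curriculum_out
   echo '{"stage": "s_stage_1"}' > ./curriculum_out/suggestion.json

   # Promote the suggestion to be the next session's --input-trainer-state:
   cp ./curriculum_out/suggestion.json ./trainer_state.json
   echo "next input state: $(cat ./trainer_state.json)"
   ```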

---

## Style guide

To keep things clear, the following naming conventions are recommended:

- **Policies** should start with `p_` (e.g., `p_identity_policy`)
- **Policy transitions** should start with `pt_`
- **Stages** should start with `s_` (e.g., `s_stage1`)
- **Stage transitions** should start with `st_` and be named after the stages they transition between (e.g., `st_s_stage1_s_stage2`)

Define the following modules within a curriculum:

- **metrics**: Defines (or imports) metrics classes and how to calculate them from data
- **stages**: Defines the different stages of the task, including task settings and optionally policies
- **curriculum**: Defines transitions between stages and generates the entry point to the application
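As an illustration only, a module following these conventions might look like the sketch below. The dataclasses, field names, and function signatures here are hypothetical stand-ins, not the real `aind-behavior-curriculum` API, which defines its own pydantic models:

```python
from dataclasses import dataclass


# Hypothetical containers; the real DSL uses pydantic models instead.
@dataclass
class Metrics:
    finish_ratio: float
    total_trials: int


@dataclass
class Settings:
    reward_family: int


def p_identity_policy(metrics: Metrics, settings: Settings) -> Settings:
    """Policies (p_ prefix) map metrics and current settings to new settings."""
    return settings


def st_s_stage1_s_stage2(metrics: Metrics) -> bool:
    """Stage transitions (st_ prefix) are named after the stages they connect."""
    return metrics.total_trials >= 100 and metrics.finish_ratio >= 0.8


# Stages use the s_ prefix (thresholds above and settings here are made up).
s_stage1 = Settings(reward_family=1)
s_stage2 = Settings(reward_family=2)

print(st_s_stage1_s_stage2(Metrics(finish_ratio=0.9, total_trials=120)))  # → True
```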

---

## 🎮 Experiment launcher (temporarily CLABE)

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -23,6 +23,9 @@ dependencies = [
"pydantic-settings",
]

[tool.uv.workspace]
Reviewer comment: What I had in mind was slightly different. I think we could just move the curriculum package to ./src and use that as the workspace root instead. I think it makes the organization a bit more uniform and easier to follow.

members = ["workspace/*"]

[project.urls]
Documentation = "https://allenneuraldynamics.github.io/Aind.Behavior.DynamicForaging/"
Repository = "https://github.com/AllenNeuralDynamics/Aind.Behavior.DynamicForaging/"
17 changes: 6 additions & 11 deletions schema/aind_behavior_dynamic_foraging.json
@@ -63,7 +63,7 @@
]
},
"harp_lickometer_right": {
"defulat": null,
"default": null,
"description": "Harp right lickometer",
"oneOf": [
{
@@ -121,7 +121,6 @@
"data_directory",
"triggered_camera_controller",
"harp_behavior",
"harp_lickometer_right",
"harp_clock_generator",
"harp_sound_card",
"manipulator",
@@ -1983,7 +1982,7 @@
"type": "object"
},
"RewardProbabilityParameters": {
"description": "Defines the reward probability structure for a dynamic foraging task.\n\nReward probabilities are defined as pairs (p_left, p_right) normalized by\nbase_reward_sum. Pairs are drawn from a family representing a difficulty level:\n\n Family 0: [[8, 1], [6, 1], [3, 1], [1, 1]]\n Family 1: [[8, 1], [1, 1]]\n Family 2: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]\n Family 3: [[6, 1], [3, 1], [1, 1]]",
"description": "Defines the reward probability structure for a dynamic foraging task.\n\nReward probabilities are defined as pairs (p_left, p_right) normalized by\nbase_reward_sum. Pairs are drawn from a family representing a difficulty level:\n\n Family 1: [[8, 1], [6, 1], [3, 1], [1, 1]]\n Family 2: [[8, 1], [1, 1]]\n Family 3: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]\n Family 4: [[6, 1], [3, 1], [1, 1]]",
"properties": {
"base_reward_sum": {
"default": 0.8,
@@ -3201,16 +3200,12 @@
"block_len": {
"$ref": "#/$defs/Distribution",
"default": {
"family": "Exponential",
"family": "Scalar",
"distribution_parameters": {
"family": "Exponential",
"rate": 1.0
},
"truncation_parameters": {
"max": 2.0,
"min": 1.0,
"truncation_mode": "exclude"
"family": "Scalar",
"value": 0.0
},
"truncation_parameters": null,
"scaling_parameters": null
},
"description": "Distribution describing block length."
37 changes: 37 additions & 0 deletions scripts/walk_through_session.py
@@ -0,0 +1,37 @@
import logging
import os

from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset
from aind_behavior_dynamic_foraging.task_logic.trial_generators.warmup_trial_generator import WarmupTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome

logging.basicConfig(
level=logging.DEBUG,
)
logger = logging.getLogger(__name__)


def walk_through_session(data_directory: os.PathLike):
dataset = df_foraging_dataset(data_directory)
software_events = dataset["Behavior"]["SoftwareEvents"]
software_events.load_all()

trial_outcomes = software_events["TrialOutcome"].data["data"].iloc
warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator()
for i, outcome in enumerate(trial_outcomes):
warmup_trial_generator.update(TrialOutcome.model_validate(outcome))
trial = warmup_trial_generator.next()

if not trial:
print(f"Session finished at trial {i}")
return


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Walk through a behavior session.")
parser.add_argument("--data-directory", help="Path to the session directory")
args = parser.parse_args()

walk_through_session(args.data_directory)
18 changes: 9 additions & 9 deletions src/Extensions/AindBehaviorDynamicForaging.Generated.cs
@@ -232,7 +232,7 @@ public HarpLicketySplit HarpLickometerLeft
/// Harp right lickometer
/// </summary>
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("harp_lickometer_right", Required=Newtonsoft.Json.Required.AllowNull)]
[Newtonsoft.Json.JsonPropertyAttribute("harp_lickometer_right")]
[System.ComponentModel.DescriptionAttribute("Harp right lickometer")]
public HarpLicketySplit HarpLickometerRight
{
@@ -2939,21 +2939,21 @@ public override string ToString()
///Reward probabilities are defined as pairs (p_left, p_right) normalized by
///base_reward_sum. Pairs are drawn from a family representing a difficulty level:
///
/// Family 0: [[8, 1], [6, 1], [3, 1], [1, 1]]
/// Family 1: [[8, 1], [1, 1]]
/// Family 2: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
/// Family 3: [[6, 1], [3, 1], [1, 1]]
/// Family 1: [[8, 1], [6, 1], [3, 1], [1, 1]]
/// Family 2: [[8, 1], [1, 1]]
/// Family 3: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
/// Family 4: [[6, 1], [3, 1], [1, 1]]
/// </summary>
[System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
[System.ComponentModel.DescriptionAttribute(@"Defines the reward probability structure for a dynamic foraging task.

Reward probabilities are defined as pairs (p_left, p_right) normalized by
base_reward_sum. Pairs are drawn from a family representing a difficulty level:

Family 0: [[8, 1], [6, 1], [3, 1], [1, 1]]
Family 1: [[8, 1], [1, 1]]
Family 2: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
Family 3: [[6, 1], [3, 1], [1, 1]]")]
Family 1: [[8, 1], [6, 1], [3, 1], [1, 1]]
Family 2: [[8, 1], [1, 1]]
Family 3: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
Family 4: [[6, 1], [3, 1], [1, 1]]")]
[Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
[Bonsai.CombinatorAttribute(MethodName="Generate")]
public partial class RewardProbabilityParameters
16 changes: 15 additions & 1 deletion src/aind_behavior_dynamic_foraging/data_contract/_dataset.py
@@ -1,13 +1,14 @@
from pathlib import Path

from aind_behavior_curriculum import TrainerState
from aind_behavior_services.session import Session
from contraqctor.contract import Dataset, DataStreamCollection
from contraqctor.contract.camera import Camera
from contraqctor.contract.harp import (
DeviceYmlByFile,
HarpDevice,
)
from contraqctor.contract.json import PydanticModel, SoftwareEvents
from contraqctor.contract.json import Json, PydanticModel, SoftwareEvents
from contraqctor.contract.mux import MapFromPaths

from .. import __semver__
@@ -58,6 +59,19 @@ def make_dataset(
name="Behavior",
description="Data from the Behavior modality",
data_streams=[
Json(
name="PreviousMetrics",
reader_params=Json.make_params(
path=root_path / "behavior/previous_metrics.json",
),
),
PydanticModel(
name="TrainerState",
reader_params=PydanticModel.make_params(
model=TrainerState,
path=root_path / "behavior/trainer_state.json",
),
),
HarpDevice(
name="HarpBehavior",
reader_params=HarpDevice.make_params(
2 changes: 1 addition & 1 deletion src/aind_behavior_dynamic_foraging/rig.py
@@ -72,7 +72,7 @@ class AindDynamicForagingRig(rig.Rig):
)
harp_behavior: harp.HarpBehavior = Field(description="Harp behavior")
harp_lickometer_left: Optional[harp.HarpLicketySplit] = Field(default=None, description="Harp left lickometer")
harp_lickometer_right: Optional[harp.HarpLicketySplit] = Field(defulat=None, description="Harp right lickometer")
harp_lickometer_right: Optional[harp.HarpLicketySplit] = Field(default=None, description="Harp right lickometer")
harp_clock_generator: harp.HarpWhiteRabbit = Field(description="Harp clock generator")
harp_sound_card: DynamicForagingSoundCard = Field(description="Harp sound card")
harp_sniff_detector: Optional[harp.HarpSniffDetector] = Field(default=None, description="Harp sniff detector")
@@ -26,10 +26,10 @@ class RewardProbabilityParameters(BaseModel):
Reward probabilities are defined as pairs (p_left, p_right) normalized by
base_reward_sum. Pairs are drawn from a family representing a difficulty level:

Family 0: [[8, 1], [6, 1], [3, 1], [1, 1]]
Family 1: [[8, 1], [1, 1]]
Family 2: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
Family 3: [[6, 1], [3, 1], [1, 1]]
Family 1: [[8, 1], [6, 1], [3, 1], [1, 1]]
Family 2: [[8, 1], [1, 1]]
Family 3: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]
Family 4: [[6, 1], [3, 1], [1, 1]]

"""

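One plausible reading of the normalization described in the docstring above — an assumption, not something the schema states explicitly — is that each family pair is scaled so the two probabilities sum to `base_reward_sum`:

```python
BASE_REWARD_SUM = 0.8  # default value from the schema above

# Families as listed in the docstring (1-indexed after this PR's renumbering).
FAMILIES = {
    1: [[8, 1], [6, 1], [3, 1], [1, 1]],
    2: [[8, 1], [1, 1]],
    3: [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.5, 0.5]],
    4: [[6, 1], [3, 1], [1, 1]],
}


def normalized_pair(left: float, right: float, base: float = BASE_REWARD_SUM):
    """Scale (left, right) so the pair sums to `base` (assumed semantics)."""
    total = left + right
    return (left / total * base, right / total * base)


p_left, p_right = normalized_pair(*FAMILIES[1][0])
print(round(p_left, 3), round(p_right, 3))  # → 0.711 0.089
```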
@@ -1,12 +1,7 @@
import logging
from typing import Literal

from aind_behavior_services.task.distributions import (
Distribution,
ExponentialDistribution,
ExponentialDistributionParameters,
TruncationParameters,
)
from aind_behavior_services.task.distributions import Distribution, Scalar
from pydantic import BaseModel, Field

from ..trial_models import TrialOutcome
@@ -41,10 +36,7 @@ class WarmupTrialGeneratorSpec(BlockBasedTrialGeneratorSpec):
type: Literal["WarmupTrialGenerator"] = "WarmupTrialGenerator"

block_len: Distribution = Field(
default=ExponentialDistribution(
distribution_parameters=ExponentialDistributionParameters(rate=1),
truncation_parameters=TruncationParameters(min=1, max=2),
),
default=Scalar(value=1),
description="Distribution describing block length.",
)

@@ -69,7 +61,8 @@ def _are_end_conditions_met(self) -> bool:
"""

end_conditions = self.spec.trial_generation_end_parameters
choice_history = self.is_right_choice_history
win = end_conditions.evaluation_window
choice_history = self.is_right_choice_history[-win:] if win > 0 else self.is_right_choice_history

choice_len = len(choice_history)
left_choices = choice_history.count(False)
@@ -78,20 +71,25 @@

finish_ratio = 0 if choice_len == 0 else (unignored) / choice_len
choice_ratio = 0 if unignored == 0 else right_choices / (unignored)

if (
choice_len >= end_conditions.min_trial
len(self.is_right_choice_history) >= end_conditions.min_trial
and finish_ratio >= end_conditions.min_response_rate
and abs(choice_ratio - 0.5) <= end_conditions.max_choice_bias
):
logger.debug(
"Warmup trial generation end conditions met: "
f"total trials={choice_len}, "
f"total trials={len(self.is_right_choice_history)}, "
f"finish ratio={finish_ratio}, "
f"choice bias={abs(choice_ratio - 0.5)}"
)
return True

logger.debug(
"Warmup trial generation end conditions are not met: "
f"total trials={len(self.is_right_choice_history)}, "
f"finish ratio={finish_ratio}, "
f"choice bias={abs(choice_ratio - 0.5)}"
)
return False

def update(self, outcome: TrialOutcome | str) -> None:
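The windowed end-condition check introduced in this diff can be sketched in isolation as follows. This is a simplified stand-in, not the actual `_are_end_conditions_met` implementation: the real choice history and the `unignored` count come from the generator's internal state, and representing ignored trials as `None` here is an assumption:

```python
def end_conditions_met(
    choice_history: list,  # True = right choice, False = left, None = ignored (assumed encoding)
    min_trial: int,
    min_response_rate: float,
    max_choice_bias: float,
    evaluation_window: int,
) -> bool:
    """Windowed warmup end-condition check (simplified sketch)."""
    # A window of 0 (or negative) means "use the full history", as in the diff.
    window = choice_history[-evaluation_window:] if evaluation_window > 0 else choice_history

    n = len(window)
    unignored = sum(1 for c in window if c is not None)
    right_choices = sum(1 for c in window if c is True)

    finish_ratio = 0 if n == 0 else unignored / n
    choice_ratio = 0 if unignored == 0 else right_choices / unignored

    # min_trial is compared against the full history, not just the window.
    return (
        len(choice_history) >= min_trial
        and finish_ratio >= min_response_rate
        and abs(choice_ratio - 0.5) <= max_choice_bias
    )
```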