Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions schema/aind_behavior_dynamic_foraging.json
Original file line number Diff line number Diff line change
Expand Up @@ -3137,6 +3137,12 @@
"minimum": 0,
"title": "Evaluation Window",
"type": "integer"
},
"is_baiting": {
"default": true,
"description": "Whether uncollected rewards carry over to the next trial.",
"title": "Is Baiting",
"type": "boolean"
}
},
"title": "WarmupTrialGenerationEndConditions",
Expand Down Expand Up @@ -3241,8 +3247,7 @@
"description": "Parameters defining the reward probability structure."
},
"is_baiting": {
"const": true,
"default": true,
"default": false,
"description": "Whether uncollected rewards carry over to the next trial.",
"title": "Is Baiting",
"type": "boolean"
Expand All @@ -3253,7 +3258,8 @@
"min_trial": 50,
"max_choice_bias": 0.1,
"min_response_rate": 0.8,
"evaluation_window": 20
"evaluation_window": 20,
"is_baiting": true
},
"description": "Conditions to end trial generation."
}
Expand Down
37 changes: 37 additions & 0 deletions scripts/walk_through_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging
import os

from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset
from aind_behavior_dynamic_foraging.task_logic.trial_generators.warmup_trial_generator import WarmupTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome

logging.basicConfig(
level=logging.DEBUG,
)
logger = logging.getLogger(__name__)


def walk_through_session(data_directory: os.PathLike):
dataset = df_foraging_dataset(data_directory)
software_events = dataset["Behavior"]["SoftwareEvents"]
software_events.load_all()

trial_outcomes = software_events["TrialOutcome"].data["data"].iloc
warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator()
for i, outcome in enumerate(trial_outcomes):
warmup_trial_generator.update(TrialOutcome.model_validate(outcome))
trial = warmup_trial_generator.next()

if not trial:
print(f"Session finished at trial {i}")
return


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Walk through a behavior session.")
parser.add_argument("data_directory", help="Path to the session directory")
args = parser.parse_args()

walk_through_session(args.data_directory)
26 changes: 24 additions & 2 deletions src/Extensions/AindBehaviorDynamicForaging.Generated.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5229,12 +5229,15 @@ public partial class WarmupTrialGenerationEndConditions

private int _evaluationWindow;

private bool _isBaiting;

public WarmupTrialGenerationEndConditions()
{
_minTrial = 50;
_maxChoiceBias = 0.1D;
_minResponseRate = 0.8D;
_evaluationWindow = 20;
_isBaiting = true;
}

protected WarmupTrialGenerationEndConditions(WarmupTrialGenerationEndConditions other)
Expand All @@ -5243,6 +5246,7 @@ protected WarmupTrialGenerationEndConditions(WarmupTrialGenerationEndConditions
_maxChoiceBias = other._maxChoiceBias;
_minResponseRate = other._minResponseRate;
_evaluationWindow = other._evaluationWindow;
_isBaiting = other._isBaiting;
}

/// <summary>
Expand Down Expand Up @@ -5313,6 +5317,23 @@ public int EvaluationWindow
}
}

/// <summary>
/// Whether uncollected rewards carry over to the next trial.
/// </summary>
[Newtonsoft.Json.JsonPropertyAttribute("is_baiting")]
[System.ComponentModel.DescriptionAttribute("Whether uncollected rewards carry over to the next trial.")]
public bool IsBaiting
{
get
{
return _isBaiting;
}
set
{
_isBaiting = value;
}
}

public System.IObservable<WarmupTrialGenerationEndConditions> Generate()
{
return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new WarmupTrialGenerationEndConditions(this)));
Expand All @@ -5328,7 +5349,8 @@ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("MinTrial = " + _minTrial + ", ");
stringBuilder.Append("MaxChoiceBias = " + _maxChoiceBias + ", ");
stringBuilder.Append("MinResponseRate = " + _minResponseRate + ", ");
stringBuilder.Append("EvaluationWindow = " + _evaluationWindow);
stringBuilder.Append("EvaluationWindow = " + _evaluationWindow + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting);
return true;
}

Expand Down Expand Up @@ -5383,7 +5405,7 @@ public WarmupTrialGeneratorSpec()
_minBlockReward = 1;
_kernelSize = 2;
_rewardProbabilityParameters = new RewardProbabilityParameters();
_isBaiting = true;
_isBaiting = false;
_trialGenerationEndParameters = new WarmupTrialGenerationEndConditions();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class WarmupTrialGenerationEndConditions(BaseModel):
default=20, ge=0, description="Number of most recent trials to evaluate the end criteria."
)

is_baiting: bool = Field(default=True, description="Whether uncollected rewards carry over to the next trial.")


class WarmupTrialGeneratorSpec(BlockBasedTrialGeneratorSpec):
type: Literal["WarmupTrialGenerator"] = "WarmupTrialGenerator"
Expand All @@ -52,9 +54,6 @@ class WarmupTrialGeneratorSpec(BlockBasedTrialGeneratorSpec):
default=WarmupTrialGenerationEndConditions(), description="Conditions to end trial generation."
)
min_block_reward: Literal[1] = Field(default=1, title="Minimal rewards in a block to switch")
is_baiting: Literal[True] = Field(
default=True, description="Whether uncollected rewards carry over to the next trial."
)

def create_generator(self) -> "WarmupTrialGenerator":
return WarmupTrialGenerator(self)
Expand All @@ -69,7 +68,8 @@ def _are_end_conditions_met(self) -> bool:
"""

end_conditions = self.spec.trial_generation_end_parameters
choice_history = self.is_right_choice_history
win = end_conditions.evaluation_window
choice_history = self.is_right_choice_history[-win:] if win > 0 else self.is_right_choice_history

choice_len = len(choice_history)
left_choices = choice_history.count(False)
Expand All @@ -80,18 +80,24 @@ def _are_end_conditions_met(self) -> bool:
choice_ratio = 0 if unignored == 0 else right_choices / (unignored)

if (
choice_len >= end_conditions.min_trial
len(self.is_right_choice_history) >= end_conditions.min_trial
and finish_ratio >= end_conditions.min_response_rate
and abs(choice_ratio - 0.5) <= end_conditions.max_choice_bias
):
logger.debug(
"Warmup trial generation end conditions met: "
f"total trials={choice_len}, "
f"total trials={len(self.is_right_choice_history)}, "
f"finish ratio={finish_ratio}, "
f"choice bias={abs(choice_ratio - 0.5)}"
)
return True

logger.debug(
"Warmup trial generation end conditions are not met: "
f"total trials={len(self.is_right_choice_history)}, "
f"finish ratio={finish_ratio}, "
f"choice bias={abs(choice_ratio - 0.5)}"
)
return False

def update(self, outcome: TrialOutcome | str) -> None:
Expand Down
Loading