diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 7a5bed9..3800c2e 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -3137,6 +3137,12 @@ "minimum": 0, "title": "Evaluation Window", "type": "integer" + }, + "is_baiting": { + "default": true, + "description": "Whether uncollected rewards carry over to the next trial.", + "title": "Is Baiting", + "type": "boolean" } }, "title": "WarmupTrialGenerationEndConditions", @@ -3241,8 +3247,7 @@ "description": "Parameters defining the reward probability structure." }, "is_baiting": { - "const": true, - "default": true, + "default": false, "description": "Whether uncollected rewards carry over to the next trial.", "title": "Is Baiting", "type": "boolean" @@ -3253,7 +3258,8 @@ "min_trial": 50, "max_choice_bias": 0.1, "min_response_rate": 0.8, - "evaluation_window": 20 + "evaluation_window": 20, + "is_baiting": true }, "description": "Conditions to end trial generation." } diff --git a/scripts/walk_through_session.py b/scripts/walk_through_session.py new file mode 100644 index 0000000..6de550e --- /dev/null +++ b/scripts/walk_through_session.py @@ -0,0 +1,37 @@ +import logging +import os + +from aind_behavior_dynamic_foraging.data_contract import dataset as df_foraging_dataset +from aind_behavior_dynamic_foraging.task_logic.trial_generators.warmup_trial_generator import WarmupTrialGeneratorSpec +from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome + +logging.basicConfig( + level=logging.DEBUG, +) +logger = logging.getLogger(__name__) + + +def walk_through_session(data_directory: os.PathLike): + dataset = df_foraging_dataset(data_directory) + software_events = dataset["Behavior"]["SoftwareEvents"] + software_events.load_all() + + trial_outcomes = software_events["TrialOutcome"].data["data"].iloc + warmup_trial_generator = WarmupTrialGeneratorSpec().create_generator() + for i, outcome in enumerate(trial_outcomes): + warmup_trial_generator.update(TrialOutcome.model_validate(outcome)) + trial = warmup_trial_generator.next() + + if not trial: + print(f"Session finished at trial {i}") + return + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Walk through a behavior session.") + parser.add_argument("data_directory", help="Path to the session directory") + args = parser.parse_args() + + walk_through_session(args.data_directory) diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index ef72cbe..9fe7113 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -5229,12 +5229,15 @@ public partial class WarmupTrialGenerationEndConditions private int _evaluationWindow; + private bool _isBaiting; + public WarmupTrialGenerationEndConditions() { _minTrial = 50; _maxChoiceBias = 0.1D; _minResponseRate = 0.8D; _evaluationWindow = 20; + _isBaiting = true; } protected WarmupTrialGenerationEndConditions(WarmupTrialGenerationEndConditions other) @@ -5243,6 +5246,7 @@ protected WarmupTrialGenerationEndConditions(WarmupTrialGenerationEndConditions _maxChoiceBias = other._maxChoiceBias; _minResponseRate = other._minResponseRate; _evaluationWindow = other._evaluationWindow; + _isBaiting = other._isBaiting; } /// @@ -5313,6 +5317,23 @@ public int EvaluationWindow } } + /// + /// Whether uncollected rewards carry over to the next trial. + /// + [Newtonsoft.Json.JsonPropertyAttribute("is_baiting")] + [System.ComponentModel.DescriptionAttribute("Whether uncollected rewards carry over to the next trial.")] + public bool IsBaiting + { + get + { + return _isBaiting; + } + set + { + _isBaiting = value; + } + } + public System.IObservable Generate() { return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new WarmupTrialGenerationEndConditions(this))); @@ -5328,7 +5349,8 @@ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("MinTrial = " + _minTrial + ", "); stringBuilder.Append("MaxChoiceBias = " + _maxChoiceBias + ", "); stringBuilder.Append("MinResponseRate = " + _minResponseRate + ", "); - stringBuilder.Append("EvaluationWindow = " + _evaluationWindow); + stringBuilder.Append("EvaluationWindow = " + _evaluationWindow + ", "); + stringBuilder.Append("IsBaiting = " + _isBaiting); return true; } @@ -5383,7 +5405,7 @@ public WarmupTrialGeneratorSpec() _minBlockReward = 1; _kernelSize = 2; _rewardProbabilityParameters = new RewardProbabilityParameters(); - _isBaiting = true; + _isBaiting = false; _trialGenerationEndParameters = new WarmupTrialGenerationEndConditions(); } diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/warmup_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/warmup_trial_generator.py index e560b19..26f2c52 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/warmup_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/warmup_trial_generator.py @@ -36,6 +36,8 @@ class WarmupTrialGenerationEndConditions(BaseModel): default=20, ge=0, description="Number of most recent trials to evaluate the end criteria." ) + is_baiting: bool = Field(default=True, description="Whether uncollected rewards carry over to the next trial.") + class WarmupTrialGeneratorSpec(BlockBasedTrialGeneratorSpec): type: Literal["WarmupTrialGenerator"] = "WarmupTrialGenerator" @@ -52,9 +54,6 @@ class WarmupTrialGeneratorSpec(BlockBasedTrialGeneratorSpec): default=WarmupTrialGenerationEndConditions(), description="Conditions to end trial generation." ) min_block_reward: Literal[1] = Field(default=1, title="Minimal rewards in a block to switch") - is_baiting: Literal[True] = Field( - default=True, description="Whether uncollected rewards carry over to the next trial." - ) def create_generator(self) -> "WarmupTrialGenerator": return WarmupTrialGenerator(self) @@ -69,7 +68,8 @@ def _are_end_conditions_met(self) -> bool: """ end_conditions = self.spec.trial_generation_end_parameters - choice_history = self.is_right_choice_history + win = end_conditions.evaluation_window + choice_history = self.is_right_choice_history[-win:] if win > 0 else self.is_right_choice_history choice_len = len(choice_history) left_choices = choice_history.count(False) @@ -80,18 +80,24 @@ def _are_end_conditions_met(self) -> bool: choice_ratio = 0 if unignored == 0 else right_choices / (unignored) if ( - choice_len >= end_conditions.min_trial + len(self.is_right_choice_history) >= end_conditions.min_trial and finish_ratio >= end_conditions.min_response_rate and abs(choice_ratio - 0.5) <= end_conditions.max_choice_bias ): logger.debug( "Warmup trial generation end conditions met: " - f"total trials={choice_len}, " + f"total trials={len(self.is_right_choice_history)}, " f"finish ratio={finish_ratio}, " f"choice bias={abs(choice_ratio - 0.5)}" ) return True + logger.debug( + "Warmup trial generation end conditions are not met: " + f"total trials={len(self.is_right_choice_history)}, " + f"finish ratio={finish_ratio}, " + f"choice bias={abs(choice_ratio - 0.5)}" + ) return False def update(self, outcome: TrialOutcome | str) -> None: