Skip to content
4 changes: 4 additions & 0 deletions ajet/copilot/write-swarm-client/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ Below are some reference materials.
Please run `ajet-swarm overwatch` during training, this panel displays everything about the weight update timing, transparently.
When opening this panel, you can see 3 modes which you can select from: "rollout_until_finish_enough_episodes"(only count episodes), "rollout_until_finish_enough_tasks" (+consider task group), "rollout_until_finish_enough_non_dummy_tasks" (+consider group reward)

Another important thing to notice: each task must have a valid task_id (str), which is used to:
- Group up episodes that belong to the same task inside the swarm server (you do not have to worry about that).
- Used as a random seed if the task is a game that requires random initialization (e.g. the werewolves game's player identities).


### 2-3. Integrate with your agent loop.

Expand Down
1 change: 1 addition & 0 deletions ajet/tuner_lib/experimental/swarm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
return

task_id = task.task_id
assert task_id, "task.task_id must be valid!"
workflow_output.metadata["task_id"] = task_id
req_obj = EndEpisodeRequest(
client_uuid=self.client_uuid,
Expand Down
13 changes: 13 additions & 0 deletions ajet/tuner_lib/experimental/swarm_overwatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@
from pydantic import BaseModel


class RewardHistoryEntry(BaseModel):
    """A single entry in the reward history.

    One entry summarizes all episode rewards collected during one training
    global step; entries are produced server-side when the step advances.
    """
    global_step: int    # training global step this entry summarizes
    mean_reward: float  # mean of the rewards accumulated during this step
    std_reward: float   # standard deviation of those rewards
    timestamp: float    # Unix timestamp when this entry was recorded


class RewardHistoryResponse(BaseModel):
    """Response containing the reward history for visualization (reward curves)."""
    # One entry per finalized global step, in insertion order.
    # NOTE: pydantic deep-copies mutable defaults, so `= []` is safe here
    # (unlike a plain-function mutable default argument).
    history: List[RewardHistoryEntry] = []


class CurrentBatchRolloutPoolInformation(BaseModel):
sample_collection_method: str = ""
completed_episodes: int = 0
Expand Down
81 changes: 80 additions & 1 deletion ajet/tuner_lib/experimental/swarm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from multiprocessing.managers import DictProxy
from typing import Coroutine, Optional, Tuple, List
from ajet.utils.process_killer import kill_process_tree
from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
CurrentBatchRolloutPoolInformation,
RewardHistoryEntry,
RewardHistoryResponse,
)
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE
from ajet.tuner_lib.experimental.interchange_utils import (
SyncTrainConfigRequest,
Expand Down Expand Up @@ -63,6 +67,14 @@ def register_enable_swarm_mode_routes(
if "current_batch_rollout_pool_information" not in shared_mem_dict:
shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation()

# Initialize reward history storage for visualization
if "reward_history" not in shared_mem_dict:
shared_mem_dict["reward_history"] = [] # List of RewardHistoryEntry dicts

# Initialize reward accumulator for collecting rewards of current global step
if "current_rewards" not in shared_mem_dict:
shared_mem_dict["current_rewards"] = [] # [rewards...]

# ------------------------------------------------------------------------------------------------
# ------ Recycle claimed episodes that client failed to complete in (promised) time --------------
# --------------------------------- claimed -> unclaimed ----------------------------------------
Expand Down Expand Up @@ -166,6 +178,35 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l
if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
shared_mem_dict["unclaimed_episodes"].remove(episode_uuid)

# --------------------------------------------------------------------------------------
# -------------------------- reward history management ---------------------------------
# --------------------------------------------------------------------------------------

def _finalize_reward_history_for_step(global_step, shared_mem_dict, shared_mem_dict_lock):
"""Finalize reward statistics for a given global step and add to reward_history."""
import numpy as np

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import numpy as np statement is placed inside the _finalize_reward_history_for_step function. It is generally considered a best practice in Python to place all imports at the top of the file, outside of any functions or methods. This improves readability, makes dependencies clear at a glance, and avoids potential performance overhead from repeated imports if the function is called frequently. Please move this import to the top of the file with other imports.


rewards = shared_mem_dict.get("current_rewards", [])
if rewards:
rewards = list(rewards) # Convert proxy to list if needed
mean_reward = float(np.mean(rewards))
std_reward = float(np.std(rewards))

history = shared_mem_dict.get("reward_history", [])
history = list(history) # Convert proxy to list if needed

entry = RewardHistoryEntry(
global_step=global_step,
mean_reward=mean_reward,
std_reward=std_reward,
timestamp=time.time(),
)
history.append(entry.model_dump())
shared_mem_dict["reward_history"] = history

# Clear current rewards for next step
shared_mem_dict["current_rewards"] = []

# --------------------------------------------------------------------------------------
# -------------------------- return workflow output ------------------------------------
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -272,6 +313,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict):
shared_mem_dict["unclaimed_episodes"] = []
logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes")

# clear reward tracking
shared_mem_dict["current_rewards"] = []
shared_mem_dict["reward_history"] = []

# --------------------------------------------------------------------------------------
# -------------------------- fastapi routes --------------------------------------------
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -446,7 +491,12 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
engine_status_detail = req.engine_status_detail
global_step = req.global_step
if global_step is not None:
previous_global_step = shared_mem_dict.get("global_step", None)
shared_mem_dict["global_step"] = global_step
# When global_step changes, finalize reward statistics for the previous step
if previous_global_step is not None and previous_global_step != global_step:
_finalize_reward_history_for_step(previous_global_step, shared_mem_dict, shared_mem_dict_lock)

if engine_status_detail is not None:
shared_mem_dict["engine_status_detail"] = engine_status_detail
logger.info(f"[update_engine_status] Engine status set to {req.engine_status}")
Expand Down Expand Up @@ -636,6 +686,21 @@ async def end_episode(req: EndEpisodeRequest):
shared_mem_dict_lock,
)

# Record reward to current_rewards
if workflow_output.reward is not None:
reward_value = workflow_output.reward
# Handle both single reward and list of rewards
if isinstance(reward_value, list):
rewards_to_record = reward_value
else:
rewards_to_record = [reward_value]

with shared_mem_dict_lock:
current_rewards = shared_mem_dict.get("current_rewards", [])
current_rewards = list(current_rewards) # Convert proxy to list if needed
current_rewards.extend(rewards_to_record)
shared_mem_dict["current_rewards"] = current_rewards

elif episode_type == "eval":
if engine_status in ["ENGINE.ROLLING"]:
await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
Expand Down Expand Up @@ -779,6 +844,20 @@ async def get_current_batch_rollout_pool_information():
logger.error(f"Error getting current batch rollout pool information: {e}")
return CurrentBatchRolloutPoolInformation()

# --------------------------------------------------------------------
# ------------ get reward history for visualization ------------------
# --------------------------------------------------------------------
@app.get("/get_reward_history", response_model=RewardHistoryResponse)
async def get_reward_history():
"""Get the reward history for visualization (reward curves)."""
try:
history = shared_mem_dict.get("reward_history", [])
entries = [RewardHistoryEntry(**entry) for entry in history]
return RewardHistoryResponse(history=entries)
except Exception as e:
logger.error(f"Error getting reward history: {e}")
return RewardHistoryResponse(history=[])

# --------------------------------------------------------------------
# ------------ bring engine back to ENGINE.OFFLINE -------------------
# --------------------------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions ajet/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _dive_to_set_value(config, dotted_key, value):
sub_config[keys[-1]] = value


def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone):
def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone):
"""Align configuration values based on a conversion map.

Parameters
Expand All @@ -107,7 +107,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
Source YAML path to read values from.
to_config_fp : str
Destination YAML path that is updated in place.
convertion_json_fg : str
convertion_json_fp : str
JSON path mapping dotted keys between configs.
backbone : str
Backbone identifier used for framework-specific alignment.
Expand All @@ -121,7 +121,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
# read convertion json
import json

with open(convertion_json_fg, "r", encoding="utf-8") as file:
with open(convertion_json_fp, "r", encoding="utf-8") as file:
convertion_json = json.load(file)

logger.success("----------------------------------------------------")
Expand Down
161 changes: 159 additions & 2 deletions ajet/utils/swarm_overwatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from rich.text import Text
from loguru import logger

from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
CurrentBatchRolloutPoolInformation,
RewardHistoryResponse,
)


class SwarmOverwatch:
Expand Down Expand Up @@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
# logger.error(f"Failed to fetch pool info: {e}")
return None

def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
    """Fetch the reward history from the server for visualization.

    Returns the parsed response, or None on any network/validation failure.
    """
    url = f"{self.server_url}/get_reward_history"
    try:
        resp = self._httpx_client.get(url, timeout=5.0)
        resp.raise_for_status()
        return RewardHistoryResponse.model_validate(resp.json())
    except Exception as e:
        logger.error(f"Failed to fetch reward history: {e}")
        return None

def create_header(
self, info: Optional[CurrentBatchRolloutPoolInformation] = None
) -> Panel:
Expand Down Expand Up @@ -450,6 +467,141 @@ def create_dashboard(

return layout

def display_reward_curve(self):
    """Render the reward history as an ASCII line chart in the terminal.

    Fetches (global_step, mean_reward) pairs from the server, scales them to
    the terminal size, draws data points ('*') connected by interpolated
    dots ('.'), then prints axis labels plus latest/min/max statistics.
    Blocks on input() until the user presses Enter, then returns to the menu.
    """
    self.console.clear()

    # Fetch reward history
    history = self.fetch_reward_history()
    if history is None or not history.history:
        self.console.print("[bold yellow]No reward history available yet.[/bold yellow]")
        self.console.print("[dim]Reward history is recorded when training completes batches with rewards.[/dim]")
        self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
        input()
        return

    # Get terminal size (fall back to a classic 80x24 terminal).
    terminal_width = self.console.width or 80
    terminal_height = self.console.height or 24

    # Reserve space for y-axis labels, header, and x-axis. Clamp to a sane
    # MINIMUM as well: a tiny terminal would otherwise yield a zero/negative
    # grid and crash the plotting below with an IndexError on an empty chart.
    chart_width = max(20, min(terminal_width - 15, 120))
    chart_height = max(5, min(terminal_height - 10, 30))

    # Extract data series.
    global_steps = [entry.global_step for entry in history.history]
    mean_rewards = [entry.mean_reward for entry in history.history]

    # Y-axis range with padding (10%, or +/-0.5 when the curve is flat).
    y_min = min(mean_rewards)
    y_max = max(mean_rewards)
    y_range = y_max - y_min
    if y_range == 0:
        y_range = 1.0  # avoid division by zero for a flat curve
        y_min -= 0.5
        y_max += 0.5
    else:
        y_min -= y_range * 0.1
        y_max += y_range * 0.1
        y_range = y_max - y_min

    # X-axis range (force 1 for a single data point).
    x_min = min(global_steps)
    x_max = max(global_steps)
    x_range = x_max - x_min
    if x_range == 0:
        x_range = 1

    # Create the chart grid.
    chart = [[' ' for _ in range(chart_width)] for _ in range(chart_height)]

    def _to_xy(step, reward):
        # Map a data point to clamped (col, row) chart coordinates.
        # Terminal rows grow downward, so invert y.
        x = int((step - x_min) / x_range * (chart_width - 1))
        y = int((reward - y_min) / y_range * (chart_height - 1))
        y = chart_height - 1 - y
        return max(0, min(chart_width - 1, x)), max(0, min(chart_height - 1, y))

    # Plot the data points.
    for step, reward in zip(global_steps, mean_rewards):
        x, y = _to_xy(step, reward)
        chart[y][x] = '*'

    # Connect consecutive points with '.' via simple linear interpolation,
    # never overwriting an existing '*' marker.
    for i in range(len(global_steps) - 1):
        x1, y1 = _to_xy(global_steps[i], mean_rewards[i])
        x2, y2 = _to_xy(global_steps[i + 1], mean_rewards[i + 1])
        steps_between = max(abs(x2 - x1), abs(y2 - y1))
        for s in range(1, steps_between):
            t = s / steps_between
            x = max(0, min(chart_width - 1, int(x1 + t * (x2 - x1))))
            y = max(0, min(chart_height - 1, int(y1 + t * (y2 - y1))))
            if chart[y][x] == ' ':
                chart[y][x] = '.'

    # Build the output.
    output = Text()
    output.append("\n Reward Curve (Mean Reward vs Global Step)\n", style="bold cyan")
    output.append(f" Server: {self.server_url}\n", style="dim")
    output.append(f" Data points: {len(global_steps)}\n\n", style="dim")

    # Y-axis values for each row (top row = y_max).
    y_labels = [
        y_max - (i / (chart_height - 1)) * y_range if chart_height > 1 else y_max
        for i in range(chart_height)
    ]

    for i, row in enumerate(chart):
        # Show a numeric y-label only on the top, middle, and bottom rows.
        if i in (0, chart_height - 1, chart_height // 2):
            label = f"{y_labels[i]:8.3f} |"
        else:
            label = "         |"
        output.append(label, style="dim")
        output.append(''.join(row), style="green")
        output.append("\n")

    # X-axis line.
    output.append("         +" + "-" * chart_width + "\n", style="dim")

    # X-axis labels: min left-aligned, midpoint centered, max right-aligned.
    x_label_line = "          "
    x_label_line += f"{x_min:<{chart_width // 3}}"
    mid_step = x_min + x_range // 2
    x_label_line += f"{mid_step:^{chart_width // 3}}"
    x_label_line += f"{x_max:>{chart_width // 3}}"
    output.append(x_label_line[:chart_width + 10] + "\n", style="dim")
    output.append("          " + " " * (chart_width // 2 - 5) + "Global Step\n", style="dim cyan")

    # Statistics footer.
    output.append("\n Statistics:\n", style="bold yellow")
    output.append(f"   Latest Global Step: {global_steps[-1]}\n", style="green")
    output.append(f"   Latest Mean Reward: {mean_rewards[-1]:.4f}\n", style="green")
    worst = min(mean_rewards)
    best = max(mean_rewards)
    output.append(f"   Min Mean Reward: {worst:.4f} (step {global_steps[mean_rewards.index(worst)]})\n", style="cyan")
    output.append(f"   Max Mean Reward: {best:.4f} (step {global_steps[mean_rewards.index(best)]})\n", style="cyan")

    self.console.print(output)
    self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
    input()

def display_latest_llm_call(self):
while True:
Expand Down Expand Up @@ -515,6 +667,7 @@ def choose_run(self) -> str:
self.console.print("\n[bold]Choose action:[/bold]")
self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch")
self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call")
self.console.print(" [bold cyan]c[/bold cyan] - Show reward curve")
self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit")
choice = input("\n> ").strip().lower()

Expand All @@ -526,8 +679,12 @@ def choose_run(self) -> str:
mode = "replay_latest_llm_call"
self.console.clear()
continue
elif choice == "c":
self.display_reward_curve()
self.console.clear()
continue
else:
self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]")
self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]")

def run(self):
"""Start the monitoring interface"""
Expand Down
Loading
Loading