From 6ce1492b35d2c5042752e953019dfe84beadd7ea Mon Sep 17 00:00:00 2001
From: "qingxu.fu"
Date: Thu, 12 Mar 2026 16:42:04 +0800
Subject: [PATCH 1/9] deep-fin-pre-commit-patch

---
 .pre-commit-config.yaml                           |  1 +
 .../example_deep_finance/deep_finance_judge.py    | 15 ++++++++-------
 .../example_deep_finance/judge/cgcv/json_utils.py |  2 +-
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6b97f95..15cebb4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,6 +5,7 @@ repos:
   - id: trailing-whitespace
   - id: end-of-file-fixer
   - id: check-yaml
+    exclude: ^tutorial/example_deep_finance/
   - id: check-added-large-files
   - id: check-ast
   - id: check-json
diff --git a/tutorial/example_deep_finance/deep_finance_judge.py b/tutorial/example_deep_finance/deep_finance_judge.py
index 8a8e354..071eccf 100644
--- a/tutorial/example_deep_finance/deep_finance_judge.py
+++ b/tutorial/example_deep_finance/deep_finance_judge.py
@@ -200,7 +200,7 @@ def _load(path, key):
         _load(train_ref_ans_path, "train")
         _load(val_ref_ans_path, "val")

-    def _get_reference_data(self, task_id: str) -> Tuple[str, str]:
+    def _get_reference_data(self, task_id: str) -> Tuple[str, str | None]:
         """获取任务的参考答案和领域"""
         cache_key = "val" if task_id.startswith("val_") else "train"
         ans = DeepFinanceJudgeByOpenJudge._ref_answers_cache.get(cache_key, {}).get(task_id, "")
@@ -301,8 +301,8 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO
         # 1. 提取输入数据
         history = metadata.get("conversation_history", [])
-        query = metadata.get("query") or getattr(workflow_task.task, "main_query", "")
-        task_id = metadata.get("task_id") or getattr(workflow_task.task, "task_id", "")
+        query: str = metadata.get("query") or getattr(workflow_task.task, "main_query", "")
+        task_id: str = metadata.get("task_id") or getattr(workflow_task.task, "task_id", "")
         rubrics = metadata.get("rubrics")  # 可能是 None 或 list of dicts
         step_reward = metadata.get("reward_stats", {}).get("step_reward", 0.0)
         chat_date = metadata.get("chat_date") if metadata else datetime.now().strftime("%Y-%m-%d")
@@ -318,7 +318,7 @@
         # RM Gallery 耗时记录
         rm_start_time = time.time()
         if self._rm_enabled and self.rm_evaluator:
-            rm_raw = self._evaluate_with_rm_gallery(query, assistants[-1] if assistants else "", ref_ans, task_id, domain)
+            rm_raw = self._evaluate_with_rm_gallery(query, assistants[-1] if assistants else "", ref_ans, task_id, domain or "")
         else:
             rm_raw = 0.0
         rm_time = time.time() - rm_start_time
@@ -788,19 +788,20 @@ def _save_evaluation_log(self, task_id: str, grader_results: Dict[str, List[Any]
         保存 OpenJudge 评估日志(可选)
         """
         try:
+            grader_results_log: Dict[str, List[Dict[str, Any]]] = {}
             log = {
                 "task_id": task_id,
                 "query": query,
                 "timestamp": datetime.now().isoformat(),
-                "grader_results": {}
+                "grader_results": grader_results_log
             }

             # 简化 grader_results 以便序列化
             for grader_name, score_list in grader_results.items():
-                log["grader_results"][grader_name] = []
+                grader_results_log[grader_name] = []
                 for score in score_list:
                     if hasattr(score, "score"):
-                        log["grader_results"][grader_name].append({
+                        grader_results_log[grader_name].append({
                             "score": score.score,
                             "reason": score.reason[:200] if hasattr(score, "reason") else "",
                         })
diff --git a/tutorial/example_deep_finance/judge/cgcv/json_utils.py b/tutorial/example_deep_finance/judge/cgcv/json_utils.py
index 7301401..fe6c810 100644
--- a/tutorial/example_deep_finance/judge/cgcv/json_utils.py
+++ b/tutorial/example_deep_finance/judge/cgcv/json_utils.py
@@ -299,7 +299,7 @@ def validate_cgcv_schema(obj: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]],
     # 验证 status
     if normalized["status"] not in VALID_STATUSES:
         # 尝试模糊匹配
-        status_lower = normalized["status"]
+        status_lower: str = normalized["status"]
         matched = False
         for valid_status in VALID_STATUSES:
             if valid_status in status_lower or status_lower in valid_status:

From 9707fafcd362ea68c28cab45781045558a7edc54 Mon Sep 17 00:00:00 2001
From: "qingxu.fu"
Date: Fri, 13 Mar 2026 16:04:29 +0800
Subject: [PATCH 2/9] revise openclaw training

---
 .gitignore                                        |   1 +
 .../opencode_build_openclaw_agent/README.md       |  46 +++++-
 .../on_compute_relative_reward.py                 | 138 ++++++++++++++----
 .../test_reward.py                                |  93 ++++++++++++
 4 files changed, 249 insertions(+), 29 deletions(-)
 create mode 100644 tutorial/opencode_build_openclaw_agent/test_reward.py

diff --git a/.gitignore b/.gitignore
index db79fdf..00da513 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,4 @@ werewolves_swarm
 .claude
 tensorboard_log
 tutorial/**/*.json
+node_modules
diff --git a/tutorial/opencode_build_openclaw_agent/README.md b/tutorial/opencode_build_openclaw_agent/README.md
index 5c69a53..eefb51d 100644
--- a/tutorial/opencode_build_openclaw_agent/README.md
+++ b/tutorial/opencode_build_openclaw_agent/README.md
@@ -75,8 +75,19 @@ In a new terminal:

 ```bash
 cd tutorial/opencode_build_openclaw_agent
+
+# Option 1: Use OpenJudge pointwise grading (default)
+export AJET_SWARM_URL="http://localhost:10086"
+export NUM_REPEAT=4
+export REWARD_MODE=pointwise
+export DASHSCOPE_API_KEY=your_api_key_here
+python fake_vllm_endpoint.py
+
+# Option 2: Use OpenJudge listwise ranking
 export AJET_SWARM_URL="http://localhost:10086"
 export NUM_REPEAT=4
+export REWARD_MODE=listwise
+export DASHSCOPE_API_KEY=your_api_key_here
 python fake_vllm_endpoint.py
 ```

@@ -113,13 +124,40 @@ Key parameters in `fake_vllm_endpoint.py`:
 - `num_repeat=4` - GRPO N parameter (responses per query)
 - `model` - Base model path

+Environment variables for reward computation:
+
+- `REWARD_MODE` - Reward computation mode: `pointwise` (default) or `listwise`
+- `DASHSCOPE_API_KEY` - API key for the OpenJudge LLM grader
+- `JUDGE_BASE_URL` - Base URL for the judge model API (default: DashScope)
+- `JUDGE_MODEL` - Judge model name (default: `qwen-plus`)
+
 ## Reward Function

-The `ExtraversionGrader` evaluates responses on a 1-10 scale:
-- 1 = Highly introverted (reserved, quiet)
-- 10 = Highly extraverted (energetic, enthusiastic)
+Two OpenJudge-based reward modes are available:
+
+### 1. Pointwise Mode (Default)

-Scores are normalized to [-1, 1] for GRPO training.
+Uses the OpenJudge LLM grader to evaluate each response independently:
+- Rates extraversion traits on a 0.0-1.0 scale
+- Provides detailed reasoning for each score
+- Scores in [0, 1] are passed to GRPO, which computes group-relative advantages
+
+```bash
+export REWARD_MODE=pointwise
+export DASHSCOPE_API_KEY=your_api_key_here
+```
+
+### 2. Listwise Mode
+
+Uses OpenJudge to rank all responses together:
+- Compares responses directly against each other
+- Produces relative rankings
+- Best for capturing subtle differences
+
+```bash
+export REWARD_MODE=listwise
+export DASHSCOPE_API_KEY=your_api_key_here
+```

 ## Monitoring
diff --git a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
index ea7c164..5bafd2f 100644
--- a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
+++ b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
@@ -1,41 +1,129 @@
 # -*- coding: utf-8 -*-
-"""Compute relative rewards based on extraversion personality alignment."""
+"""Compute relative rewards based on extraversion personality alignment using OpenJudge."""
+import os
 from typing import List, Dict
 from beast_logger import print_listofdict
+from openjudge.graders.base_grader import GraderMode, GraderScore, GraderRank
+from openjudge.graders.llm_grader import LLMGrader
+from openjudge.models import OpenAIChatModel
+
+# Configuration
+REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # Options: pointwise, listwise
+API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+BASE_URL = os.getenv("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen-plus")

-def score_extraversion(response_text: str) -> float:
-    """Score response for extraversion traits (1-10 scale)."""
-    extraversion_keywords = [
-        'excited', 'love', 'amazing', 'awesome', 'fantastic', 'great',
-        'wonderful', 'thrilled', 'energetic', 'enthusiastic', 'fun',
-        'social', 'outgoing', 'active', 'lively', 'vibrant', 'happy',
-        'enjoy', 'delighted', 'cheerful', 'positive'
-    ]
+# OpenJudge grader setup
+judge_model = OpenAIChatModel(
+    model=JUDGE_MODEL,
+    api_key=API_KEY,
+    base_url=BASE_URL,
+)

-    text_lower = response_text.lower()
-    score = 5.0
+EXTRAVERSION_PROMPT = """You are evaluating responses for extraversion personality traits.

-    for keyword in extraversion_keywords:
-        if keyword in text_lower:
-            score += 0.5
+Extraversion characteristics include:
+- Outgoing, energetic, enthusiastic tone
+- Social engagement and excitement
+- Positive, upbeat language
+- Action-oriented expressions
+- Use of exclamation marks and emotional words

-    score += min(response_text.count('!') * 0.3, 2.0)
+Rate the response on a scale of 0.0-1.0:
+0.0 = Highly introverted (reserved, quiet, minimal emotion)
+1.0 = Highly extraverted (energetic, enthusiastic, very expressive)

-    if len(response_text) < 50:
-        score -= 1.0
+Question: {question}

-    return max(1.0, min(10.0, score))
+Response: {response}

+Return a json object with exactly two fields:
+- "score": float between 0.0 and 1.0
+- "reason": brief explanation"""

-async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]:
-    """Compute relative rewards for extraversion alignment."""
+
+def build_listwise_template(n: int) -> str:
+    """Build a listwise prompt template for n responses."""
+    answers_block = "\n".join([f"{i+1}. {{answer_{i+1}}}" for i in range(n)])
+    return f"""You are ranking multiple responses based on extraversion personality traits.
+
+Extraversion characteristics include:
+- Outgoing, energetic, enthusiastic tone
+- Social engagement and excitement
+- Positive, upbeat language
+- Action-oriented expressions
+
+Question: {{question}}
+
+Responses to rank:
+{answers_block}
+
+Rank these responses from most extraverted to least extraverted.
+Return a json object with exactly two fields:
+- "rank": list of integers (1-indexed) ordered from most to least extraverted, e.g. [2, 1, 3]
+- "reason": brief explanation of the ranking"""
+
+pointwise_grader = LLMGrader(
+    name="extraversion_pointwise",
+    mode=GraderMode.POINTWISE,
+    description="Evaluate extraversion traits",
+    model=judge_model,
+    template=EXTRAVERSION_PROMPT,
+)
+
+
+async def compute_pointwise_rewards(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute rewards using OpenJudge pointwise grading."""
     scores = []
     for answer in all_answers:
         content = answer.get("content", "")
-        raw_score = score_extraversion(content)
-        normalized = (raw_score - 5.5) / 4.5
-        scores.append(normalized)
-        answer["reward"] = normalized
+        result = await pointwise_grader.aevaluate(question=question, response=content)
+        if isinstance(result, GraderScore):
+            # score is already normalized 0-1 by OpenJudge
+            score = result.score
+        else:
+            score = 0.0
+        scores.append(score)
+        answer["reward"] = score
+    return scores
+
+
+async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute rewards using OpenJudge listwise ranking."""
+    n = len(all_answers)
+    template = build_listwise_template(n)
+    grader = LLMGrader(
+        name="extraversion_listwise",
+        mode=GraderMode.LISTWISE,
+        description="Rank responses by extraversion",
+        model=judge_model,
+        template=template,
+    )
+    kwargs = {"question": question}
+    for i, ans in enumerate(all_answers):
+        kwargs[f"answer_{i+1}"] = ans.get("content", "")
+
+    result = await grader.aevaluate(**kwargs)
+
+    scores = [0.0] * n
+    if isinstance(result, GraderRank):
+        # rank is a list of 1-indexed positions ordered best to worst
+        # convert to reward: rank 1 (best) -> 1.0, rank n (worst) -> 0.0
+        for position, idx in enumerate(result.rank):
+            scores[idx - 1] = 1.0 - (position / (n - 1)) if n > 1 else 0.5
+
+    for answer, score in zip(all_answers, scores):
+        answer["reward"] = score
+    return scores
+
+
+async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]:
+    """Compute relative rewards for extraversion alignment."""
+    question = valid_results[0].get("question", "") if valid_results else ""
+
+    if REWARD_MODE == "listwise":
+        scores = await compute_listwise_rewards(question, all_answers)
+    else:  # pointwise (default)
+        scores = await compute_pointwise_rewards(question, all_answers)

-    print_listofdict(all_answers, header="on_compute_relative_reward")
+    print_listofdict(all_answers, header=f"on_compute_relative_reward (mode={REWARD_MODE})")
     return scores
diff --git a/tutorial/opencode_build_openclaw_agent/test_reward.py b/tutorial/opencode_build_openclaw_agent/test_reward.py
new file mode 100644
index 0000000..a731b25
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_agent/test_reward.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+"""Test script for on_compute_relative_reward.py using real OpenJudge API."""
+
+import asyncio
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(__file__))
+os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+
+
+async def test_pointwise():
+    """Test pointwise reward mode with real API."""
+    print("\n=== Testing Pointwise Mode (real API) ===")
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+
+    valid_results = [{"question": "What are your thoughts on Paris?"}]
+    all_answers = [
+        {"content": "I'm so excited about Paris! It's amazing and wonderful!"},
+        {"content": "Paris is a city in France."},
+        {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward(valid_results, all_answers)
+        print(f"Scores: {scores}")
+        assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}"
+        assert all(isinstance(s, float) for s in scores), "All scores should be floats"
+        # extraverted responses should score higher than neutral
+        assert scores[0] > scores[1], f"Extraverted response should score higher than neutral: {scores}"
+        assert scores[2] > scores[1], f"Extraverted response should score higher than neutral: {scores}"
+        print("✓ Pointwise mode test passed")
+        return True
+    except Exception as e:
+        print(f"✗ Pointwise mode test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+async def test_listwise():
+    """Test listwise reward mode with real API."""
+    print("\n=== Testing Listwise Mode (real API) ===")
+    os.environ["REWARD_MODE"] = "listwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+
+    valid_results = [{"question": "What are your thoughts on Paris?"}]
+    all_answers = [
+        {"content": "I'm so excited about Paris! It's amazing and wonderful!"},
+        {"content": "Paris is a city in France."},
+        {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward(valid_results, all_answers)
+        print(f"Scores: {scores}")
+        assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}"
+        assert all(isinstance(s, float) for s in scores), "All scores should be floats"
+        # neutral response should score lowest
+        assert scores[1] < scores[0] or scores[1] < scores[2], \
+            f"Neutral response should score lower than at least one extraverted response: {scores}"
+        print("✓ Listwise mode test passed")
+        return True
+    except Exception as e:
+        print(f"✗ Listwise mode test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+async def main():
+    print("Testing on_compute_relative_reward.py (real API)")
+    print("=" * 50)
+
+    results = []
+    results.append(await test_pointwise())
+    results.append(await test_listwise())
+
+    print("\n" + "=" * 50)
+    print(f"Tests passed: {sum(results)}/{len(results)}")
+    return all(results)
+
+
+if __name__ == "__main__":
+    success = asyncio.run(main())
+    sys.exit(0 if success else 1)

From b6da77fe48b4431b04ac24a2439b6bb7b9b231d6 Mon Sep 17 00:00:00 2001
From: "qingxu.fu"
Date: Fri, 13 Mar 2026 16:16:16 +0800
Subject: [PATCH 3/9] add illustration

---
 docs/en/example_train_multi_model.md    | 3 +++
 docs/en/example_train_multi_model.zh.md | 5 +++++
 2 files changed, 8 insertions(+)

diff --git a/docs/en/example_train_multi_model.md b/docs/en/example_train_multi_model.md
index a062ac0..e55d49d 100644
--- a/docs/en/example_train_multi_model.md
+++ b/docs/en/example_train_multi_model.md
@@ -90,6 +90,9 @@ graph TB
     C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **Architecture Explanation**:

 - **Swarm Server 1 (Port 10086)**: Hosts the 7B model, responsible for Agent 1 and Agent 3's inference and training
diff --git a/docs/en/example_train_multi_model.zh.md b/docs/en/example_train_multi_model.zh.md
index 772a84f..8e74c6b 100644
--- a/docs/en/example_train_multi_model.zh.md
+++ b/docs/en/example_train_multi_model.zh.md
@@ -88,6 +88,9 @@ graph TB
     C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **架构说明**:

 - **Swarm Server 1 (端口 10086)**:承载 7B 模型,负责 Agent 1 和 Agent 3 的推理与训练
@@ -176,6 +179,8 @@ sequenceDiagram
 4. 将各自的奖励汇报给对应的 Swarm Server
 5. 两个 Server 独立执行策略梯度更新

+
+
 ## 训练曲线

 ![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png)

From f091efcf8243d1ff58f4db918b32c4eb3de34b00 Mon Sep 17 00:00:00 2001
From: "qingxu.fu"
Date: Thu, 19 Mar 2026 18:05:56 +0800
Subject: [PATCH 4/9] add better reward for openclaw agent build

---
 .../cheatsheet.md                                 |  47 +++
 .../fake_vllm_endpoint.py                         |  14 +-
 .../on_compute_relative_reward.py                 | 298 ++++++++++++++++--
 .../on_user_submit_new_requests.py                |  42 ++-
 .../test_reward.py                                | 283 ++++++++++++++---
 5 files changed, 609 insertions(+), 75 deletions(-)
 create mode 100644 tutorial/opencode_build_openclaw_agent/cheatsheet.md

diff --git a/tutorial/opencode_build_openclaw_agent/cheatsheet.md b/tutorial/opencode_build_openclaw_agent/cheatsheet.md
new file mode 100644
index 0000000..0d79b05
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_agent/cheatsheet.md
@@ -0,0 +1,47 @@
+# OpenClaw Reward Cheatsheet
+
+## Run the test
+
+```bash
+cd agentjet/tutorial/opencode_build_openclaw_agent
+
+# pointwise (default)
+DASHSCOPE_API_KEY=your_key python test_reward.py
+
+# listwise
+REWARD_MODE=listwise DASHSCOPE_API_KEY=your_key python test_reward.py
+```
+
+## Run the training endpoint
+
+```bash
+# pointwise (default)
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=pointwise \
+python fake_vllm_endpoint.py
+
+# listwise
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=listwise \
+python fake_vllm_endpoint.py
+```
+
+## Reward modes
+
+| Mode | Description |
+|------|-------------|
+| `pointwise` | Each response scored independently (0.0–1.0) |
+| `listwise` | All responses ranked together (best=1.0, worst=0.0) |
+
+## Environment variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `REWARD_MODE` | `pointwise` | `pointwise` or `listwise` |
+| `DASHSCOPE_API_KEY` | — | DashScope API key (required) |
+| `JUDGE_MODEL` | `qwen-plus` | Judge model name |
+| `JUDGE_BASE_URL` | DashScope endpoint | Judge model base URL |
+| `AJET_SWARM_URL` | `http://localhost:10086` | Swarm server URL |
+| `NUM_REPEAT` | `4` | GRPO N (responses per query) |
diff --git a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
index 0831cd2..e73cc80 100644
--- a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
+++ b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
@@ -25,7 +25,7 @@ import sys
 sys.path.insert(0, os.path.dirname(__file__))

-from on_user_submit_new_requests import on_user_submit_new_requests
+from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history
 from on_compute_relative_reward import on_compute_relative_reward

 # Configuration
@@ -91,6 +91,14 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i
     json_data = await request.json()
     json_data["stream"] = is_stream

+    # Remove fields not supported by vLLM to avoid warnings
+    UNSUPPORTED_FIELDS = {"strict", "store"}
+    for field in UNSUPPORTED_FIELDS:
+        json_data.pop(field, None)
+    # Also remove 'strict' from response_format if present
+    if "response_format" in json_data and isinstance(json_data["response_format"], dict):
+        json_data["response_format"].pop("strict", None)
+
     async with httpx.AsyncClient(timeout=300.0) as client:
         resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers)
         resp.raise_for_status()
@@ -200,7 +208,7 @@ async def handle_one2many_request(request: Request, request_id: str) -> Dict | L
     valid_results = await run_all_episodes(request, is_stream)

     all_answers = [extract_assistant_message(r.response) for r in valid_results]
-    rewards = await on_compute_relative_reward(valid_results, all_answers)
+    rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query)

     await finalize_episodes(task, valid_results, rewards)

@@ -259,7 +267,7 @@ async def health_check():
 @app.get("/requests")
 async def get_requests():
     """Get all recorded user requests."""
-    return {"requests": USER_REQUEST_RECORD}
+    return {"requests": get_query_history()}


 if __name__ == "__main__":
diff --git a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
index 5bafd2f..53894a9 100644
--- a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
+++ b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
@@ -1,26 +1,55 @@
 # -*- coding: utf-8 -*-
-"""Compute relative rewards based on extraversion personality alignment using OpenJudge."""
+"""Compute relative rewards based on extraversion, relevance, diversity, and repetition quality."""
 import os
+import collections
 from typing import List, Dict
+
+from loguru import logger
 from beast_logger import print_listofdict
 from openjudge.graders.base_grader import GraderMode, GraderScore, GraderRank
 from openjudge.graders.llm_grader import LLMGrader
+from openjudge.graders.common.relevance import RelevanceGrader
+from openjudge.graders.format.ngram_repetition_penalty import NgramRepetitionPenaltyGrader
 from openjudge.models import OpenAIChatModel

+try:
+    from ajet.utils.compute_madness import has_repeat
+except ImportError:
+    # Fallback: when running outside the full ajet package (e.g. tests),
+    # resolve relative to the repo root.
+    import sys as _sys
+    from pathlib import Path as _Path
+    _repo_root = str(_Path(__file__).resolve().parents[2])
+    if _repo_root not in _sys.path:
+        _sys.path.insert(0, _repo_root)
+    from ajet.utils.compute_madness import has_repeat

+
+# ---------------------------------------------------------------------------
 # Configuration
-REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # Options: pointwise, listwise
+# ---------------------------------------------------------------------------
+REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # pointwise | listwise
 API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
 BASE_URL = os.getenv("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
 JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen-plus")

-# OpenJudge grader setup
+# Reward weights (must sum to 1.0)
+W_EXTRAVERSION = float(os.getenv("W_EXTRAVERSION", "0.5"))
+W_RELEVANCE = float(os.getenv("W_RELEVANCE", "0.3"))
+W_DIVERSITY = float(os.getenv("W_DIVERSITY", "0.2"))
+
+# Cross-request history buffer size
+HISTORY_MAX_SIZE = int(os.getenv("DIVERSITY_HISTORY_SIZE", "25"))
+
+# ---------------------------------------------------------------------------
+# Shared model & graders
+# ---------------------------------------------------------------------------
 judge_model = OpenAIChatModel(
     model=JUDGE_MODEL,
     api_key=API_KEY,
     base_url=BASE_URL,
 )

+# --- Extraversion grader (custom LLM prompt) ---
 EXTRAVERSION_PROMPT = """You are evaluating responses for extraversion personality traits.

 Extraversion characteristics include:
@@ -41,6 +70,153 @@
 - "score": float between 0.0 and 1.0
 - "reason": brief explanation"""

+pointwise_grader = LLMGrader(
+    name="extraversion_pointwise",
+    mode=GraderMode.POINTWISE,
+    description="Evaluate extraversion traits",
+    model=judge_model,
+    template=EXTRAVERSION_PROMPT,
+)
+
+# --- Relevance grader (built-in OpenJudge) ---
+relevance_grader = RelevanceGrader(model=judge_model)
+
+# --- Repetition penalty grader (deterministic, no LLM) ---
+# Detects n-gram repetition within a single response.
+# Returns a penalty in [-1, 0]: 0 = no repetition, -1 = heavily repetitive.
+repetition_grader = NgramRepetitionPenaltyGrader(
+    n=4,                     # 4-gram detection
+    penalty_threshold=0.15,  # trigger penalty when >15% of n-grams are repeated
+    use_soft_penalty=True,   # gradual penalty rather than a cliff
+    max_penalty=-1.0,        # worst case: score becomes 0
+    min_scaling=0.0,         # at max penalty, the multiplier goes to 0
+)
+
+# ---------------------------------------------------------------------------
+# In-process history of recent responses (for cross-request diversity)
+# ---------------------------------------------------------------------------
+_response_history: List[str] = []
+
+
+def record_responses_to_history(contents: List[str]) -> None:
+    """Append new responses to the rolling history buffer."""
+    _response_history.extend(contents)
+    # Trim to keep only the most recent entries
+    while len(_response_history) > HISTORY_MAX_SIZE:
+        _response_history.pop(0)
+
+
+# ---------------------------------------------------------------------------
+# Diversity: n-gram overlap (fast, deterministic, no LLM needed)
+# ---------------------------------------------------------------------------
+def _get_ngrams(text: str, n: int = 3) -> collections.Counter:
+    """Extract word-level n-grams from text."""
+    tokens = text.lower().split()
+    if len(tokens) < n:
+        return collections.Counter(tokens)
+    return collections.Counter(
+        tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)
+    )
+
+
+def _ngram_overlap(text_a: str, text_b: str, n: int = 3) -> float:
+    """Compute Jaccard overlap of n-grams between two texts. Returns 0-1."""
+    ngrams_a = _get_ngrams(text_a, n)
+    ngrams_b = _get_ngrams(text_b, n)
+    if not ngrams_a or not ngrams_b:
+        return 0.0
+    intersection = sum((ngrams_a & ngrams_b).values())
+    union = sum((ngrams_a | ngrams_b).values())
+    return intersection / union if union > 0 else 0.0
+
+
+def compute_diversity_scores(contents: List[str], history: List[str]) -> List[float]:
+    """
+    Compute a diversity score for each response (0 = duplicate, 1 = fully unique).
+
+    Two components:
+    1. Within-batch: worst-case (maximum) n-gram overlap with the other responses in the batch
+    2. Cross-request: max n-gram overlap with any response in the history buffer
+
+    Final diversity_score = 1 - max(within_batch_overlap, cross_request_overlap)
+    """
+    n = len(contents)
+    scores = []
+    for i, content_i in enumerate(contents):
+        # Within-batch overlap: maximum overlap with other responses in this batch
+        if n > 1:
+            batch_overlaps = [
+                _ngram_overlap(content_i, contents[j])
+                for j in range(n)
+                if j != i
+            ]
+            within_batch = max(batch_overlaps)  # worst-case overlap within batch
+        else:
+            within_batch = 0.0
+
+        # Cross-request overlap: max overlap with any historical response
+        if history:
+            cross_request = max(_ngram_overlap(content_i, h) for h in history)
+        else:
+            cross_request = 0.0
+
+        overlap = max(within_batch, cross_request)
+        scores.append(1.0 - overlap)
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Quality gate: repetition & degeneration detection (deterministic)
+# ---------------------------------------------------------------------------
+async def compute_quality_scores(contents: List[str]) -> List[float]:
+    """
+    Compute a quality multiplier for each response (0 = degenerate, 1 = clean).
+
+    Combines two signals:
+    1. NgramRepetitionPenaltyGrader — detects looping/repeated n-gram blocks
+    2. has_repeat (from ajet.utils.compute_madness) — catches special-token
+       leaks and word/character-level repetition
+
+    Returns a score in [0, 1] that will be used as a *multiplier* on the
+    composite reward, so degenerate outputs get crushed to near-zero.
+    """
+    scores = []
+    for content in contents:
+        # --- Signal 1: n-gram repetition (OpenJudge) ---
+        try:
+            rep_result = await repetition_grader.aevaluate(response=content)
+            # NgramRepetitionPenaltyGrader returns a penalty in [-1, 0]:
+            # 0 = no repetition, -1 = max repetition
+            # Convert to quality: add 1 → [0, 1]
+            ngram_penalty = rep_result.score if isinstance(rep_result, GraderScore) else 0.0
+            ngram_score = 1.0 + ngram_penalty
+        except Exception as e:
+            logger.warning(f"NgramRepetitionPenaltyGrader failed: {e}")
+            ngram_score = 1.0
+
+        # --- Signal 2: string madness (char-level degeneration) ---
+        # We skip the full madness checklist (its non-ASCII check would flag
+        # legitimate accented characters like é) and instead test directly for
+        # special-token leaks plus word- and character-level repetition.
+        madness_score = 1.0  # assume clean
+        if "<|im_start|>" in content:
+            madness_score = 0.0
+        elif has_repeat(content.split(), remember_n_words=5, patience_max=10):
+            madness_score = 0.0
+        elif has_repeat(content, remember_n_words=4, patience_max=200):
+            madness_score = 0.0
+
+        # Combined quality: take the minimum (strictest gate wins)
+        quality = max(0.0, min(1.0, min(ngram_score, madness_score)))
+        scores.append(quality)
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Extraversion scoring (pointwise / listwise)
+# ---------------------------------------------------------------------------
 def build_listwise_template(n: int) -> str:
     """Build a listwise prompt template for n responses."""
     answers_block = "\n".join([f"{i+1}. {{answer_{i+1}}}" for i in range(n)])
     return f"""You are ranking multiple responses based on extraversion personality traits.

@@ -62,33 +238,20 @@ def build_listwise_template(n: int) -> str:
 - "rank": list of integers (1-indexed) ordered from most to least extraverted, e.g. [2, 1, 3]
 - "reason": brief explanation of the ranking"""

-pointwise_grader = LLMGrader(
-    name="extraversion_pointwise",
-    mode=GraderMode.POINTWISE,
-    description="Evaluate extraversion traits",
-    model=judge_model,
-    template=EXTRAVERSION_PROMPT,
-)
-

-async def compute_pointwise_rewards(question: str, all_answers: List[Dict]) -> List[float]:
-    """Compute rewards using OpenJudge pointwise grading."""
+async def compute_pointwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute extraversion scores using pointwise grading."""
     scores = []
     for answer in all_answers:
         content = answer.get("content", "")
         result = await pointwise_grader.aevaluate(question=question, response=content)
-        if isinstance(result, GraderScore):
-            # score is already normalized 0-1 by OpenJudge
-            score = result.score
-        else:
-            score = 0.0
+        score = result.score if isinstance(result, GraderScore) else 0.0
         scores.append(score)
-        answer["reward"] = score
     return scores


-async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> List[float]:
-    """Compute rewards using OpenJudge listwise ranking."""
+async def compute_listwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute extraversion scores using listwise ranking."""
     n = len(all_answers)
     template = build_listwise_template(n)
     grader = LLMGrader(
         name="extraversion_listwise",
         mode=GraderMode.LISTWISE,
         description="Rank responses by extraversion",
         model=judge_model,
         template=template,
     )
     kwargs = {"question": question}
     for i, ans in enumerate(all_answers):
         kwargs[f"answer_{i+1}"] = ans.get("content", "")

     result = await grader.aevaluate(**kwargs)

     scores = [0.0] * n
     if isinstance(result, GraderRank):
-        # rank is a list of 1-indexed positions ordered best to worst
-        # convert to reward: rank 1 (best) -> 1.0, rank n (worst) -> 0.0
         for position, idx in enumerate(result.rank):
             scores[idx - 1] = 1.0 - (position / (n - 1)) if n > 1 else 0.5
-
-    for answer, score in zip(all_answers, scores):
-        answer["reward"] = score
     return scores
+
+
+# ---------------------------------------------------------------------------
+# Relevance scoring (built-in OpenJudge RelevanceGrader, score 1-5 → 0-1)
+# ---------------------------------------------------------------------------
+async def compute_relevance_scores(question: str, all_answers: List[Dict]) -> List[float]:
+    """Score how relevant each response is to the question. Returns 0-1."""
+    scores = []
+    for answer in all_answers:
+        content = answer.get("content", "")
+        result = await relevance_grader.aevaluate(query=question, response=content)
+        if isinstance(result, GraderScore):
+            # RelevanceGrader returns 1-5; normalise to 0-1
+            score = (result.score - 1.0) / 4.0
+        else:
+            score = 0.0
+        scores.append(max(0.0, min(1.0, score)))
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Main entry point
+# ---------------------------------------------------------------------------
-async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]:
-    """Compute relative rewards for extraversion alignment."""
-    question = valid_results[0].get("question", "") if valid_results else ""
+async def on_compute_relative_reward(
+    valid_results: List,
+    all_answers: List[Dict],
+    question: str = "",
+) -> List[float]:
+    """
+    Compute composite rewards combining extraversion, relevance, diversity,
+    and a quality gate for repetition/degeneration.
+
+    Final reward = quality * (W_EXTRAVERSION * extraversion +
+                              W_RELEVANCE * relevance +
+                              W_DIVERSITY * diversity)
+
+    The quality multiplier (0-1) acts as a hard gate: degenerate responses
+    (looping, repeated paragraphs, nonsense characters) get their reward
+    crushed toward zero regardless of other signal scores.
+    """
+    contents = [a.get("content", "") for a in all_answers]
+
+    # 0. Quality gate (deterministic — fast, runs first)
+    quality_scores = await compute_quality_scores(contents)
+
+    # 1. Extraversion score (LLM-based)
     if REWARD_MODE == "listwise":
-        scores = await compute_listwise_rewards(question, all_answers)
-    else:  # pointwise (default)
-        scores = await compute_pointwise_rewards(question, all_answers)
+        extraversion_scores = await compute_listwise_extraversion(question, all_answers)
+    else:
+        extraversion_scores = await compute_pointwise_extraversion(question, all_answers)
+
+    # 2. Relevance score (LLM-based)
+    relevance_scores = await compute_relevance_scores(question, all_answers)
+
+    # 3. Diversity score (deterministic, n-gram overlap)
+    diversity_scores = compute_diversity_scores(contents, _response_history)
+
+    # Composite reward = quality * weighted_sum
+    final_scores = []
+    for i in range(len(all_answers)):
+        weighted_sum = (
+            W_EXTRAVERSION * extraversion_scores[i]
+            + W_RELEVANCE * relevance_scores[i]
+            + W_DIVERSITY * diversity_scores[i]
+        )
+        composite = quality_scores[i] * weighted_sum
+        final_scores.append(round(composite, 4))
+
+        # Annotate the answer dict for logging
+        all_answers[i]["reward"] = final_scores[i]
+        all_answers[i]["quality"] = round(quality_scores[i], 4)
+        all_answers[i]["extraversion"] = round(extraversion_scores[i], 4)
+        all_answers[i]["relevance"] = round(relevance_scores[i], 4)
+        all_answers[i]["diversity"] = round(diversity_scores[i], 4)
+
+    # Update history buffer with this batch's responses
+    record_responses_to_history(contents)

-    print_listofdict(all_answers, header=f"on_compute_relative_reward (mode={REWARD_MODE})")
-    return scores
+    print_listofdict(
+        all_answers,
+        header=(
+            f"on_compute_relative_reward (mode={REWARD_MODE}, "
+            f"w_ext={W_EXTRAVERSION}, w_rel={W_RELEVANCE}, w_div={W_DIVERSITY}, "
+            f"quality_gate=multiplicative)"
+        ),
+    )
+    return final_scores
diff --git a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py
index 07f32a5..11b7932 100644
--- a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py
+++ b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py
@@ -1,8 +1,44 @@
 # -*- coding: utf-8 -*-
-"""Handle new user requests."""
+"""Handle new user requests and track query history for diversity awareness."""
+from typing import List, Dict
+
+from loguru import logger

 from ajet.schema.task import Task

+# Rolling buffer of recent queries — used to detect repeated / near-duplicate
+# questions so the system can log warnings. The response-level diversity
+# signal lives in on_compute_relative_reward._response_history.
+_query_history: List[Dict] = []
+QUERY_HISTORY_MAX = 100
+
+
+def get_query_history() -> List[Dict]:
+    """Return the current query history (read-only copy)."""
+    return list(_query_history)
+
+
 async def on_user_submit_new_requests(request_id: str, task: Task) -> None:
-    """Store user request when submitted."""
-    pass  # No special processing needed for this use case
+    """
+    Store user request metadata when submitted.
+
+    This populates a lightweight in-process history so that:
+    1. The /requests endpoint can expose recent queries for debugging.
+    2. We can detect if the same question keeps appearing, which signals
+       a data distribution issue upstream rather than a model problem.
+    """
+    entry = {
+        "request_id": request_id,
+        "task_id": task.task_id,
+        "query": task.main_query,
+    }
+    _query_history.append(entry)
+
+    # Trim oldest entries
+    while len(_query_history) > QUERY_HISTORY_MAX:
+        _query_history.pop(0)
+
+    logger.info(
+        f"[on_user_submit] request_id={request_id} "
+        f"query_len={len(task.main_query)} "
+        f"history_size={len(_query_history)}"
+    )
diff --git a/tutorial/opencode_build_openclaw_agent/test_reward.py b/tutorial/opencode_build_openclaw_agent/test_reward.py
index a731b25..8b65922 100644
--- a/tutorial/opencode_build_openclaw_agent/test_reward.py
+++ b/tutorial/opencode_build_openclaw_agent/test_reward.py
@@ -1,90 +1,301 @@
 #!/usr/bin/env python3
-"""Test script for on_compute_relative_reward.py using real OpenJudge API."""
+"""Test script for on_compute_relative_reward.py using real OpenJudge API.
+
+Tests four reward dimensions:
+  1. Extraversion — enthusiastic responses score higher
+  2. Relevance — on-topic responses score higher than off-topic
+  3. Diversity — unique responses score higher than near-duplicates
+  4. Quality gate — repetitive/degenerate responses get crushed
+"""

 import asyncio
 import sys
 import os

 sys.path.insert(0, os.path.dirname(__file__))
-os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-REDACTED")


-async def test_pointwise():
-    """Test pointwise reward mode with real API."""
-    print("\n=== Testing Pointwise Mode (real API) ===")
+async def test_pointwise_composite():
+    """Test pointwise composite reward (extraversion + relevance + diversity)."""
+    print("\n=== Testing Pointwise Composite Reward ===")
     os.environ["REWARD_MODE"] = "pointwise"

     import importlib
     import on_compute_relative_reward as mod
     importlib.reload(mod)
+    mod._response_history.clear()  # fresh history for test isolation

-    valid_results = [{"question": "What are your thoughts on Paris?"}]
+    question = "What are your thoughts on Paris?"
     all_answers = [
-        {"content": "I'm so excited about Paris! It's amazing and wonderful!"},
+        {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking and the cafes are amazing!"},
         {"content": "Paris is a city in France."},
-        {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"},
+        {"content": "I absolutely love Paris! The energy on the Champs-Élysées is fantastic and so vibrant!"},
     ]

     try:
-        scores = await mod.on_compute_relative_reward(valid_results, all_answers)
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
         print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
         assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}"
         assert all(isinstance(s, float) for s in scores), "All scores should be floats"
-        # extraverted responses should score higher than neutral
-        assert scores[0] > scores[1], f"Extraverted response should score higher than neutral: {scores}"
-        assert scores[2] > scores[1], f"Extraverted response should score higher than neutral: {scores}"
-        print("✓ Pointwise mode test passed")
+        # Extraverted + relevant responses should beat the flat neutral one
+        assert scores[0] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}"
+        assert scores[2] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}"
+        print("PASSED")
         return True
     except Exception as e:
-        print(f"✗ Pointwise mode test failed: {e}")
-        import traceback
-        traceback.print_exc()
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
         return False
+
+
+async def test_relevance_penalty():
+    """Off-topic answers should get lower composite scores than on-topic ones."""
+    print("\n=== Testing Relevance Penalty ===")
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    question = "What is your favorite food?"
+    all_answers = [
+        # On-topic, extraverted
+        {"content": "Oh my gosh, I absolutely LOVE sushi! The flavors are incredible and I get so excited every time!"},
+        # Off-topic, extraverted (talks about space, not food)
+        {"content": "WOW space exploration is SO exciting! Rockets launching into the sky fills me with energy!!!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
+        # Both are extraverted, but on-topic should win because of relevance
+        assert scores[0] > scores[1], \
+            f"On-topic extraverted should beat off-topic extraverted: {scores}"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
+async def test_diversity_penalty():
+    """Near-duplicate answers should get lower diversity scores."""
+    print("\n=== Testing Diversity Penalty ===")
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    question = "Tell me about your hobbies."
+    all_answers = [
+        {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"},
+        # Near-duplicate of answer 0
+        {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"},
+        # Unique answer
+        {"content": "Dancing is my absolute passion! Nothing beats the energy of moving to great music with friends!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
+        # The duplicate pair should have lower diversity than the unique one
+        div_duplicate = all_answers[0].get("diversity", 1.0)
+        div_unique = all_answers[2].get("diversity", 0.0)
+        assert div_unique > div_duplicate, \
+            f"Unique response should have higher diversity ({div_unique}) than duplicate ({div_duplicate})"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
+async def test_cross_request_diversity():
+    """Answers that repeat historical responses should be penalized."""
+    print("\n=== Testing Cross-Request Diversity ===")
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    # Simulate a prior request that produced a response
+    mod.record_responses_to_history([
+        "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!"
+    ])
+
+    question = "What do you enjoy doing on weekends?"
+    all_answers = [
+        # Repeats the historical response almost verbatim
+        {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!"},
+        # Fresh, unique response
+        {"content": "Weekends are for exploring new restaurants and trying exotic cuisines! I get so thrilled by new flavors!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
+        div_stale = all_answers[0].get("diversity", 1.0)
+        div_fresh = all_answers[1].get("diversity", 0.0)
+        assert div_fresh > div_stale, \
+            f"Fresh response should have higher diversity ({div_fresh}) than stale ({div_stale})"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
+async def test_repetition_penalty():
+    """Degenerate looping responses should get near-zero reward."""
+    print("\n=== Testing Repetition / Degeneration Penalty ===")
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    question = "Tell me about Dunfermline."
+
+    # Build a degenerate looping response (similar to the real failure case)
+    good_intro = "Hello! Dunfermline is a charming town in Fife, Scotland, with a rich history."
+    loop_block = (
+        "\n\n---\n\n"
+        "If you have any specific questions or need more information, just "
+        "let me know! I'm here to assist you in making your visit to "
+        "Dunfermline a delightful experience.\n\n---\n\n"
+        "Looking forward to your wonderful Dunfermline adventures!\n\n---\n\n"
+        "Thank you for the opportunity to share my thoughts on Dunfermline. "
+        "If you have any more questions or need assistance, feel free to "
+        "reach out!"
+    )
+    degenerate_response = good_intro + (loop_block * 15)  # repeat the block many times
+
+    all_answers = [
+        # Degenerate looping response
+        {"content": degenerate_response},
+        # Clean, concise, extraverted response
+        {"content": "Dunfermline is absolutely wonderful! The abbey ruins are breathtaking and the town has such vibrant energy. I love the mix of history and modern community spirit there!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  quality={a.get('quality')}, ext={a.get('extraversion')}, "
+                  f"rel={a.get('relevance')}, div={a.get('diversity')}, "
+                  f"reward={a.get('reward')} "
+                  f"content={a['content'][:60]}...")
+
+        quality_degenerate = all_answers[0].get("quality", 1.0)
+        quality_clean = all_answers[1].get("quality", 0.0)
+        print(f"  Quality scores: degenerate={quality_degenerate}, clean={quality_clean}")
+
+        # The degenerate response should have much lower quality
+        assert quality_clean > quality_degenerate, \
+            f"Clean response quality ({quality_clean}) should exceed degenerate ({quality_degenerate})"
+        # The clean response should win overall
+        assert scores[1] > scores[0], \
+            f"Clean response ({scores[1]}) should beat degenerate ({scores[0]})"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
-async def test_listwise():
-    """Test listwise reward mode with real API."""
-    print("\n=== Testing Listwise Mode (real API) ===")
+async def test_listwise_composite():
+    """Listwise mode should also produce composite rewards."""
+    print("\n=== Testing Listwise Composite Reward ===")
     os.environ["REWARD_MODE"] = "listwise"

     import importlib
     import on_compute_relative_reward as mod
     importlib.reload(mod)
+    mod._response_history.clear()

-    valid_results = [{"question": "What are your thoughts on Paris?"}]
+    question = "What are your thoughts on Paris?"
     all_answers = [
-        {"content": "I'm so excited about Paris! It's amazing and wonderful!"},
+        {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking!"},
         {"content": "Paris is a city in France."},
-        {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"},
+        {"content": "I absolutely love Paris! The Champs-Élysées energy is fantastic!"},
     ]

     try:
-        scores = await mod.on_compute_relative_reward(valid_results, all_answers)
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
         print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
         assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}"
-        assert all(isinstance(s, float) for s in scores), "All scores should be floats"
-        # neutral response should score lowest
+        # Neutral response should score lowest
         assert scores[1] < scores[0] or scores[1] < scores[2], \
             f"Neutral response should score lower than at least one extraverted response: {scores}"
-        print("✓ Listwise mode test passed")
+        print("PASSED")
         return True
     except Exception as e:
-        print(f"✗ Listwise mode test failed: {e}")
-        import traceback
-        traceback.print_exc()
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
         return False


 async def main():
-    print("Testing on_compute_relative_reward.py (real API)")
-    print("=" * 50)
+    print("Testing on_compute_relative_reward.py — Composite Reward")
+    print("(extraversion + relevance + diversity + quality gate)")
+    print("=" * 60)

     results = []
-    results.append(await test_pointwise())
-    results.append(await test_listwise())
+    results.append(await test_pointwise_composite())
+    results.append(await test_relevance_penalty())
+    results.append(await test_diversity_penalty())
+    results.append(await test_cross_request_diversity())
+    results.append(await test_repetition_penalty())
+    results.append(await test_listwise_composite())

-    print("\n" + "=" * 50)
-    print(f"Tests passed: {sum(results)}/{len(results)}")
+    print("\n" + "=" * 60)
+    passed = sum(results)
+    total = len(results)
+    print(f"Tests passed: {passed}/{total}")
+    if not all(results):
+        names = [
+            "pointwise_composite", "relevance_penalty", "diversity_penalty",
+            "cross_request_diversity", "repetition_penalty", "listwise_composite",
+        ]
+        for name, ok in zip(names, results):
+            if not ok:
+                print(f"  FAILED: {name}")
     return all(results)

From 78182cd886d94fdf2e5a30e21fb6b761e7964e97 Mon Sep 17 00:00:00 2001
From: binary-husky <96192199+binary-husky@users.noreply.github.com>
Date: Thu, 19 Mar 2026 18:10:59 +0800
Subject: [PATCH 5/9] Openclaw exp (#17)

* deep-fin-pre-commit-patch

* revise openclaw training

* add illustration

* add better reward for openclaw agent build
---
 docs/en/example_train_multi_model.md              |   3 +
 docs/en/example_train_multi_model.zh.md           |   5 +
 .../cheatsheet.md                                 |  47 +++
 .../fake_vllm_endpoint.py                         |  14 +-
 .../on_compute_relative_reward.py                 | 298 ++++++++++++++++--
 .../on_user_submit_new_requests.py                |  42 ++-
 .../test_reward.py                                | 283 ++++++++++++++---
 7 files changed, 617 insertions(+), 75 deletions(-)
 create mode 100644 tutorial/opencode_build_openclaw_agent/cheatsheet.md

diff --git a/docs/en/example_train_multi_model.md b/docs/en/example_train_multi_model.md
index a062ac0..e55d49d 100644
--- a/docs/en/example_train_multi_model.md
+++ b/docs/en/example_train_multi_model.md
@@ -90,6 +90,9 @@ graph TB
     C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **Architecture Explanation**:

 - **Swarm Server 1 (Port 10086)**: Hosts the 7B model, responsible for Agent 1 and Agent 3's inference and training
diff --git a/docs/en/example_train_multi_model.zh.md b/docs/en/example_train_multi_model.zh.md
index 772a84f..8e74c6b 100644
--- a/docs/en/example_train_multi_model.zh.md
+++ b/docs/en/example_train_multi_model.zh.md
@@ -88,6 +88,9 @@ graph TB
     C -->|end_episode + reward_14b| S2
 ```

+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png)
+
+
 **架构说明**:

 - **Swarm Server 1 (端口 10086)**:承载 7B 模型,负责 Agent 1 和 Agent 3 的推理与训练
@@ -176,6 +179,8 @@ sequenceDiagram
 4. 将各自的奖励汇报给对应的 Swarm Server
 5. 两个 Server 独立执行策略梯度更新

+
+
 ## 训练曲线

 ![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png)
diff --git a/tutorial/opencode_build_openclaw_agent/cheatsheet.md b/tutorial/opencode_build_openclaw_agent/cheatsheet.md
new file mode 100644
index 0000000..0d79b05
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_agent/cheatsheet.md
@@ -0,0 +1,47 @@
+# OpenClaw Reward Cheatsheet
+
+## Run the test
+
+```bash
+cd agentjet/tutorial/opencode_build_openclaw_agent
+
+# pointwise (default)
+DASHSCOPE_API_KEY=your_key python test_reward.py
+
+# listwise
+REWARD_MODE=listwise DASHSCOPE_API_KEY=your_key python test_reward.py
+```
+
+## Run the training endpoint
+
+```bash
+# pointwise (default)
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=pointwise \
+python fake_vllm_endpoint.py
+
+# listwise
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=listwise \
+python fake_vllm_endpoint.py
+```
+
+## Reward modes
+
+| Mode | Description |
+|------|-------------|
+| `pointwise` | Each response scored independently (0.0–1.0) |
+| `listwise` | All responses ranked together (best=1.0, worst=0.0) |
+
+## Environment variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `REWARD_MODE` | `pointwise` | `pointwise` or `listwise` |
+| `DASHSCOPE_API_KEY` | — | DashScope API key (required) |
+| `JUDGE_MODEL` | `qwen-plus` | Judge model name |
+| `JUDGE_BASE_URL` | DashScope endpoint | Judge model base URL |
+| `AJET_SWARM_URL` | `http://localhost:10086` | Swarm server URL |
+| `NUM_REPEAT` | `4` | GRPO N (responses per query) |
diff --git a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
index 0831cd2..e73cc80 100644
--- a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
+++ b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py
@@ -25,7 +25,7 @@ import sys
 sys.path.insert(0, os.path.dirname(__file__))

-from on_user_submit_new_requests import on_user_submit_new_requests
+from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history
 from on_compute_relative_reward import on_compute_relative_reward

 # Configuration
@@ -91,6 +91,14 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i
     json_data = await request.json()
     json_data["stream"] = is_stream

+    # Remove fields not supported by vLLM to avoid warnings
+    UNSUPPORTED_FIELDS = {"strict", "store"}
+    for field in UNSUPPORTED_FIELDS:
+        json_data.pop(field, None)
+    # Also remove 'strict' from response_format if present
+    if "response_format" in json_data and isinstance(json_data["response_format"], dict):
+        json_data["response_format"].pop("strict", None)
+
     async with httpx.AsyncClient(timeout=300.0) as client:
         resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers)
         resp.raise_for_status()
@@ -200,7 +208,7 @@ async def handle_one2many_request(request: Request, request_id: str) -> Dict | L
     valid_results = await run_all_episodes(request, is_stream)

     all_answers = [extract_assistant_message(r.response) for r in valid_results]
-    rewards = await on_compute_relative_reward(valid_results, all_answers)
+    rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query)

     await finalize_episodes(task, valid_results, rewards)

@@ -259,7 +267,7 @@ async def health_check():
 @app.get("/requests")
 async def get_requests():
     """Get all recorded user requests."""
-    return {"requests": USER_REQUEST_RECORD}
+    return {"requests": get_query_history()}


 if __name__ == "__main__":
diff --git a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
index 5bafd2f..53894a9 100644
--- a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
+++ b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py
@@ -1,26 +1,55 @@
 # -*- coding: utf-8 -*-
-"""Compute relative rewards based on extraversion personality alignment using OpenJudge."""
+"""Compute relative rewards based on extraversion, relevance, diversity, and repetition quality."""
 import os
+import collections
 from typing import List, Dict
+
+from loguru import logger
 from beast_logger import print_listofdict
 from openjudge.graders.base_grader import GraderMode, GraderScore, GraderRank
 from openjudge.graders.llm_grader import LLMGrader
+from openjudge.graders.common.relevance import RelevanceGrader
+from openjudge.graders.format.ngram_repetition_penalty import NgramRepetitionPenaltyGrader
 from openjudge.models import OpenAIChatModel

+try:
+    from ajet.utils.compute_madness import has_repeat
+except ImportError:
+    # Fallback: when running outside the full ajet package (e.g. tests),
+    # resolve relative to the repo root.
+    import sys as _sys
+    from pathlib import Path as _Path
+    _repo_root = str(_Path(__file__).resolve().parents[2])
+    if _repo_root not in _sys.path:
+        _sys.path.insert(0, _repo_root)
+    from ajet.utils.compute_madness import has_repeat

+
+# ---------------------------------------------------------------------------
 # Configuration
-REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # Options: pointwise, listwise
+# ---------------------------------------------------------------------------
+REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # pointwise | listwise
 API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
 BASE_URL = os.getenv("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
 JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen-plus")

-# OpenJudge grader setup
+# Reward weights (must sum to 1.0)
+W_EXTRAVERSION = float(os.getenv("W_EXTRAVERSION", "0.5"))
+W_RELEVANCE = float(os.getenv("W_RELEVANCE", "0.3"))
+W_DIVERSITY = float(os.getenv("W_DIVERSITY", "0.2"))
+
+# Cross-request history buffer size
+HISTORY_MAX_SIZE = int(os.getenv("DIVERSITY_HISTORY_SIZE", "25"))
+
+# ---------------------------------------------------------------------------
+# Shared model & graders
+# ---------------------------------------------------------------------------
 judge_model = OpenAIChatModel(
     model=JUDGE_MODEL,
     api_key=API_KEY,
     base_url=BASE_URL,
 )

+# --- Extraversion grader (custom LLM prompt) ---
 EXTRAVERSION_PROMPT = """You are evaluating responses for extraversion personality traits.
Extraversion characteristics include:
@@ -41,6 +70,153 @@
 - "score": float between 0.0 and 1.0
 - "reason": brief explanation"""
+pointwise_grader = LLMGrader(
+    name="extraversion_pointwise",
+    mode=GraderMode.POINTWISE,
+    description="Evaluate extraversion traits",
+    model=judge_model,
+    template=EXTRAVERSION_PROMPT,
+)
+
+# --- Relevance grader (built-in OpenJudge) ---
+relevance_grader = RelevanceGrader(model=judge_model)
+
+# --- Repetition penalty grader (deterministic, no LLM) ---
+# Detects n-gram repetition within a single response.
+# Returns a penalty in [-1, 0] where 0 = no repetition and -1 = heavily
+# repetitive; compute_quality_scores converts it to a [0, 1] quality score.
+repetition_grader = NgramRepetitionPenaltyGrader(
+    n=4,  # 4-gram detection
+    penalty_threshold=0.15,  # trigger penalty when >15% of n-grams are repeated
+    use_soft_penalty=True,  # gradual penalty rather than cliff
+    max_penalty=-1.0,  # worst case: quality becomes 0
+    min_scaling=0.0,  # at max penalty, multiplier goes to 0
+)
+
+# ---------------------------------------------------------------------------
+# In-process history of recent responses (for cross-request diversity)
+# ---------------------------------------------------------------------------
+_response_history: List[str] = []
+
+
+def record_responses_to_history(contents: List[str]) -> None:
+    """Append new responses to the rolling history buffer."""
+    _response_history.extend(contents)
+    # Trim to keep only the most recent entries
+    while len(_response_history) > HISTORY_MAX_SIZE:
+        _response_history.pop(0)
+
+
+# ---------------------------------------------------------------------------
+# Diversity: n-gram overlap (fast, deterministic, no LLM needed)
+# ---------------------------------------------------------------------------
+def _get_ngrams(text: str, n: int = 3) -> collections.Counter:
+    """Extract word-level n-grams from text."""
+    tokens = text.lower().split()
+    if len(tokens) < n:
+        return collections.Counter(tokens)
+    return collections.Counter(
+        tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)
+    )
+
+
+def _ngram_overlap(text_a: str, text_b: str, n: int = 3) -> float:
+    """Compute Jaccard overlap of n-grams between two texts. Returns 0-1."""
+    ngrams_a = _get_ngrams(text_a, n)
+    ngrams_b = _get_ngrams(text_b, n)
+    if not ngrams_a or not ngrams_b:
+        return 0.0
+    intersection = sum((ngrams_a & ngrams_b).values())
+    union = sum((ngrams_a | ngrams_b).values())
+    return intersection / union if union > 0 else 0.0
+
+
+def compute_diversity_scores(contents: List[str], history: List[str]) -> List[float]:
+    """
+    Compute a diversity score for each response (0 = duplicate, 1 = fully unique).
+
+    Two components:
+    1. Within-batch: maximum pairwise n-gram overlap with the other responses in the batch
+    2. Cross-request: max n-gram overlap with any response in the history buffer
+
+    Final diversity_score = 1 - max(within_batch_overlap, cross_request_overlap)
+    """
+    n = len(contents)
+    scores = []
+    for i, content_i in enumerate(contents):
+        # Within-batch overlap: worst-case (maximum) overlap with the other
+        # responses in this batch
+        if n > 1:
+            batch_overlaps = [
+                _ngram_overlap(content_i, contents[j])
+                for j in range(n)
+                if j != i
+            ]
+            within_batch = max(batch_overlaps)
+        else:
+            within_batch = 0.0
+
+        # Cross-request overlap: max overlap with any historical response
+        if history:
+            cross_request = max(_ngram_overlap(content_i, h) for h in history)
+        else:
+            cross_request = 0.0
+
+        overlap = max(within_batch, cross_request)
+        scores.append(1.0 - overlap)
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Quality gate: repetition & degeneration detection (deterministic)
+# ---------------------------------------------------------------------------
+async def compute_quality_scores(contents: List[str]) -> List[float]:
+    """
+    Compute a quality multiplier for each response (0 = degenerate, 1 = clean).
+
+    Combines two signals:
+    1. NgramRepetitionPenaltyGrader — detects looping/repeated n-gram blocks
+    2. has_repeat (from ajet.utils.compute_madness) plus a special-token
+       check — catches template token leaks and word/character-level repetition
+
+    Returns a score in [0, 1] that will be used as a *multiplier* on the
+    composite reward, so degenerate outputs get crushed to near-zero.
+    """
+    scores = []
+    for content in contents:
+        # --- Signal 1: n-gram repetition (OpenJudge) ---
+        try:
+            rep_result = await repetition_grader.aevaluate(response=content)
+            # NgramRepetitionPenaltyGrader returns penalty in [-1, 0]:
+            #   0 = no repetition, -1 = max repetition
+            # Convert to quality: add 1 → [0, 1]
+            ngram_penalty = rep_result.score if isinstance(rep_result, GraderScore) else 0.0
+            ngram_score = 1.0 + ngram_penalty
+        except Exception as e:
+            logger.warning(f"NgramRepetitionPenaltyGrader failed: {e}")
+            ngram_score = 1.0
+
+        # --- Signal 2: string madness (char-level degeneration) ---
+        # Only check for word/char repetition and special token leaks.
+        # The non-ASCII check is deliberately skipped (accented characters
+        # like é are legitimate); repetition is tested manually below.
+        madness_score = 1.0  # assume clean
+        if "<|im_start|>" in content:
+            madness_score = 0.0
+        elif has_repeat(content.split(), remember_n_words=5, patience_max=10):
+            madness_score = 0.0
+        elif has_repeat(content, remember_n_words=4, patience_max=200):
+            madness_score = 0.0
+
+        # Combined quality: take the minimum (strictest gate wins)
+        quality = max(0.0, min(1.0, min(ngram_score, madness_score)))
+        scores.append(quality)
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Extraversion scoring (pointwise / listwise)
+# ---------------------------------------------------------------------------
 def build_listwise_template(n: int) -> str:
     """Build a listwise prompt template for n responses."""
     answers_block = "\n".join([f"{i+1}. {{answer_{i+1}}}" for i in range(n)])
@@ -62,33 +238,20 @@ def build_listwise_template(n: int) -> Li
 - "rank": list of integers (1-indexed) ordered from most to least extraverted, e.g.
[2, 1, 3] - "reason": brief explanation of the ranking""" -pointwise_grader = LLMGrader( - name="extraversion_pointwise", - mode=GraderMode.POINTWISE, - description="Evaluate extraversion traits", - model=judge_model, - template=EXTRAVERSION_PROMPT, -) - -async def compute_pointwise_rewards(question: str, all_answers: List[Dict]) -> List[float]: - """Compute rewards using OpenJudge pointwise grading.""" +async def compute_pointwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]: + """Compute extraversion scores using pointwise grading.""" scores = [] for answer in all_answers: content = answer.get("content", "") result = await pointwise_grader.aevaluate(question=question, response=content) - if isinstance(result, GraderScore): - # score is already normalized 0-1 by OpenJudge - score = result.score - else: - score = 0.0 + score = result.score if isinstance(result, GraderScore) else 0.0 scores.append(score) - answer["reward"] = score return scores -async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> List[float]: - """Compute rewards using OpenJudge listwise ranking.""" +async def compute_listwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]: + """Compute extraversion scores using listwise ranking.""" n = len(all_answers) template = build_listwise_template(n) grader = LLMGrader( @@ -106,24 +269,93 @@ async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> Li scores = [0.0] * n if isinstance(result, GraderRank): - # rank is a list of 1-indexed positions ordered best to worst - # convert to reward: rank 1 (best) -> 1.0, rank n (worst) -> 0.0 for position, idx in enumerate(result.rank): scores[idx - 1] = 1.0 - (position / (n - 1)) if n > 1 else 0.5 + return scores + - for answer, score in zip(all_answers, scores): - answer["reward"] = score +# --------------------------------------------------------------------------- +# Relevance scoring (built-in OpenJudge RelevanceGrader, score 1-5 → 0-1) +# --------------------------------------------------------------------------- +async def compute_relevance_scores(question: str, all_answers: List[Dict]) -> List[float]: + """Score how relevant each response is to the question. Returns 0-1.""" + scores = [] + for answer in all_answers: + content = answer.get("content", "") + result = await relevance_grader.aevaluate(query=question, response=content) + if isinstance(result, GraderScore): + # RelevanceGrader returns 1-5; normalise to 0-1 + score = (result.score - 1.0) / 4.0 + else: + score = 0.0 + scores.append(max(0.0, min(1.0, score))) return scores -async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]: - """Compute relative rewards for extraversion alignment.""" - question = valid_results[0].get("question", "") if valid_results else "" +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- +async def on_compute_relative_reward( + valid_results: List, + all_answers: List[Dict], + question: str = "", +) -> List[float]: + """ + Compute composite rewards combining extraversion, relevance, diversity, + and a quality gate for repetition/degeneration. 
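+    All component scores lie in [0, 1]; the W_* weights are read from
+    environment variables at import time. With the default weights
+    (0.5/0.3/0.2), a clean response (quality 1.0) scoring 0.8/0.9/0.5
+    yields 1.0 * (0.5*0.8 + 0.3*0.9 + 0.2*0.5) = 0.77.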
+ + Final reward = quality * (W_EXTRAVERSION * extraversion + + W_RELEVANCE * relevance + + W_DIVERSITY * diversity) + The quality multiplier (0-1) acts as a hard gate: degenerate responses + (looping, repeated paragraphs, nonsense characters) get their reward + crushed toward zero regardless of other signal scores. + """ + contents = [a.get("content", "") for a in all_answers] + + # 0. Quality gate (deterministic — fast, runs first) + quality_scores = await compute_quality_scores(contents) + + # 1. Extraversion score (LLM-based) if REWARD_MODE == "listwise": - scores = await compute_listwise_rewards(question, all_answers) - else: # pointwise (default) - scores = await compute_pointwise_rewards(question, all_answers) + extraversion_scores = await compute_listwise_extraversion(question, all_answers) + else: + extraversion_scores = await compute_pointwise_extraversion(question, all_answers) - print_listofdict(all_answers, header=f"on_compute_relative_reward (mode={REWARD_MODE})") - return scores + # 2. Relevance score (LLM-based) + relevance_scores = await compute_relevance_scores(question, all_answers) + + # 3. Diversity score (deterministic, n-gram overlap) + diversity_scores = compute_diversity_scores(contents, _response_history) + + # Composite reward = quality * weighted_sum + final_scores = [] + for i in range(len(all_answers)): + weighted_sum = ( + W_EXTRAVERSION * extraversion_scores[i] + + W_RELEVANCE * relevance_scores[i] + + W_DIVERSITY * diversity_scores[i] + ) + composite = quality_scores[i] * weighted_sum + final_scores.append(round(composite, 4)) + + # Annotate the answer dict for logging + all_answers[i]["reward"] = final_scores[i] + all_answers[i]["quality"] = round(quality_scores[i], 4) + all_answers[i]["extraversion"] = round(extraversion_scores[i], 4) + all_answers[i]["relevance"] = round(relevance_scores[i], 4) + all_answers[i]["diversity"] = round(diversity_scores[i], 4) + + # Update history buffer with this batch's responses + record_responses_to_history(contents) + + print_listofdict( + all_answers, + header=( + f"on_compute_relative_reward (mode={REWARD_MODE}, " + f"w_ext={W_EXTRAVERSION}, w_rel={W_RELEVANCE}, w_div={W_DIVERSITY}, " + f"quality_gate=multiplicative)" + ), + ) + return final_scores diff --git a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py index 07f32a5..11b7932 100644 --- a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py +++ b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py @@ -1,8 +1,44 @@ # -*- coding: utf-8 -*- -"""Handle new user requests.""" +"""Handle new user requests and track query history for diversity awareness.""" +from typing import List, Dict +from loguru import logger from ajet.schema.task import Task +# Rolling buffer of recent queries — used to detect repeated / near-duplicate +# questions so the system can log warnings. The response-level diversity +# signal lives in on_compute_relative_reward._response_history. +_query_history: List[Dict] = [] +QUERY_HISTORY_MAX = 100 + + +def get_query_history() -> List[Dict]: + """Return the current query history (read-only copy).""" + return list(_query_history) + + async def on_user_submit_new_requests(request_id: str, task: Task) -> None: - """Store user request when submitted.""" - pass # No special processing needed for this use case + """ + Store user request metadata when submitted. + + This populates a lightweight in-process history so that: + 1. 
The /requests endpoint can expose recent queries for debugging.
+    2. We can detect if the same question keeps appearing, which signals
+       a data distribution issue upstream rather than a model problem.
+    """
+    entry = {
+        "request_id": request_id,
+        "task_id": task.task_id,
+        "query": task.main_query,
+    }
+    _query_history.append(entry)
+
+    # Trim oldest entries
+    while len(_query_history) > QUERY_HISTORY_MAX:
+        _query_history.pop(0)
+
+    logger.info(
+        f"[on_user_submit] request_id={request_id} "
+        f"query_len={len(task.main_query)} "
+        f"history_size={len(_query_history)}"
+    )
diff --git a/tutorial/opencode_build_openclaw_agent/test_reward.py b/tutorial/opencode_build_openclaw_agent/test_reward.py
index a731b25..8b65922 100644
--- a/tutorial/opencode_build_openclaw_agent/test_reward.py
+++ b/tutorial/opencode_build_openclaw_agent/test_reward.py
@@ -1,90 +1,301 @@
 #!/usr/bin/env python3
-"""Test script for on_compute_relative_reward.py using real OpenJudge API."""
+"""Test script for on_compute_relative_reward.py using real OpenJudge API.
+
+Tests four reward dimensions:
+    1. Extraversion — enthusiastic responses score higher
+    2. Relevance — on-topic responses score higher than off-topic
+    3. Diversity — unique responses score higher than near-duplicates
+    4. Quality gate — repetitive/degenerate responses get crushed
+"""
 
 import asyncio
 import sys
 import os
 sys.path.insert(0, os.path.dirname(__file__))
 
-os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")  # placeholder only: supply the real key via the environment, never hardcode it
 
 
-async def test_pointwise():
-    """Test pointwise reward mode with real API."""
-    print("\n=== Testing Pointwise Mode (real API) ===")
+async def test_pointwise_composite():
+    """Test pointwise composite reward (extraversion + relevance + diversity)."""
+    print("\n=== Testing Pointwise Composite Reward ===")
     os.environ["REWARD_MODE"] = "pointwise"
 
     import importlib
     import on_compute_relative_reward as mod
     importlib.reload(mod)
+    mod._response_history.clear()  # fresh history for test isolation
 
-    valid_results = [{"question": "What are your thoughts on Paris?"}]
+    question = "What are your thoughts on Paris?"
     all_answers = [
-        {"content": "I'm so excited about Paris! It's amazing and wonderful!"},
+        {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking and the cafes are amazing!"},
         {"content": "Paris is a city in France."},
-        {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"},
+        {"content": "I absolutely love Paris!
The energy on the Champs-Élysées is fantastic and so vibrant!"}, ] try: - scores = await mod.on_compute_relative_reward(valid_results, all_answers) - print(f"Scores: {scores}") + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Composite scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}" assert all(isinstance(s, float) for s in scores), "All scores should be floats" - # extraverted responses should score higher than neutral - assert scores[0] > scores[1], f"Extraverted response should score higher than neutral: {scores}" - assert scores[2] > scores[1], f"Extraverted response should score higher than neutral: {scores}" - print("✓ Pointwise mode test passed") + # Extraverted + relevant responses should beat the flat neutral one + assert scores[0] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}" + assert scores[2] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_relevance_penalty(): + """Off-topic answers should get lower composite scores than on-topic ones.""" + print("\n=== Testing Relevance Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "What is your favorite food?" + all_answers = [ + # On-topic, extraverted + {"content": "Oh my gosh, I absolutely LOVE sushi! The flavors are incredible and I get so excited every time!"}, + # Off-topic, extraverted (talks about space, not food) + {"content": "WOW space exploration is SO exciting! Rockets launching into the sky fills me with energy!!!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + # Both are extraverted, but on-topic should win because of relevance + assert scores[0] > scores[1], \ + f"On-topic extraverted should beat off-topic extraverted: {scores}" + print("PASSED") return True except Exception as e: - print(f"✗ Pointwise mode test failed: {e}") - import traceback - traceback.print_exc() + print(f"FAILED: {e}") + import traceback; traceback.print_exc() return False -async def test_listwise(): - """Test listwise reward mode with real API.""" - print("\n=== Testing Listwise Mode (real API) ===") +async def test_diversity_penalty(): + """Near-duplicate answers should get lower diversity scores.""" + print("\n=== Testing Diversity Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "Tell me about your hobbies." + all_answers = [ + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"}, + # Near-duplicate of answer 0 + {"content": "I love hiking in the mountains! 
The fresh air and stunning views make me feel so alive and energized!"}, + # Unique answer + {"content": "Dancing is my absolute passion! Nothing beats the energy of moving to great music with friends!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + # The duplicate pair should have lower diversity than the unique one + div_duplicate = all_answers[0].get("diversity", 1.0) + div_unique = all_answers[2].get("diversity", 0.0) + assert div_unique > div_duplicate, \ + f"Unique response should have higher diversity ({div_unique}) than duplicate ({div_duplicate})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_cross_request_diversity(): + """Answers that repeat historical responses should be penalized.""" + print("\n=== Testing Cross-Request Diversity ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + # Simulate a prior request that produced a response + mod.record_responses_to_history([ + "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!" + ]) + + question = "What do you enjoy doing on weekends?" + all_answers = [ + # Repeats the historical response almost verbatim + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!"}, + # Fresh, unique response + {"content": "Weekends are for exploring new restaurants and trying exotic cuisines! I get so thrilled by new flavors!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + div_stale = all_answers[0].get("diversity", 1.0) + div_fresh = all_answers[1].get("diversity", 0.0) + assert div_fresh > div_stale, \ + f"Fresh response should have higher diversity ({div_fresh}) than stale ({div_stale})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_repetition_penalty(): + """Degenerate looping responses should get near-zero reward.""" + print("\n=== Testing Repetition / Degeneration Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "Tell me about Dunfermline." + + # Build a degenerate looping response (similar to the real failure case) + good_intro = "Hello! Dunfermline is a charming town in Fife, Scotland, with a rich history." + loop_block = ( + "\n\n---\n\n" + "If you have any specific questions or need more information, just " + "let me know! I'm here to assist you in making your visit to " + "Dunfermline a delightful experience.\n\n---\n\n" + "Looking forward to your wonderful Dunfermline adventures!\n\n---\n\n" + "Thank you for the opportunity to share my thoughts on Dunfermline. " + "If you have any more questions or need assistance, feel free to " + "reach out!" 
+ ) + degenerate_response = good_intro + (loop_block * 15) # repeat the block many times + + all_answers = [ + # Degenerate looping response + {"content": degenerate_response}, + # Clean, concise, extraverted response + {"content": "Dunfermline is absolutely wonderful! The abbey ruins are breathtaking and the town has such vibrant energy. I love the mix of history and modern community spirit there!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" quality={a.get('quality')}, ext={a.get('extraversion')}, " + f"rel={a.get('relevance')}, div={a.get('diversity')}, " + f"reward={a.get('reward')} " + f"content={a['content'][:60]}...") + + quality_degenerate = all_answers[0].get("quality", 1.0) + quality_clean = all_answers[1].get("quality", 0.0) + print(f" Quality scores: degenerate={quality_degenerate}, clean={quality_clean}") + + # The degenerate response should have much lower quality + assert quality_clean > quality_degenerate, \ + f"Clean response quality ({quality_clean}) should exceed degenerate ({quality_degenerate})" + # The clean response should win overall + assert scores[1] > scores[0], \ + f"Clean response ({scores[1]}) should beat degenerate ({scores[0]})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_listwise_composite(): + """Listwise mode should also produce composite rewards.""" + print("\n=== Testing Listwise Composite Reward ===") os.environ["REWARD_MODE"] = "listwise" import importlib import on_compute_relative_reward as mod importlib.reload(mod) + mod._response_history.clear() - valid_results = [{"question": "What are your thoughts on Paris?"}] + question = "What are your thoughts on Paris?" all_answers = [ - {"content": "I'm so excited about Paris! It's amazing and wonderful!"}, + {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking!"}, {"content": "Paris is a city in France."}, - {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"}, + {"content": "I absolutely love Paris! 
The Champs-Élysées energy is fantastic!"}, ] try: - scores = await mod.on_compute_relative_reward(valid_results, all_answers) + scores = await mod.on_compute_relative_reward([], all_answers, question=question) print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}" - assert all(isinstance(s, float) for s in scores), "All scores should be floats" - # neutral response should score lowest + # Neutral response should score lowest assert scores[1] < scores[0] or scores[1] < scores[2], \ f"Neutral response should score lower than at least one extraverted response: {scores}" - print("✓ Listwise mode test passed") + print("PASSED") return True except Exception as e: - print(f"✗ Listwise mode test failed: {e}") - import traceback - traceback.print_exc() + print(f"FAILED: {e}") + import traceback; traceback.print_exc() return False async def main(): - print("Testing on_compute_relative_reward.py (real API)") - print("=" * 50) + print("Testing on_compute_relative_reward.py — Composite Reward") + print("(extraversion + relevance + diversity + quality gate)") + print("=" * 60) results = [] - results.append(await test_pointwise()) - results.append(await test_listwise()) + results.append(await test_pointwise_composite()) + results.append(await test_relevance_penalty()) + results.append(await test_diversity_penalty()) + results.append(await test_cross_request_diversity()) + results.append(await test_repetition_penalty()) + results.append(await test_listwise_composite()) - print("\n" + "=" * 50) - print(f"Tests passed: {sum(results)}/{len(results)}") + print("\n" + "=" * 60) + passed = sum(results) + total = len(results) + print(f"Tests passed: {passed}/{total}") + if not all(results): + names = [ + "pointwise_composite", "relevance_penalty", "diversity_penalty", + "cross_request_diversity", "repetition_penalty", "listwise_composite", + ] + for name, ok in zip(names, results): + if not ok: + print(f" FAILED: {name}") return all(results) From 69b19cb7b9e05fffbc49d216b3bcd886083212c8 Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Fri, 20 Mar 2026 15:02:25 +0800 Subject: [PATCH 6/9] add who-is-spy vibe rl (prompt+result+blog) --- ajet/copilot/write-swarm-client/SKILL.md | 4 + ajet/tuner_lib/experimental/swarm_client.py | 1 + docs/en/example_vibe_rl_who_is_spy.md | 139 +++++++ docs/en/installation.md | 5 +- tutorial/opencode_build_spy_game/__init__.py | 1 + .../opencode_build_spy_game/agent_roll.py | 175 +++++++++ .../opencode_build_spy_game/agent_roll_adv.py | 199 ++++++++++ tutorial/opencode_build_spy_game/agent_run.py | 179 +++++++++ .../opencode_build_spy_game/agent_run_adv.py | 218 +++++++++++ .../opencode_build_spy_game/game_engine.py | 365 ++++++++++++++++++ .../opencode_build_spy_game/mock_dataset.py | 102 +++++ tutorial/opencode_build_spy_game/readme.md | 237 ++++++++++++ .../spy_game_config.yaml | 35 ++ .../test_single_game.py | 43 +++ tutorial/opencode_build_who_is_spy.prompt.md | 48 +++ 15 files changed, 1748 insertions(+), 3 deletions(-) create mode 100644 docs/en/example_vibe_rl_who_is_spy.md create mode 100644 tutorial/opencode_build_spy_game/__init__.py create mode 100644 tutorial/opencode_build_spy_game/agent_roll.py create mode 100644 tutorial/opencode_build_spy_game/agent_roll_adv.py create mode 100644 tutorial/opencode_build_spy_game/agent_run.py create mode 100644 
tutorial/opencode_build_spy_game/agent_run_adv.py
 create mode 100644 tutorial/opencode_build_spy_game/game_engine.py
 create mode 100644 tutorial/opencode_build_spy_game/mock_dataset.py
 create mode 100644 tutorial/opencode_build_spy_game/readme.md
 create mode 100644 tutorial/opencode_build_spy_game/spy_game_config.yaml
 create mode 100644 tutorial/opencode_build_spy_game/test_single_game.py
 create mode 100644 tutorial/opencode_build_who_is_spy.prompt.md

diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md
index 3fd6fc4..5112a61 100644
--- a/ajet/copilot/write-swarm-client/SKILL.md
+++ b/ajet/copilot/write-swarm-client/SKILL.md
@@ -194,6 +194,10 @@ Below are some reference materials.
 
     Please run `ajet-swarm overwatch` during training, this panel displays everything about the weight update timing, transparently.
     When opening this panel, you can see 3 modes which you can select from: "rollout_until_finish_enough_episodes"(only count episodes), "rollout_until_finish_enough_tasks" (+consider task group), "rollout_until_finish_enough_non_dummy_tasks" (+consider group reward)
 
+    Another important thing to notice: each task must have a valid task_id (str), which is used to:
+    - Group up episodes that belong to the same task inside the swarm server (you do not have to worry about that).
+    - Serve as a random seed if the task is a game that requires random initialization (e.g. the werewolves game's player identities).
+
 
 ### 2-3. Intergrate with your agent loop.
 
diff --git a/ajet/tuner_lib/experimental/swarm_client.py b/ajet/tuner_lib/experimental/swarm_client.py
index 340980d..ba24517 100644
--- a/ajet/tuner_lib/experimental/swarm_client.py
+++ b/ajet/tuner_lib/experimental/swarm_client.py
@@ -356,6 +356,7 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
             return
 
         task_id = task.task_id
+        assert task_id, "task.task_id must be valid!"
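+        # A valid, non-empty task_id matters twice over: the swarm server
+        # groups episodes of the same task by it, and game-style tasks reuse
+        # it as their random seed (see the SKILL.md note above).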
workflow_output.metadata["task_id"] = task_id
         req_obj = EndEpisodeRequest(
             client_uuid=self.client_uuid,
diff --git a/docs/en/example_vibe_rl_who_is_spy.md b/docs/en/example_vibe_rl_who_is_spy.md
new file mode 100644
index 0000000..ff84044
--- /dev/null
+++ b/docs/en/example_vibe_rl_who_is_spy.md
@@ -0,0 +1,139 @@
+# Vibe RL in Practice: Building a "Who Is the Spy" Agent Trainer from Scratch, Without Writing a Single Line of Code
+
+
+Abstract: In reinforcement learning research, the road from a spark of inspiration, through writing code, to the first successful training curve is long and tedious.
+Fortunately, with the AgentJet framework, going from idea to a successful training run only requires talking: spend a few minutes writing a short prompt,
+wait a little while, and a **complete, concise, human-readable and human-editable training program** plus **the training curve of the first run** appear in front of you.
+Below, we take the classic party game "Who Is the Spy" as an example and show the whole process of training an agent from scratch without writing code.
+
+
+## Install the AgentJet Environment
+
+You can [install manually](https://doc.agentjet.top/en/installation/), or install via skills. Run the following commands to copy the skills into Claude Code or OpenCode.
+```bash
+npx skills add modelscope/agentjet
+npx skills add binary-husky/Vibe-RL
+```
+Once the skills are added, you can direct Claude Code or OpenCode to install AgentJet with uv (or conda / docker).
+
+## Write the Prompt
+
+With AgentJet installed, you can get straight to work. Open OpenCode (Claude Code is more capable than OpenCode, but the author prefers fully open-source tools; besides, Vibe RL on AgentJet is easy enough that a very strong agent is not required),
+select the claude-4.5-sonnet model (it is faster at inference than opus, and sufficient for problems that are not too hard), and kick off the task:
+
+```txt
+Your task:
+- Build an agent that learns the "Who Is the Spy" task, trained with a combination of reinforcement learning and supervised learning. The game rules are:
+  - The game has N players; the majority are **civilians** and a minority are **spies**
+  - At the start of the game, every civilian receives the same **civilian word**, and every spy receives a **spy word** that is close to, but different from, the civilian word (e.g. civilian word "apple", spy word "pear")
+  - In each round, all players take turns giving a **verbal description** of the word they hold; the description must truthfully reflect the word, but must not say the word itself, nor reveal the player's identity too obviously
+  - After everyone has spoken, the game enters the **voting phase**: every player votes for the player they consider the most suspicious spy, and the player with the most votes is eliminated
+  - The game runs for multiple rounds until either end condition is met:
+    - **Civilians win**: all spies have been eliminated
+    - **Spies win**: the number of spies >= the number of civilians (the spies gain the numerical advantage)
+  - Through a large number of training games, the agent must master two core skills:
+    - **Description strategy**: learn to produce, given its own word and the current situation, the best description that neither exposes its identity nor loses the trust of its own faction
+    - **Reasoning and decision making**: learn to identify spies accurately and cast the optimal vote, based on the dialogue history and the description patterns and behavior of other players
+  - Training objective: maximize the agent's win rate in both roles (civilian/spy), continually improving the policy via self-play and the reward mechanism
+- I want to use the base model `/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct`
+- Train with 8 GPUs
+- Batch Size 16
+- I have no dataset yet; you need to mock a small amount of game data for testing and initial training
+- Use the OpenAI SDK, and make flexible use of Tools
+- No Chinese may appear in the code
+
+Your skill (read this SKILL file first to acquire the necessary knowledge):
+./ajet/copilot/write-swarm-client/SKILL.md
+
+- Additional requirements:
+  - optional 0. (agent_roll) team A civilians share one 7B model, team B spies use qwen-max (DASHSCOPE_API_KEY is already in the environment variables),
+    each episode randomly assigns every player's ID and name (randomly generate a long list of names); the winning side gets reward 1, the losing side gets reward 0
+  - optional 1. (agent_roll_adv) adversarial training: team A civilians share one 7B model (swarm server 1), team B spies share another 7B model (swarm server 2),
+    each episode randomly assigns every player's ID and name (randomly generate a long list of names); the winning side gets reward 1, the losing side gets reward 0
+
+- Additional requirements:
+    agent_roll: use 4 GPUs
+    agent_roll_adv: swarm server 1 and swarm server 2 each use 4 GPUs (8 GPUs in total)
+
+- Additional requirement: debug with tmux + uv's .venv until every bug is eliminated & training starts normally. You may use three tmux sessions: `spy-swarm-server`, `spy-swarm-server-2`, `spy-swarm-client`
+
+  - Current debugging stage:
+      debug agent_roll [do the debugging]
+      debug agent_roll_adv [skip debugging]
+```
+
+
+## Check the Results
+
+### The generated training code
+
+Guided by the agentjet skill, OpenCode generates all the training code under tutorial/opencode_build_***:
+
+```bash
+(base) ➜  agentjet git:(main) ✗ tree tutorial/opencode_build_spy_game
+tutorial/opencode_build_spy_game/
+├── mock_dataset.py          # Generate mock game configurations
+├── mock_game_dataset.json   # 200 game scenarios with word pairs
+├── game_engine.py           # Core game mechanics and player logic
+├── agent_run.py             # Agent executor for agent_roll mode
+├── agent_roll.py            # Training script for agent_roll mode
+├── agent_run_adv.py         # Agent executor for adversarial mode
+├── agent_roll_adv.py        # Training script for adversarial mode
+└── readme.md                # This file
+```
+
+### Watch the training swarm, spot the bug, and steer the agent to fix it
+
+
+After waiting a while, run the `ajet-swarm overwatch` command to see which training step we have reached, and it turns out claude-sonnet produced a hilariously awkward bug:
+
+```bash
+ Completed Episode Pool Summary (Progress to Hit Next Weight Update)
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Metric                                 ┃ Current     ┃ Target      ┃ Progress     ┃ Bar                                                                     ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ Completed Episodes                     │ 140         │ 16          │ 875.0%       │ █████████████████████████████████████████████████████████████████████ │
+│                                        │             │             │              │ █████████████████████████████████████████████████████████████████████ │
+│                                        │             │             │              │ █████████████████████████████████████                                 │
+│ -> *Completed Tasks (chosen)*          │ 1           │ 4           │ 25.0%        │ █████░░░░░░░░░░░░░░░                                                    │
+│ Completed Non-Dummy Tasks              │ 1           │ 4           │ 25.0%        │ █████░░░░░░░░░░░░░░░                                                    │
+│ Average Episode Per Task               │ 140.00      │ 4           │ -            │ -                                                                       │
+└────────────────────────────────────────┴─────────────┴─────────────┴──────────────┴─────────────────────────────────────────────────────────────────────────┘
+
+ Task Completion Details
+┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Task ID      ┃ Episodes      ┃ Reward                ┃ Episode UUIDs (first 3)                                                   ┃
+┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│              │ 140           │ 0.779 ± 0.448         │ b47d7b96..., 8caec2d7..., b48bd9fb... (+137 more)                         │
+└──────────────┴───────────────┴───────────────────────┴───────────────────────────────────────────────────────────────────────────┘
+```
+
+The swarm overwatch tables show that the sample pool has already accumulated 875.0% (140) of the required episode samples, yet AgentJet has not started training.
+A closer look shows the CompletedTasks progress is just 1, meaning all 140 episodes were recognized as one single task. And the task id of these samples is... wait, why is it an empty string?
+No doubt about it, the dataset claude mocked has a rather comical problem. Issue a new instruction to OpenCode right away:
+
+```txt
+task.task_id has a serious problem: the task_id should be the random seed of each episode, and it must not be empty!
+```
+
+While at it, I also tweaked the parameters (batch size from 4 to 32, grpo_n from 4 to 6), then went off for a cup of tea and came back. Nice, this time everything is normal.
+
+![alt text](https://img.alicdn.com/imgextra/i4/O1CN01cQny931D4FI93OwyB_!!6000000000162-2-tps-2445-1227.png)
+
+
+To make sure the agent's runtime logic is correct, we also open beast_logger (the log-monitoring component that ships with agentjet) for a look:
+
+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01w7QLeg26hS3yIma36_!!6000000007693-2-tps-3782-1963.png)
+
+One glance and, sure enough, there is still a problem (starting to regret not using opus). The requirement was that team A civilians share one brain, a single 7B model, while team B spies use qwen-max. So how did a spy sneak into the civilian team?
+Time to make claude-sonnet reflect properly on this one:
+
+![alt text](https://img.alicdn.com/imgextra/i3/O1CN01ECZFjI286viB25hk1_!!6000000007884-2-tps-1079-498.png)
+
+After a short wait and another look, all the problems have been fixed.
+
+### Check the training curve
+
+Over to SwanLab: nice, the reward climbs steadily.
+
+![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png)
\ No newline at end of file
diff --git a/docs/en/installation.md b/docs/en/installation.md
index 2909ca1..1c01533 100644
--- a/docs/en/installation.md
+++ b/docs/en/installation.md
@@ -87,9 +87,8 @@ AgentJet supports multiple backbones, you can choose any of them depending on yo
 
 !!! warning "flash-attn Installation"
     - `flash-attn` must be installed **after** other dependencies.
-    - Ensure a healthy connection to GitHub to install pre-compiled wheels.
-    - If you find your machine spend a long time installing flash-attn, ensure a healthy connection to GitHub.
-    - To build faster, export `MAX_JOBS=${N_CPU}`.
+    - Please ensure a **healthy connection to GitHub**, so that the pre-compiled wheels can be downloaded (much faster than compiling).
+    - If the connection to GitHub is unstable, the install falls back to building from source; export `MAX_JOBS=${N_CPU}` to speed up the build.
 
 === "Trinity"
 
diff --git a/tutorial/opencode_build_spy_game/__init__.py b/tutorial/opencode_build_spy_game/__init__.py
new file mode 100644
index 0000000..78a7861
--- /dev/null
+++ b/tutorial/opencode_build_spy_game/__init__.py
@@ -0,0 +1 @@
+# Spy Game RL Agent
diff --git a/tutorial/opencode_build_spy_game/agent_roll.py b/tutorial/opencode_build_spy_game/agent_roll.py
new file mode 100644
index 0000000..7eb344c
--- /dev/null
+++ b/tutorial/opencode_build_spy_game/agent_roll.py
@@ -0,0 +1,175 @@
+"""
+Swarm client for training spy game agent - agent_roll mode.
+Civilians (7B model) vs Spies (qwen-max)
+"""
+
+import os
+import json
+import uuid
+from pathlib import Path
+from ajet.copilot.job import AgentJetJob
+from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete
+from ajet.default_config.ajet_default import AjetTaskReader
+from ajet.task_reader import RouterTaskReader
+from ajet.schema.task import Task
+from tutorial.opencode_build_spy_game.agent_run import run_agent_and_compute_reward
+
+
+# Local configurations (client-side)
+LOCAL_GRPO_N = 6  # GRPO group size (rollout.n)
+LOCAL_NUM_EPOCH = 100
+LOCAL_MAX_PARALLEL = 32
+LOCAL_DATASET_PATH = str(Path(__file__).parent / "mock_game_dataset.json")
+
+# Remote configurations (swarm server)
+REMOTE_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
+REMOTE_BATCH_SIZE = 32  # Number of distinct tasks per training batch
+REMOTE_ALLOCATE_GPU = 8  # Use 8 GPUs for this single-server run
+REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
+
+
+class SpyGameDatasetReader:
+    """Custom dataset reader for spy game configurations."""
+
+    def __init__(self, dataset_path: str):
+        self.dataset_path = dataset_path
+        with open(dataset_path, 'r', encoding='utf-8') as f:
+            self.data = json.load(f)
+
+    def generate_training_tasks(self):
+        """Generate training tasks from dataset."""
+        for idx, item in enumerate(self.data):
+            # Each task needs a unique task_id - use a deterministic ID based on index
+            task_id = f"spy_game_task_{idx:04d}"
+            yield Task(
+                task_id=task_id,  # Required: explicit task_id
+                main_query=f"Play spy game episode {idx}",
+                metadata={
+                    "civilian_word": item["civilian_word"],
+                    "spy_word": item["spy_word"],
+                    "num_players": item["num_players"],
+                    "num_spies": item["num_spies"],
+                    "episode_id": idx
+                }
+            )
+
+
+def main():
+    """Main training loop."""
+
+    # Load dataset
+    print(f"Loading dataset from: {LOCAL_DATASET_PATH}")
+    dataset_reader = SpyGameDatasetReader(LOCAL_DATASET_PATH)
+
+    # Connect to swarm server
+    print(f"Connecting to swarm server: {REMOTE_SWARM_URL}")
+    swarm_worker = SwarmClient(REMOTE_SWARM_URL)
+
+    # Configure and start training
+    ajet_job = AgentJetJob(
+        algorithm="grpo",
+        project_name="spy-game-rl",
+        logging="swanlab",
+        experiment_name="agent_roll_7b_vs_qwen_max",
+        n_gpu=REMOTE_ALLOCATE_GPU,
+        model=REMOTE_TRAIN_MODEL,
+        batch_size=REMOTE_BATCH_SIZE,
+        num_repeat=LOCAL_GRPO_N,
+    )
+
+    print("Starting swarm engine...")
+    swarm_worker.auto_sync_train_config_and_start_engine(ajet_job)
+
+    def rollout(task: Task):
+        """Execute one episode rollout."""
+        try:
+            # Begin episode
+            episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=300)
+
+            # Execute agent workflow
+            workflow_output = run_agent_and_compute_reward(
+                task=task,
+                base_url=api_baseurl_key.base_url,
+                api_key=api_baseurl_key.api_key
+            )
+
+            # Report result back to swarm server
+            swarm_worker.end_episode(task, episode_uuid, workflow_output)
+
+            # Print status
+            print(f"Episode {task.metadata.get('episode_id', '?')}: "
+                  f"Winner={workflow_output.metadata.get('winner', '?')}, "
+                  f"Reward={workflow_output.reward:.2f}")
+
+            swarm_worker.print_rollout_stat()
+
+            return workflow_output.reward
+
+        except Exception as e:
+            print(f"Error in rollout: {e}")
+            return None
+
+    # Training loop
+    print(f"\nStarting training for {LOCAL_NUM_EPOCH} epochs...")
+
+    for epoch in range(LOCAL_NUM_EPOCH):
+        print(f"\n{'='*60}")
+        print(f"EPOCH {epoch + 1}/{LOCAL_NUM_EPOCH}")
+        print(f"{'='*60}")
+
+        next_batch = []
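+        # Assemble one GRPO batch: REMOTE_BATCH_SIZE distinct tasks, each
+        # duplicated LOCAL_GRPO_N times so the server can form reward groups.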
task_count = 0 + for task in dataset_reader.generate_training_tasks(): + task_count += 1 + # For each task, add it LOCAL_GRPO_N times to the batch + # These are multiple rollouts of the SAME task for GRPO + for _ in range(LOCAL_GRPO_N): + next_batch.append(task) + + # Debug logging + if task_count <= 5: + print(f"[DEBUG] Added task {task_count} (episode_id={task.metadata.get('episode_id')}), batch size now: {len(next_batch)}") + + # When we have enough tasks in batch, execute them + if len(next_batch) >= (REMOTE_BATCH_SIZE * LOCAL_GRPO_N): + # Execute batch with retry logic + episode_results = run_episodes_until_all_complete( + next_batch, + func=rollout, + auto_retry=True + ) + + # Print batch statistics + valid_results = [r for r in episode_results if r is not None] + if valid_results: + avg_reward = sum(valid_results) / len(valid_results) + num_tasks = len(next_batch) // LOCAL_GRPO_N + print(f"\nBatch completed: {len(valid_results)}/{len(next_batch)} episodes " + f"({num_tasks} tasks x {LOCAL_GRPO_N} episodes), Avg reward: {avg_reward:.3f}") + + next_batch.clear() + + # Process any remaining tasks in the batch at end of epoch + if len(next_batch) > 0: + episode_results = run_episodes_until_all_complete( + next_batch, + func=rollout, + auto_retry=True + ) + valid_results = [r for r in episode_results if r is not None] + if valid_results: + avg_reward = sum(valid_results) / len(valid_results) + num_tasks = len(next_batch) // LOCAL_GRPO_N + print(f"\nFinal batch completed: {len(valid_results)}/{len(next_batch)} episodes " + f"({num_tasks} tasks x {LOCAL_GRPO_N} episodes), Avg reward: {avg_reward:.3f}") + + print("\n" + "="*60) + print("Training completed!") + print("="*60) + + # Optionally stop the engine (commented out to keep it running) + # swarm_worker.stop_engine() + + +if __name__ == "__main__": + main() diff --git a/tutorial/opencode_build_spy_game/agent_roll_adv.py b/tutorial/opencode_build_spy_game/agent_roll_adv.py new file mode 100644 index 0000000..3094ffd --- /dev/null +++ b/tutorial/opencode_build_spy_game/agent_roll_adv.py @@ -0,0 +1,199 @@ +""" +Swarm client for adversarial training - agent_roll_adv mode. +Team A (civilians): 7B model from swarm server 1 +Team B (spies): 7B model from swarm server 2 +Both teams train simultaneously in competitive setting. 
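+Per the game design, rewards are zero-sum per episode: the winning team
+gets 1.0 and the losing team gets 0.0.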
+""" + +import os +import json +import uuid +from pathlib import Path +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.experimental.swarm_client import SwarmClient, run_episodes_until_all_complete +from ajet.schema.task import Task +from tutorial.opencode_build_spy_game.agent_run_adv import run_agent_and_compute_reward + + +# Local configurations (client-side) +LOCAL_GRPO_N = 4 # GRPO group size (rollout.n) +LOCAL_NUM_EPOCH = 100 +LOCAL_MAX_PARALLEL = 16 +LOCAL_DATASET_PATH = str(Path(__file__).parent / "mock_game_dataset.json") + +# Remote configurations for swarm server 1 (civilian team) +REMOTE_SWARM_URL_1 = os.getenv("AJET_SWARM_URL_1", "http://localhost:10086") +REMOTE_BATCH_SIZE_1 = 16 +REMOTE_ALLOCATE_GPU_1 = 4 +REMOTE_TRAIN_MODEL_1 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' + +# Remote configurations for swarm server 2 (spy team) +REMOTE_SWARM_URL_2 = os.getenv("AJET_SWARM_URL_2", "http://localhost:10087") +REMOTE_BATCH_SIZE_2 = 16 +REMOTE_ALLOCATE_GPU_2 = 4 +REMOTE_TRAIN_MODEL_2 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct' + + +class SpyGameDatasetReader: + """Custom dataset reader for spy game configurations.""" + + def __init__(self, dataset_path: str): + self.dataset_path = dataset_path + with open(dataset_path, 'r', encoding='utf-8') as f: + self.data = json.load(f) + + def generate_training_tasks(self): + """Generate training tasks from dataset.""" + for idx, item in enumerate(self.data): + yield Task( + main_query=f"Play adversarial spy game episode {idx}", + metadata={ + "civilian_word": item["civilian_word"], + "spy_word": item["spy_word"], + "num_players": item["num_players"], + "num_spies": item["num_spies"], + "episode_id": idx + } + ) + + +def main(): + """Main adversarial training loop.""" + + # Load dataset + print(f"Loading dataset from: {LOCAL_DATASET_PATH}") + dataset_reader = SpyGameDatasetReader(LOCAL_DATASET_PATH) + + # Connect to swarm server 1 (civilian team) + print(f"Connecting to swarm server 1 (civilians): {REMOTE_SWARM_URL_1}") + swarm_worker_1 = SwarmClient(REMOTE_SWARM_URL_1) + + ajet_job_1 = AgentJetJob( + algorithm="grpo", + project_name="spy-game-rl-adv", + experiment_name="civilians_team_7b", + n_gpu=REMOTE_ALLOCATE_GPU_1, + model=REMOTE_TRAIN_MODEL_1, + batch_size=REMOTE_BATCH_SIZE_1, + num_repeat=LOCAL_GRPO_N, + ) + + print("Starting swarm engine 1 (civilians)...") + swarm_worker_1.auto_sync_train_config_and_start_engine(ajet_job_1) + + # Connect to swarm server 2 (spy team) + print(f"Connecting to swarm server 2 (spies): {REMOTE_SWARM_URL_2}") + swarm_worker_2 = SwarmClient(REMOTE_SWARM_URL_2) + + ajet_job_2 = AgentJetJob( + algorithm="grpo", + project_name="spy-game-rl-adv", + experiment_name="spies_team_7b", + n_gpu=REMOTE_ALLOCATE_GPU_2, + model=REMOTE_TRAIN_MODEL_2, + batch_size=REMOTE_BATCH_SIZE_2, + num_repeat=LOCAL_GRPO_N, + ) + + print("Starting swarm engine 2 (spies)...") + swarm_worker_2.auto_sync_train_config_and_start_engine(ajet_job_2) + + def rollout(task: Task): + """Execute one adversarial episode rollout.""" + try: + # Begin episode for both teams + episode_uuid_1, api_baseurl_key_1 = swarm_worker_1.begin_episode(discard_episode_timeout=300) + episode_uuid_2, api_baseurl_key_2 = swarm_worker_2.begin_episode(discard_episode_timeout=300) + + # Execute adversarial agent workflow + workflow_output_civilians, workflow_output_spies = run_agent_and_compute_reward( + task=task, + base_url_civilians=api_baseurl_key_1.base_url, + 
api_key_civilians=api_baseurl_key_1.api_key, + base_url_spies=api_baseurl_key_2.base_url, + api_key_spies=api_baseurl_key_2.api_key + ) + + # Report results back to both swarm servers + swarm_worker_1.end_episode(task, episode_uuid_1, workflow_output_civilians) + swarm_worker_2.end_episode(task, episode_uuid_2, workflow_output_spies) + + # Print status + winner = workflow_output_civilians.metadata.get('winner', '?') + print(f"Episode {task.metadata.get('episode_id', '?')}: " + f"Winner={winner}, " + f"Civilian_Reward={workflow_output_civilians.reward:.2f}, " + f"Spy_Reward={workflow_output_spies.reward:.2f}") + + # Print rollout statistics + print("Civilian team stats:") + swarm_worker_1.print_rollout_stat() + print("Spy team stats:") + swarm_worker_2.print_rollout_stat() + + # Return average reward for logging + return (workflow_output_civilians.reward + workflow_output_spies.reward) / 2.0 + + except Exception as e: + print(f"Error in adversarial rollout: {e}") + return None + + # Training loop + print(f"\nStarting adversarial training for {LOCAL_NUM_EPOCH} epochs...") + + for epoch in range(LOCAL_NUM_EPOCH): + print(f"\n{'='*60}") + print(f"EPOCH {epoch + 1}/{LOCAL_NUM_EPOCH}") + print(f"{'='*60}") + + next_batch = [] + for task in dataset_reader.generate_training_tasks(): + # For each task, add it LOCAL_GRPO_N times to the batch + # These are multiple rollouts of the SAME task for GRPO + for _ in range(LOCAL_GRPO_N): + next_batch.append(task) + + # When we have enough tasks in batch, execute them + if len(next_batch) >= (REMOTE_BATCH_SIZE_1 * LOCAL_GRPO_N): + # Execute batch with retry logic + episode_results = run_episodes_until_all_complete( + next_batch, + func=rollout, + auto_retry=True + ) + + # Print batch statistics + valid_results = [r for r in episode_results if r is not None] + if valid_results: + avg_reward = sum(valid_results) / len(valid_results) + num_tasks = len(next_batch) // LOCAL_GRPO_N + print(f"\nBatch completed: {len(valid_results)}/{len(next_batch)} episodes " + f"({num_tasks} tasks x {LOCAL_GRPO_N} episodes), Avg combined reward: {avg_reward:.3f}") + + next_batch.clear() + + # Process any remaining tasks in the batch at end of epoch + if len(next_batch) > 0: + episode_results = run_episodes_until_all_complete( + next_batch, + func=rollout, + auto_retry=True + ) + valid_results = [r for r in episode_results if r is not None] + if valid_results: + avg_reward = sum(valid_results) / len(valid_results) + num_tasks = len(next_batch) // LOCAL_GRPO_N + print(f"\nFinal batch completed: {len(valid_results)}/{len(next_batch)} episodes " + f"({num_tasks} tasks x {LOCAL_GRPO_N} episodes), Avg combined reward: {avg_reward:.3f}") + + print("\n" + "="*60) + print("Adversarial training completed!") + print("="*60) + + # Optionally stop the engines (commented out to keep them running) + # swarm_worker_1.stop_engine() + # swarm_worker_2.stop_engine() + + +if __name__ == "__main__": + main() diff --git a/tutorial/opencode_build_spy_game/agent_run.py b/tutorial/opencode_build_spy_game/agent_run.py new file mode 100644 index 0000000..6c4d9b0 --- /dev/null +++ b/tutorial/opencode_build_spy_game/agent_run.py @@ -0,0 +1,179 @@ +""" +Agent runner for spy game - agent_roll mode. 
+Team A (civilians): shared 7B model +Team B (spies): qwen-max via DashScope API +""" + +import os +import random +from typing import Dict +from ajet.schema.task import Task, WorkflowOutput +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from tutorial.opencode_build_spy_game.game_engine import SpyGame + + +# Pre-generated diverse name pool for players +PLAYER_NAMES = [ + "Alexander", "Benjamin", "Christopher", "Daniel", "Elizabeth", + "Fitzgerald", "Gabriella", "Harrison", "Isabella", "Jonathan", + "Katherine", "Leonardo", "Margaret", "Nathaniel", "Ophelia", + "Penelope", "Quentin", "Rosalind", "Sebastian", "Theodora", + "Ulysses", "Victoria", "Wellington", "Xander", "Yasmine", + "Zachary", "Adelaide", "Beatrice", "Cornelius", "Desmond", + "Eleanor", "Frederick", "Genevieve", "Humphrey", "Imogen", + "Jasper", "Lillian", "Maximilian", "Nicolette", "Orlando", + "Percival", "Quintessa", "Reginald", "Seraphina", "Tristan", + "Valentina", "Winifred", "Xavier", "Yolanda", "Zephyr" +] + + +def _compute_reward(game_result: Dict) -> float: + """ + Compute reward for the trainable team (civilians using 7B model). + + Args: + game_result: Dictionary containing game outcome + + Returns: + Reward value: 1.0 if civilians win, 0.0 if spies win + """ + return game_result["civilian_reward"] + + +def _execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey) -> Dict: + """ + Execute one episode of the spy game. + + Args: + task: Task containing game configuration + api_baseurl_key: API credentials for the trainable 7B model + + Returns: + Game result dictionary + """ + # Extract game configuration from task + civilian_word = task.metadata["civilian_word"] + spy_word = task.metadata["spy_word"] + num_players = task.metadata["num_players"] + num_spies = task.metadata["num_spies"] + + # Get DashScope API key for opponent + dashscope_api_key = os.environ.get("DASHSCOPE_API_KEY", "") + if not dashscope_api_key: + raise ValueError("DASHSCOPE_API_KEY environment variable is not set") + + # Randomly sample player names for this episode + selected_names = random.sample(PLAYER_NAMES, num_players) + + # Randomly assign player IDs and roles + player_indices = list(range(num_players)) + random.shuffle(player_indices) + + # First num_spies indices become spies (using qwen-max) + # Remaining indices become civilians (using 7B model) + player_configs = [] + spy_count = 0 + civilian_count = 0 + + for i in range(num_players): + # Use shuffled index to determine role + is_spy = player_indices[i] < num_spies + + if is_spy: + # Spy uses qwen-max + config = { + "name": selected_names[i], + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "api_key": dashscope_api_key, + "model": "qwen-max" + } + spy_count += 1 + else: + # Civilian uses trainable 7B model + config = { + "name": selected_names[i], + "base_url": api_baseurl_key.base_url, + "api_key": api_baseurl_key.api_key, + "model": "agentjet-model" + } + civilian_count += 1 + + player_configs.append(config) + + # Assert correct role distribution + assert spy_count == num_spies, f"Expected {num_spies} spies, got {spy_count}" + assert civilian_count == num_players - num_spies, f"Expected {num_players - num_spies} civilians, got {civilian_count}" + + # Assert all trainable model users have the same role (all civilians in this mode) + trainable_base_url = api_baseurl_key.base_url + trainable_roles = [] + for i, config in enumerate(player_configs): + if config["base_url"] == trainable_base_url: + # Determine role by checking which 
model is used + is_civilian = config["model"] == "agentjet-model" and config["base_url"] == trainable_base_url + trainable_roles.append(is_civilian) + + # All trainable model users must have the same role + if trainable_roles: + assert all(trainable_roles) or not any(trainable_roles), \ + f"All trainable model users must have the same role (all civilians or all spies), but got mixed roles" + # In agent_run mode, all trainable model users should be civilians + assert all(trainable_roles), \ + f"In agent_run mode, all trainable model users should be civilians" + + # Create and run game + game = SpyGame( + civilian_word=civilian_word, + spy_word=spy_word, + num_players=num_players, + num_spies=num_spies, + player_configs=player_configs + ) + + game_result = game.play_game() + return game_result + + +def run_agent_and_compute_reward(task: Task, base_url: str, api_key: str) -> WorkflowOutput: + """ + Main entry point for running the agent and computing reward. + + Args: + task: Task containing game configuration + base_url: Base URL for the trainable model + api_key: API key for the trainable model + + Returns: + WorkflowOutput with reward and game metadata + """ + api_baseurl_key = OpenaiBaseUrlAndApiKey(base_url=base_url, api_key=api_key) + + try: + # Execute game + game_result = _execute_agent(task, api_baseurl_key) + + # Compute reward (1.0 if civilians win, 0.0 if spies win) + reward = _compute_reward(game_result) + + # Return workflow output + return WorkflowOutput( + reward=reward, + metadata={ + "winner": game_result["winner"], + "total_rounds": game_result["total_rounds"], + "civilian_word": game_result["civilian_word"], + "spy_word": game_result["spy_word"], + "final_alive": game_result["final_alive"] + } + ) + + except Exception as e: + print(f"Error during game execution: {e}") + # Return 0 reward on failure + return WorkflowOutput( + reward=0.0, + metadata={ + "error": str(e), + "winner": "error" + } + ) diff --git a/tutorial/opencode_build_spy_game/agent_run_adv.py b/tutorial/opencode_build_spy_game/agent_run_adv.py new file mode 100644 index 0000000..5a581bb --- /dev/null +++ b/tutorial/opencode_build_spy_game/agent_run_adv.py @@ -0,0 +1,218 @@ +""" +Agent runner for spy game - agent_roll_adv mode (adversarial training). +Team A (civilians): shared 7B model from swarm server 1 +Team B (spies): shared 7B model from swarm server 2 +""" + +import random +from typing import Dict, Tuple +from ajet.schema.task import Task, WorkflowOutput +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from tutorial.opencode_build_spy_game.game_engine import SpyGame + + +# Pre-generated diverse name pool for players +PLAYER_NAMES = [ + "Alexander", "Benjamin", "Christopher", "Daniel", "Elizabeth", + "Fitzgerald", "Gabriella", "Harrison", "Isabella", "Jonathan", + "Katherine", "Leonardo", "Margaret", "Nathaniel", "Ophelia", + "Penelope", "Quentin", "Rosalind", "Sebastian", "Theodora", + "Ulysses", "Victoria", "Wellington", "Xander", "Yasmine", + "Zachary", "Adelaide", "Beatrice", "Cornelius", "Desmond", + "Eleanor", "Frederick", "Genevieve", "Humphrey", "Imogen", + "Jasper", "Lillian", "Maximilian", "Nicolette", "Orlando", + "Percival", "Quintessa", "Reginald", "Seraphina", "Tristan", + "Valentina", "Winifred", "Xavier", "Yolanda", "Zephyr" +] + + +def _compute_rewards(game_result: Dict) -> Tuple[float, float]: + """ + Compute rewards for both teams in adversarial mode. 
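+    Per the game design, each member of the winning team receives 1.0 and
+    each member of the losing team receives 0.0.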
+ + Args: + game_result: Dictionary containing game outcome + + Returns: + (civilian_team_reward, spy_team_reward) + """ + civilian_reward = game_result["civilian_reward"] + spy_reward = game_result["spy_reward"] + return civilian_reward, spy_reward + + +def _execute_agent(task: Task, + api_baseurl_key_civilians: OpenaiBaseUrlAndApiKey, + api_baseurl_key_spies: OpenaiBaseUrlAndApiKey) -> Dict: + """ + Execute one episode of the adversarial spy game. + + Args: + task: Task containing game configuration + api_baseurl_key_civilians: API credentials for civilian team (swarm server 1) + api_baseurl_key_spies: API credentials for spy team (swarm server 2) + + Returns: + Game result dictionary + """ + # Extract game configuration from task + civilian_word = task.metadata["civilian_word"] + spy_word = task.metadata["spy_word"] + num_players = task.metadata["num_players"] + num_spies = task.metadata["num_spies"] + + # Randomly sample player names for this episode + selected_names = random.sample(PLAYER_NAMES, num_players) + + # Randomly assign player IDs and roles + player_indices = list(range(num_players)) + random.shuffle(player_indices) + + # First num_spies indices become spies (using swarm server 2) + # Remaining indices become civilians (using swarm server 1) + player_configs = [] + spy_count = 0 + civilian_count = 0 + + for i in range(num_players): + # Use shuffled index to determine role + is_spy = player_indices[i] < num_spies + + if is_spy: + # Spy uses swarm server 2 + config = { + "name": selected_names[i], + "base_url": api_baseurl_key_spies.base_url, + "api_key": api_baseurl_key_spies.api_key, + "model": "agentjet-model" + } + spy_count += 1 + else: + # Civilian uses swarm server 1 + config = { + "name": selected_names[i], + "base_url": api_baseurl_key_civilians.base_url, + "api_key": api_baseurl_key_civilians.api_key, + "model": "agentjet-model" + } + civilian_count += 1 + + player_configs.append(config) + + # Assert correct role distribution + assert spy_count == num_spies, f"Expected {num_spies} spies, got {spy_count}" + assert civilian_count == num_players - num_spies, f"Expected {num_players - num_spies} civilians, got {civilian_count}" + + # Assert all trainable model users from each server have the same role + # Swarm server 1 (civilians) users should all be civilians + # Swarm server 2 (spies) users should all be spies + civilians_base_url = api_baseurl_key_civilians.base_url + spies_base_url = api_baseurl_key_spies.base_url + + server1_users = [i for i, cfg in enumerate(player_configs) if cfg["base_url"] == civilians_base_url] + server2_users = [i for i, cfg in enumerate(player_configs) if cfg["base_url"] == spies_base_url] + + # Check all server 1 users are civilians (indices >= num_spies in shuffled assignment) + for idx in server1_users: + assert player_configs[idx]["base_url"] == civilians_base_url, \ + f"Player {idx} should use civilian server but uses {player_configs[idx]['base_url']}" + + # Check all server 2 users are spies (indices < num_spies in shuffled assignment) + for idx in server2_users: + assert player_configs[idx]["base_url"] == spies_base_url, \ + f"Player {idx} should use spy server but uses {player_configs[idx]['base_url']}" + + # Verify role consistency: all server1 users should be civilians, all server2 users should be spies + assert len(server1_users) == civilian_count, \ + f"Server 1 (civilians) should have {civilian_count} users, but has {len(server1_users)}" + assert len(server2_users) == spy_count, \ + f"Server 2 (spies) should have {spy_count} 
users, but has {len(server2_users)}" + + # Create and run game + game = SpyGame( + civilian_word=civilian_word, + spy_word=spy_word, + num_players=num_players, + num_spies=num_spies, + player_configs=player_configs + ) + + game_result = game.play_game() + return game_result + + +def run_agent_and_compute_reward( + task: Task, + base_url_civilians: str, + api_key_civilians: str, + base_url_spies: str, + api_key_spies: str +) -> Tuple[WorkflowOutput, WorkflowOutput]: + """ + Main entry point for adversarial mode - returns two WorkflowOutputs. + + Args: + task: Task containing game configuration + base_url_civilians: Base URL for civilian team model (swarm server 1) + api_key_civilians: API key for civilian team + base_url_spies: Base URL for spy team model (swarm server 2) + api_key_spies: API key for spy team + + Returns: + (workflow_output_civilians, workflow_output_spies) + """ + api_baseurl_key_civilians = OpenaiBaseUrlAndApiKey( + base_url=base_url_civilians, + api_key=api_key_civilians + ) + api_baseurl_key_spies = OpenaiBaseUrlAndApiKey( + base_url=base_url_spies, + api_key=api_key_spies + ) + + try: + # Execute game + game_result = _execute_agent(task, api_baseurl_key_civilians, api_baseurl_key_spies) + + # Compute rewards for both teams + civilian_reward, spy_reward = _compute_rewards(game_result) + + # Create separate workflow outputs for each team + workflow_output_civilians = WorkflowOutput( + reward=civilian_reward, + metadata={ + "team": "civilians", + "winner": game_result["winner"], + "total_rounds": game_result["total_rounds"], + "civilian_word": game_result["civilian_word"], + "spy_word": game_result["spy_word"], + "final_alive": game_result["final_alive"] + } + ) + + workflow_output_spies = WorkflowOutput( + reward=spy_reward, + metadata={ + "team": "spies", + "winner": game_result["winner"], + "total_rounds": game_result["total_rounds"], + "civilian_word": game_result["civilian_word"], + "spy_word": game_result["spy_word"], + "final_alive": game_result["final_alive"] + } + ) + + return workflow_output_civilians, workflow_output_spies + + except Exception as e: + print(f"Error during adversarial game execution: {e}") + # Return neutral rewards on failure + error_output_civilians = WorkflowOutput( + reward=0.5, + metadata={"error": str(e), "winner": "error", "team": "civilians"} + ) + error_output_spies = WorkflowOutput( + reward=0.5, + metadata={"error": str(e), "winner": "error", "team": "spies"} + ) + return error_output_civilians, error_output_spies diff --git a/tutorial/opencode_build_spy_game/game_engine.py b/tutorial/opencode_build_spy_game/game_engine.py new file mode 100644 index 0000000..1d77db1 --- /dev/null +++ b/tutorial/opencode_build_spy_game/game_engine.py @@ -0,0 +1,365 @@ +import random +from typing import List, Dict, Tuple +from openai import OpenAI + + +class SpyGamePlayer: + """Represents a single player in the spy game.""" + + def __init__(self, player_id: str, name: str, role: str, word: str, + base_url: str, api_key: str, model: str = "agentjet-model"): + self.player_id = player_id + self.name = name + self.role = role # "civilian" or "spy" + self.word = word + self.base_url = base_url + self.api_key = api_key + self.model = model + self.is_alive = True + self.descriptions: List[str] = [] + self.votes_received = 0 + + def get_client(self) -> OpenAI: + """Get OpenAI client for this player.""" + return OpenAI(base_url=self.base_url, api_key=self.api_key) + + def generate_description(self, round_num: int, game_history: List[Dict]) -> str: + """ + 
Generate a description of the word without saying it directly. + + Args: + round_num: Current round number + game_history: List of previous descriptions and events + + Returns: + Generated description string + """ + client = self.get_client() + + # Build context from game history + history_context = "" + if game_history: + history_context = "\n\nPrevious descriptions from other players:\n" + for entry in game_history: + if entry.get("type") == "description": + history_context += f"- {entry['player_name']}: \"{entry['description']}\"\n" + + prompt = f"""You are playing a social deduction game called "Who is the Spy". + +Your role: {self.role.upper()} +Your word: {self.word} + +Game rules: +- Most players are CIVILIANS with the same word +- A few players are SPIES with a similar but different word +- Each round, players describe their word WITHOUT saying it directly +- After descriptions, players vote to eliminate someone suspicious +- CIVILIANS win if all spies are eliminated +- SPIES win if they equal or outnumber civilians + +Current round: {round_num} +{history_context} + +Your task: Generate a brief description (1-2 sentences) of your word that: +1. Accurately reflects your word +2. Helps your teammates identify you +3. Does NOT reveal your word directly +4. Does NOT make your role too obvious if you're a spy + +Output only the description, nothing else.""" + + try: + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a strategic player in a social deduction game."}, + {"role": "user", "content": prompt} + ], + temperature=0.8, + max_tokens=150, + timeout=60 + ) + description = response.choices[0].message.content.strip() + self.descriptions.append(description) + return description + except Exception as e: + print(f"Error generating description for {self.name}: {e}") + # Fallback description + fallback = f"It's something related to {self.word[0]}... things." + self.descriptions.append(fallback) + return fallback + + def vote(self, alive_players: List['SpyGamePlayer'], game_history: List[Dict]) -> str: + """ + Vote for the most suspicious player. + + Args: + alive_players: List of players still in the game + game_history: Full game history including descriptions + + Returns: + Name of the player to vote for + """ + client = self.get_client() + + # Build player list and their descriptions + players_info = "\n\nPlayers and their descriptions:\n" + for entry in game_history: + if entry.get("type") == "description": + players_info += f"- {entry['player_name']}: \"{entry['description']}\"\n" + + # Available players to vote for + available_players = [p.name for p in alive_players if p.name != self.name] + players_list = ", ".join(available_players) + + prompt = f"""You are playing "Who is the Spy" game. + +Your role: {self.role.upper()} +Your word: {self.word} + +{players_info} + +Available players to vote for: {players_list} + +Your goal: +- If you're a CIVILIAN: Vote for the player who seems to have a DIFFERENT word (the spy) +- If you're a SPY: Vote strategically to survive and avoid suspicion + +Analyze all descriptions carefully. 
Look for: +- Descriptions that don't quite match the majority +- Vague or contradictory descriptions +- Players who seem to be hiding something + +Output ONLY the name of the player you want to vote for (choose from: {players_list}) +Do not include any explanation, just the name.""" + + try: + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a strategic player making voting decisions."}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=50, + timeout=60 + ) + vote_text = response.choices[0].message.content.strip() + + # Extract valid player name from response + for player_name in available_players: + if player_name.lower() in vote_text.lower(): + return player_name + + # Fallback: random vote + return random.choice(available_players) + + except Exception as e: + print(f"Error generating vote for {self.name}: {e}") + return random.choice(available_players) + + +class SpyGame: + """Main game engine for "Who is the Spy" game.""" + + def __init__(self, civilian_word: str, spy_word: str, + num_players: int, num_spies: int, + player_configs: List[Dict]): + """ + Initialize a spy game. + + Args: + civilian_word: Word for civilians + spy_word: Word for spies + num_players: Total number of players + num_spies: Number of spies + player_configs: List of player configurations with base_url, api_key, model, name + """ + self.civilian_word = civilian_word + self.spy_word = spy_word + self.num_players = num_players + self.num_spies = num_spies + self.players: List[SpyGamePlayer] = [] + self.game_history: List[Dict] = [] + self.current_round = 0 + self.max_rounds = 10 + + # Assign roles randomly + roles = ["spy"] * num_spies + ["civilian"] * (num_players - num_spies) + random.shuffle(roles) + + # Create players + for i, (role, config) in enumerate(zip(roles, player_configs)): + word = spy_word if role == "spy" else civilian_word + player = SpyGamePlayer( + player_id=f"player_{i}", + name=config["name"], + role=role, + word=word, + base_url=config["base_url"], + api_key=config["api_key"], + model=config.get("model", "agentjet-model") + ) + self.players.append(player) + + def get_alive_players(self) -> List[SpyGamePlayer]: + """Get list of players still in the game.""" + return [p for p in self.players if p.is_alive] + + def check_game_end(self) -> Tuple[bool, str, float]: + """ + Check if game has ended and determine winner. + + Returns: + (is_ended, winner, civilian_team_reward) + winner: "civilians", "spies", or "draw" + civilian_team_reward: 1.0 if civilians win, 0.0 if spies win, 0.5 for draw + """ + alive = self.get_alive_players() + alive_spies = [p for p in alive if p.role == "spy"] + alive_civilians = [p for p in alive if p.role == "civilian"] + + # Spies win if they equal or outnumber civilians + if len(alive_spies) >= len(alive_civilians): + return True, "spies", 0.0 + + # Civilians win if all spies are eliminated + if len(alive_spies) == 0: + return True, "civilians", 1.0 + + # Draw if max rounds reached + if self.current_round >= self.max_rounds: + return True, "draw", 0.5 + + return False, "", 0.5 + + def play_round(self) -> bool: + """ + Play one round of the game. 
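+
+        A round is one description phase (every living player speaks once)
+        followed by one voting phase; the most-voted player is eliminated,
+        with ties broken by random choice.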
+ + Returns: + True if game should continue, False if game ended + """ + self.current_round += 1 + print(f"\n{'='*60}") + print(f"ROUND {self.current_round}") + print(f"{'='*60}") + + alive_players = self.get_alive_players() + + # Phase 1: Description phase + print("\n--- Description Phase ---") + round_descriptions = [] + for player in alive_players: + description = player.generate_description(self.current_round, self.game_history) + print(f"{player.name} ({player.role}): \"{description}\"") + + entry = { + "type": "description", + "round": self.current_round, + "player_id": player.player_id, + "player_name": player.name, + "role": player.role, + "description": description + } + self.game_history.append(entry) + round_descriptions.append(entry) + + # Check if game should end before voting + is_ended, winner, _ = self.check_game_end() + if is_ended: + return False + + # Phase 2: Voting phase + print("\n--- Voting Phase ---") + votes: Dict[str, List[str]] = {p.name: [] for p in alive_players} + + for player in alive_players: + voted_name = player.vote(alive_players, self.game_history) + votes[voted_name].append(player.name) + print(f"{player.name} votes for: {voted_name}") + + self.game_history.append({ + "type": "vote", + "round": self.current_round, + "voter_name": player.name, + "voted_name": voted_name + }) + + # Determine who gets eliminated + max_votes = max(len(v) for v in votes.values()) + candidates = [name for name, voters in votes.items() if len(voters) == max_votes] + + if len(candidates) > 1: + # Tie - randomly eliminate one + eliminated_name = random.choice(candidates) + else: + eliminated_name = candidates[0] + + eliminated_player = next(p for p in alive_players if p.name == eliminated_name) + eliminated_player.is_alive = False + + print(f"\n{eliminated_name} ({eliminated_player.role}) has been eliminated!") + print(f"Their word was: {eliminated_player.word}") + + self.game_history.append({ + "type": "elimination", + "round": self.current_round, + "eliminated_name": eliminated_name, + "eliminated_role": eliminated_player.role, + "votes_received": len(votes[eliminated_name]) + }) + + # Check game end condition + is_ended, winner, _ = self.check_game_end() + return not is_ended + + def play_game(self) -> Dict: + """ + Play the full game until completion. 
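+
+        Rounds repeat until check_game_end() reports a winner or the
+        max_rounds cap (10) forces a draw. A minimal usage sketch, assuming
+        player_configs was built the way agent_run._execute_agent builds it:
+
+            game = SpyGame("apple", "pear", num_players=6, num_spies=1,
+                           player_configs=player_configs)
+            result = game.play_game()
+            # result["winner"] is "civilians", "spies", or "draw"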
+ + Returns: + Game result dictionary with winner, rewards, and history + """ + print(f"\n{'#'*60}") + print(f"GAME START") + print(f"Civilian word: {self.civilian_word}") + print(f"Spy word: {self.spy_word}") + print(f"Players: {self.num_players}, Spies: {self.num_spies}") + print(f"{'#'*60}") + + # Print initial player assignments + print("\nPlayers:") + for player in self.players: + print(f" {player.name}: {player.role} (word: {player.word})") + + # Play rounds until game ends + while self.play_round(): + pass + + # Determine final result + is_ended, winner, civilian_reward = self.check_game_end() + + print(f"\n{'#'*60}") + print(f"GAME END - {winner.upper()} WIN!") + print(f"{'#'*60}") + + # Calculate individual rewards + player_rewards = {} + for player in self.players: + if player.role == "civilian": + player_rewards[player.name] = civilian_reward + else: # spy + player_rewards[player.name] = 1.0 - civilian_reward + + return { + "winner": winner, + "civilian_reward": civilian_reward, + "spy_reward": 1.0 - civilian_reward, + "player_rewards": player_rewards, + "total_rounds": self.current_round, + "game_history": self.game_history, + "final_alive": [p.name for p in self.get_alive_players()], + "civilian_word": self.civilian_word, + "spy_word": self.spy_word + } diff --git a/tutorial/opencode_build_spy_game/mock_dataset.py b/tutorial/opencode_build_spy_game/mock_dataset.py new file mode 100644 index 0000000..68a2606 --- /dev/null +++ b/tutorial/opencode_build_spy_game/mock_dataset.py @@ -0,0 +1,102 @@ +import json +import random +from pathlib import Path + + +CIVILIAN_SPY_PAIRS = [ + ("apple", "pear"), + ("coffee", "tea"), + ("basketball", "football"), + ("piano", "guitar"), + ("rose", "tulip"), + ("dog", "cat"), + ("bicycle", "motorcycle"), + ("ocean", "sea"), + ("winter", "autumn"), + ("sunrise", "sunset"), + ("rice", "noodle"), + ("book", "magazine"), + ("violin", "cello"), + ("lion", "tiger"), + ("river", "lake"), + ("mountain", "hill"), + ("sun", "moon"), + ("chair", "stool"), + ("milk", "yogurt"), + ("bread", "cake"), + ("airplane", "helicopter"), + ("train", "subway"), + ("doctor", "nurse"), + ("teacher", "professor"), + ("pen", "pencil"), + ("email", "letter"), + ("computer", "laptop"), + ("phone", "tablet"), + ("shoes", "slippers"), + ("hat", "cap"), + ("sword", "knife"), + ("bow", "crossbow"), + ("king", "emperor"), + ("princess", "queen"), + ("spring", "summer"), + ("rain", "snow"), + ("thunder", "lightning"), + ("diamond", "crystal"), + ("gold", "silver"), + ("red", "orange"), + ("square", "rectangle"), + ("circle", "oval"), + ("triangle", "pyramid"), + ("watermelon", "melon"), + ("strawberry", "raspberry"), + ("carrot", "radish"), + ("potato", "sweet potato"), + ("chicken", "duck"), + ("beef", "pork"), + ("shark", "whale"), +] + + +def generate_mock_dataset(num_samples: int = 100, output_path: str = None) -> list[dict]: + """ + Generate mock game configuration dataset. 
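+
+    Word pairs come from CIVILIAN_SPY_PAIRS and are swapped with 50%
+    probability, so either word of a pair may serve as the civilian word.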
+ Each sample contains: + - civilian_word: word for civilians + - spy_word: word for spies + - num_players: total number of players (6-9) + - num_spies: number of spies (1-2) + """ + dataset = [] + + for _ in range(num_samples): + civilian_word, spy_word = random.choice(CIVILIAN_SPY_PAIRS) + + # Randomly swap to increase diversity + if random.random() > 0.5: + civilian_word, spy_word = spy_word, civilian_word + + num_players = random.randint(6, 9) + num_spies = 1 if num_players <= 7 else random.choice([1, 2]) + + dataset.append({ + "civilian_word": civilian_word, + "spy_word": spy_word, + "num_players": num_players, + "num_spies": num_spies, + }) + + if output_path: + output_file = Path(output_path) + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + print(f"Dataset saved to {output_path}") + + return dataset + + +if __name__ == "__main__": + output_path = Path(__file__).parent / "mock_game_dataset.json" + dataset = generate_mock_dataset(num_samples=200, output_path=str(output_path)) + print(f"Generated {len(dataset)} game configurations") + print(f"Sample: {dataset[0]}") diff --git a/tutorial/opencode_build_spy_game/readme.md b/tutorial/opencode_build_spy_game/readme.md new file mode 100644 index 0000000..6e6a3fc --- /dev/null +++ b/tutorial/opencode_build_spy_game/readme.md @@ -0,0 +1,237 @@ +# Spy Game Reinforcement Learning Agent + +A trainable multi-agent system for the social deduction game "Who is the Spy" using reinforcement learning. + +## Game Overview + +"Who is the Spy" is a social deduction game where: +- **N players** participate (typically 6-9) +- Most are **civilians** with the same word +- A few are **spies** with a similar but different word +- Each round, players describe their word without saying it directly +- After descriptions, players vote to eliminate suspects +- **Civilians win** if all spies are eliminated +- **Spies win** if they equal or outnumber civilians + +The agent learns to: +1. Generate strategic descriptions that help teammates while avoiding detection +2. Analyze other players' descriptions to identify spies +3. Make optimal voting decisions + +## Project Structure + +``` +tutorial/opencode_build_spy_game/ +├── mock_dataset.py # Generate mock game configurations +├── mock_game_dataset.json # 200 game scenarios with word pairs +├── game_engine.py # Core game mechanics and player logic +├── agent_run.py # Agent executor for agent_roll mode +├── agent_roll.py # Training script for agent_roll mode +├── agent_run_adv.py # Agent executor for adversarial mode +├── agent_roll_adv.py # Training script for adversarial mode +└── readme.md # This file +``` + +## Training Modes + +### Mode 1: agent_roll (Civilians vs Fixed Opponent) + +Train a 7B model as the civilian team against qwen-max (via DashScope API) as spies. + +**Hardware:** 4 GPUs +**Reward:** 1.0 if civilians win, 0.0 if spies win + +#### Setup: + +1. Ensure DASHSCOPE_API_KEY is set in environment: + ```bash + export DASHSCOPE_API_KEY="your_api_key_here" + ``` + +2. Start swarm server in one terminal: + ```bash + cd /root/agentjet + source .venv/bin/activate + ajet-swarm start --swarm-port=10086 + ``` + +3. Run training in another terminal: + ```bash + cd /root/agentjet + source .venv/bin/activate + python -m tutorial.opencode_build_spy_game.agent_roll + ``` + +### Mode 2: agent_roll_adv (Adversarial Training) + +Train two 7B models competitively - one as civilians, one as spies. 
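+
+Both teams learn from the same episode: the adversarial runner returns one
+`WorkflowOutput` per team. A minimal sketch of the call (the endpoint URLs and
+API keys are placeholders, not real defaults):
+
+```python
+from tutorial.opencode_build_spy_game.agent_run_adv import run_agent_and_compute_reward
+
+# task: an ajet Task whose metadata carries civilian_word / spy_word /
+# num_players / num_spies, e.g. one entry of mock_game_dataset.json
+out_civ, out_spy = run_agent_and_compute_reward(
+    task,
+    base_url_civilians="http://localhost:10086", api_key_civilians="KEY1",
+    base_url_spies="http://localhost:10087", api_key_spies="KEY2",
+)
+# Zero-sum by design: out_civ.reward + out_spy.reward == 1.0 (0.5 each on a draw)
+```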
+ +**Hardware:** 8 GPUs total (4 per swarm server) +**Reward:** Team-based (1.0 for winners, 0.0 for losers) + +#### Setup: + +1. Start swarm server 1 (civilians) in terminal 1: + ```bash + cd /root/agentjet + source .venv/bin/activate + ajet-swarm start --swarm-port=10086 + ``` + +2. Start swarm server 2 (spies) in terminal 2: + ```bash + cd /root/agentjet + source .venv/bin/activate + ajet-swarm start --swarm-port=10087 + ``` + +3. Run adversarial training in terminal 3: + ```bash + cd /root/agentjet + source .venv/bin/activate + export AJET_SWARM_URL_1="http://localhost:10086" + export AJET_SWARM_URL_2="http://localhost:10087" + python -m tutorial.opencode_build_spy_game.agent_roll_adv + ``` + +## Debugging with tmux + +For easier debugging, use tmux sessions: + +### For agent_roll mode: + +```bash +# Terminal 1: Start swarm server +tmux new -s spy-swarm-server +cd /root/agentjet && source .venv/bin/activate +ajet-swarm start --swarm-port=10086 + +# Detach: Ctrl+B, then D + +# Terminal 2: Start training client +tmux new -s spy-swarm-client +cd /root/agentjet && source .venv/bin/activate +python -m tutorial.opencode_build_spy_game.agent_roll + +# View sessions +tmux ls + +# Attach to server: tmux attach -t spy-swarm-server +# Attach to client: tmux attach -t spy-swarm-client +``` + +### For agent_roll_adv mode: + +```bash +# Terminal 1: Start swarm server 1 +tmux new -s spy-swarm-server +cd /root/agentjet && source .venv/bin/activate +ajet-swarm start --swarm-port=10086 + +# Terminal 2: Start swarm server 2 +tmux new -s spy-swarm-server-2 +cd /root/agentjet && source .venv/bin/activate +ajet-swarm start --swarm-port=10087 + +# Terminal 3: Start training client +tmux new -s spy-swarm-client +cd /root/agentjet && source .venv/bin/activate +export AJET_SWARM_URL_1="http://localhost:10086" +export AJET_SWARM_URL_2="http://localhost:10087" +python -m tutorial.opencode_build_spy_game.agent_roll_adv +``` + +## Configuration + +### Key Parameters (in agent_roll.py / agent_roll_adv.py): + +- `LOCAL_GRPO_N = 4`: Number of rollouts per task (GRPO group size) +- `LOCAL_NUM_EPOCH = 100`: Number of training epochs +- `REMOTE_BATCH_SIZE = 16`: Batch size for policy updates +- `REMOTE_ALLOCATE_GPU = 4`: Number of GPUs per swarm server +- `REMOTE_TRAIN_MODEL`: Path to base model + +### Dataset: + +The `mock_game_dataset.json` contains 200 diverse game scenarios with word pairs like: +- apple vs pear +- coffee vs tea +- basketball vs football +- piano vs guitar +- etc. + +Regenerate dataset: +```bash +python tutorial/opencode_build_spy_game/mock_dataset.py +``` + +## Game Mechanics Details + +1. **Random Player Assignment**: Each episode randomly assigns player names and roles +2. **Description Phase**: Players generate descriptions using LLM without revealing their word +3. **Voting Phase**: Players vote to eliminate the most suspicious player +4. **Win Conditions**: + - Civilians win when all spies eliminated + - Spies win when they equal/outnumber civilians + - Draw if max rounds (10) reached + +## Reward Structure + +### agent_roll mode: +- Civilian team (trainable 7B): 1.0 for win, 0.0 for loss +- Spy team (qwen-max): Not trained + +### agent_roll_adv mode: +- Civilian team (7B server 1): 1.0 for win, 0.0 for loss +- Spy team (7B server 2): 1.0 for win, 0.0 for loss +- Both models train competitively + +## Expected Training Behavior + +The agent should learn to: +1. Generate contextually appropriate descriptions +2. Balance between being informative and protective +3. 
Recognize inconsistent descriptions from opponents +4. Make strategic voting decisions +5. Adapt strategies based on role (civilian vs spy) + +## Troubleshooting + +### Import errors: +Make sure you're running from the agentjet root directory with proper Python path. + +### Connection errors: +- Check swarm server is running: `ajet-swarm overwatch` +- Verify port availability: `netstat -an | grep 10086` +- Check firewall settings if running on different machines + +### DASHSCOPE_API_KEY errors (agent_roll mode): +```bash +export DASHSCOPE_API_KEY="your_key" +echo $DASHSCOPE_API_KEY # Verify it's set +``` + +### GPU memory issues: +- Reduce batch size in config +- Reduce number of parallel episodes +- Check GPU availability: `nvidia-smi` + +## Monitoring Training + +Use the swarm overwatch tool: +```bash +ajet-swarm overwatch --swarm-url=http://localhost:10086 +``` + +This displays: +- Current training step +- Sample pool status +- Policy gradient updates +- Model loading status + +## Notes + +- Each episode creates random player names from a pool of 50 diverse names +- Game typically completes in 3-10 rounds depending on player strategies +- Training uses GRPO (Group Relative Policy Optimization) algorithm +- Models are trained with temperature=0.7-0.8 for creative descriptions diff --git a/tutorial/opencode_build_spy_game/spy_game_config.yaml b/tutorial/opencode_build_spy_game/spy_game_config.yaml new file mode 100644 index 0000000..f7eff1f --- /dev/null +++ b/tutorial/opencode_build_spy_game/spy_game_config.yaml @@ -0,0 +1,35 @@ +# Custom configuration for spy game training +# Extends default AgentJet configuration with optimized memory settings + +defaults: + - ajet_ts_default + +# Override rollout configuration to use tensor parallelism +actor_rollout_ref: + rollout: + tensor_model_parallel_size: 2 # Use 2-way tensor parallelism for vLLM to reduce memory per GPU + max_num_seqs: 8 # Reduce concurrent sequences + max_model_len: 15000 # Slightly reduce max length + + # Optimize actor configuration for memory + actor: + fsdp_config: + param_offload: true # Offload parameters to CPU when not needed + optimizer_offload: true # Offload optimizer states to CPU + ppo_max_token_len_per_gpu: 15000 # Reduce token length per GPU + ppo_micro_batch_size_per_gpu: 1 # Keep small micro batch size + + # Optimize reference model configuration + ref: + fsdp_config: + param_offload: true + reshard_after_forward: true + log_prob_max_token_len_per_gpu: 15000 + log_prob_micro_batch_size_per_gpu: 3 + +# Environment settings +ajet: + rollout: + tensor_model_parallel_size: 2 + max_num_seqs: 8 + max_model_len: 15000 diff --git a/tutorial/opencode_build_spy_game/test_single_game.py b/tutorial/opencode_build_spy_game/test_single_game.py new file mode 100644 index 0000000..58d1ebd --- /dev/null +++ b/tutorial/opencode_build_spy_game/test_single_game.py @@ -0,0 +1,43 @@ +""" +Test script to verify a single game works correctly. 
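+
+Assumes an OpenAI-compatible endpoint is reachable at the base_url below;
+without one, every LLM call falls back to a canned description or a random
+vote (see game_engine), so the game still runs to completion.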
+""" + +import os +from ajet.schema.task import Task +from tutorial.opencode_build_spy_game.agent_run import run_agent_and_compute_reward + +# Test with a simple game configuration +test_task = Task( + main_query="Test spy game episode", + metadata={ + "civilian_word": "apple", + "spy_word": "pear", + "num_players": 6, + "num_spies": 1, + "episode_id": 0 + } +) + +# Use a fake base_url and api_key for testing (will be replaced by swarm server) +fake_base_url = "http://localhost:10086" +fake_api_key = "test_key" + +print("Testing single game execution...") +print(f"Civilian word: {test_task.metadata['civilian_word']}") +print(f"Spy word: {test_task.metadata['spy_word']}") +print(f"Players: {test_task.metadata['num_players']}, Spies: {test_task.metadata['num_spies']}") +print("\nStarting game...\n") + +try: + result = run_agent_and_compute_reward(test_task, fake_base_url, fake_api_key) + print(f"\n{'='*60}") + print(f"Game Result:") + print(f" Winner: {result.metadata.get('winner', 'unknown')}") + print(f" Reward: {result.reward}") + print(f" Rounds: {result.metadata.get('total_rounds', 'unknown')}") + print(f" Survivors: {result.metadata.get('final_alive', [])}") + print(f"{'='*60}") +except Exception as e: + print(f"\nError: {e}") + import traceback + traceback.print_exc() diff --git a/tutorial/opencode_build_who_is_spy.prompt.md b/tutorial/opencode_build_who_is_spy.prompt.md new file mode 100644 index 0000000..95e053a --- /dev/null +++ b/tutorial/opencode_build_who_is_spy.prompt.md @@ -0,0 +1,48 @@ +# Generate an agent / agent loop with AgentJet Swarm and train it with one key + +Use prompt below in opencode or claudecode to generate a one-key-to-tune agent + +============================= + +English prompt to be tranlated ... + +============================= + +你的任务: +- 编写一个学习"谁是卧底"任务的智能体,通过强化学习和监督学习相结合的方式训练,游戏规则如下: + - 游戏共有 N 名玩家,其中大多数人是**平民**,少数人是**卧底** + - 游戏开始时,每位平民会收到同一个**平民词**,每位卧底会收到一个与平民词相近但不同的**卧底词**(例如平民词为"苹果",卧底词为"梨") + - 每轮游戏中,所有玩家依次对自己拿到的词进行**口头描述**,描述必须真实反映自己的词,但不能直接说出词语本身,也不能过于明显地暴露自己的身份 + - 全部玩家描述完毕后,进入**投票环节**,所有玩家投票选出自己认为最可疑的卧底,得票最多的玩家被淘汰出局 + - 游戏持续多轮,直到满足以下任一结束条件: + - **平民获胜**:所有卧底均被淘汰 + - **卧底获胜**:卧底人数 ≥ 平民人数(卧底在数量上取得优势) + - 智能体需要通过大量对局训练掌握两种核心能力: + - **描述策略学习**:学会根据自己的词语和当前局势,生成既不暴露身份、又能让同阵营玩家认同的最优描述 + - **推理决策学习**:学会根据历史对话、其他玩家的描述模式和行为特征,准确识别卧底并做出最优投票决策 + - 训练目标:最大化智能体在不同角色(平民/卧底)下的游戏胜率,通过自对弈和奖励机制不断优化策略 +- 我希望使用基础模型 `/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct` +- 使用 8 GPU 训练 +- Batch Size 16 +- 我目前没有数据集,你需要帮助我 mock 少量游戏对局数据以供测试和初始训练 +- 使用OpenAI SDK,灵活使用Tools +- 代码中不得出现中文 + +你的 skill(首先读取该 SKILL 文件,获取必要知识): +./ajet/copilot/write-swarm-client/SKILL.md + +- 追加要求: + - optional 0. (agent_roll) team A 平民 共享一个7B模型, team B卧底使用qwen-max (DASHSCOPE_API_KEY已经在环境变量中), + 每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0 + - optional 1. 
(agent_roll_adv) 对抗式训练,team A 平民 共享一个7B模型(swarm server 1), team B卧底共享另一个7B模型(swarm server 2), + 每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0 + +- 追加要求: + agent_roll: 使用4个显卡 + agent_roll_adv:swarm server 1 和 swarm server 2 分别使用4个显卡(一共8个显卡) + +- 追加要求:使用 tmux + uv 的 .venv 调试,直到所有Bug都已经排除 & 训练正常开始。你可以使用 `spy-swarm-server`, `spy-swarm-server-2`, `spy-swarm-client` 三个 tmux session + + - 当前调试阶段: + 调试 agent_roll 【执行调试】 + 调试 agent_roll_adv 【跳过调试】 From adb93d841326d9509ce89164a695ae38499e929b Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Fri, 20 Mar 2026 15:08:52 +0800 Subject: [PATCH 7/9] add vibe rl example --- docs/en/example_vibe_rl_who_is_spy.md | 124 ++++++++++---------- docs/en/example_vibe_rl_who_is_spy.zh.md | 139 +++++++++++++++++++++++ mkdocs.yml | 2 + 3 files changed, 200 insertions(+), 65 deletions(-) create mode 100644 docs/en/example_vibe_rl_who_is_spy.zh.md diff --git a/docs/en/example_vibe_rl_who_is_spy.md b/docs/en/example_vibe_rl_who_is_spy.md index ff84044..a398280 100644 --- a/docs/en/example_vibe_rl_who_is_spy.md +++ b/docs/en/example_vibe_rl_who_is_spy.md @@ -1,73 +1,72 @@ -# Vibe RL 实例:不写一行代码,从零构建一个会玩“谁是卧底”的 Agent 训练器 +# Vibe RL Example: Building a "Who is the Spy" Agent Trainer from Scratch Without Writing a Single Line of Code +> This article is a translated version of the [Chinese original](./example_vibe_rl_who_is_spy.zh.md). -摘要:强化学习研究中,从灵感迸发,到编写代码,再到第一条成功的训练曲线产生,这个过程是漫长、乏味的。 -幸运的是,如今在 AgentJet 框架中,从想法到训练成功,你只需要动动嘴,花几分钟写一点点提示词, -然后只需要等待片刻,然后你就可以看到**完整、简洁、人类易读易改的训练程序** + **初次训练的训练曲线** 展现在你面前了。 -接下来,我们以经典的“谁是摸底”桌游游戏为例,从零展示不写代码训练Agent的全过程。 +## Abstract +In reinforcement learning research, the journey from inspiration to writing code to generating the first successful training curve is long and tedious. Fortunately, with the AgentJet framework, going from idea to successful training is now just a matter of speaking up and spending a few minutes writing some prompts. After a short wait, you get to see **complete, concise, human-readable and editable training code** alongside **the first training curve** displayed before you. In this article, we use the classic "Who is the Spy" board game as an example to demonstrate the entire process of training an Agent without writing code. -## 安装 AgentJet 环境 +## Install AgentJet Environment + +You can choose to [install manually](https://doc.agentjet.top/en/installation/) or use skills. Run the following commands to copy skills into Claude Code or OpenCode: -您可以选择[手动安装](https://doc.agentjet.top/en/installation/),或者使用skills安装。运行以下指令将skills复制到claude code或者 opencode中。 ```bash npx skills add modelscope/agentjet npx skills add binary-husky/Vibe-RL ``` -在skill添加完成之后,你可以指挥claude code或者opencode使用uv(或者conda / docker)安装 AgentJet。 -## 撰写提示词 +After the skills are added, you can instruct Claude Code or OpenCode to install AgentJet using uv (or conda / docker). + +## Write the Prompt -在安装完成 AgentJet 之后,就可以直接开始工作了,打开OpenCode(尽管ClaudeCode比OpenCode更加强大,但笔者还是喜欢完全开源的东西;再者,在AgentJet中Vibe RL的难度很低,我们也不需要非常强的agent), -然后选择 claude-4.5-sonnet 模型 (这个模型在推理速度比opus更快,对于不太困难的问题已经足够了),开始执行任务: +Once AgentJet is installed, you can get started right away. 
Open OpenCode (while ClaudeCode is more powerful, the author prefers fully open-source tools; moreover, Vibe RL difficulty in AgentJet is quite low, so we don't need a very strong agent), then select the claude-4.5-sonnet model (this model is faster than opus for reasoning speed and sufficient for tasks that aren't too difficult), and start executing the task: ```txt -你的任务: -- 编写一个学习"谁是卧底"任务的智能体,通过强化学习和监督学习相结合的方式训练,游戏规则如下: - - 游戏共有 N 名玩家,其中大多数人是**平民**,少数人是**卧底** - - 游戏开始时,每位平民会收到同一个**平民词**,每位卧底会收到一个与平民词相近但不同的**卧底词**(例如平民词为"苹果",卧底词为"梨") - - 每轮游戏中,所有玩家依次对自己拿到的词进行**口头描述**,描述必须真实反映自己的词,但不能直接说出词语本身,也不能过于明显地暴露自己的身份 - - 全部玩家描述完毕后,进入**投票环节**,所有玩家投票选出自己认为最可疑的卧底,得票最多的玩家被淘汰出局 - - 游戏持续多轮,直到满足以下任一结束条件: - - **平民获胜**:所有卧底均被淘汰 - - **卧底获胜**:卧底人数 ≥ 平民人数(卧底在数量上取得优势) - - 智能体需要通过大量对局训练掌握两种核心能力: - - **描述策略学习**:学会根据自己的词语和当前局势,生成既不暴露身份、又能让同阵营玩家认同的最优描述 - - **推理决策学习**:学会根据历史对话、其他玩家的描述模式和行为特征,准确识别卧底并做出最优投票决策 - - 训练目标:最大化智能体在不同角色(平民/卧底)下的游戏胜率,通过自对弈和奖励机制不断优化策略 -- 我希望使用基础模型 `/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct` -- 使用 8 GPU 训练 +Your task: +- Write an agent that learns the "Who is the Spy" task, trained using a combination of reinforcement learning and supervised learning. The game rules are as follows: + - The game has N players, most of whom are **civilians**, with a few being **spies** + - At the start of the game, each civilian receives the same **civilian word**, and each spy receives a **spy word** that is similar to the civilian word but different (e.g., civilian word is "apple", spy word is "pear") + - In each round, all players take turns giving **verbal descriptions** of their word. The description must truthfully reflect the word, but cannot directly say the word itself or expose the player's identity too obviously + - After all players have described, the game enters the **voting phase**, where all players vote for who they think is the most suspicious spy. The player with the most votes is eliminated + - The game continues for multiple rounds until one of the following end conditions is met: + - **Civilians win**: All spies are eliminated + - **Spies win**: The number of spies >= the number of civilians (spies have the numerical advantage) + - The agent needs to master two core abilities through extensive gameplay: + - **Description strategy learning**: Learn to generate optimal descriptions based on the agent's word and current game state that neither expose identity nor alienate teammates + - **Reasoning and decision learning**: Learn to accurately identify spies based on conversation history, other players' description patterns, and behavioral characteristics, and make optimal voting decisions + - Training objective: Maximize the agent's win rate across different roles (civilian/spy), continuously optimizing strategy through self-play and reward mechanisms +- I want to use the base model `/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct` +- Use 8 GPUs for training - Batch Size 16 -- 我目前没有数据集,你需要帮助我 mock 少量游戏对局数据以供测试和初始训练 -- 使用OpenAI SDK,灵活使用Tools -- 代码中不得出现中文 +- I don't have a dataset yet, please help me mock some game data for testing and initial training +- Use OpenAI SDK, flexibly use Tools +- Code must not contain Chinese characters -你的 skill(首先读取该 SKILL 文件,获取必要知识): +Your skill (please read this SKILL file first to get necessary knowledge): ./ajet/copilot/write-swarm-client/SKILL.md -- 追加要求: - - optional 0. 
(agent_roll) team A 平民 共享一个7B模型, team B卧底使用qwen-max (DASHSCOPE_API_KEY已经在环境变量中), - 每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0 - - optional 1. (agent_roll_adv) 对抗式训练,team A 平民 共享一个7B模型(swarm server 1), team B卧底共享另一个7B模型(swarm server 2), - 每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0 +- Additional requirements: + - optional 0. (agent_roll) Team A civilians share one 7B model, Team B spies use qwen-max (DASHSCOPE_API_KEY is already in environment variables), + each episode randomly assigns each player's ID and name (randomly generate a long list of random names), winner gets reward 1, loser gets reward 0 + - optional 1. (agent_roll_adv) Adversarial training, Team A civilians share one 7B model (swarm server 1), Team B spies share another 7B model (swarm server 2), + each episode randomly assigns each player's ID and name (randomly generate a long list of random names), winner gets reward 1, loser gets reward 0 -- 追加要求: - agent_roll: 使用4个显卡 - agent_roll_adv:swarm server 1 和 swarm server 2 分别使用4个显卡(一共8个显卡) +- Additional requirements: + agent_roll: Use 4 GPUs + agent_roll_adv: swarm server 1 and swarm server 2 each use 4 GPUs (total 8 GPUs) -- 追加要求:使用 tmux + uv 的 .venv 调试,直到所有Bug都已经排除 & 训练正常开始。你可以使用 `spy-swarm-server`, `spy-swarm-server-2`, `spy-swarm-client` 三个 tmux session +- Additional requirements: Use tmux + uv's .venv for debugging until all bugs are fixed & training starts normally. You can use `spy-swarm-server`, `spy-swarm-server-2`, `spy-swarm-client` three tmux sessions - - 当前调试阶段: - 调试 agent_roll 【执行调试】 - 调试 agent_roll_adv 【跳过调试】 + - Current debugging stage: + Debugging agent_roll [Execute debugging] + Debugging agent_roll_adv [Skip debugging] ``` +## Check Results -## 检查结果 - -### 生成的训练代码 +### Generated Training Code -在agentjet skill的指导下,OpenCode会在 tutorial/opencode_build_*** 生成训练的全部代码: +Under the guidance of the agentjet skill, OpenCode generates all training code in `tutorial/opencode_build_***`: ```bash (base) ➜ agentjet git:(main) ✗ tree tutorial/opencode_build_spy_game @@ -82,13 +81,12 @@ tutorial/opencode_build_spy_game/ └── readme.md # This file ``` -### 检查训练蜂群,发现并引导智能体修复训练的Bug +### Inspect the Training Swarm, Find and Fix Agent Training Bugs - -等了一会,运行 `ajet-swarm overwatch` 命令,看一下现在训练运行到第几步了,结果发现 claude-sonnet 搞出了一个令人难绷错误: +After waiting a while, running the `ajet-swarm overwatch` command shows the current training progress: ```bash - Completed Episode Pool Summary (Progress to Hit Next Weight Update) + Completed Episode Pool Summary (Progress to Hit Next Weight Update) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Metric ┃ Current ┃ Target ┃ Progress ┃ Bar ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ @@ -100,7 +98,7 @@ tutorial/opencode_build_spy_game/ │ Average Episode Per Task │ 140.00 │ 4 │ - │ - │ └────────────────────────────────────────┴─────────────┴─────────────┴──────────────┴───────────────────────────────────────────────────────────────────────┘ - Task Completion Details + Task Completion Details ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Task ID ┃ Episodes ┃ Reward ┃ Episode UUIDs (first 3) ┃ 
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ │ 140 │ 0.779 ± 0.448 │ b47d7b96..., 8caec2d7..., b48bd9fb... (+137 more) │
+└──────────────┴───────────────┴───────────────────────┴───────────────────────────────────────────────────────────────────────────┘
+```
+
-从蜂群监视表格可以看出,现在样本池已经累计了 875.0%(140个)的回合样本,但AgentJet并没有开始训练。
-仔细一看,CompletedTasks 进度只有 1个,说明140个回合都被识别成一个task了。这些样本的task id,哎,怎么是空字符串?
-毫无疑问,claude mock的数据集出了很搞笑的问题,直接给OpenCode下达新指令:
+From the swarm monitoring table, the sample pool has accumulated 875.0% (140) episode samples, but AgentJet hasn't started training yet. Looking closer, the Completed Tasks progress is only 1, meaning all 140 episodes were identified as one task. The task IDs for these samples? They're empty strings. No doubt, claude-sonnet produced a hilarious bug in the mock dataset. We give OpenCode a new directive:

```txt
-task.task_id 有严重的问题,task_id应该是每个episode的随机数种子,不能为空!
+task.task_id has a serious problem - task_id should be a random seed for each episode and must not be empty!
```

-顺便修改了一下参数,batchsize从4改成32,grpo_n从4改成6,然后喝杯茶,再回来看看。不错,这次正常了。
+While we're at it, we adjust some parameters: batch size from 4 to 32, grpo_n from 4 to 6. Then we have a cup of tea and come back. This time it works.

![alt text](https://img.alicdn.com/imgextra/i4/O1CN01cQny931D4FI93OwyB_!!6000000000162-2-tps-2445-1227.png)

To ensure the agent logic is correct, we also open beast_logger (the log monitoring component that comes with agentjet):

![alt text](https://img.alicdn.com/imgextra/i3/O1CN01w7QLeg26hS3yIma36_!!6000000007693-2-tps-3782-1963.png)

One look and sure enough, there are still issues (slightly regretting not using opus). Our requirement was that Team A civilians share one brain with a 7B model, while Team B spies use qwen-max. But why did a spy sneak into the civilian team? This time we need claude-sonnet to reflect carefully:

![alt text](https://img.alicdn.com/imgextra/i3/O1CN01ECZFjI286viB25hk1_!!6000000007884-2-tps-1079-498.png)

After a while, we check again and the issues are all fixed.

### Check Training Curves

Heading over to SwanLab — not bad, the reward is steadily climbing. 
-![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png)
\ No newline at end of file
+![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png)
diff --git a/docs/en/example_vibe_rl_who_is_spy.zh.md b/docs/en/example_vibe_rl_who_is_spy.zh.md
new file mode 100644
index 0000000..ff84044
--- /dev/null
+++ b/docs/en/example_vibe_rl_who_is_spy.zh.md
@@ -0,0 +1,139 @@
+# Vibe RL 实例:不写一行代码,从零构建一个会玩“谁是卧底”的 Agent 训练器
+
+
+摘要:强化学习研究中,从灵感迸发,到编写代码,再到第一条成功的训练曲线产生,这个过程是漫长、乏味的。
+幸运的是,如今在 AgentJet 框架中,从想法到训练成功,你只需要动动嘴,花几分钟写一点点提示词,
+然后只需要等待片刻,然后你就可以看到**完整、简洁、人类易读易改的训练程序** + **初次训练的训练曲线** 展现在你面前了。
+接下来,我们以经典的“谁是卧底”桌游游戏为例,从零展示不写代码训练Agent的全过程。
+
+
+## 安装 AgentJet 环境
+
+您可以选择[手动安装](https://doc.agentjet.top/en/installation/),或者使用skills安装。运行以下指令将skills复制到claude code或者 opencode中。
+```bash
+npx skills add modelscope/agentjet
+npx skills add binary-husky/Vibe-RL
+```
+在skill添加完成之后,你可以指挥claude code或者opencode使用uv(或者conda / docker)安装 AgentJet。
+
+## 撰写提示词
+
+在安装完成 AgentJet 之后,就可以直接开始工作了,打开OpenCode(尽管ClaudeCode比OpenCode更加强大,但笔者还是喜欢完全开源的东西;再者,在AgentJet中Vibe RL的难度很低,我们也不需要非常强的agent),
+然后选择 claude-4.5-sonnet 模型 (这个模型在推理速度上比opus更快,对于不太困难的问题已经足够了),开始执行任务:
+
+```txt
+你的任务:
+- 编写一个学习"谁是卧底"任务的智能体,通过强化学习和监督学习相结合的方式训练,游戏规则如下:
+  - 游戏共有 N 名玩家,其中大多数人是**平民**,少数人是**卧底**
+  - 游戏开始时,每位平民会收到同一个**平民词**,每位卧底会收到一个与平民词相近但不同的**卧底词**(例如平民词为"苹果",卧底词为"梨")
+  - 每轮游戏中,所有玩家依次对自己拿到的词进行**口头描述**,描述必须真实反映自己的词,但不能直接说出词语本身,也不能过于明显地暴露自己的身份
+  - 全部玩家描述完毕后,进入**投票环节**,所有玩家投票选出自己认为最可疑的卧底,得票最多的玩家被淘汰出局
+  - 游戏持续多轮,直到满足以下任一结束条件:
+    - **平民获胜**:所有卧底均被淘汰
+    - **卧底获胜**:卧底人数 ≥ 平民人数(卧底在数量上取得优势)
+  - 智能体需要通过大量对局训练掌握两种核心能力:
+    - **描述策略学习**:学会根据自己的词语和当前局势,生成既不暴露身份、又能让同阵营玩家认同的最优描述
+    - **推理决策学习**:学会根据历史对话、其他玩家的描述模式和行为特征,准确识别卧底并做出最优投票决策
+  - 训练目标:最大化智能体在不同角色(平民/卧底)下的游戏胜率,通过自对弈和奖励机制不断优化策略
+- 我希望使用基础模型 `/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct`
+- 使用 8 GPU 训练
+- Batch Size 16
+- 我目前没有数据集,你需要帮助我 mock 少量游戏对局数据以供测试和初始训练
+- 使用OpenAI SDK,灵活使用Tools
+- 代码中不得出现中文
+
+你的 skill(首先读取该 SKILL 文件,获取必要知识):
+./ajet/copilot/write-swarm-client/SKILL.md
+
+- 追加要求:
+  - optional 0. (agent_roll) team A 平民 共享一个7B模型, team B卧底使用qwen-max (DASHSCOPE_API_KEY已经在环境变量中),
+    每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0
+  - optional 1. 
(agent_roll_adv) 对抗式训练,team A 平民 共享一个7B模型(swarm server 1), team B卧底共享另一个7B模型(swarm server 2), + 每个episode随机分配每个所有人的ID和名字(随机生成一个长长的随机姓名名字清单),胜者奖励 1,败者奖励 0 + +- 追加要求: + agent_roll: 使用4个显卡 + agent_roll_adv:swarm server 1 和 swarm server 2 分别使用4个显卡(一共8个显卡) + +- 追加要求:使用 tmux + uv 的 .venv 调试,直到所有Bug都已经排除 & 训练正常开始。你可以使用 `spy-swarm-server`, `spy-swarm-server-2`, `spy-swarm-client` 三个 tmux session + + - 当前调试阶段: + 调试 agent_roll 【执行调试】 + 调试 agent_roll_adv 【跳过调试】 +``` + + +## 检查结果 + +### 生成的训练代码 + +在agentjet skill的指导下,OpenCode会在 tutorial/opencode_build_*** 生成训练的全部代码: + +```bash +(base) ➜ agentjet git:(main) ✗ tree tutorial/opencode_build_spy_game +tutorial/opencode_build_spy_game/ +├── mock_dataset.py # Generate mock game configurations +├── mock_game_dataset.json # 200 game scenarios with word pairs +├── game_engine.py # Core game mechanics and player logic +├── agent_run.py # Agent executor for agent_roll mode +├── agent_roll.py # Training script for agent_roll mode +├── agent_run_adv.py # Agent executor for adversarial mode +├── agent_roll_adv.py # Training script for adversarial mode +└── readme.md # This file +``` + +### 检查训练蜂群,发现并引导智能体修复训练的Bug + + +等了一会,运行 `ajet-swarm overwatch` 命令,看一下现在训练运行到第几步了,结果发现 claude-sonnet 搞出了一个令人难绷错误: + +```bash + Completed Episode Pool Summary (Progress to Hit Next Weight Update) +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Metric ┃ Current ┃ Target ┃ Progress ┃ Bar ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ Completed Episodes │ 140 │ 16 │ 875.0% │ █████████████████████████████████████████████████████████████████████ │ +│ │ │ │ │ █████████████████████████████████████████████████████████████████████ │ +│ │ │ │ │ █████████████████████████████████████ │ +│ -> *Completed Tasks (chosen)* │ 1 │ 4 │ 25.0% │ █████░░░░░░░░░░░░░░░ │ +│ Completed Non-Dummy Tasks │ 1 │ 4 │ 25.0% │ █████░░░░░░░░░░░░░░░ │ +│ Average Episode Per Task │ 140.00 │ 4 │ - │ - │ +└────────────────────────────────────────┴─────────────┴─────────────┴──────────────┴───────────────────────────────────────────────────────────────────────┘ + + Task Completion Details +┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Task ID ┃ Episodes ┃ Reward ┃ Episode UUIDs (first 3) ┃ +┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ │ 140 │ 0.779 ± 0.448 │ b47d7b96..., 8caec2d7..., b48bd9fb... (+137 more) │ +└──────────────┴───────────────┴───────────────────────┴───────────────────────────────────────────────────────────────────────────┘ +``` + +从蜂群监视表格可以看出,现在样本池已经累计了 875.0%(140个)的回合样本,但AgentJet并没有开始训练。 +仔细一看,CompletedTasks 进度只有 1个,说明140个回合都被识别成一个task了。这些样本的task id,哎,怎么是空字符串? +毫无疑问,claude mock的数据集出了很搞笑的问题,直接给OpenCode下达新指令: + +```txt +task.task_id 有严重的问题,task_id应该是每个episode的随机数种子,不能为空! 
+``` + +顺便修改了一下参数,batchsize从4改成32,grpo_n从4改成6,然后喝杯茶,再回来看看。不错,这次正常了。 + +![alt text](https://img.alicdn.com/imgextra/i4/O1CN01cQny931D4FI93OwyB_!!6000000000162-2-tps-2445-1227.png) + + +为了保证agent运行逻辑是准确无误的,我们再打开 beast_logger (和agentjet配套的日志监视组件) 看一眼: + +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01w7QLeg26hS3yIma36_!!6000000007693-2-tps-3782-1963.png) + +看了一眼,果然还是有问题(有点后悔没用opus了)。我们的要求是team A平民共享大脑用一个7B模型, team B卧底使用qwen-max。但平民队伍里面怎么混进来一个间谍? +这回得让claude-sonnet好好反省一下了: + +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01ECZFjI286viB25hk1_!!6000000007884-2-tps-1079-498.png) + +等一会,再看了一下,问题都已经修复了 + +### 检查训练曲线 + +去SwanLab看看,不错,奖励平稳上升。 + +![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png) \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 1d3c8da..d03d4cb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -68,6 +68,8 @@ nav: - Multi Model Training (ZH): en/example_train_multi_model.zh.md - Training OpenClaw (EN): en/example_openclaw.md - Training OpenClaw (ZH): en/example_openclaw.zh.md + - Vibe RL Who is Spy (EN): en/example_vibe_rl_who_is_spy.md + - Vibe RL Who is Spy (ZH): en/example_vibe_rl_who_is_spy.zh.md plugins: - search: From 8f56e4a47a2ea01aa2ba55f3820d4c0de0805ec9 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Fri, 20 Mar 2026 15:46:14 +0800 Subject: [PATCH 8/9] stage openclaw examples --- .../REWARD_UPDATE.md | 108 ++++++ .../README.md | 197 ++++++++++ .../REWARD_UPDATE.md | 109 ++++++ .../cheatsheet.md | 47 +++ .../download_dataset.py | 28 ++ .../fake_vllm_endpoint.py | 275 +++++++++++++ .../mock_user_request.py | 119 ++++++ .../on_compute_relative_reward.py | 361 ++++++++++++++++++ .../on_user_submit_new_requests.py | 44 +++ .../openclaw.md | 120 ++++++ .../test_reward.py | 304 +++++++++++++++ 11 files changed, 1712 insertions(+) create mode 100644 tutorial/opencode_build_openclaw_agent/REWARD_UPDATE.md create mode 100644 tutorial/opencode_build_openclaw_interactive_train/README.md create mode 100644 tutorial/opencode_build_openclaw_interactive_train/REWARD_UPDATE.md create mode 100644 tutorial/opencode_build_openclaw_interactive_train/cheatsheet.md create mode 100644 tutorial/opencode_build_openclaw_interactive_train/download_dataset.py create mode 100644 tutorial/opencode_build_openclaw_interactive_train/fake_vllm_endpoint.py create mode 100644 tutorial/opencode_build_openclaw_interactive_train/mock_user_request.py create mode 100644 tutorial/opencode_build_openclaw_interactive_train/on_compute_relative_reward.py create mode 100644 tutorial/opencode_build_openclaw_interactive_train/on_user_submit_new_requests.py create mode 100644 tutorial/opencode_build_openclaw_interactive_train/openclaw.md create mode 100644 tutorial/opencode_build_openclaw_interactive_train/test_reward.py diff --git a/tutorial/opencode_build_openclaw_agent/REWARD_UPDATE.md b/tutorial/opencode_build_openclaw_agent/REWARD_UPDATE.md new file mode 100644 index 0000000..28b3c42 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/REWARD_UPDATE.md @@ -0,0 +1,108 @@ +# OpenClaw奖励模块更新 + +**项目**: OpenClaw Agent 构建 +**时间范围**: 2026 年 3 月 13 日 — 2026 年 3 月 20 日 + +--- + +## 一、之前的奖励模块的问题 + +之前奖励模块处于最小可行状态:仅依赖语言模型评判回复的外向性(Extraversion)人格特征分数,逻辑简洁但功能单一。在实际训练中,这种单一维度的评估暴露出三类问题: + +- **离题回复仍然获得较高奖励**:回复虽然热情洋溢、表达力强,但若与问题无关,仍能得到不错分数,导致模型学会"热情地答非所问"。 +- **批量内回复趋于同质化**:语言模型在生成多个候选回复时,容易产出大量近似重复的内容,这些内容各自获得接近的分数,缺乏多样性信号。 +- **退化输出缺乏惩罚机制**:训练过程中偶尔出现的循环段落、特殊 token 泄露或字符级重复(nonsense generation),因为在 Extraversion 维度上没有明显短板,仍能获得中上奖励,无法被有效压制。 + 
+本次更新的核心目标,就是将奖励系统从单一维度扩展为多维度复合架构,以更精细的信号引导模型同时兼顾**相关性**、**多样性**和**输出质量**。 + +--- + +## 二、核心升级:从单一维度到四维复合奖励 + +### 2.1 奖励公式 + +新的奖励由四个维度加权融合,并通过一个乘法质量门控进行修正: + +``` +最终奖励 = 质量分数 × (外向性权重 × 外向性 + 相关性权重 × 相关性 + 多样性权重 × 多样性) +``` + +默认权重配置为:外向性 0.5、相关性 0.3、多样性 0.2。三个子维度权重之和为 1.0,质量门控以乘法形式作用于最终得分。 + +### 2.2 各维度说明 + +**外向性(Extraversion)** + +沿用上一版本的 LLM 评判方案,由语言模型评估回复在热情、活力和表达力方面的表现。评估模式保持两种:pointwise 模式对每个回复独立打分(0–1),listwise 模式在同一批回复中做相对排名(最好 1.0,最差 0.0)。 + +**相关性(Relevance)** + +新增的维度。评判回复是否围绕问题展开、是否切中主题。相关性的加入解决了"热情但跑题"这一问题:即使回复在表达力上得分很高,若相关性不足,综合得分也会被拉低。 + +**多样性(Diversity)** + +新增的维度。鼓励模型在生成多个候选回复时保持差异性,避免同质化输出。多样性评估分为两个层面: + +- **批量内多样性**:当前这批候选回复中,各回复之间的相似程度。相似度越高,多样性分越低。 +- **跨请求多样性**:当前回复与近期历史上出现过的回复之间的相似程度。若模型反复产出与历史相似的回复,多样性分也会被压低。 + +多样性评估采用 n-gram 字符级重叠度(Jaccard 相似度)作为量化指标,无需语言模型调用,完全确定性执行。 + +**质量门控(Quality Gate)** + +新增的维度。作为一个乘法修正项(0–1 之间),质量门控以"硬开关"的方式惩罚两类退化输出: + +- **段落级循环**:同一结构化段落(如 `If you have any questions...` 模板段落)被重复多次。 +- **字符级重复与 token 泄露**:连续重复词汇、特殊标记(如 `<|im_start|>`)泄露等。 + +质量门控采用 OpenJudge 的 NgramRepetitionPenaltyGrader 结合字符串退化检测工具联合判定。当检测到上述退化模式时,质量分数直接压至接近零,无论其他三个维度的得分有多高。 + +--- + +## 三、其他变更 + +### 3.1 查询历史记录 + +在请求处理环节新增了一个轻量级的查询历史滚动缓冲区(上限 100 条),记录每次提交的请求元信息。其目的不在训练奖励计算,而在于系统层面的可观测性:若同一问题在短时间内高频出现,说明上游数据分发存在问题,需要及时告警,而非归咎于模型。 + +### 3.2 vLLM 兼容处理 + +服务端点在转发请求时,自动剥离了上游不支持的字段(如 `strict`、`store`),避免不必要的警告输出。同时,`/requests` 接口的返回值从原始请求记录改为查询历史,提供更清晰的调试视图。 + +### 3.3 测试体系 + +原有的两个端到端测试(pointwise 模式、listwise 模式)被扩展为六个专项测试,覆盖复合奖励的各个维度以及质量门控的惩罚效果: + +- 外向性复合奖励测试:验证热情回复优于平淡回复 +- 相关性惩罚测试:验证离题回复得分低于切题回复 +- 多样性惩罚测试:验证近似重复回复得分低于独特回复 +- 跨请求多样性测试:验证重复历史回复的代价 +- 退化惩罚测试:验证循环段落和特殊 token 泄露会被质量门控压制 +- listwise 复合测试:验证 listwise 模式下复合奖励同样生效 + +每个测试在运行时隔离历史状态,确保测试结果不受执行顺序影响。 + +### 3.4 快速参考文档 + +新增了一份速查文档(cheatsheet),包含测试运行命令、服务启动命令、所有奖励模式说明和环境变量速查表,方便日常操作时快速查阅。 + +--- + +## 四、架构升级概览 + +| 特性 | 更新前 | 更新后 | +|------|--------|--------| +| 奖励维度 | 1 个(外向性) | 4 个(外向性 + 相关性 + 多样性 + 质量门控) | +| 质量门控 | 无 | 乘法门控,压制退化输出至 ~0 | +| 批量内多样性 | 无 | n-gram 相似度检测 | +| 跨请求记忆 | 无 | 25 条回复历史滚动缓冲区 | +| 相关性评估 | 无 | LLM 评判 | +| 测试用例数 | 2 个 | 6 个 | +| 快速参考文档 | 无 | 新增 cheatsheet | +| 请求可观测性 | 无 | 查询历史记录接口 | + +--- + +## 五、总结 + +本次更新的本质,是将奖励模块从"外向性评分器"转变为"多维度质量评估系统"。新增的相关性和多样性维度填补了上一版本的盲区,质量门控则为训练稳定性提供了最后一道防线。更新后的系统能够在鼓励热情表达的同时,确保回复切题、不重复、无退化,使模型真正学会在正确方向上发挥外向性人格优势。 diff --git a/tutorial/opencode_build_openclaw_interactive_train/README.md b/tutorial/opencode_build_openclaw_interactive_train/README.md new file mode 100644 index 0000000..eefb51d --- /dev/null +++ b/tutorial/opencode_build_openclaw_interactive_train/README.md @@ -0,0 +1,197 @@ +# OpenClaw Agent Training - Extraversion Personality + +Train an LLM agent to exhibit more extraverted personality traits using reinforcement learning. + +## Overview + +This training program uses GRPO (Group Relative Policy Optimization) to train Qwen2.5-7B-Instruct to respond with more extraverted characteristics: +- Outgoing, energetic, enthusiastic tone +- Social engagement and excitement +- Positive, upbeat language +- Action-oriented expressions + +## Architecture + +``` +User Query → fake_vllm_endpoint.py → Swarm Server (8 GPUs) + ↓ + Generate N=4 responses in parallel + ↓ + Evaluate with ExtraversionGrader (OpenJudge) + ↓ + Compute rewards & update model (GRPO) + ↓ + Return best response to user +``` + +## Prerequisites + +```bash +pip install py-openjudge datasets +``` + +## Setup + +### 1. 
Download Dataset + +```bash +cd tutorial/opencode_build_openclaw_agent +python download_dataset.py +``` + +This downloads the `holistic-ai/personality_manipulation` dataset and extracts extraversion examples. + +### 2. Configure API Key + +Edit `on_compute_relative_reward.py` and set your API key for the judge model: + +```python +model = OpenAIChatModel( + model="qwen-plus", + api_key="YOUR_API_KEY_HERE", # Change this + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +``` + +## Training + +### Step 1: Start Swarm Server + +On your GPU server (with 8 GPUs available): + +```bash +ajet-swarm start +``` + +Or with monitoring: + +```bash +(ajet-swarm start &> ajet-swarm-server.log) & (ajet-swarm overwatch) +``` + +### Step 2: Start Fake vLLM Endpoint + +In a new terminal: + +```bash +cd tutorial/opencode_build_openclaw_agent + +# Option 1: Use OpenJudge pointwise grading (default) +export AJET_SWARM_URL="http://localhost:10086" +export NUM_REPEAT=4 +export REWARD_MODE=pointwise +export DASHSCOPE_API_KEY=your_api_key_here +python fake_vllm_endpoint.py + +# Option 2: Use OpenJudge listwise ranking +export AJET_SWARM_URL="http://localhost:10086" +export NUM_REPEAT=4 +export REWARD_MODE=listwise +export DASHSCOPE_API_KEY=your_api_key_here +python fake_vllm_endpoint.py +``` + +This starts the training proxy on `http://localhost:8090`. + +### Step 3: Configure OpenClaw to Use Training Endpoint + +OpenClaw needs to connect to the fake vLLM endpoint. + +Configure it to use `http://localhost:8090` as the LLM backend. + +### Step 4: Send Training Requests + +Option A - Manual testing via OpenClaw Web / Cli: + +```bash +openclaw agent --message "What are your thoughts on Paris?" --thinking high +``` + +Option B - Automated dataset iteration: + +```bash +python mock_user_request.py +``` + +This will iterate through the personality_manipulation dataset and send each question via OpenClaw CLI. + +## Configuration + +Key parameters in `fake_vllm_endpoint.py`: + +- `n_gpu=8` - Number of GPUs for training +- `batch_size=32` - Training batch size +- `num_repeat=4` - GRPO N parameter (responses per query) +- `model` - Base model path + +Environment variables for reward computation: + +- `REWARD_MODE` - Reward computation mode: `pointwise` (default) or `listwise` +- `DASHSCOPE_API_KEY` - API key for OpenJudge LLM grader +- `JUDGE_BASE_URL` - Base URL for judge model API (default: DashScope) +- `JUDGE_MODEL` - Judge model name (default: `qwen-plus`) + +## Reward Function + +Two OpenJudge-based reward modes are available: + +### 1. Pointwise Mode (Default) + +Uses OpenJudge LLM grader to evaluate each response independently: +- Evaluates extraversion traits on 1-10 scale +- Provides detailed reasoning for each score +- Scores normalized to [-1, 1] for GRPO training + +```bash +export REWARD_MODE=pointwise +export DASHSCOPE_API_KEY=your_api_key_here +``` + +### 2. 
Listwise Mode
+
+Uses OpenJudge to rank all responses together:
+- Compares responses directly against each other
+- Produces relative rankings
+- Best for capturing subtle differences
+
+```bash
+export REWARD_MODE=listwise
+export DASHSCOPE_API_KEY=your_api_key_here
+```
+
+## Monitoring
+
+Check training progress:
+
+```bash
+# View swarm status
+ajet-swarm overwatch
+
+# Check request history
+curl http://localhost:8090/requests
+
+# Health check
+curl http://localhost:8090/health
+```
+
+## Files
+
+- `fake_vllm_endpoint.py` - Main training server
+- `on_compute_relative_reward.py` - Extraversion reward function
+- `on_user_submit_new_requests.py` - Request handler
+- `download_dataset.py` - Dataset downloader
+- `mock_user_request.py` - Automated testing client
+
+## Troubleshooting
+
+**Import errors**: LSP warnings about unresolved imports are normal - dependencies will be available at runtime.
+
+**Connection refused**: Ensure swarm server is running on port 10086.
+
+**All episodes failed**: Check GPU availability and swarm server logs.
+
+## Notes
+
+- Training is passive - the endpoint waits for requests rather than iterating a dataset
+- Each request generates N=4 responses, evaluates them, and trains on the best
+- The model gradually learns to produce more extraverted responses over time
diff --git a/tutorial/opencode_build_openclaw_interactive_train/REWARD_UPDATE.md b/tutorial/opencode_build_openclaw_interactive_train/REWARD_UPDATE.md
new file mode 100644
index 0000000..2a1907d
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/REWARD_UPDATE.md
@@ -0,0 +1,109 @@
+# Reward Module Weekly Update Report
+
+**Project**: OpenClaw Agent build
+**Period**: March 13, 2026 to March 20, 2026
+**Baseline**: commit `9707faf` → commit `f091efc`
+
+---
+
+## 1. Background
+
+A week ago the reward module was in a minimum-viable state: it relied solely on a language model to judge each response's extraversion personality score. The logic was simple, but it measured only one thing. In real training, this single-dimension evaluation exposed three classes of problems:
+
+- **Off-topic responses still earned high rewards**: a response that was enthusiastic and expressive could score well even when it had nothing to do with the question, teaching the model to "answer the wrong question with gusto".
+- **Within-batch responses converged**: when generating multiple candidate responses, the language model tended to produce large numbers of near-duplicates, each receiving similar scores and providing no diversity signal.
+- **Degenerate outputs went unpunished**: looping paragraphs, special-token leaks, and character-level repetition (nonsense generation) that occasionally appeared during training had no obvious weakness on the extraversion dimension, so they still earned above-average rewards and could not be suppressed.
+
+The core goal of this update is to extend the reward system from a single dimension into a multi-dimensional composite architecture, using finer-grained signals to steer the model toward **relevance**, **diversity**, and **output quality** at the same time.
+
+---
+
+## 2. Core Upgrade: From One Dimension to a Four-Dimensional Composite Reward
+
+### 2.1 Reward Formula
+
+The new reward fuses four dimensions: a weighted sum over three sub-scores, corrected by a multiplicative quality gate:
+
+```
+final_reward = quality_score × (w_extraversion × extraversion + w_relevance × relevance + w_diversity × diversity)
+```
+
+The default weights are extraversion 0.5, relevance 0.3, and diversity 0.2. The three sub-dimension weights sum to 1.0, and the quality gate acts multiplicatively on the final score.
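+
+A minimal numeric sketch of the formula (the variable names here are illustrative; the actual implementation lives in `on_compute_relative_reward.py`):
+
+```python
+# default weights: extraversion 0.5, relevance 0.3, diversity 0.2
+quality, ext, rel, div = 1.0, 0.8, 0.9, 0.5
+reward = quality * (0.5 * ext + 0.3 * rel + 0.2 * div)   # = 0.77
+
+# the multiplicative gate crushes a degenerate response outright:
+reward = 0.0 * (0.5 * 1.0 + 0.3 * 1.0 + 0.2 * 1.0)       # = 0.0
+```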
+
+### 2.2 The Dimensions
+
+**Extraversion**
+
+Carries over the LLM-judge scheme from the previous version: a language model rates each response for warmth, energy, and expressiveness. Both evaluation modes remain: pointwise mode scores each response independently (0–1), while listwise mode ranks the responses within a batch relative to each other (best 1.0, worst 0.0).
+
+**Relevance**
+
+A new dimension. It judges whether a response actually engages with the question and stays on topic. Adding relevance fixes the "enthusiastic but off-topic" failure: even a response that scores highly on expressiveness is pulled down in the composite if its relevance is poor.
+
+**Diversity**
+
+A new dimension. It encourages the model to keep its candidate responses distinct and to avoid homogeneous output. Diversity is evaluated at two levels:
+
+- **Within-batch diversity**: how similar the candidate responses in the current batch are to one another. The higher the similarity, the lower the diversity score.
+- **Cross-request diversity**: how similar the current response is to responses seen in recent history. If the model keeps reproducing responses that resemble its history, the diversity score is suppressed as well.
+
+Diversity is quantified with word-level n-gram overlap (Jaccard similarity); it requires no language model call and runs fully deterministically.
+
+**Quality Gate**
+
+A new dimension. Acting as a multiplicative correction term (between 0 and 1), the quality gate is a hard switch that punishes two classes of degenerate output:
+
+- **Paragraph-level loops**: the same boilerplate paragraph (such as an `If you have any questions...` template block) repeated many times.
+- **Character-level repetition and token leaks**: runs of repeated words, leaked special tokens (such as `<|im_start|>`), and the like.
+
+The quality gate combines OpenJudge's NgramRepetitionPenaltyGrader with a string-degeneration detector. When one of these degenerate patterns is detected, the quality score is driven to near zero, no matter how well the other three dimensions score.
+
+---
+
+## 3. Other Changes
+
+### 3.1 Query History
+
+Request handling gained a lightweight rolling buffer of query history (capped at 100 entries) that records the metadata of each submitted request. Its purpose is not reward computation but system-level observability: if the same question shows up at high frequency within a short window, that points to a problem in upstream data distribution and should raise an alert, rather than be blamed on the model.
+
+### 3.2 vLLM Compatibility
+
+When forwarding requests, the endpoint now automatically strips fields the upstream server does not support (such as `strict` and `store`), avoiding needless warning output. The `/requests` endpoint also returns the query history instead of the raw request records, giving a cleaner debugging view.
+
+### 3.3 Test Suite
+
+The original two end-to-end tests (pointwise mode, listwise mode) were expanded into six targeted tests covering each dimension of the composite reward as well as the penalty behavior of the quality gate:
+
+- Extraversion composite test: verifies that enthusiastic responses beat flat ones
+- Relevance penalty test: verifies that off-topic responses score below on-topic ones
+- Diversity penalty test: verifies that near-duplicate responses score below unique ones
+- Cross-request diversity test: verifies the cost of repeating historical responses
+- Degeneration penalty test: verifies that looping paragraphs and special-token leaks are suppressed by the quality gate
+- Listwise composite test: verifies that the composite reward also works in listwise mode
+
+Each test isolates the history state at runtime, so results do not depend on execution order.
+
+### 3.4 Quick-Reference Document
+
+A cheatsheet was added with the test commands, the service launch commands, descriptions of all reward modes, and a quick table of environment variables, for fast lookup in day-to-day operation.
+
+---
+
+## 4. Architecture Upgrade at a Glance
+
+| Feature | Before | After |
+|------|--------|--------|
+| Reward dimensions | 1 (extraversion) | 4 (extraversion + relevance + diversity + quality gate) |
+| Quality gate | None | Multiplicative gate; crushes degenerate output to ~0 |
+| Within-batch diversity | None | n-gram similarity detection |
+| Cross-request memory | None | Rolling buffer of the last 25 responses |
+| Relevance evaluation | None | LLM judge |
+| Test cases | 2 | 6 |
+| Quick-reference doc | None | New cheatsheet |
+| Request observability | None | Query-history endpoint |
+
+---
+
+## 5. Summary
+
+In essence, this update turns the reward module from an "extraversion scorer" into a multi-dimensional quality evaluation system. The new relevance and diversity dimensions fill the previous version's blind spots, and the quality gate provides a last line of defense for training stability. The updated system can encourage enthusiastic expression while keeping responses on topic, non-repetitive, and non-degenerate, so the model genuinely learns to exercise its extraverted personality in the right direction.
diff --git a/tutorial/opencode_build_openclaw_interactive_train/cheatsheet.md b/tutorial/opencode_build_openclaw_interactive_train/cheatsheet.md
new file mode 100644
index 0000000..0d79b05
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/cheatsheet.md
@@ -0,0 +1,47 @@
+# OpenClaw Reward Cheatsheet
+
+## Run the test
+
+```bash
+cd agentjet/tutorial/opencode_build_openclaw_agent
+
+# pointwise (default)
+DASHSCOPE_API_KEY=your_key python test_reward.py
+
+# listwise
+REWARD_MODE=listwise DASHSCOPE_API_KEY=your_key python test_reward.py
+```
+
+## Run the training endpoint
+
+```bash
+# pointwise (default)
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=pointwise \
+python fake_vllm_endpoint.py
+
+# listwise
+AJET_SWARM_URL=http://localhost:10086 \
+DASHSCOPE_API_KEY=your_key \
+REWARD_MODE=listwise \
+python fake_vllm_endpoint.py
+```
+
+## Reward modes
+
+| Mode | Description |
+|------|-------------|
+| `pointwise` | Each response scored independently (0.0–1.0) |
+| `listwise` | All responses ranked together (best=1.0, worst=0.0) |
+
+## Environment variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `REWARD_MODE` | `pointwise` | `pointwise` or `listwise` |
+| `DASHSCOPE_API_KEY` | — | DashScope API key (required) |
+| `JUDGE_MODEL` | `qwen-plus` | Judge model name |
+| `JUDGE_BASE_URL` | DashScope endpoint | Judge model base URL |
+| `AJET_SWARM_URL` | `http://localhost:10086` | Swarm server URL |
+| `NUM_REPEAT` | `4` | GRPO N (responses per query) |
diff --git a/tutorial/opencode_build_openclaw_interactive_train/download_dataset.py b/tutorial/opencode_build_openclaw_interactive_train/download_dataset.py
new file mode 100644
index 0000000..69d8007
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/download_dataset.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+"""Download personality_manipulation dataset from HuggingFace."""
+
+from datasets import load_dataset
+import json
+
+def download_and_save_dataset():
+    """Download personality_manipulation dataset and save extraversion samples."""
+    print("Downloading personality_manipulation dataset...")
+    dataset = load_dataset("holistic-ai/personality_manipulation")
+
+    # Filter for extraversion personality
+    extraversion_data = [item for item in dataset['train'] if item['Target Personality'] == 'extraversion']
+
+    # Save to JSON
+    with open('extraversion_questions.json', 'w', encoding='utf-8') as f:
+        json.dump(extraversion_data, f, ensure_ascii=False, indent=2)
+
+    print(f"Saved {len(extraversion_data)} extraversion samples to extraversion_questions.json")
+
+    # Also save all
personalities for reference + with open('all_personalities.json', 'w', encoding='utf-8') as f: + json.dump(list(dataset['train']), f, ensure_ascii=False, indent=2) + + print(f"Saved {len(dataset['train'])} total samples to all_personalities.json") + +if __name__ == "__main__": + download_and_save_dataset() diff --git a/tutorial/opencode_build_openclaw_interactive_train/fake_vllm_endpoint.py b/tutorial/opencode_build_openclaw_interactive_train/fake_vllm_endpoint.py new file mode 100644 index 0000000..e73cc80 --- /dev/null +++ b/tutorial/opencode_build_openclaw_interactive_train/fake_vllm_endpoint.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- +""" +Fake vLLM endpoint for OpenClaw agent training. +Based on ajet/tuner_lib/experimental/oai_model_one2many.py +""" + +import os +import uuid +import asyncio +import httpx +import json +import threading +from contextlib import asynccontextmanager +from typing import Dict, List, Optional + +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import StreamingResponse +from loguru import logger +from pydantic import BaseModel + +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.tuner_lib.experimental.swarm_client import SwarmClient + +import sys +sys.path.insert(0, os.path.dirname(__file__)) + +from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history +from on_compute_relative_reward import on_compute_relative_reward + +# Configuration +SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +NUM_REPEAT = int(os.getenv("NUM_REPEAT", "4")) +TRAINING_OBJECTIVE = "Train model to be more extraverted" + +# Global State +USER_REQUEST_RECORD: List[Dict] = [] +REQUEST_COUNTER = 0 +swarm_client: Optional[SwarmClient] = None +ajet_job = AgentJetJob( + algorithm="grpo", + project_name="openclaw-extraversion", + experiment_name="extraversion_training", + n_gpu=8, + model='/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct', + batch_size=32, + logging="swanlab", + num_repeat=NUM_REPEAT, + max_prompt_length=16000, # at least 16000 + max_response_length=8000, + max_model_len=24000, # bigger than / equal to `max_prompt_length + max_response_length` + max_response_length_in_one_turn=4000, +) + +class EpisodeResult(BaseModel): + """Result from a single episode execution.""" + episode_uuid: str + response: Dict | List[bytes] + + +def extract_assistant_message(resp: Dict | List[bytes]) -> Dict: + """Extract assistant message from response.""" + if isinstance(resp, list): + content_parts: List[str] = [] + for raw in resp: + line = raw.decode() if isinstance(raw, bytes) else raw + if not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + delta = chunk.get("choices", [{}])[0].get("delta", {}) + if delta.get("content"): + content_parts.append(delta["content"]) + except Exception: + pass + return {"role": "assistant", "content": "".join(content_parts)} + else: + return resp.get("choices", [{}])[0].get("message", {}) + + +async def proxy_chat_completion(base_url: str, api_key: str, request: Request, is_stream: bool = False) -> Dict | List[bytes]: + """Proxy a chat completion request.""" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Connection": "close", + } + json_data = await request.json() + json_data["stream"] = is_stream + + # Remove fields not supported by vLLM to avoid warnings + 
UNSUPPORTED_FIELDS = {"strict", "store"} + for field in UNSUPPORTED_FIELDS: + json_data.pop(field, None) + # Also remove 'strict' from response_format if present + if "response_format" in json_data and isinstance(json_data["response_format"], dict): + json_data["response_format"].pop("strict", None) + + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers) + resp.raise_for_status() + if is_stream: + chunks = [] + async for line in resp.aiter_lines(): + if line.strip(): + chunks.append(line.encode() if isinstance(line, str) else line) + return chunks + else: + return resp.json() + + +def _check_finish_reason_length(response_data: Dict | List[bytes]) -> bool: + """Return True if any choice has finish_reason='length'.""" + if isinstance(response_data, list): + for raw in response_data: + line = raw.decode() if isinstance(raw, bytes) else raw + if not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + finish_reason = chunk.get("choices", [{}])[0].get("finish_reason") + if finish_reason == "length": + return True + except Exception: + pass + return False + else: + choices = response_data.get("choices", []) + return any(c.get("finish_reason") == "length" for c in choices) + + +async def run_single_episode(episode_index: int, request: Request, is_stream: bool) -> EpisodeResult: + """Run a single episode.""" + assert swarm_client is not None + episode_uuid, api_baseurl_key = await asyncio.to_thread(swarm_client.begin_episode) + try: + response_data = await proxy_chat_completion( + base_url=api_baseurl_key.base_url, + api_key=api_baseurl_key.api_key, + request=request, + is_stream=is_stream, + ) + if _check_finish_reason_length(response_data): + raise HTTPException( + status_code=400, + detail={ + "error": { + "message": "This model's maximum context length is exceeded. 
Please reduce the length of the messages.", + "type": "invalid_request_error", + "param": "messages", + "code": "context_length_exceeded", + } + }, + ) + return EpisodeResult(episode_uuid=episode_uuid, response=response_data) + except Exception as e: + logger.error(f"Error in episode {episode_index}: {e}") + swarm_client.abort_episode(episode_uuid) + raise + + +async def run_all_episodes(request: Request, is_stream: bool) -> List[EpisodeResult]: + """Run all episodes in parallel.""" + episode_tasks = [run_single_episode(i, request, is_stream) for i in range(NUM_REPEAT)] + results = await asyncio.gather(*episode_tasks, return_exceptions=True) + valid_results: List[EpisodeResult] = [] + for result in results: + if isinstance(result, HTTPException) and result.status_code == 400: + # Propagate context_length_exceeded directly to client + raise result + elif isinstance(result, Exception): + logger.warning(f"Episode failed: {result}") + elif isinstance(result, EpisodeResult): + valid_results.append(result) + if not valid_results: + raise HTTPException(status_code=500, detail="All episodes failed") + return valid_results + + +async def finalize_episodes(task: Task, valid_results: List[EpisodeResult], rewards: List[float]) -> None: + """Finalize all episodes by sending rewards.""" + assert swarm_client is not None + loop = asyncio.get_event_loop() + for episode_result, reward in zip(valid_results, rewards): + workflow_output = WorkflowOutput(reward=reward, metadata={}) + await loop.run_in_executor( + None, + lambda ep=episode_result, wo=workflow_output: swarm_client.end_episode(task, ep.episode_uuid, wo), + ) + + +async def handle_one2many_request(request: Request, request_id: str) -> Dict | List[bytes]: + """Handle a one-to-many request.""" + json_data = await request.json() + is_stream = json_data.get('stream', False) + messages = json_data.get('messages', []) + message_latest = messages[-1] + user_query = str(message_latest.get("content", "") if isinstance(message_latest, dict) else "") + + task = Task(task_id=str(uuid.uuid4()), main_query=user_query, metadata={"TRAINING_OBJECTIVE": TRAINING_OBJECTIVE}) + await on_user_submit_new_requests(request_id, task) + + valid_results = await run_all_episodes(request, is_stream) + all_answers = [extract_assistant_message(r.response) for r in valid_results] + rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query) + + await finalize_episodes(task, valid_results, rewards) + + best_idx = rewards.index(max(rewards)) + return valid_results[best_idx].response + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager.""" + global swarm_client + logger.info(f"Initializing swarm client with URL: {SWARM_URL}") + swarm_client = SwarmClient(SWARM_URL) + logger.info(f"Syncing train config and starting engine with num_repeat={NUM_REPEAT}") + + def start_engine_background(): + try: + swarm_client.auto_sync_train_config_and_start_engine(ajet_job, force_restart=False) + logger.info("Swarm engine is ready!") + except Exception as e: + logger.warning(f"Engine auto-sync skipped or failed: {e}") + + engine_thread = threading.Thread(target=start_engine_background, daemon=True) + engine_thread.start() + yield + + +app = FastAPI(title="OpenClaw Extraversion Training", lifespan=lifespan) + + +@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]) +async def one2many_proxy(request: Request, path: str): + """Main proxy endpoint.""" + global REQUEST_COUNTER + if 
request.method == "POST" and path == "chat/completions": + REQUEST_COUNTER += 1 + request_id = f"req_{REQUEST_COUNTER}_{uuid.uuid4().hex[:8]}" + logger.info(f"Received chat completion request {request_id}") + response_data = await handle_one2many_request(request, request_id) + if isinstance(response_data, list): + async def stream_chunks(chunks: List[bytes]): + for chunk in chunks: + yield chunk + b"\n\n" + return StreamingResponse(stream_chunks(response_data), media_type="text/event-stream") + return response_data + else: + raise HTTPException(status_code=404, detail="Not Found") + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + +@app.get("/requests") +async def get_requests(): + """Get all recorded user requests.""" + return {"requests": get_query_history()} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8090) diff --git a/tutorial/opencode_build_openclaw_interactive_train/mock_user_request.py b/tutorial/opencode_build_openclaw_interactive_train/mock_user_request.py new file mode 100644 index 0000000..6ad6a6d --- /dev/null +++ b/tutorial/opencode_build_openclaw_interactive_train/mock_user_request.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +"""Mock user requests using OpenClaw CLI interface.""" + +import json +import subprocess +import time +import os +import random +from typing import List, Dict + +GATEWAY_PORT = os.getenv("OPENCLAW_PORT", "18789") + +def load_dataset(filepath: str = "extraversion_questions.json") -> List[Dict]: + """Load personality manipulation dataset.""" + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + + +def generate_agent_name() -> str: + """Generate a random agent name.""" + adjectives = ["happy", "quick", "bright", "clever", "bold", "calm", "eager", "gentle"] + nouns = ["fox", "wolf", "bear", "eagle", "hawk", "lion", "tiger", "owl"] + return f"{random.choice(adjectives)}_{random.choice(nouns)}_{random.randint(1000, 9999)}" + + +def create_agent(agent_name: str) -> bool: + """Create a new agent using OpenClaw CLI.""" + try: + workspace = f"/root/.openclaw/workspace-{agent_name}" + result = subprocess.run( + ["openclaw", "agents", "add", agent_name, "--workspace", workspace, "--non-interactive"], + capture_output=True, + text=True, + timeout=60 + ) + if result.returncode == 0: + print(f"Created agent: {agent_name}") + return True + else: + print(f"Error creating agent {agent_name}: {result.stderr}") + return False + except Exception as e: + print(f"Error creating agent: {str(e)}") + return False + + +def delete_agent(agent_name: str) -> bool: + """Delete an agent using OpenClaw CLI.""" + try: + result = subprocess.run( + ["openclaw", "agents", "delete", agent_name, "--force"], + capture_output=True, + text=True, + timeout=60 + ) + if result.returncode == 0: + print(f"Deleted agent: {agent_name}") + return True + else: + print(f"Error deleting agent {agent_name}: {result.stderr}") + return False + except Exception as e: + print(f"Error deleting agent: {str(e)}") + return False + + +def send_openclaw_message(agent_name: str, message: str) -> str: + """Send message via OpenClaw CLI to specific agent.""" + try: + result = subprocess.run( + ["openclaw", "agent", "--agent", agent_name, "--message", message], + capture_output=True, + text=True, + timeout=300 + ) + return result.stdout if result.returncode == 0 else f"Error: {result.stderr}" + except Exception as e: + return f"Error: {str(e)}" + + +def main(): + """Main loop to send 
requests from dataset.""" + print("Starting OpenClaw mock user requests") + + # Load dataset + dataset = load_dataset() + random.shuffle(dataset) + print(f"Loaded {len(dataset)} questions from dataset\n") + + # Process dataset in chunks of 5 + for chunk_start in range(0, len(dataset), 5): + chunk = dataset[chunk_start:chunk_start + 5] + + # Generate random agent name + agent_name = generate_agent_name() + print(f"\n=== Creating agent: {agent_name} ===\n") + + # Create agent + if not create_agent(agent_name): + print(f"Failed to create agent, skipping chunk") + continue + + # Send 5 messages + for i, item in enumerate(chunk): + question = item.get("Question", "") + print(f"[{agent_name}/{i+1}/5] Sending: {question[:80]}...") + response = send_openclaw_message(agent_name, question) + print(f"Response: {response[:200]}...\n") + time.sleep(2) + + # Delete agent + delete_agent(agent_name) + print(f"\n=== Deleted agent: {agent_name} ===\n") + + print("\nAll agents processed successfully") + + +if __name__ == "__main__": + main() diff --git a/tutorial/opencode_build_openclaw_interactive_train/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_interactive_train/on_compute_relative_reward.py new file mode 100644 index 0000000..53894a9 --- /dev/null +++ b/tutorial/opencode_build_openclaw_interactive_train/on_compute_relative_reward.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +"""Compute relative rewards based on extraversion, relevance, diversity, and repetition quality.""" + +import os +import collections +from typing import List, Dict + +from loguru import logger +from beast_logger import print_listofdict +from openjudge.graders.base_grader import GraderMode, GraderScore, GraderRank +from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.common.relevance import RelevanceGrader +from openjudge.graders.format.ngram_repetition_penalty import NgramRepetitionPenaltyGrader +from openjudge.models import OpenAIChatModel +try: + from ajet.utils.compute_madness import has_repeat +except ImportError: + # Fallback: when running outside the full ajet package (e.g. tests), + # resolve relative to the repo root. 
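+    # (This file sits two directories below the repo root, at
+    # tutorial/opencode_build_openclaw_interactive_train/, which is why
+    # parents[2] of the resolved path below is the repository root.)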
+    import sys as _sys
+    from pathlib import Path as _Path
+    _repo_root = str(_Path(__file__).resolve().parents[2])
+    if _repo_root not in _sys.path:
+        _sys.path.insert(0, _repo_root)
+    from ajet.utils.compute_madness import has_repeat
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+REWARD_MODE = os.getenv("REWARD_MODE", "pointwise")  # pointwise | listwise
+API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+BASE_URL = os.getenv("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen-plus")
+
+# Reward weights (must sum to 1.0)
+W_EXTRAVERSION = float(os.getenv("W_EXTRAVERSION", "0.5"))
+W_RELEVANCE = float(os.getenv("W_RELEVANCE", "0.3"))
+W_DIVERSITY = float(os.getenv("W_DIVERSITY", "0.2"))
+
+# Cross-request history buffer size
+HISTORY_MAX_SIZE = int(os.getenv("DIVERSITY_HISTORY_SIZE", "25"))
+
+# ---------------------------------------------------------------------------
+# Shared model & graders
+# ---------------------------------------------------------------------------
+judge_model = OpenAIChatModel(
+    model=JUDGE_MODEL,
+    api_key=API_KEY,
+    base_url=BASE_URL,
+)
+
+# --- Extraversion grader (custom LLM prompt) ---
+EXTRAVERSION_PROMPT = """You are evaluating responses for extraversion personality traits.
+
+Extraversion characteristics include:
+- Outgoing, energetic, enthusiastic tone
+- Social engagement and excitement
+- Positive, upbeat language
+- Action-oriented expressions
+- Use of exclamation marks and emotional words
+
+Rate the response on a scale of 0.0-1.0:
+0.0 = Highly introverted (reserved, quiet, minimal emotion)
+1.0 = Highly extraverted (energetic, enthusiastic, very expressive)
+
+Question: {question}
+Response: {response}
+
+Return a json object with exactly two fields:
+- "score": float between 0.0 and 1.0
+- "reason": brief explanation"""
+
+pointwise_grader = LLMGrader(
+    name="extraversion_pointwise",
+    mode=GraderMode.POINTWISE,
+    description="Evaluate extraversion traits",
+    model=judge_model,
+    template=EXTRAVERSION_PROMPT,
+)
+
+# --- Relevance grader (built-in OpenJudge) ---
+relevance_grader = RelevanceGrader(model=judge_model)
+
+# --- Repetition penalty grader (deterministic, no LLM) ---
+# Detects n-gram repetition within a single response.
+# Its score is a penalty in [-1, 0] (0 = no repetition, -1 = heavily
+# repetitive); compute_quality_scores() converts it into a [0, 1] multiplier.
+repetition_grader = NgramRepetitionPenaltyGrader(
+    n=4,  # 4-gram detection
+    penalty_threshold=0.15,  # trigger penalty when >15% of n-grams are repeated
+    use_soft_penalty=True,  # gradual penalty rather than cliff
+    max_penalty=-1.0,  # worst case: score becomes 0
+    min_scaling=0.0,  # at max penalty, multiplier goes to 0
+)
+
+# ---------------------------------------------------------------------------
+# In-process history of recent responses (for cross-request diversity)
+# ---------------------------------------------------------------------------
+_response_history: List[str] = []
+
+
+def record_responses_to_history(contents: List[str]) -> None:
+    """Append new responses to the rolling history buffer."""
+    _response_history.extend(contents)
+    # Trim to keep only the most recent entries
+    while len(_response_history) > HISTORY_MAX_SIZE:
+        _response_history.pop(0)
+
+
+# ---------------------------------------------------------------------------
+# Diversity: n-gram overlap (fast, deterministic, no LLM needed)
+# ---------------------------------------------------------------------------
+def _get_ngrams(text: str, n: int = 3) -> collections.Counter:
+    """Extract word-level n-grams (tuples of whitespace tokens) from text."""
+    tokens = text.lower().split()
+    if len(tokens) < n:
+        return collections.Counter(tokens)
+    return collections.Counter(
+        tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)
+    )
+
+
+def _ngram_overlap(text_a: str, text_b: str, n: int = 3) -> float:
+    """Compute Jaccard overlap of n-grams between two texts. Returns 0-1."""
+    ngrams_a = _get_ngrams(text_a, n)
+    ngrams_b = _get_ngrams(text_b, n)
+    if not ngrams_a or not ngrams_b:
+        return 0.0
+    intersection = sum((ngrams_a & ngrams_b).values())
+    union = sum((ngrams_a | ngrams_b).values())
+    return intersection / union if union > 0 else 0.0
+
+
+def compute_diversity_scores(contents: List[str], history: List[str]) -> List[float]:
+    """
+    Compute a diversity score for each response (0 = duplicate, 1 = fully unique).
+
+    Two components:
+    1. Within-batch: maximum pairwise n-gram overlap with the other responses in the batch
+    2. Cross-request: max n-gram overlap with any response in the history buffer
+
+    Final diversity_score = 1 - max(within_batch_overlap, cross_request_overlap)
+    """
+    n = len(contents)
+    scores = []
+    for i, content_i in enumerate(contents):
+        # Within-batch overlap: worst-case overlap with the other responses in this batch
+        if n > 1:
+            batch_overlaps = [
+                _ngram_overlap(content_i, contents[j])
+                for j in range(n)
+                if j != i
+            ]
+            within_batch = max(batch_overlaps)  # worst-case overlap within batch
+        else:
+            within_batch = 0.0
+
+        # Cross-request overlap: max overlap with any historical response
+        if history:
+            cross_request = max(_ngram_overlap(content_i, h) for h in history)
+        else:
+            cross_request = 0.0
+
+        overlap = max(within_batch, cross_request)
+        scores.append(1.0 - overlap)
+
+    return scores
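+
+
+# A quick sanity check of the metric above (values follow directly from the
+# definitions; illustrative only, not part of the training path):
+#   _ngram_overlap("I love hiking in the mountains",
+#                  "I love hiking in the hills")     # -> 0.6 (3 of 5 trigrams shared)
+#   compute_diversity_scores(["same", "same"], [])   # -> [0.0, 0.0] (exact duplicates)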
+
+
+# ---------------------------------------------------------------------------
+# Quality gate: repetition & degeneration detection (deterministic)
+# ---------------------------------------------------------------------------
+async def compute_quality_scores(contents: List[str]) -> List[float]:
+    """
+    Compute a quality multiplier for each response (0 = degenerate, 1 = clean).
+
+    Combines two signals:
+    1. NgramRepetitionPenaltyGrader — detects looping/repeated n-gram blocks
+    2. has_repeat (from ajet.utils.compute_madness) — catches runs of repeated
+       words, special token leaks, and character-level repetition
+
+    Returns a score in [0, 1] that will be used as a *multiplier* on the
+    composite reward, so degenerate outputs get crushed to near-zero.
+    """
+    scores = []
+    for content in contents:
+        # --- Signal 1: n-gram repetition (OpenJudge) ---
+        try:
+            rep_result = await repetition_grader.aevaluate(response=content)
+            # NgramRepetitionPenaltyGrader returns penalty in [-1, 0]:
+            #   0 = no repetition, -1 = max repetition
+            # Convert to quality: add 1 → [0, 1]
+            ngram_penalty = rep_result.score if isinstance(rep_result, GraderScore) else 0.0
+            ngram_score = 1.0 + ngram_penalty
+        except Exception as e:
+            logger.warning(f"NgramRepetitionPenaltyGrader failed: {e}")
+            ngram_score = 1.0
+
+        # --- Signal 2: string madness (char-level degeneration) ---
+        # Only check for word/char repetition and special-token leaks;
+        # there is deliberately no non-ASCII check here, since accented
+        # characters like é are legitimate.
+        madness_score = 1.0  # assume clean
+        if "<|im_start|>" in content:
+            madness_score = 0.0
+        elif has_repeat(content.split(), remember_n_words=5, patience_max=10):
+            madness_score = 0.0
+        elif has_repeat(content, remember_n_words=4, patience_max=200):
+            madness_score = 0.0
+
+        # Combined quality: take the minimum (strictest gate wins)
+        quality = max(0.0, min(1.0, min(ngram_score, madness_score)))
+        scores.append(quality)
+
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Extraversion scoring (pointwise / listwise)
+# ---------------------------------------------------------------------------
+def build_listwise_template(n: int) -> str:
+    """Build a listwise prompt template for n responses."""
+    answers_block = "\n".join([f"{i+1}. {{answer_{i+1}}}" for i in range(n)])
+    return f"""You are ranking multiple responses based on extraversion personality traits.
+
+Extraversion characteristics include:
+- Outgoing, energetic, enthusiastic tone
+- Social engagement and excitement
+- Positive, upbeat language
+- Action-oriented expressions
+
+Question: {{question}}
+
+Responses to rank:
+{answers_block}
+
+Rank these responses from most extraverted to least extraverted.
+Return a json object with exactly two fields:
+- "rank": list of integers (1-indexed) ordered from most to least extraverted, e.g.
[2, 1, 3]
+- "reason": brief explanation of the ranking"""
+
+
+async def compute_pointwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute extraversion scores using pointwise grading."""
+    scores = []
+    for answer in all_answers:
+        content = answer.get("content", "")
+        result = await pointwise_grader.aevaluate(question=question, response=content)
+        score = result.score if isinstance(result, GraderScore) else 0.0
+        scores.append(score)
+    return scores
+
+
+async def compute_listwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]:
+    """Compute extraversion scores using listwise ranking."""
+    n = len(all_answers)
+    template = build_listwise_template(n)
+    grader = LLMGrader(
+        name="extraversion_listwise",
+        mode=GraderMode.LISTWISE,
+        description="Rank responses by extraversion",
+        model=judge_model,
+        template=template,
+    )
+    kwargs = {"question": question}
+    for i, ans in enumerate(all_answers):
+        kwargs[f"answer_{i+1}"] = ans.get("content", "")
+
+    result = await grader.aevaluate(**kwargs)
+
+    scores = [0.0] * n
+    if isinstance(result, GraderRank):
+        for position, idx in enumerate(result.rank):
+            scores[idx - 1] = (1.0 - position / (n - 1)) if n > 1 else 0.5
+    return scores
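+
+
+# Note on the rank-to-score mapping above (illustrative): with n=3 and the
+# judge returning rank=[2, 1, 3], response 2 receives 1.0, response 1
+# receives 0.5, and response 3 receives 0.0.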
+
+
+# ---------------------------------------------------------------------------
+# Relevance scoring (built-in OpenJudge RelevanceGrader, score 1-5 → 0-1)
+# ---------------------------------------------------------------------------
+async def compute_relevance_scores(question: str, all_answers: List[Dict]) -> List[float]:
+    """Score how relevant each response is to the question. Returns 0-1."""
+    scores = []
+    for answer in all_answers:
+        content = answer.get("content", "")
+        result = await relevance_grader.aevaluate(query=question, response=content)
+        if isinstance(result, GraderScore):
+            # RelevanceGrader returns 1-5; normalise to 0-1
+            score = (result.score - 1.0) / 4.0
+        else:
+            score = 0.0
+        scores.append(max(0.0, min(1.0, score)))
+    return scores
+
+
+# ---------------------------------------------------------------------------
+# Main entry point
+# ---------------------------------------------------------------------------
+async def on_compute_relative_reward(
+    valid_results: List,
+    all_answers: List[Dict],
+    question: str = "",
+) -> List[float]:
+    """
+    Compute composite rewards combining extraversion, relevance, diversity,
+    and a quality gate for repetition/degeneration.
+
+    Final reward = quality * (W_EXTRAVERSION * extraversion +
+                              W_RELEVANCE * relevance +
+                              W_DIVERSITY * diversity)
+
+    The quality multiplier (0-1) acts as a hard gate: degenerate responses
+    (looping, repeated paragraphs, nonsense characters) get their reward
+    crushed toward zero regardless of other signal scores.
+    """
+    contents = [a.get("content", "") for a in all_answers]
+
+    # 0. Quality gate (deterministic — fast, runs first)
+    quality_scores = await compute_quality_scores(contents)
+
+    # 1. Extraversion score (LLM-based)
+    if REWARD_MODE == "listwise":
+        extraversion_scores = await compute_listwise_extraversion(question, all_answers)
+    else:
+        extraversion_scores = await compute_pointwise_extraversion(question, all_answers)
+
+    # 2. Relevance score (LLM-based)
+    relevance_scores = await compute_relevance_scores(question, all_answers)
+
+    # 3. Diversity score (deterministic, n-gram overlap)
+    diversity_scores = compute_diversity_scores(contents, _response_history)
+
+    # Composite reward = quality * weighted_sum
+    final_scores = []
+    for i in range(len(all_answers)):
+        weighted_sum = (
+            W_EXTRAVERSION * extraversion_scores[i]
+            + W_RELEVANCE * relevance_scores[i]
+            + W_DIVERSITY * diversity_scores[i]
+        )
+        composite = quality_scores[i] * weighted_sum
+        final_scores.append(round(composite, 4))
+
+        # Annotate the answer dict for logging
+        all_answers[i]["reward"] = final_scores[i]
+        all_answers[i]["quality"] = round(quality_scores[i], 4)
+        all_answers[i]["extraversion"] = round(extraversion_scores[i], 4)
+        all_answers[i]["relevance"] = round(relevance_scores[i], 4)
+        all_answers[i]["diversity"] = round(diversity_scores[i], 4)
+
+    # Update history buffer with this batch's responses
+    record_responses_to_history(contents)
+
+    print_listofdict(
+        all_answers,
+        header=(
+            f"on_compute_relative_reward (mode={REWARD_MODE}, "
+            f"w_ext={W_EXTRAVERSION}, w_rel={W_RELEVANCE}, w_div={W_DIVERSITY}, "
+            f"quality_gate=multiplicative)"
+        ),
+    )
+    return final_scores
diff --git a/tutorial/opencode_build_openclaw_interactive_train/on_user_submit_new_requests.py b/tutorial/opencode_build_openclaw_interactive_train/on_user_submit_new_requests.py
new file mode 100644
index 0000000..11b7932
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/on_user_submit_new_requests.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+"""Handle new user requests and track query history for diversity awareness."""
+
+from typing import List, Dict
+from loguru import logger
+from ajet.schema.task import Task
+
+# Rolling buffer of recent queries — used to detect repeated / near-duplicate
+# questions so the system can log warnings. The response-level diversity
+# signal lives in on_compute_relative_reward._response_history.
+_query_history: List[Dict] = []
+QUERY_HISTORY_MAX = 100
+
+
+def get_query_history() -> List[Dict]:
+    """Return the current query history (read-only copy)."""
+    return list(_query_history)
+
+
+async def on_user_submit_new_requests(request_id: str, task: Task) -> None:
+    """
+    Store user request metadata when submitted.
+
+    This populates a lightweight in-process history so that:
+    1. The /requests endpoint can expose recent queries for debugging.
+    2. We can detect if the same question keeps appearing, which signals
+       a data distribution issue upstream rather than a model problem.
+    """
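+    # Illustrative shape of what /requests serves from this buffer (the
+    # request_id format comes from fake_vllm_endpoint.py; values invented):
+    #   {"requests": [{"request_id": "req_1_ab12cd34",
+    #                  "task_id": "<uuid4>",
+    #                  "query": "What are your thoughts on Paris?"}]}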
+    entry = {
+        "request_id": request_id,
+        "task_id": task.task_id,
+        "query": task.main_query,
+    }
+    _query_history.append(entry)
+
+    # Trim oldest entries
+    while len(_query_history) > QUERY_HISTORY_MAX:
+        _query_history.pop(0)
+
+    logger.info(
+        f"[on_user_submit] request_id={request_id} "
+        f"query_len={len(task.main_query)} "
+        f"history_size={len(_query_history)}"
+    )
diff --git a/tutorial/opencode_build_openclaw_interactive_train/openclaw.md b/tutorial/opencode_build_openclaw_interactive_train/openclaw.md
new file mode 100644
index 0000000..ab9ca3d
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/openclaw.md
@@ -0,0 +1,120 @@
+# OpenClaw Overview Report
+
+## 1. General Overview
+
+OpenClaw is an open-source personal AI assistant platform born in 2025, positioned at its core as an "OS for AI Agents"[[1]](https://openclaw.ai/blog/introducing-openclaw). The project started as a weekend hack, passed through naming stages such as Clawdbot and Moltbot, and was finally renamed OpenClaw to underline its open-source, community-driven nature[[1]](https://openclaw.ai/blog/introducing-openclaw). In a very short time it accumulated more than 100k GitHub stars and millions of visits, demonstrating strong community pull[[1]](https://openclaw.ai/blog/introducing-openclaw). In February 2026, OpenClaw was acquired by OpenAI; it is headquartered in San Francisco and classified under business/productivity software[[2]](https://pitchbook.com/profiles/company/1318645-09).
+
+The project's core philosophy can be summed up in three keywords: local-first, user sovereignty, and model-agnostic. It is not a simple chatbot but a full-featured AI agent platform that automates email, calendars, and browser operations, with persistent memory and full system access[[1]](https://openclaw.ai/blog/introducing-openclaw).
+
+---
+
+## 2. Technical Architecture in Depth
+
+### 2.1 Base Technology Stack
+
+OpenClaw is built on Node.js (v22.12.0+) and supports macOS, Windows (via WSL2), and Linux[[3]](https://ppaolo.substack.com). Choosing Node.js as the runtime is no accident: its event-driven, non-blocking I/O is a natural fit for concurrent multi-channel messaging, and the richness of the JavaScript/TypeScript ecosystem lowers the barrier for community contributions.
+
+### 2.2 Layered Architecture
+
+OpenClaw's architecture follows a clearly layered design[[3]](https://ppaolo.substack.com):
+
+**Channel Adapters**: the system's "perception layer", responsible for connecting to mainstream messaging platforms such as WhatsApp, Telegram, Discord, Slack, Teams, Twitch, and Google Chat[[1]](https://openclaw.ai/blog/introducing-openclaw). Each adapter handles that platform's authentication protocol, message parsing, access control, and outbound formatting. With this design, adding a new channel only requires implementing a standard interface, with no changes to the core logic.
+
+**Control Interfaces**: Web UI, CLI, a native macOS app, and mobile clients[[3]](https://ppaolo.substack.com), so users can manage and monitor their AI agents in different settings.
+
+**Gateway Control Plane**: the system's "central nervous system", responsible for request routing, load balancing, and global policy enforcement[[3]](https://ppaolo.substack.com).
+
+**Agent Runtime**: the heart of the architecture, with the following key components[[3]](https://ppaolo.substack.com):
+- Session Resolution: identifies and manages user session context
+- Context Assembly: assembles conversation history, memory, and tool state into context the model can consume
+- Execution Loop: drives the agent's think-act-observe loop
+- System Prompt Architecture: manages and composes system-level instructions
+
+**Data Storage Layer**: covers Session State Compaction, Memory Search, Storage Indexing, and Embedding Provider Selection[[3]](https://ppaolo.substack.com). Session state compaction deserves particular attention: it addresses context-window overflow in long conversations by preserving key information through intelligent summarization.
+
+### 2.3 Multi-Agent Collaboration
+
+OpenClaw supports Multi-Agent Routing, Agent-to-Agent Communication, Scheduled Actions, and External Triggers[[3]](https://ppaolo.substack.com). Users can therefore build collaborative systems of specialized agents: one agent sorts email, another manages the calendar, a third reviews code, and they communicate and coordinate with each other.
+
+---
+
+## 3. AI Model Support Ecosystem
+
+### 3.1 Multi-Vendor Model Matrix
+
+OpenClaw's model-support strategy reflects a no-single-vendor-lock-in philosophy. Currently supported models include[[4]](https://docs.openclaw.ai):
+
+| Vendor | Models |
+|--------|------|
+| OpenAI | GPT-5.1, Codex |
+| Anthropic | Claude Opus 4.6 |
+| Google | Gemini 3 Pro |
+| Z.AI | GLM 4.7 |
+| Moonshot AI | Kimi K2.5 |
+| MiniMax | M2.1 |
+| Alibaba Cloud | Qwen |
+| Local runtime | Ollama |
+| Others | OpenCode Zen, Synthetic, etc. |
+
+### 3.2 Strategic Implications
+
+Notably, OpenClaw integrates both US and Chinese AI models[[4]](https://docs.openclaw.ai)[[5]](https://scmp.com). This strategy carries several implications:
+
+First, cost optimization: cost-effectiveness varies widely across models and tasks, so users can pick a low-cost model for simple tasks and a high-end model for complex reasoning. Second, redundancy: when one vendor's service goes down, the system can fail over to an alternative model. Third, complementary strengths: US and Chinese models each excel at different mixes of Chinese and English processing, code generation, and multimodal understanding.
+
+By supporting local LLM execution through Ollama[[4]](https://docs.openclaw.ai), OpenClaw also offers a fully offline option for users with strict data-privacy requirements, which matters especially in enterprise scenarios.
+
+---
+
+## 4. Deployment Options and Hardware Requirements
+
+### 4.1 Cloud Deployment
+
+Cloud deployment offers a one-click quick start with built-in hardening: firewall rules, non-root execution, and elastic scaling[[6]](https://help.apiyi.com). The upside is zero operational burden and fast go-live; the downside is a monthly cost and data that is not fully under the user's control.
+
+### 4.2 Local Deployment
+
+Local deployment is OpenClaw's core differentiator. It provides full data privacy, offline operation, and deep customizability, but demands some technical skill and hardware from the user[[6]](https://help.apiyi.com):
+
+- Recommended CPU: AMD Ryzen 9 7950X or Intel Core i9-13900K
+- Recommended GPU: NVIDIA RTX 4090 or RTX 4080
+
+These hardware requirements mainly target running large language models locally. If models are only called via cloud APIs, the requirements drop substantially.
+
+### 4.3 Security Considerations
+
+Security best practice explicitly advises against running OpenClaw on your primary workstation[[7]](https://safeclaw.io). The reason is OpenClaw's powerful system-level capabilities, including browser automation, file-system access, and command execution: a prompt-injection attack or a misconfiguration could affect the host system. A dedicated homelab server or VPS is the recommended deployment target[[1]](https://openclaw.ai/blog/introducing-openclaw).
+
+---
+
+## 5. Security Architecture
+
+OpenClaw's security architecture is multi-layered[[3]](https://ppaolo.substack.com):
+
+- Network Security: transport-layer encryption and network isolation
+- Authentication: multi-factor authentication
+- Channel Access Control: fine-grained platform-level permission management
+- Tool Sandboxing: limits the system capabilities an agent may invoke
+- Session-based Boundaries: prevents cross-session information leakage
+- Prompt Injection Defenses: resists malicious input attacks
+- Machine-checkable Security Models: formally verifiable security policies[[1]](https://openclaw.ai/blog/introducing-openclaw)
+
+The machine-checkable security models are especially forward-looking: security policy becomes not just documented rules but a formal specification that automated tooling can verify and enforce, a comparatively advanced practice in AI agent security.
+
+---
+
+## 6. Key Insights and Takeaways
+
+**From weekend project to OpenAI acquisition**: OpenClaw's trajectory reveals an important 2025 trend in AI infrastructure: open-source AI agent frameworks are becoming strategic acquisition targets for large AI companies. By acquiring OpenClaw[[2]](https://pitchbook.com/profiles/company/1318645-09), OpenAI is essentially filling in its "agent runtime" layer, extending from a pure model provider toward a platform.
+
+**The local-first vs. cloud-convenience tension**: OpenClaw tries to balance data sovereignty against ease of use. Its dual-track deployment strategy reflects a real split in demand: enterprises and privacy-sensitive users lean toward local deployment, while individual developers and rapid prototyping favor the cloud.
+
+**The industry signal of a multi-model strategy**: OpenClaw's broad integration of both US and Chinese models[[5]](https://scmp.com) shows that, at the application level, a model's geographic origin is giving way to practical considerations, a useful reference point for the direction of the AI application ecosystem.
+
+**Security as a first-class citizen**: given that its agents hold system-level execution rights, OpenClaw puts security at the core of the architecture[[7]](https://safeclaw.io) rather than patching it in afterwards. This "shift-left" approach to security is worth emulating by similar projects.
+
+---
+
+## 7. Conclusion
+
+OpenClaw is a representative sample of how AI agent platforms evolved in 2025: driven by an open-source community, sold on local deployment and user data sovereignty, supported by multi-model compatibility and multi-channel integration, and grounded in a layered security architecture for trust. Its journey from independent project to OpenAI acquisition both validates the market value of AI agent infrastructure and signals that the field is moving from fragmented open-source exploration into a new stage of platform consolidation. For developers and technical decision-makers tracking the AI agent stack, OpenClaw's architecture and ecosystem strategy are well worth studying.
diff --git a/tutorial/opencode_build_openclaw_interactive_train/test_reward.py b/tutorial/opencode_build_openclaw_interactive_train/test_reward.py
new file mode 100644
index 0000000..8b65922
--- /dev/null
+++ b/tutorial/opencode_build_openclaw_interactive_train/test_reward.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+"""Test script for on_compute_relative_reward.py using real OpenJudge API.
+
+Tests four reward dimensions:
+    1. Extraversion — enthusiastic responses score higher
+    2. Relevance — on-topic responses score higher than off-topic
+    3. Diversity — unique responses score higher than near-duplicates
+    4.
Quality gate — repetitive/degenerate responses get crushed
+"""
+
+import asyncio
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(__file__))
+os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx")
+
+
+async def test_pointwise_composite():
+    """Test pointwise composite reward (extraversion + relevance + diversity)."""
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()  # fresh history for test isolation
+
+    question = "What are your thoughts on Paris?"
+    all_answers = [
+        {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking and the cafes are amazing!"},
+        {"content": "Paris is a city in France."},
+        {"content": "I absolutely love Paris! The energy on the Champs-Élysées is fantastic and so vibrant!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Composite scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
+        assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}"
+        assert all(isinstance(s, float) for s in scores), "All scores should be floats"
+        # Extraverted + relevant responses should beat the flat neutral one
+        assert scores[0] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}"
+        assert scores[2] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
+async def test_relevance_penalty():
+    """Off-topic answers should get lower composite scores than on-topic ones."""
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    question = "What is your favorite food?"
+    all_answers = [
+        # On-topic, extraverted
+        {"content": "Oh my gosh, I absolutely LOVE sushi! The flavors are incredible and I get so excited every time!"},
+        # Off-topic, extraverted (talks about space, not food)
+        {"content": "WOW space exploration is SO exciting! Rockets launching into the sky fills me with energy!!!"},
+    ]
+
+    try:
+        scores = await mod.on_compute_relative_reward([], all_answers, question=question)
+        print(f"Scores: {scores}")
+        for a in all_answers:
+            print(f"  ext={a.get('extraversion')}, rel={a.get('relevance')}, "
+                  f"div={a.get('diversity')}, reward={a.get('reward')} "
+                  f"content={a['content'][:50]}...")
+
+        # Both are extraverted, but on-topic should win because of relevance
+        assert scores[0] > scores[1], \
+            f"On-topic extraverted should beat off-topic extraverted: {scores}"
+        print("PASSED")
+        return True
+    except Exception as e:
+        print(f"FAILED: {e}")
+        import traceback; traceback.print_exc()
+        return False
+
+
+async def test_diversity_penalty():
+    """Near-duplicate answers should get lower diversity scores."""
+    os.environ["REWARD_MODE"] = "pointwise"
+
+    import importlib
+    import on_compute_relative_reward as mod
+    importlib.reload(mod)
+    mod._response_history.clear()
+
+    question = "Tell me about your hobbies."
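+    # The first two answers below are verbatim duplicates: their within-batch
+    # trigram overlap is 1.0, so compute_diversity_scores drives their
+    # diversity to 0.0, while the third (unique) answer keeps a high score.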
+ all_answers = [ + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"}, + # Near-duplicate of answer 0 + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"}, + # Unique answer + {"content": "Dancing is my absolute passion! Nothing beats the energy of moving to great music with friends!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + # The duplicate pair should have lower diversity than the unique one + div_duplicate = all_answers[0].get("diversity", 1.0) + div_unique = all_answers[2].get("diversity", 0.0) + assert div_unique > div_duplicate, \ + f"Unique response should have higher diversity ({div_unique}) than duplicate ({div_duplicate})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_cross_request_diversity(): + """Answers that repeat historical responses should be penalized.""" + print("\n=== Testing Cross-Request Diversity ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + # Simulate a prior request that produced a response + mod.record_responses_to_history([ + "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!" + ]) + + question = "What do you enjoy doing on weekends?" + all_answers = [ + # Repeats the historical response almost verbatim + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!"}, + # Fresh, unique response + {"content": "Weekends are for exploring new restaurants and trying exotic cuisines! I get so thrilled by new flavors!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + div_stale = all_answers[0].get("diversity", 1.0) + div_fresh = all_answers[1].get("diversity", 0.0) + assert div_fresh > div_stale, \ + f"Fresh response should have higher diversity ({div_fresh}) than stale ({div_stale})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_repetition_penalty(): + """Degenerate looping responses should get near-zero reward.""" + print("\n=== Testing Repetition / Degeneration Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "Tell me about Dunfermline." + + # Build a degenerate looping response (similar to the real failure case) + good_intro = "Hello! Dunfermline is a charming town in Fife, Scotland, with a rich history." + loop_block = ( + "\n\n---\n\n" + "If you have any specific questions or need more information, just " + "let me know! 
I'm here to assist you in making your visit to " + "Dunfermline a delightful experience.\n\n---\n\n" + "Looking forward to your wonderful Dunfermline adventures!\n\n---\n\n" + "Thank you for the opportunity to share my thoughts on Dunfermline. " + "If you have any more questions or need assistance, feel free to " + "reach out!" + ) + degenerate_response = good_intro + (loop_block * 15) # repeat the block many times + + all_answers = [ + # Degenerate looping response + {"content": degenerate_response}, + # Clean, concise, extraverted response + {"content": "Dunfermline is absolutely wonderful! The abbey ruins are breathtaking and the town has such vibrant energy. I love the mix of history and modern community spirit there!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" quality={a.get('quality')}, ext={a.get('extraversion')}, " + f"rel={a.get('relevance')}, div={a.get('diversity')}, " + f"reward={a.get('reward')} " + f"content={a['content'][:60]}...") + + quality_degenerate = all_answers[0].get("quality", 1.0) + quality_clean = all_answers[1].get("quality", 0.0) + print(f" Quality scores: degenerate={quality_degenerate}, clean={quality_clean}") + + # The degenerate response should have much lower quality + assert quality_clean > quality_degenerate, \ + f"Clean response quality ({quality_clean}) should exceed degenerate ({quality_degenerate})" + # The clean response should win overall + assert scores[1] > scores[0], \ + f"Clean response ({scores[1]}) should beat degenerate ({scores[0]})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_listwise_composite(): + """Listwise mode should also produce composite rewards.""" + print("\n=== Testing Listwise Composite Reward ===") + os.environ["REWARD_MODE"] = "listwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "What are your thoughts on Paris?" + all_answers = [ + {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking!"}, + {"content": "Paris is a city in France."}, + {"content": "I absolutely love Paris! 
The Champs-Élysées energy is fantastic!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}" + # Neutral response should score lowest + assert scores[1] < scores[0] or scores[1] < scores[2], \ + f"Neutral response should score lower than at least one extraverted response: {scores}" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def main(): + print("Testing on_compute_relative_reward.py — Composite Reward") + print("(extraversion + relevance + diversity + quality gate)") + print("=" * 60) + + results = [] + results.append(await test_pointwise_composite()) + results.append(await test_relevance_penalty()) + results.append(await test_diversity_penalty()) + results.append(await test_cross_request_diversity()) + results.append(await test_repetition_penalty()) + results.append(await test_listwise_composite()) + + print("\n" + "=" * 60) + passed = sum(results) + total = len(results) + print(f"Tests passed: {passed}/{total}") + if not all(results): + names = [ + "pointwise_composite", "relevance_penalty", "diversity_penalty", + "cross_request_diversity", "repetition_penalty", "listwise_composite", + ] + for name, ok in zip(names, results): + if not ok: + print(f" FAILED: {name}") + return all(results) + + +if __name__ == "__main__": + success = asyncio.run(main()) + sys.exit(0 if success else 1) From 4658ce027f7288e945f50a67906bcf99df50da77 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Fri, 20 Mar 2026 18:16:41 +0800 Subject: [PATCH 9/9] Add reward history tracking and visualization for swarm overwatch - Add RewardHistoryEntry and RewardHistoryResponse models for reward data - Implement reward collection and history finalization in swarm_server - Add /get_reward_history API endpoint for visualization - Add ASCII reward curve display in swarm_overwatch UI - Fix typo in config_utils parameter name (convertion_json_fg -> convertion_json_fp) --- .../experimental/swarm_overwatch_utils.py | 13 ++ ajet/tuner_lib/experimental/swarm_server.py | 81 ++++++++- ajet/utils/config_utils.py | 6 +- ajet/utils/swarm_overwatch.py | 161 +++++++++++++++++- 4 files changed, 255 insertions(+), 6 deletions(-) diff --git a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py index 1064289..4439013 100644 --- a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py +++ b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py @@ -2,6 +2,19 @@ from pydantic import BaseModel +class RewardHistoryEntry(BaseModel): + """A single entry in the reward history.""" + global_step: int + mean_reward: float + std_reward: float + timestamp: float # Unix timestamp when this entry was recorded + + +class RewardHistoryResponse(BaseModel): + """Response containing the reward history for visualization.""" + history: List[RewardHistoryEntry] = [] + + class CurrentBatchRolloutPoolInformation(BaseModel): sample_collection_method: str = "" completed_episodes: int = 0 diff --git a/ajet/tuner_lib/experimental/swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py index aa82984..4ba500d 100644 --- a/ajet/tuner_lib/experimental/swarm_server.py +++ 
b/ajet/tuner_lib/experimental/swarm_server.py @@ -11,7 +11,11 @@ from multiprocessing.managers import DictProxy from typing import Coroutine, Optional, Tuple, List from ajet.utils.process_killer import kill_process_tree -from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation +from ajet.tuner_lib.experimental.swarm_overwatch_utils import ( + CurrentBatchRolloutPoolInformation, + RewardHistoryEntry, + RewardHistoryResponse, +) from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE from ajet.tuner_lib.experimental.interchange_utils import ( SyncTrainConfigRequest, @@ -63,6 +67,14 @@ def register_enable_swarm_mode_routes( if "current_batch_rollout_pool_information" not in shared_mem_dict: shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation() + # Initialize reward history storage for visualization + if "reward_history" not in shared_mem_dict: + shared_mem_dict["reward_history"] = [] # List of RewardHistoryEntry dicts + + # Initialize reward accumulator for collecting rewards of current global step + if "current_rewards" not in shared_mem_dict: + shared_mem_dict["current_rewards"] = [] # [rewards...] + # ------------------------------------------------------------------------------------------------ # ------ Recycle claimed episodes that client failed to complete in (promised) time -------------- # --------------------------------- claimed -> unclaimed ---------------------------------------- @@ -166,6 +178,35 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l if episode_uuid in shared_mem_dict["unclaimed_episodes"]: shared_mem_dict["unclaimed_episodes"].remove(episode_uuid) + # -------------------------------------------------------------------------------------- + # -------------------------- reward history management --------------------------------- + # -------------------------------------------------------------------------------------- + + def _finalize_reward_history_for_step(global_step, shared_mem_dict, shared_mem_dict_lock): + """Finalize reward statistics for a given global step and add to reward_history.""" + import numpy as np + + rewards = shared_mem_dict.get("current_rewards", []) + if rewards: + rewards = list(rewards) # Convert proxy to list if needed + mean_reward = float(np.mean(rewards)) + std_reward = float(np.std(rewards)) + + history = shared_mem_dict.get("reward_history", []) + history = list(history) # Convert proxy to list if needed + + entry = RewardHistoryEntry( + global_step=global_step, + mean_reward=mean_reward, + std_reward=std_reward, + timestamp=time.time(), + ) + history.append(entry.model_dump()) + shared_mem_dict["reward_history"] = history + + # Clear current rewards for next step + shared_mem_dict["current_rewards"] = [] + # -------------------------------------------------------------------------------------- # -------------------------- return workflow output ------------------------------------ # -------------------------------------------------------------------------------------- @@ -272,6 +313,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict): shared_mem_dict["unclaimed_episodes"] = [] logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes") + # clear reward tracking + shared_mem_dict["current_rewards"] = [] + shared_mem_dict["reward_history"] = [] + # -------------------------------------------------------------------------------------- # -------------------------- 
@@ -446,7 +491,12 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
         engine_status_detail = req.engine_status_detail
         global_step = req.global_step
         if global_step is not None:
+            previous_global_step = shared_mem_dict.get("global_step", None)
             shared_mem_dict["global_step"] = global_step
+            # When global_step changes, finalize reward statistics for the previous step
+            if previous_global_step is not None and previous_global_step != global_step:
+                _finalize_reward_history_for_step(previous_global_step, shared_mem_dict, shared_mem_dict_lock)
+
         if engine_status_detail is not None:
             shared_mem_dict["engine_status_detail"] = engine_status_detail
         logger.info(f"[update_engine_status] Engine status set to {req.engine_status}")
@@ -636,6 +686,21 @@ async def end_episode(req: EndEpisodeRequest):
                     shared_mem_dict_lock,
                 )
 
+            # Record reward to current_rewards
+            if workflow_output.reward is not None:
+                reward_value = workflow_output.reward
+                # Handle both single reward and list of rewards
+                if isinstance(reward_value, list):
+                    rewards_to_record = reward_value
+                else:
+                    rewards_to_record = [reward_value]
+
+                with shared_mem_dict_lock:
+                    current_rewards = shared_mem_dict.get("current_rewards", [])
+                    current_rewards = list(current_rewards)  # Convert proxy to list if needed
+                    current_rewards.extend(rewards_to_record)
+                    shared_mem_dict["current_rewards"] = current_rewards
+
         elif episode_type == "eval":
             if engine_status in ["ENGINE.ROLLING"]:
                 await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
@@ -779,6 +844,20 @@ async def get_current_batch_rollout_pool_information():
             logger.error(f"Error getting current batch rollout pool information: {e}")
             return CurrentBatchRolloutPoolInformation()
 
+    # --------------------------------------------------------------------
+    # ------------ get reward history for visualization ------------------
+    # --------------------------------------------------------------------
+    @app.get("/get_reward_history", response_model=RewardHistoryResponse)
+    async def get_reward_history():
+        """Get the reward history for visualization (reward curves)."""
+        try:
+            history = shared_mem_dict.get("reward_history", [])
+            entries = [RewardHistoryEntry(**entry) for entry in history]
+            return RewardHistoryResponse(history=entries)
+        except Exception as e:
+            logger.error(f"Error getting reward history: {e}")
+            return RewardHistoryResponse(history=[])
+
     # --------------------------------------------------------------------
     # ------------ bring engine back to ENGINE.OFFLINE -------------------
     # --------------------------------------------------------------------
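Taken together, the hunks above accumulate per-episode rewards, fold them into per-step mean/std entries when `global_step` advances, and expose the result on `/get_reward_history`. A minimal client-side sketch, assuming the swarm server is reachable at `http://localhost:10086`:

```python
# Poll the new route and print one line per finished global step.
# The URL is an assumption; substitute your AJET_SWARM_URL.
import httpx

resp = httpx.get("http://localhost:10086/get_reward_history", timeout=5.0)
resp.raise_for_status()
for entry in resp.json()["history"]:
    print(f"step {entry['global_step']}: "
          f"mean={entry['mean_reward']:.4f} std={entry['std_reward']:.4f}")
```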
diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py
index 9e6e284..c13b6aa 100644
--- a/ajet/utils/config_utils.py
+++ b/ajet/utils/config_utils.py
@@ -98,7 +98,7 @@ def _dive_to_set_value(config, dotted_key, value):
     sub_config[keys[-1]] = value
 
 
-def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone):
+def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone):
     """Align configuration values based on a conversion map.
 
     Parameters
     ----------
@@ -107,7 +107,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone)
         Source YAML path to read values from.
     to_config_fp : str
         Destination YAML path that is updated in place.
-    convertion_json_fg : str
+    convertion_json_fp : str
         JSON path mapping dotted keys between configs.
     backbone : str
         Backbone identifier used for framework-specific alignment.
@@ -121,7 +121,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone)
     # read convertion json
     import json
 
-    with open(convertion_json_fg, "r", encoding="utf-8") as file:
+    with open(convertion_json_fp, "r", encoding="utf-8") as file:
         convertion_json = json.load(file)
 
     logger.success("----------------------------------------------------")
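For context, a hypothetical call to the renamed function. The exact schema of the conversion JSON is not shown in this patch, so the mapping below (source dotted key to destination dotted key) and the `backbone` value are assumptions for illustration only:

```python
# Sketch of aligning two YAML configs via a conversion map; all file names,
# keys, and the backbone identifier here are made up for the example.
import json

from ajet.utils.config_utils import align_parameters

convertion = {
    "trainer.lr": "optim.learning_rate",        # assumed key mapping
    "trainer.batch_size": "data.train_batch_size",
}
with open("convertion.json", "w", encoding="utf-8") as f:
    json.dump(convertion, f)

align_parameters(
    from_config_fp="from_config.yaml",  # source YAML to read values from
    to_config_fp="to_config.yaml",      # destination YAML, updated in place
    convertion_json_fp="convertion.json",
    backbone="verl",                    # hypothetical backbone identifier
)
```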
diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py
index f8003b1..ffbb7d5 100644
--- a/ajet/utils/swarm_overwatch.py
+++ b/ajet/utils/swarm_overwatch.py
@@ -17,7 +17,10 @@
 from rich.text import Text
 
 from loguru import logger
-from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
+from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
+    CurrentBatchRolloutPoolInformation,
+    RewardHistoryResponse,
+)
 
 
 class SwarmOverwatch:
@@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
             # logger.error(f"Failed to fetch pool info: {e}")
             return None
 
+    def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
+        """Fetch the reward history from the server for visualization."""
+        try:
+            response = self._httpx_client.get(
+                f"{self.server_url}/get_reward_history",
+                timeout=5.0,
+            )
+            response.raise_for_status()
+            data = RewardHistoryResponse.model_validate(response.json())
+            return data
+        except Exception as e:
+            logger.error(f"Failed to fetch reward history: {e}")
+            return None
+
     def create_header(
         self, info: Optional[CurrentBatchRolloutPoolInformation] = None
     ) -> Panel:
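The `display_reward_curve` method added below maps each `(global_step, mean_reward)` pair onto a character grid and flips the y axis, since terminal rows are numbered top-down. The transform in isolation, as a standalone sketch:

```python
# Standalone sketch of the coordinate mapping used by display_reward_curve:
# normalize into [0, 1], scale to the grid, then invert y because row 0 is
# the top of the terminal. The helper name to_cell is ours, not the patch's.
def to_cell(step, reward, x_min, x_range, y_min, y_range, width, height):
    x = int((step - x_min) / x_range * (width - 1)) if x_range > 0 else 0
    y = int((reward - y_min) / y_range * (height - 1)) if y_range > 0 else 0
    y = height - 1 - y  # invert: larger rewards appear nearer the top
    # Clamp so padding or rounding can never index outside the grid
    return max(0, min(width - 1, x)), max(0, min(height - 1, y))
```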
input("\n> ").strip().lower() @@ -526,8 +679,12 @@ def choose_run(self) -> str: mode = "replay_latest_llm_call" self.console.clear() continue + elif choice == "c": + self.display_reward_curve() + self.console.clear() + continue else: - self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]") + self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]") def run(self): """Start the monitoring interface"""