diff --git a/judgearena/config.py b/judgearena/config.py
new file mode 100644
index 0000000..4c018c6
--- /dev/null
+++ b/judgearena/config.py
@@ -0,0 +1,183 @@
+"""CLI argument configuration for generation and evaluation entrypoints."""
+
+import argparse
+import json
+from dataclasses import dataclass, field
+
+
+@dataclass
+class CliArgs:
+    dataset: str
+    model_A: str
+    model_B: str
+    judge_model: str
+
+    n_instructions: int | None = None
+    provide_explanation: bool = False
+    swap_mode: str = "fixed"
+    ignore_cache: bool = False
+    use_tqdm: bool = False
+    truncate_all_input_chars: int = 8192
+    max_out_tokens_models: int = 32768
+    max_out_tokens_judge: int = 32768
+    max_model_len: int | None = None
+    chat_template: str | None = None
+    result_folder: str = "results"
+    engine_kwargs: dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        supported_modes = ["fixed", "both"]
+        if self.swap_mode not in supported_modes:
+            raise ValueError(
+                f"Only {supported_modes} modes are supported but got {self.swap_mode}."
+            )
+
+    @classmethod
+    def parse_args(cls):
+        parser = argparse.ArgumentParser(
+            description="Generate completion and evaluate with a judge",
+        )
+        parser.add_argument(
+            "--dataset",
+            required=True,
+            help="The dataset to use. For instance `alpaca-eval`, `arena-hard`, `m-arena-hard-EU` for instruction "
+            "tuning cases or `french-contexts`, `spanish-contexts` for base models.",
+        )
+        parser.add_argument(
+            "--model_A",
+            required=True,
+            help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`",
+        )
+        parser.add_argument(
+            "--model_B",
+            required=True,
+            help="Name of the LLM to use for a generation, must be a valid choice for `generation_provider`",
+        )
+        parser.add_argument(
+            "--judge_model",
+            required=True,
+            help="Name of the LLM to use, for instance `Together/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, "
+            "`VLLM/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, `LangChain/LocalPath` etc",
+        )
+        parser.add_argument(
+            "--n_instructions",
+            type=int,
+            required=False,
+        )
+        parser.add_argument(
+            "--provide_explanation",
+            action="store_true",
+            help="If specified, judge will provide explanation before making a judgement. Does not necessarily improve "
+            "the accuracy of the judge but enables some result interpretation.",
+        )
+        parser.add_argument(
+            "--swap_mode",
+            type=str,
+            choices=["fixed", "both"],
+            default="fixed",
+            help="Model comparison order mode. 'fixed': always use model order A-B. 'both': correct for model order "
+            "bias by evaluating each instruction twice, once as A-B and once as B-A, and average. This helps account "
+            "for judge position bias. Default is 'fixed'.",
+        )
+        parser.add_argument(
+            "--ignore_cache",
+            action="store_true",
+            help="If specified, ignore cache of previous completions.",
+        )
+        parser.add_argument(
+            "--use_tqdm",
+            action="store_true",
+            help="If specified, use tqdm, does not work with all model providers, vLLM in particular.",
+        )
+        parser.add_argument(
+            "--result_folder",
+            type=str,
+            required=False,
+            default="results",
+            help="The folder to save the results. Defaults to `results`. Evaluation results will be saved in"
+            " `[result_folder]/[evaluation_name]`.",
+        )
+        parser.add_argument(
+            "--truncate_all_input_chars",
+            type=int,
+            required=False,
+            default=8192,
+            help="Character-level truncation applied before tokenization: truncates each instruction "
+            "before model A/B generation and truncates each completion before judge evaluation.",
+        )
+        parser.add_argument(
+            "--max_out_tokens_models",
+            type=int,
+            required=False,
+            default=32768,
+            help=(
+                "Generation token budget for each model A/B response. For VLLM, keep this <= "
+                "--max_model_len (if provided)."
+            ),
+        )
+        parser.add_argument(
+            "--max_out_tokens_judge",
+            type=int,
+            required=False,
+            default=32768,
+            help=(
+                "Generation token budget for the judge response (reasoning + scores). For "
+                "VLLM, keep this <= --max_model_len (if provided)."
+            ),
+        )
+        parser.add_argument(
+            "--max_model_len",
+            type=int,
+            required=False,
+            default=None,
+            help=(
+                "Optional total context window for VLLM models (prompt + generation). This is "
+                "independent from --max_out_tokens_models/--max_out_tokens_judge, which only cap "
+                "generated tokens. This is useful on smaller GPUs to avoid OOM."
+            ),
+        )
+        parser.add_argument(
+            "--chat_template",
+            type=str,
+            required=False,
+            default=None,
+            help="Jinja2 chat template string to use instead of the model's tokenizer template. "
+            "If not provided, ChatML is used as fallback for models without a chat template.",
+        )
+        parser.add_argument(
+            "--engine_kwargs",
+            type=str,
+            required=False,
+            default="{}",
+            help=(
+                "JSON dict of engine-specific kwargs forwarded to the underlying engine. "
+                'Example for vLLM: \'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.'
+            ),
+        )
+        args = parser.parse_args()
+
+        try:
+            engine_kwargs = json.loads(args.engine_kwargs) if args.engine_kwargs else {}
+            if not isinstance(engine_kwargs, dict):
+                raise ValueError("engine_kwargs must be a JSON object")
+        except Exception as e:
+            raise SystemExit(f"Failed to parse --engine_kwargs: {e}") from e
+
+        return cls(
+            dataset=args.dataset,
+            model_A=args.model_A,
+            model_B=args.model_B,
+            judge_model=args.judge_model,
+            n_instructions=args.n_instructions,
+            provide_explanation=args.provide_explanation,
+            swap_mode=args.swap_mode,
+            ignore_cache=args.ignore_cache,
+            use_tqdm=args.use_tqdm,
+            truncate_all_input_chars=args.truncate_all_input_chars,
+            max_out_tokens_models=args.max_out_tokens_models,
+            max_out_tokens_judge=args.max_out_tokens_judge,
+            max_model_len=args.max_model_len,
+            chat_template=args.chat_template,
+            result_folder=args.result_folder,
+            engine_kwargs=engine_kwargs,
+        )
diff --git a/judgearena/eval_utils.py b/judgearena/eval_utils.py
new file mode 100644
index 0000000..e56e1e9
--- /dev/null
+++ b/judgearena/eval_utils.py
@@ -0,0 +1,155 @@
+"""Shared evaluation runtime helpers used by entrypoints and benchmark pipelines."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import pandas as pd
+
+from judgearena.evaluate import PairScore, annotate_battles
+from judgearena.utils import compute_pref_summary
+
+
+def print_results(results):
+    """Print battle results in a readable format."""
+    print("\n" + "=" * 60)
+    print("🏆 MODEL BATTLE RESULTS 🏆".center(60))
+    print(f"📊 Dataset: {results['dataset']}")
+    print(
+        f"🤖 Competitors: Model A: {results['model_A']} vs Model B: {results['model_B']}"
+    )
+    print(f"⚖️ Judge: {results['judge_model']}")
+    print("📈 Results Summary:")
+    print(f" Total Battles: {results['num_battles']}")
+    print(f" Win Rate (A): {results['winrate']:.1%}")
+    print(f" ✅ Wins: {results['num_wins']}")
+    print(f" ❌ Losses: {results['num_losses']}")
+    print(f" 🤝 Ties: {results['num_ties']}")
+    if results.get("num_missing", 0) > 0:
+        print(f" ❓ Missing: {results['num_missing']}")
+
+    per_category = results.get("per_category")
+    if per_category:
+        print("\nPer-Category Breakdown:")
+        print(
+            f" {'Category':<14} | {'Win Rate(A)':>11} | {'Wins':>4} | {'Losses':>6} | {'Ties':>4}"
+        )
+        print(f" {'-' * 14}-+-{'-' * 11}-+-{'-' * 4}-+-{'-' * 6}-+-{'-' * 4}")
+        for cat, stats in sorted(per_category.items()):
+            print(
+                f" {cat:<14} | {stats['winrate']:>11.1%} | "
+                f"{stats['num_wins']:>4} | {stats['num_losses']:>6} | {stats['num_ties']:>4}"
+            )
+
+    per_turn = results.get("per_turn")
+    if per_turn:
+        print("\nPer-Turn Breakdown:")
+        for turn, stats in sorted(per_turn.items()):
+            print(
+                f" Turn {turn} Win Rate(A): {stats['winrate']:.1%} "
+                f"(W:{stats['num_wins']} L:{stats['num_losses']} T:{stats['num_ties']})"
+            )
+    print("=" * 60 + "\n")
+
+
+def _compute_grouped_stats(
+    preferences: pd.Series,
+    metadata: list[dict[str, object]],
+    group_by: str,
+) -> dict[object, dict[str, float | int]]:
+    grouped: dict[object, list[float]] = {}
+    for meta, pref in zip(metadata, preferences, strict=True):
+        key = meta.get(group_by)
+        if key is None:
+            continue
+        grouped.setdefault(key, []).append(pref)
+    return {key: compute_pref_summary(pd.Series(vals)) for key, vals in grouped.items()}
+
+
+def _parse_preferences_from_annotations(
+    annotations: list,
+    score_parser: PairScore,
+) -> pd.Series:
+    return pd.Series(
+        [
+            score_parser.parse_model_raw(annotation.judge_completion)
+            for annotation in annotations
+        ]
+    )
+
+
+@dataclass
+class JudgeAnnotationResult:
+    annotations: list
+    annotations_reversed: list
+    metadata_for_annotations: list[dict[str, object]]
+    metadata_for_reversed_annotations: list[dict[str, object]]
+    preferences: pd.Series
+    combined_metadata: list[dict[str, object]]
+
+
+def _make_judge_annotation(
+    *,
+    judge_chat_model,
+    instructions: list[str],
+    completions_A: list[str],
+    completions_B: list[str],
+    metadata: list[dict[str, object]],
+    score_parser: PairScore,
+    provide_explanation: bool,
+    swap_mode: str,
+    truncate_input_chars: int | None,
+    use_tqdm: bool,
+    system_prompt: str | None = None,
+    user_prompt_template: str | None = None,
+) -> JudgeAnnotationResult:
+    if not instructions:
+        raise ValueError("instructions must be non-empty")
+
+    annotations = annotate_battles(
+        judge_chat_model=judge_chat_model,
+        instructions=instructions,
+        completions_A=completions_A,
+        completions_B=completions_B,
+        provide_explanation=provide_explanation,
+        system_prompt=system_prompt,
+        user_prompt_template=user_prompt_template,
+        truncate_input_chars=truncate_input_chars,
+        use_tqdm=use_tqdm,
+    )
+    preference_parts = [_parse_preferences_from_annotations(annotations, score_parser)]
+
+    annotations_reversed: list = []
+    metadata_for_reversed_annotations: list[dict[str, object]] = []
+    combined_metadata = list(metadata)
+
+    if swap_mode == "both":
+        print("Correction for judge bias towards a certain model position is set.")
+        print("Evaluating completions with models reversed.")
+        annotations_reversed = annotate_battles(
+            judge_chat_model=judge_chat_model,
+            instructions=instructions,
+            completions_A=completions_B,
+            completions_B=completions_A,
+            provide_explanation=provide_explanation,
+            system_prompt=system_prompt,
+            user_prompt_template=user_prompt_template,
+            truncate_input_chars=truncate_input_chars,
+            use_tqdm=use_tqdm,
+        )
+        prefs_reversed = _parse_preferences_from_annotations(
+            annotations_reversed, score_parser
+        )
+        preference_parts.append(1 - prefs_reversed)
+        metadata_for_reversed_annotations = list(metadata)
+        combined_metadata.extend(metadata)
+
+    preferences = pd.concat(preference_parts).reset_index(drop=True)
+    return JudgeAnnotationResult(
+        annotations=annotations,
+        annotations_reversed=annotations_reversed,
+        metadata_for_annotations=list(metadata),
+        metadata_for_reversed_annotations=metadata_for_reversed_annotations,
+        preferences=preferences,
+        combined_metadata=combined_metadata,
+    )
diff --git a/judgearena/evaluate.py b/judgearena/evaluate.py index ee12481..c86d123 100644 --- a/judgearena/evaluate.py +++ b/judgearena/evaluate.py @@ -51,18 +51,31 @@ def get_regexp_match(self, s: str, regex: str, group_index: int = 1): return float(m.group(group_index).strip(" ")) +_COMPLETION_LABEL_SINGLE = "Answer" +_COMPLETION_LABEL_MULTI_TURN = "Conversation with User" +_EXPLANATION_SUFFIX = ", first starts with an explanation of your judgement" +_SCORE_FENCE = "\n```" + + def load_judge_system_and_user_prompt( provide_explanation: bool = True, + multi_turn: bool = False, ) -> tuple[str, str]: - # Prepare judge - with open(Path(__file__).parent / "prompts" / "system-prompt.txt") as f: - system_prompt = str(f.read()) + prompts_dir = Path(__file__).parent / "prompts" + system_prompt = (prompts_dir / "system-prompt.txt").read_text() prompt_filename = ( "prompt-with-explanation.txt" if provide_explanation else "prompt.txt" ) - with open(Path(__file__).parent / "prompts" / prompt_filename) as f: - user_prompt_template = str(f.read()) + user_prompt_template = (prompts_dir / prompt_filename).read_text() + user_prompt_template = user_prompt_template.replace( + "{completion_label}", + _COMPLETION_LABEL_MULTI_TURN if multi_turn else _COMPLETION_LABEL_SINGLE, + ) + user_prompt_template = user_prompt_template.replace( + "{explanation_suffix}", + _EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE, + ) return system_prompt, user_prompt_template @@ -70,11 +83,14 @@ def load_judge_system_and_user_prompt( def resolve_judge_prompts( *, provide_explanation: bool, + multi_turn: bool = False, system_prompt: str | None = None, user_prompt_template: str | None = None, ) -> tuple[str, str]: default_system_prompt, default_user_prompt_template = ( - load_judge_system_and_user_prompt(provide_explanation=provide_explanation) + load_judge_system_and_user_prompt( + provide_explanation=provide_explanation, multi_turn=multi_turn + ) ) return ( system_prompt if system_prompt is 
not None else default_system_prompt, diff --git a/judgearena/generate.py b/judgearena/generate.py index 88a7d53..f9f0885 100644 --- a/judgearena/generate.py +++ b/judgearena/generate.py @@ -57,6 +57,167 @@ def generate_instructions( return df_outputs +def _set_temperature_on_model(chat_model, temperature: float) -> None: + if hasattr(chat_model, "set_temperature"): + chat_model.set_temperature(temperature) + return + if hasattr(chat_model, "temperature"): + chat_model.temperature = temperature + + +def _infer_grouped_by_temperature( + *, + model_spec: str, + provider: str, + max_tokens: int | None, + model_kwargs: dict, + base_model, + inputs: list, + temperatures: list[float], + use_tqdm: bool, +) -> list[str]: + outputs: list[str] = [""] * len(inputs) + groups: dict[float, list[int]] = {} + for idx, temp in enumerate(temperatures): + groups.setdefault(float(temp), []).append(idx) + + for temp in sorted(groups.keys()): + idxs = groups[temp] + group_inputs = [inputs[i] for i in idxs] + + if provider in {"VLLM", "LlamaCpp"}: + _set_temperature_on_model(base_model, temp) + group_model = base_model + else: + group_model = make_model( + model_spec, max_tokens=max_tokens, temperature=temp, **model_kwargs + ) + + group_outs = do_inference( + chat_model=group_model, + inputs=group_inputs, + use_tqdm=use_tqdm, + ) + for i, out in zip(idxs, group_outs, strict=True): + outputs[i] = out + + return outputs + + +def generate_multiturn( + questions: pd.DataFrame, + model: str, + truncate_input_chars: int | None = 8192, + max_tokens: int | None = 8192, + use_tqdm: bool = True, + temperature_config: dict[str, float] | None = None, + **model_kwargs, +) -> pd.DataFrame: + """Generate two-turn completions for MT-Bench style questions.""" + provider = model.split("/")[0] + use_category_temperatures = temperature_config is not None + local_provider = provider in {"VLLM", "LlamaCpp"} + + if use_category_temperatures and local_provider: + chat_model = make_model( + model, 
max_tokens=max_tokens, temperature=0.0, **model_kwargs + ) + else: + chat_model = make_model(model, max_tokens=max_tokens, **model_kwargs) + + system_prompt = "You are a helpful assistant." + idxs = questions.index.tolist() + temperatures: list[float] = [] + if use_category_temperatures: + temperatures = [ + temperature_config.get(str(questions.loc[idx].get("category") or ""), 0.7) + for idx in idxs + ] + + turn1_template = ChatPromptTemplate.from_messages( + [("system", system_prompt), ("user", "{user_prompt}")] + ) + turn1_inputs = turn1_template.batch( + [ + {"user_prompt": truncate(row["turn_1"], max_len=truncate_input_chars)} + for _, row in questions.iterrows() + ] + ) + + if use_category_temperatures: + completions_turn_1 = _infer_grouped_by_temperature( + model_spec=model, + provider=provider, + max_tokens=max_tokens, + model_kwargs=model_kwargs, + base_model=chat_model, + inputs=turn1_inputs, + temperatures=temperatures, + use_tqdm=use_tqdm, + ) + else: + completions_turn_1 = do_inference( + chat_model=chat_model, + inputs=turn1_inputs, + use_tqdm=use_tqdm, + ) + + turn2_inputs = [] + for (_, row), t1_answer in zip( + questions.iterrows(), completions_turn_1, strict=True + ): + if row["turn_2"] is None: + turn2_inputs.append( + turn1_template.invoke({"user_prompt": "No follow-up question."}) + ) + else: + multi_turn_template = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ("user", "{turn_1}"), + ("assistant", "{turn_1_answer}"), + ("user", "{turn_2}"), + ] + ) + turn2_inputs.append( + multi_turn_template.invoke( + { + "turn_1": truncate(row["turn_1"], max_len=truncate_input_chars), + "turn_1_answer": truncate( + str(t1_answer), max_len=truncate_input_chars + ), + "turn_2": truncate(row["turn_2"], max_len=truncate_input_chars), + } + ) + ) + + if use_category_temperatures: + completions_turn_2 = _infer_grouped_by_temperature( + model_spec=model, + provider=provider, + max_tokens=max_tokens, + model_kwargs=model_kwargs, + 
base_model=chat_model, + inputs=turn2_inputs, + temperatures=temperatures, + use_tqdm=use_tqdm, + ) + else: + completions_turn_2 = do_inference( + chat_model=chat_model, + inputs=turn2_inputs, + use_tqdm=use_tqdm, + ) + + return pd.DataFrame( + data={ + "instruction_index": idxs, + "completion_turn_1": completions_turn_1, + "completion_turn_2": completions_turn_2, + }, + ) + + def generate_base( instructions: pd.Series, model: str, diff --git a/judgearena/generate_and_evaluate.py b/judgearena/generate_and_evaluate.py index 1f2242b..8502201 100644 --- a/judgearena/generate_and_evaluate.py +++ b/judgearena/generate_and_evaluate.py @@ -15,6 +15,7 @@ from judgearena.evaluate import judge_and_parse_prefs, resolve_judge_prompts from judgearena.generate import generate_base, generate_instructions from judgearena.instruction_dataset import load_instructions +from judgearena.mt_bench.mt_bench_utils import run_mt_bench from judgearena.repro import _to_jsonable, write_run_metadata from judgearena.utils import ( cache_function_dataframe, @@ -285,6 +286,9 @@ def main(args: CliArgs): # set_langchain_cache() ignore_cache = args.ignore_cache + if args.dataset == "mt-bench": + return run_mt_bench(args, ignore_cache) + # Currrently, we run context evaluation is_fluency_task = "fluency" in args.dataset if is_fluency_task: diff --git a/judgearena/instruction_dataset/__init__.py b/judgearena/instruction_dataset/__init__.py index 2f06848..ec57cac 100644 --- a/judgearena/instruction_dataset/__init__.py +++ b/judgearena/instruction_dataset/__init__.py @@ -5,7 +5,12 @@ def load_instructions(dataset: str, n_instructions: int | None = None) -> pd.DataFrame: - if "m-arena-hard" in dataset: + if dataset == "mt-bench": + from judgearena.instruction_dataset.mt_bench import load_mt_bench + + df_instructions = load_mt_bench() + + elif "m-arena-hard" in dataset: if dataset == "m-arena-hard": language = None else: diff --git a/judgearena/instruction_dataset/mt_bench.py 
b/judgearena/instruction_dataset/mt_bench.py new file mode 100644 index 0000000..e2a4233 --- /dev/null +++ b/judgearena/instruction_dataset/mt_bench.py @@ -0,0 +1,168 @@ +import warnings +from pathlib import Path +from urllib.request import urlretrieve + +import pandas as pd +from huggingface_hub import snapshot_download + +from judgearena.utils import data_root + +FASTCHAT_GPT4_REFERENCE_URL = ( + "https://raw.githubusercontent.com/lm-sys/FastChat/main/" + "fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl" +) + + +def _download_gpt4_references(local_dir: Path) -> Path | None: + reference_dir = local_dir / "reference_answer" + reference_dir.mkdir(parents=True, exist_ok=True) + gpt4_reference_path = reference_dir / "gpt-4.jsonl" + if gpt4_reference_path.exists(): + return gpt4_reference_path + try: + urlretrieve(FASTCHAT_GPT4_REFERENCE_URL, gpt4_reference_path) + except Exception as e: + warnings.warn( + "Could not download MT-Bench GPT-4 reference answers from FastChat. " + f"Falling back to inline references from question.jsonl: {e}", + RuntimeWarning, + stacklevel=2, + ) + return None + return gpt4_reference_path + + +def download_mt_bench(local_dir: Path | None = None) -> tuple[Path, Path | None]: + """Download MT-Bench questions and GPT-4 references if missing.""" + if local_dir is None: + local_dir = data_root / "mt-bench" + try: + local_dir.mkdir(parents=True, exist_ok=True) + except PermissionError as e: + raise PermissionError( + f"Cannot create MT-Bench cache directory at {local_dir}. " + "Set environment variable OPENJURY_DATA to a writable location." 
+ ) from e + + question_path = local_dir / "data" / "mt_bench" / "question.jsonl" + if not question_path.exists(): + try: + snapshot_download( + repo_id="lmsys/mt-bench", + repo_type="space", + allow_patterns=[ + "data/mt_bench/question.jsonl", + ], + local_dir=local_dir, + force_download=False, + ) + except Exception as e: + raise RuntimeError( + "Failed to download MT-Bench questions from HuggingFace space " + "'lmsys/mt-bench'. If you're in an offline / restricted-network " + "environment, pre-download the space snapshot and place the " + f"questions file at {question_path}, or set OPENJURY_DATA to " + "point to that directory." + ) from e + if not question_path.exists(): + raise FileNotFoundError( + "Could not locate MT-Bench questions after download. " + f"Expected file at {question_path}." + ) + + gpt4_reference_path = _download_gpt4_references(local_dir) + return question_path, gpt4_reference_path + + +def load_mt_bench() -> pd.DataFrame: + """Load MT-Bench questions and reference answers. + + Downloads MT-Bench questions from the HuggingFace LMSYS space and tries to + load GPT-4 references from FastChat GitHub. If GPT-4 references cannot be + downloaded or parsed, falls back to inline references from question.jsonl. 
+ """ + question_path, ref_path = download_mt_bench() + + questions = pd.read_json(question_path, lines=True).to_dict(orient="records") + + ref_by_id: dict[int | str, list[str]] = {} + use_inline_reference_fallback = ref_path is None + if ref_path is not None: + try: + reference_records = pd.read_json(ref_path, lines=True).to_dict( + orient="records" + ) + for rec in reference_records: + qid = rec.get("question_id", rec.get("id")) + if qid is None: + continue + choices = rec.get("choices") + if not (isinstance(choices, list) and choices): + continue + first_choice = choices[0] + if not isinstance(first_choice, dict): + continue + turns = first_choice.get("turns") + if not isinstance(turns, list): + continue + ref_by_id[qid] = turns + try: + ref_by_id[int(qid)] = turns + except Exception: + pass + except Exception as e: + warnings.warn( + "Failed to parse GPT-4 reference answers from FastChat. " + f"Falling back to inline references from question.jsonl: {e}", + RuntimeWarning, + stacklevel=2, + ) + use_inline_reference_fallback = True + + rows = [] + for rec in questions: + qid_raw = rec.get("question_id", rec.get("id")) + if qid_raw is None: + raise ValueError( + f"MT-Bench question record missing question_id/id: keys={list(rec.keys())}" + ) + try: + qid = int(qid_raw) + except Exception: + qid = qid_raw + + category = rec.get("category") + turns = rec.get("turns") + if isinstance(turns, list): + turn_1 = turns[0] if len(turns) > 0 else None + turn_2 = turns[1] if len(turns) > 1 else None + else: + turn_1 = rec.get("turn_1", rec.get("instruction")) + turn_2 = rec.get("turn_2") + + ref_turns = ref_by_id.get(qid_raw) or ref_by_id.get(qid) + if ref_turns is None and use_inline_reference_fallback: + inline_ref = rec.get("reference") + if isinstance(inline_ref, list): + ref_turns = inline_ref + + ref_turn_1 = ( + ref_turns[0] if isinstance(ref_turns, list) and len(ref_turns) > 0 else None + ) + ref_turn_2 = ( + ref_turns[1] if isinstance(ref_turns, list) and 
len(ref_turns) > 1 else None + ) + + rows.append( + { + "instruction_index": qid, + "category": category, + "turn_1": turn_1, + "turn_2": turn_2, + "reference_turn_1": ref_turn_1, + "reference_turn_2": ref_turn_2, + "instruction": turn_1, + } + ) + + return pd.DataFrame(rows) diff --git a/judgearena/mt_bench/__init__.py b/judgearena/mt_bench/__init__.py new file mode 100644 index 0000000..49d6058 --- /dev/null +++ b/judgearena/mt_bench/__init__.py @@ -0,0 +1,4 @@ +"""MT-Bench-specific helpers. + +This package intentionally contains MT-Bench specific logic. +""" diff --git a/judgearena/mt_bench/common.py b/judgearena/mt_bench/common.py new file mode 100644 index 0000000..d676e05 --- /dev/null +++ b/judgearena/mt_bench/common.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass + +import pandas as pd + +from judgearena.utils import safe_text + + +@dataclass(frozen=True) +class MTBenchPairwiseRow: + question_id: object + category: str | None + turn_1_question: str + turn_2_question: str + answer_a_1: str + answer_a_2: str + answer_b_1: str + answer_b_2: str + ref_1: str + ref_2: str + + +def iter_mt_bench_pairwise_rows( + *, + questions: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + truncate_input_chars: int | None, +) -> Iterator[MTBenchPairwiseRow]: + for question_id in questions.index.tolist(): + row = questions.loc[question_id] + comp_a_row = ( + completions_a.loc[question_id] + if question_id in completions_a.index + else completions_a.iloc[0] + ) + comp_b_row = ( + completions_b.loc[question_id] + if question_id in completions_b.index + else completions_b.iloc[0] + ) + yield MTBenchPairwiseRow( + question_id=question_id, + category=row.get("category"), + turn_1_question=safe_text(row.get("turn_1"), truncate_input_chars), + turn_2_question=safe_text(row.get("turn_2"), truncate_input_chars), + answer_a_1=safe_text( + comp_a_row.get("completion_turn_1", 
""), + truncate_input_chars, + ), + answer_a_2=safe_text( + comp_a_row.get("completion_turn_2", ""), + truncate_input_chars, + ), + answer_b_1=safe_text( + comp_b_row.get("completion_turn_1", ""), + truncate_input_chars, + ), + answer_b_2=safe_text( + comp_b_row.get("completion_turn_2", ""), + truncate_input_chars, + ), + ref_1=safe_text(row.get("reference_turn_1"), truncate_input_chars), + ref_2=safe_text(row.get("reference_turn_2"), truncate_input_chars), + ) diff --git a/judgearena/mt_bench/fastchat_compat.py b/judgearena/mt_bench/fastchat_compat.py new file mode 100644 index 0000000..3b0e7ec --- /dev/null +++ b/judgearena/mt_bench/fastchat_compat.py @@ -0,0 +1,493 @@ +"""MT-Bench pairwise judging aligned with FastChat ``llm_judge`` (``data/judge_prompts.jsonl``).""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +import pandas as pd +from langchain_core.prompts import ChatPromptTemplate + +from judgearena.mt_bench.common import iter_mt_bench_pairwise_rows +from judgearena.utils import do_inference + +FASTCHAT_TEMPERATURE_CONFIG: dict[str, float] = { + "writing": 0.7, + "roleplay": 0.7, + "extraction": 0.0, + "math": 0.0, + "coding": 0.0, + "reasoning": 0.0, + "stem": 0.1, + "humanities": 0.1, + "arena-hard-200": 0.0, +} + +FASTCHAT_NEED_REF_CATS: set[str] = { + "math", + "reasoning", + "coding", + "arena-hard-200", +} + +FastChatVerdict = Literal["A", "B", "tie", "error"] +PairwiseWinner = Literal["model_A", "model_B", "tie", "error"] + + +@dataclass(frozen=True) +class FastChatPairwisePrompt: + name: str + system_prompt: str + user_prompt_template: str + multi_turn: bool + ref_based: bool + + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / "prompts" / "mt_bench" +_SYSTEM_BASE_FILE = "system-base.txt" +_USER_SINGLE_BASE_FILE = "user-single-base.txt" +_USER_MULTI_BASE_FILE = "user-multi-base.txt" +_USER_SINGLE_REF_BLOCK_FILE = 
"user-single-reference-block.txt" +_USER_MULTI_REF_BLOCK_FILE = "user-multi-reference-block.txt" + + +def _load_prompt_text(filename: str) -> str: + path = _PROMPTS_DIR / filename + return path.read_text(encoding="utf-8") + + +def _render_prompt_text(filename: str, **kwargs: str) -> str: + return _load_prompt_text(filename).format(**kwargs) + + +def _build_system_prompt( + *, + user_subject: str, + task_description: str, + begin_instruction: str, + focus_line: str = "", +) -> str: + focus_segment = f"{focus_line} " if focus_line else "" + return _render_prompt_text( + _SYSTEM_BASE_FILE, + user_subject=user_subject, + task_description=task_description, + focus_line=focus_segment, + begin_instruction=begin_instruction, + ) + + +def _build_user_prompt_template(*, multi_turn: bool, ref_based: bool) -> str: + base_filename = _USER_MULTI_BASE_FILE if multi_turn else _USER_SINGLE_BASE_FILE + reference_block = "" + if ref_based: + ref_block_filename = ( + _USER_MULTI_REF_BLOCK_FILE if multi_turn else _USER_SINGLE_REF_BLOCK_FILE + ) + reference_block = _load_prompt_text(ref_block_filename).rstrip("\n") + "\n\n" + return _render_prompt_text(base_filename, reference_block=reference_block) + + +def _load_pairwise_prompt( + *, + name: str, + multi_turn: bool, + ref_based: bool, + system_user_subject: str, + system_task_description: str, + system_begin_instruction: str, + system_focus_line: str = "", +) -> FastChatPairwisePrompt: + return FastChatPairwisePrompt( + name=name, + multi_turn=multi_turn, + ref_based=ref_based, + system_prompt=_build_system_prompt( + user_subject=system_user_subject, + task_description=system_task_description, + begin_instruction=system_begin_instruction, + focus_line=system_focus_line, + ), + user_prompt_template=_build_user_prompt_template( + multi_turn=multi_turn, + ref_based=ref_based, + ), + ) + + +_PAIR_V2 = _load_pairwise_prompt( + name="pair-v2", + multi_turn=False, + ref_based=False, + system_user_subject="question displayed below", + 
system_task_description=( + "You should choose the assistant that follows the user's instructions and answers " + "the user's question better. Your evaluation should consider factors such as the " + "helpfulness, relevance, accuracy, depth, creativity, and level of detail of their " + "responses." + ), + system_begin_instruction="comparing the two responses and provide a short explanation", +) + +_PAIR_V2_MULTI = _load_pairwise_prompt( + name="pair-v2-multi-turn", + multi_turn=True, + ref_based=False, + system_user_subject="questions", + system_task_description=( + "You should choose the assistant that follows the user's instructions and answers " + "the user's questions better. Your evaluation should consider factors such as the " + "helpfulness, relevance, accuracy, depth, creativity, and level of detail of their " + "responses." + ), + system_focus_line="You should focus on who provides a better answer to the second user question.", + system_begin_instruction=( + "comparing the responses of the two assistants and provide a short explanation" + ), +) + +_PAIR_MATH_V1 = _load_pairwise_prompt( + name="pair-math-v1", + multi_turn=False, + ref_based=True, + system_user_subject="question displayed below", + system_task_description=( + "Your evaluation should consider correctness and helpfulness. You will be given a " + "reference answer, assistant A's answer, and assistant B's answer. Your job is to " + "evaluate which assistant's answer is better." + ), + system_begin_instruction=( + "comparing both assistants' answers with the reference answer. Identify and correct any mistakes" + ), +) + +_PAIR_MATH_V1_MULTI = _load_pairwise_prompt( + name="pair-math-v1-multi-turn", + multi_turn=True, + ref_based=True, + system_user_subject="questions", + system_task_description=( + "Your evaluation should consider correctness and helpfulness. You will be given " + "reference answers, the assistant A's answers, the assistant B's answers. 
Your job is " + "to determine which assistant provides correct and helpful answers to the second user question." + ), + system_begin_instruction=( + "comparing both assistants' answers with the reference answers. Identify and correct any mistakes" + ), +) + + +def _parse_fastchat_verdict(judgment: str) -> FastChatVerdict: + if "[[A]]" in judgment: + return "A" + if "[[B]]" in judgment: + return "B" + if "[[C]]" in judgment: + return "tie" + return "error" + + +def _map_verdict_to_winner(verdict: FastChatVerdict, swapped: bool) -> PairwiseWinner: + if verdict == "tie": + return "tie" + if verdict == "error": + return "error" + if verdict == "A": + return "model_B" if swapped else "model_A" + if verdict == "B": + return "model_A" if swapped else "model_B" + return "error" + + +def _conservative_winner( + g1: PairwiseWinner, g2: PairwiseWinner +) -> tuple[PairwiseWinner, bool]: + """Conservative position-bias handling (FastChat/MT-Bench paper). + + Declare a winner only if the two orderings agree; otherwise treat as tie. 
+ """ + if g1 == "error" or g2 == "error": + return "error", False + if g1 == g2: + return g1, False + return "tie", True + + +def _winner_to_preference(winner: PairwiseWinner) -> float: + if winner == "model_A": + return 0.0 + if winner == "model_B": + return 1.0 + if winner == "tie": + return 0.5 + return math.nan + + +def _select_prompt(category: str | None, multi_turn: bool) -> FastChatPairwisePrompt: + needs_ref = (category or "") in FASTCHAT_NEED_REF_CATS + if needs_ref and multi_turn: + return _PAIR_MATH_V1_MULTI + if needs_ref: + return _PAIR_MATH_V1 + if multi_turn: + return _PAIR_V2_MULTI + return _PAIR_V2 + + +def _group_indices_by_prompt( + items: list[dict[str, Any]], +) -> dict[str, list[int]]: + grouped: dict[str, list[int]] = {} + for idx, item in enumerate(items): + grouped.setdefault(item["prompt_name"], []).append(idx) + return grouped + + +def _swap_prompt_kwargs(kwargs: dict[str, str], *, multi_turn: bool) -> dict[str, str]: + swapped = dict(kwargs) + if multi_turn: + swapped["answer_a_1"], swapped["answer_b_1"] = ( + swapped["answer_b_1"], + swapped["answer_a_1"], + ) + swapped["answer_a_2"], swapped["answer_b_2"] = ( + swapped["answer_b_2"], + swapped["answer_a_2"], + ) + return swapped + swapped["answer_a"], swapped["answer_b"] = swapped["answer_b"], swapped["answer_a"] + return swapped + + +def _infer_by_prompt_groups( + *, + judge_chat_model, + items: list[dict[str, Any]], + use_tqdm: bool, + swap_answers: bool, +) -> list[str]: + """Run judge inference, grouping by prompt variant for batching.""" + grouped_indices = _group_indices_by_prompt(items) + + judgments: list[str] = [""] * len(items) + for _prompt_name, idxs in grouped_indices.items(): + prompt: FastChatPairwisePrompt = items[idxs[0]]["prompt"] + prompt_template = ChatPromptTemplate.from_messages( + [("system", prompt.system_prompt), ("user", prompt.user_prompt_template)] + ) + + batch_kwargs = [] + for i in idxs: + kwargs = items[i]["prompt_kwargs"] + if swap_answers: + kwargs = 
def _build_fastchat_judge_items(
    *,
    questions: pd.DataFrame,
    completions_a: pd.DataFrame,
    completions_b: pd.DataFrame,
    eval_single: bool,
    eval_multi: bool,
    truncate_input_chars: int | None,
) -> list[dict[str, Any]]:
    """Build one judging work item per (question, turn) to score.

    Each item bundles the selected prompt variant, its name (used as the
    batching key), and the kwargs required to render the prompt template.
    """
    items: list[dict[str, Any]] = []
    for pair_row in iter_mt_bench_pairwise_rows(
        questions=questions,
        completions_a=completions_a,
        completions_b=completions_b,
        truncate_input_chars=truncate_input_chars,
    ):
        category = pair_row.category
        if eval_single:
            # Turn-1 item; reference-guided prompt when the category needs it.
            prompt = _select_prompt(category, multi_turn=False)
            kwargs: dict[str, str] = {
                "question": pair_row.turn_1_question,
                "answer_a": pair_row.answer_a_1,
                "answer_b": pair_row.answer_b_1,
            }
            if prompt.ref_based:
                kwargs["ref_answer_1"] = pair_row.ref_1
            items.append(
                {
                    "question_id": pair_row.question_id,
                    "category": category,
                    "turn": 1,
                    "prompt": prompt,
                    "prompt_name": prompt.name,
                    "prompt_kwargs": kwargs,
                }
            )

        # Turn-2 item only when the conversation actually has a second question.
        if eval_multi and pair_row.turn_2_question:
            prompt = _select_prompt(category, multi_turn=True)
            kwargs = {
                "question_1": pair_row.turn_1_question,
                "question_2": pair_row.turn_2_question,
                "answer_a_1": pair_row.answer_a_1,
                "answer_a_2": pair_row.answer_a_2,
                "answer_b_1": pair_row.answer_b_1,
                "answer_b_2": pair_row.answer_b_2,
            }
            if prompt.ref_based:
                kwargs["ref_answer_1"] = pair_row.ref_1
                kwargs["ref_answer_2"] = pair_row.ref_2
            items.append(
                {
                    "question_id": pair_row.question_id,
                    "category": category,
                    "turn": 2,
                    "prompt": prompt,
                    "prompt_name": prompt.name,
                    "prompt_kwargs": kwargs,
                }
            )
    return items


def _resolve_fastchat_item_result(
    *,
    item: dict[str, Any],
    g1_raw: str,
    g2_raw: str | None,
    judge_model: str,
    model_a: str,
    model_b: str,
) -> tuple[dict[str, Any], dict[str, object], float, bool]:
    """Combine the judge's raw output(s) for one item into a final result.

    g1 is the original A-B presentation order; g2 (optional) is the swapped
    B-A order. Returns (annotation row, metadata row, scalar preference,
    whether the two orderings disagreed).
    """
    prompt: FastChatPairwisePrompt = item["prompt"]
    kwargs = item["prompt_kwargs"]
    g1_user_prompt = prompt.user_prompt_template.format(**kwargs)
    g1_verdict = _parse_fastchat_verdict(g1_raw)
    g1_winner = _map_verdict_to_winner(g1_verdict, swapped=False)

    final_winner = g1_winner
    inconsistent = False
    annotation_row: dict[str, Any] = {
        "question_id": item["question_id"],
        "category": item["category"],
        "turn": item["turn"],
        "model_A": model_a,
        "model_B": model_b,
        "judge": judge_model,
        "prompt_name": prompt.name,
        "system_prompt": prompt.system_prompt,
        "g1_user_prompt": g1_user_prompt,
        "g1_judgment": g1_raw,
        "g1_verdict": g1_verdict,
        "g1_winner": g1_winner,
    }

    if g2_raw is not None:
        # Swapped ordering: positional labels are mirrored back onto the models,
        # then the conservative rule keeps a win only when both orders agree.
        g2_verdict = _parse_fastchat_verdict(g2_raw)
        g2_winner = _map_verdict_to_winner(g2_verdict, swapped=True)
        final_winner, inconsistent = _conservative_winner(g1_winner, g2_winner)
        annotation_row.update(
            {
                "g2_user_prompt": prompt.user_prompt_template.format(
                    **_swap_prompt_kwargs(kwargs, multi_turn=prompt.multi_turn)
                ),
                "g2_judgment": g2_raw,
                "g2_verdict": g2_verdict,
                "g2_winner": g2_winner,
                "final_winner": final_winner,
                "inconsistent": inconsistent,
            }
        )
    else:
        annotation_row["final_winner"] = final_winner
        annotation_row["inconsistent"] = False

    preference = _winner_to_preference(final_winner)
    annotation_row["preference"] = preference
    metadata = {
        "question_id": item["question_id"],
        "category": item["category"],
        "turn": item["turn"],
    }
    return annotation_row, metadata, preference, inconsistent


def judge_mt_bench_pairwise_fastchat(
    *,
    judge_chat_model,
    judge_model: str,
    questions: pd.DataFrame,
    completions_a: pd.DataFrame,
    completions_b: pd.DataFrame,
    model_a: str,
    model_b: str,
    turns_mode: str,
    swap_mode: str,
    truncate_input_chars: int | None,
    use_tqdm: bool,
) -> tuple[pd.Series, list[dict[str, Any]], list[dict[str, object]], int]:
    """Pairwise MT-Bench judging compatible with FastChat's `[[A]]/[[B]]/[[C]]` format.

    Returns (preference series, annotation rows, per-item metadata, count of
    items where the two presentation orders disagreed).
    """
    assert turns_mode in ("both", "single", "multi")
    assert swap_mode in ("fixed", "both")

    eval_single = turns_mode in ("both", "single")
    eval_multi = turns_mode in ("both", "multi")

    items = _build_fastchat_judge_items(
        questions=questions,
        completions_a=completions_a,
        completions_b=completions_b,
        eval_single=eval_single,
        eval_multi=eval_multi,
        truncate_input_chars=truncate_input_chars,
    )

    # First pass: original A-B presentation order.
    g1_judgments = _infer_by_prompt_groups(
        judge_chat_model=judge_chat_model,
        items=items,
        use_tqdm=use_tqdm,
        swap_answers=False,
    )

    # Second pass (swap_mode="both" only): B-A order, to control position bias.
    g2_judgments: list[str] | None = None
    if swap_mode == "both":
        g2_judgments = _infer_by_prompt_groups(
            judge_chat_model=judge_chat_model,
            items=items,
            use_tqdm=use_tqdm,
            swap_answers=True,
        )

    annotations: list[dict[str, Any]] = []
    metadata: list[dict[str, object]] = []
    prefs: list[float] = []
    num_inconsistent = 0

    for idx, item in enumerate(items):
        g2_raw = g2_judgments[idx] if g2_judgments is not None else None
        annotation_row, item_metadata, preference, inconsistent = (
            _resolve_fastchat_item_result(
                item=item,
                g1_raw=g1_judgments[idx],
                g2_raw=g2_raw,
                judge_model=judge_model,
                model_a=model_a,
                model_b=model_b,
            )
        )
        if inconsistent:
            num_inconsistent += 1
        annotations.append(annotation_row)
        metadata.append(item_metadata)
        prefs.append(preference)

    return pd.Series(prefs, dtype=float), annotations, metadata, num_inconsistent
def _generate_mt_bench_completions(
    args: CliArgs,
    questions_df: pd.DataFrame,
    ignore_cache: bool,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Generate (and cache) multi-turn completions for both candidate models.

    Returns the completion DataFrames for model_A and model_B, each indexed
    by ``instruction_index``.
    """
    cache_prefix = "mt-bench"

    def _run_generation(model_name: str) -> pd.DataFrame:
        # FASTCHAT_TEMPERATURE_CONFIG supplies FastChat's per-category
        # sampling temperatures for comparability with its reference runs.
        return generate_multiturn(
            questions=questions_df,
            model=model_name,
            truncate_input_chars=args.truncate_all_input_chars,
            max_tokens=args.max_out_tokens_models,
            use_tqdm=args.use_tqdm,
            max_model_len=args.max_model_len,
            chat_template=args.chat_template,
            temperature_config=FASTCHAT_TEMPERATURE_CONFIG,
        )

    # Cache keys include the model name and instruction count so partial runs
    # with different n_instructions never collide.
    completions_a = cache_function_dataframe(
        lambda: _run_generation(args.model_A),
        ignore_cache=ignore_cache,
        cache_name=f"{cache_prefix}_{args.model_A}_{args.n_instructions}",
    ).set_index("instruction_index")

    completions_b = cache_function_dataframe(
        lambda: _run_generation(args.model_B),
        ignore_cache=ignore_cache,
        cache_name=f"{cache_prefix}_{args.model_B}_{args.n_instructions}",
    ).set_index("instruction_index")
    return completions_a, completions_b
f"{args.dataset}-{args.model_A}-{args.model_B}-{args.judge_model}" + name += f"-{args.swap_mode}" + if suffix: + name += f"-{suffix}" + return name.replace("/", "_") + + +def _save_mt_bench_results( + *, + args: CliArgs, + results: dict[str, object], + annotations_df: pd.DataFrame, + name_suffix: str | None = None, +) -> None: + name = _build_mt_bench_result_name(args, suffix=name_suffix) + res_folder = Path(args.result_folder) / name + res_folder.mkdir(parents=True, exist_ok=True) + + with open(res_folder / f"args-{name}.json", "w") as f: + json.dump(_to_jsonable(asdict(args)), f, indent=2, allow_nan=False) + + annotations_df.to_csv(res_folder / f"{name}-annotations.csv", index=False) + + with open(res_folder / f"results-{name}.json", "w") as f: + json.dump(_to_jsonable(results), f, indent=2, allow_nan=False) + + +def _run_mt_bench_fastchat( + *, + args: CliArgs, + questions_df: pd.DataFrame, + completions_a: pd.DataFrame, + completions_b: pd.DataFrame, + judge_chat_model, +) -> pd.Series: + prefs, annotations, combined_metadata, num_inconsistent = ( + judge_mt_bench_pairwise_fastchat( + judge_chat_model=judge_chat_model, + judge_model=args.judge_model, + questions=questions_df, + completions_a=completions_a, + completions_b=completions_b, + model_a=args.model_A, + model_b=args.model_B, + turns_mode="both", + swap_mode=args.swap_mode, + truncate_input_chars=args.truncate_all_input_chars, + use_tqdm=args.use_tqdm, + ) + ) + + stats = compute_pref_summary(prefs) + results = { + "dataset": args.dataset, + "model_A": args.model_A, + "model_B": args.model_B, + "judge_model": args.judge_model, + "num_inconsistent": num_inconsistent, + **stats, + "per_category": _compute_grouped_stats(prefs, combined_metadata, "category"), + "per_turn": _compute_grouped_stats(prefs, combined_metadata, "turn"), + "preferences": prefs.tolist(), + "date": str(datetime.now().isoformat()), + "user": os.getenv("USER", ""), + } + print_results(results) + _save_mt_bench_results( + args=args, + 
def run_mt_bench(args: CliArgs, ignore_cache: bool):
    """MT-Bench pipeline with FastChat-compatible pairwise judging.

    Loads the questions, generates multi-turn completions for both models,
    builds the judge model (temperature 0 for deterministic verdicts), and
    runs the FastChat pairwise evaluation. Returns the preference series.
    """
    questions_df = load_instructions("mt-bench", n_instructions=args.n_instructions)
    print(
        f"Generating multi-turn completions for MT-Bench with {args.model_A} and {args.model_B}."
    )
    completions_a, completions_b = _generate_mt_bench_completions(
        args=args,
        questions_df=questions_df,
        ignore_cache=ignore_cache,
    )
    # Judge decodes greedily (temperature=0.0) so verdicts are reproducible.
    judge_chat_model = make_model(
        model=args.judge_model,
        max_tokens=args.max_out_tokens_judge,
        temperature=0.0,
        max_model_len=args.max_model_len,
        chat_template=args.chat_template,
    )
    return _run_mt_bench_fastchat(
        args=args,
        questions_df=questions_df,
        completions_a=completions_a,
        completions_b=completions_b,
        judge_chat_model=judge_chat_model,
    )
diff --git a/judgearena/prompts/mt_bench/user-multi-base.txt b/judgearena/prompts/mt_bench/user-multi-base.txt new file mode 100644 index 0000000..33abb79 --- /dev/null +++ b/judgearena/prompts/mt_bench/user-multi-base.txt @@ -0,0 +1,32 @@ +{reference_block}<|The Start of Assistant A's Conversation with User|> + +### User: +{{question_1}} + +### Assistant A: +{{answer_a_1}} + +### User: +{{question_2}} + +### Assistant A: +{{answer_a_2}} + +<|The End of Assistant A's Conversation with User|> + + +<|The Start of Assistant B's Conversation with User|> + +### User: +{{question_1}} + +### Assistant B: +{{answer_b_1}} + +### User: +{{question_2}} + +### Assistant B: +{{answer_b_2}} + +<|The End of Assistant B's Conversation with User|> diff --git a/judgearena/prompts/mt_bench/user-multi-reference-block.txt b/judgearena/prompts/mt_bench/user-multi-reference-block.txt new file mode 100644 index 0000000..04e1e50 --- /dev/null +++ b/judgearena/prompts/mt_bench/user-multi-reference-block.txt @@ -0,0 +1,15 @@ +<|The Start of Reference Answer|> + +### User: +{question_1} + +### Reference answer: +{ref_answer_1} + +### User: +{question_2} + +### Reference answer: +{ref_answer_2} + +<|The End of Reference Answer|> diff --git a/judgearena/prompts/mt_bench/user-single-base.txt b/judgearena/prompts/mt_bench/user-single-base.txt new file mode 100644 index 0000000..ee7701c --- /dev/null +++ b/judgearena/prompts/mt_bench/user-single-base.txt @@ -0,0 +1,10 @@ +[User Question] +{{question}} + +{reference_block}[The Start of Assistant A's Answer] +{{answer_a}} +[The End of Assistant A's Answer] + +[The Start of Assistant B's Answer] +{{answer_b}} +[The End of Assistant B's Answer] diff --git a/judgearena/prompts/mt_bench/user-single-reference-block.txt b/judgearena/prompts/mt_bench/user-single-reference-block.txt new file mode 100644 index 0000000..cf18c90 --- /dev/null +++ b/judgearena/prompts/mt_bench/user-single-reference-block.txt @@ -0,0 +1,3 @@ +[The Start of Reference Answer] 
def truncate(s: str, max_len: int | None = None) -> str:
    """Clip *s* to at most *max_len* characters; non-strings collapse to ''."""
    if not isinstance(s, str):
        return ""
    return s if max_len is None else s[:max_len]


def safe_text(value: object, truncate_chars: int | None) -> str:
    """Render *value* as text, mapping None and scalar NaN to the empty string."""
    if value is None:
        return ""
    # pd.isna returns an array for array-likes; only a genuine scalar-missing
    # result (a plain bool True) is treated as empty text.
    missing = pd.isna(value)
    if isinstance(missing, bool) and missing:
        return ""
    return truncate(str(value), max_len=truncate_chars)
from pathlib import Path

from slurmpilot import JobCreationInfo, SlurmPilot, unify

# Cluster-side layout: the repo checkout and its virtualenv on the remote host.
CLUSTER = "kislurm"
REMOTE_PROJECT_ROOT = Path("/work/dlclarge1/lushtake-hiwi/JudgeArena")
LOCAL_PROJECT_ROOT = Path(__file__).resolve().parent.parent
PYTHON_BINARY = REMOTE_PROJECT_ROOT / ".venv" / "bin" / "python"
ENTRYPOINT = "generate_and_evaluate.py"
SRC_DIR = str(LOCAL_PROJECT_ROOT / "judgearena")

# Use L40S partitions from the all_dlc / ml_dlc families.
# For this cluster/account, mldlc2 and testdlc2 are available.
# NOTE(review): despite the ALL_DLC name, this constant points at a testdlc2
# partition — confirm which partition family is actually intended.
PARTITION_ALL_DLC_L40S = "testdlc2_gpu-l40s"
PARTITION_ML_DLC_L40S = "mldlc2_gpu-l40s"


def submit_smoke_job(partition: str = PARTITION_ALL_DLC_L40S) -> tuple[str, str, int]:
    """Submit a tiny MT-Bench smoke run (dummy models, 1 instruction) to Slurm.

    Returns (dataset name, job name, job id).
    """
    slurm = SlurmPilot(clusters=[CLUSTER])
    dataset = "mt-bench"
    jobname = unify("mt-bench-smoke/fastchat-canonical", method="date")

    job_info = JobCreationInfo(
        cluster=CLUSTER,
        partition=partition,
        jobname=jobname,
        entrypoint=ENTRYPOINT,
        python_binary=str(PYTHON_BINARY),
        # Dummy models keep the run cheap: this exercises the pipeline, not
        # real model quality.
        python_args={
            "dataset": dataset,
            "model_A": "Dummy/no_answer",
            "model_B": "Dummy/open_answer",
            "judge_model": "Dummy/[[A]]",
            "n_instructions": 1,
        },
        src_dir=SRC_DIR,
        n_cpus=1,
        max_runtime_minutes=15,
        env={
            "HF_HUB_OFFLINE": "0",
            "OPENJURY_DATA": "/work/dlclarge1/lushtake-hiwi/judgearena-data",
        },
    )
    job_id = slurm.schedule_job(job_info)
    print(f"Submitted {dataset}: jobname={job_info.jobname}, job_id={job_id}")
    return dataset, job_info.jobname, job_id


if __name__ == "__main__":
    print(f"Using LOCAL_PROJECT_ROOT={LOCAL_PROJECT_ROOT}")
    print(f"Using REMOTE_PROJECT_ROOT={REMOTE_PROJECT_ROOT}")
    print(f"Using PYTHON_BINARY={PYTHON_BINARY}")
    submit_smoke_job(partition=PARTITION_ALL_DLC_L40S)
import judgearena.instruction_dataset.mt_bench as mt_bench
import judgearena.utils as utils


def test_download_mt_bench_skips_question_download_if_cached(tmp_path, monkeypatch):
    """If question.jsonl already exists locally, no HF snapshot download happens."""
    question_path = tmp_path / "data" / "mt_bench" / "question.jsonl"
    question_path.parent.mkdir(parents=True, exist_ok=True)
    question_path.write_text('{"question_id": 1, "turns": ["Q1"]}\n')

    reference_path = tmp_path / "reference_answer" / "gpt-4.jsonl"
    reference_path.parent.mkdir(parents=True, exist_ok=True)
    reference_path.write_text('{"question_id": 1, "choices": [{"turns": ["A1"]}]}\n')

    # Count snapshot_download invocations; the cached path must keep it at 0.
    calls = {"snapshot_download": 0}

    def _snapshot_download_stub(**_kwargs):
        calls["snapshot_download"] += 1

    monkeypatch.setattr(mt_bench, "snapshot_download", _snapshot_download_stub)
    monkeypatch.setattr(
        mt_bench,
        "_download_gpt4_references",
        lambda _local_dir: reference_path,
    )

    downloaded_question_path, downloaded_reference_path = mt_bench.download_mt_bench(
        local_dir=tmp_path
    )

    assert downloaded_question_path == question_path
    assert downloaded_reference_path == reference_path
    assert calls["snapshot_download"] == 0


def test_download_all_includes_mt_bench(tmp_path, monkeypatch):
    """download_all() must fetch the HF datasets, contexts, and MT-Bench exactly once."""
    hf_datasets = []
    calls = {"contexts": 0, "mt_bench": 0}

    monkeypatch.setattr(utils, "data_root", tmp_path)
    monkeypatch.setattr(
        utils,
        "download_hf",
        lambda name, local_path: hf_datasets.append((name, local_path)),
    )

    def _contexts_snapshot_stub(**_kwargs):
        calls["contexts"] += 1

    monkeypatch.setattr(utils, "snapshot_download", _contexts_snapshot_stub)
    # download_all imports download_mt_bench lazily, so patch the mt_bench
    # module attribute rather than a utils-level name.
    monkeypatch.setattr(
        mt_bench,
        "download_mt_bench",
        lambda: calls.__setitem__("mt_bench", calls["mt_bench"] + 1),
    )

    utils.download_all()

    assert [name for name, _ in hf_datasets] == [
        "alpaca-eval",
        "arena-hard",
        "m-arena-hard",
    ]
    assert calls["contexts"] == 1
    assert calls["mt_bench"] == 1