78 changes: 66 additions & 12 deletions judgearena/evaluate.py
@@ -14,13 +14,20 @@
download_arena_hard,
is_arena_hard_dataset,
)
from judgearena.openrouter_reference_pricing import (
OpenRouterReferencePricingTracker,
build_openrouter_reference_pricing_summary,
format_openrouter_reference_pricing_summary,
)
from judgearena.repro import _to_jsonable, write_run_metadata
from judgearena.utils import (
compute_pref_summary,
data_root,
do_inference,
download_hf,
infer_model_spec_from_instance,
read_df,
strip_thinking_tags,
truncate,
)

@@ -36,6 +43,7 @@ def preference_from_scores(self, score_a: float, score_b: float) -> float:
)

def parse_model_raw(self, judge_completion: str) -> float | None:
judge_completion = strip_thinking_tags(judge_completion)
# lower case to avoid confusion, e.g. when "a" is used instead of "A"
score_a = self.get_regexp_match(
judge_completion.lower(), r'score.*?a[": *\n]*(-?\d+)'
@@ -44,7 +52,15 @@
judge_completion.lower(), r'score.*?b[": *\n]*(-?\d+)'
)
if score_a is None or score_b is None:
verdict_match = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion)
if verdict_match is None:
return None
bracketed_verdict = verdict_match.group(1).lower()
return {
"a": 0.0,
"b": 1.0,
"c": 0.5,
}[bracketed_verdict]
else:
return float(self.preference_from_scores(score_a, score_b))
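A minimal, self-contained sketch of the new parsing logic, for local sanity-checking. Assumptions not taken from this diff: get_regexp_match is replaced by plain re.search, strip_thinking_tags (which removes the judge's reasoning block before parsing) is omitted, and preference_from_scores is approximated by a simple higher-score-wins rule.

import re

def parse_preference_sketch(judge_completion: str) -> float | None:
    # First try to read explicit per-answer scores, as the judge prompt requests.
    text = judge_completion.lower()
    score_a = re.search(r'score.*?a[": *\n]*(-?\d+)', text)
    score_b = re.search(r'score.*?b[": *\n]*(-?\d+)', text)
    if score_a is None or score_b is None:
        # Fall back to a bracketed verdict: [[A]], [[B]], or [[C]] for a tie.
        verdict = re.search(r"\[\[\s*([ABCabc])\s*\]\]", judge_completion)
        if verdict is None:
            return None
        return {"a": 0.0, "b": 1.0, "c": 0.5}[verdict.group(1).lower()]
    # Stand-in for preference_from_scores: 0.0 = A wins, 1.0 = B wins, 0.5 = tie.
    a, b = int(score_a.group(1)), int(score_b.group(1))
    return 0.5 if a == b else float(a < b)

assert parse_preference_sketch("I prefer the first answer: [[A]]") == 0.0
assert parse_preference_sketch('{"score_a": 4, "score_b": 9}') == 1.0
assert parse_preference_sketch("no parsable verdict") is None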

@@ -81,7 +97,6 @@ def load_judge_system_and_user_prompt(
"{explanation_suffix}",
_EXPLANATION_SUFFIX if provide_explanation else _SCORE_FENCE,
)

return system_prompt, user_prompt_template


@@ -183,6 +198,8 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):
from langchain_together.llms import Together

judge_chat_model = Together(model="meta-llama/Llama-3.3-70B-Instruct-Turbo")
judge_model_spec = infer_model_spec_from_instance(judge_chat_model)
usage_tracker = OpenRouterReferencePricingTracker()

unique_string = dataset + "-" + datetime.now().strftime("%Y%m%d_%H%M%S")
output_folder = data_root / "judge-evals" / unique_string
@@ -203,6 +220,9 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):
use_tqdm=use_tqdm,
truncate_input_chars=truncate_input_chars,
provide_explanation=provide_explanation,
usage_tracker=usage_tracker,
usage_phase="judge",
usage_model_spec=judge_model_spec,
)

# Pairwise judge results
@@ -219,6 +239,13 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):
print(f"{method_A} against {method_B}:\n{results}")
with open(output_folder / "results.json", "w") as f:
json.dump(_to_jsonable(results), f, allow_nan=False)
pricing_reference = None
if judge_model_spec is not None:
pricing_reference = build_openrouter_reference_pricing_summary(
tracker=usage_tracker,
phase_model_specs={"judge": judge_model_spec},
)
print(format_openrouter_reference_pricing_summary(pricing_reference))
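For context on the usage_tracker / usage_phase / usage_model_spec wiring: the pattern assumed here is to accumulate token counts per (phase, model spec), then price them against a reference table. The sketch below is hypothetical; the real OpenRouterReferencePricingTracker lives in judgearena/openrouter_reference_pricing.py and its actual interface is not shown in this diff.

from collections import defaultdict

class UsageTrackerSketch:
    """Hypothetical stand-in: accumulate token usage per (phase, model_spec)."""

    def __init__(self) -> None:
        # (phase, model_spec) -> {"prompt": n, "completion": n}
        self._totals: dict[tuple[str, str], dict[str, int]] = defaultdict(
            lambda: {"prompt": 0, "completion": 0}
        )

    def record(
        self, phase: str, model_spec: str, prompt_tokens: int, completion_tokens: int
    ) -> None:
        bucket = self._totals[(phase, model_spec)]
        bucket["prompt"] += prompt_tokens
        bucket["completion"] += completion_tokens

    def reference_cost_usd(self, prices_per_mtok: dict[str, tuple[float, float]]) -> float:
        # prices_per_mtok: model_spec -> (input $ per 1M tokens, output $ per 1M tokens)
        total = 0.0
        for (_phase, spec), counts in self._totals.items():
            price_in, price_out = prices_per_mtok[spec]
            total += counts["prompt"] * price_in / 1e6
            total += counts["completion"] * price_out / 1e6
        return total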

run_metadata = {
"dataset": dataset,
@@ -246,6 +273,7 @@ def get_output(df_outputs: pd.DataFrame, dataset: str, method: str):
judge_system_prompt=judge_system_prompt,
judge_user_prompt_template=judge_user_prompt_template,
started_at_utc=run_started_at,
pricing_reference=pricing_reference,
)
except OSError as e:
print(f"Warning: failed to write run metadata: {e}")
@@ -270,6 +298,9 @@ def annotate_battles(
truncate_input_chars: int | None = 8192,
use_tqdm: bool = False,
provide_explanation: bool = False,
usage_tracker: OpenRouterReferencePricingTracker | None = None,
usage_phase: str | None = None,
usage_model_spec: str | None = None,
) -> list[JudgeAnnotation]:
"""
Directly evaluate from lists of instructions and completions
@@ -311,24 +342,38 @@
prompt_template = ChatPromptTemplate.from_messages(
[("system", system_prompt), ("user", user_prompt_template)]
)

truncated_completion_count = 0
input_payloads = []
for user_prompt, completion_A, completion_B in zip(
instructions, completions_A, completions_B, strict=True
):
truncated_completion_A = truncate(completion_A, max_len=truncate_input_chars)
truncated_completion_B = truncate(completion_B, max_len=truncate_input_chars)
truncated_completion_count += int(truncated_completion_A != completion_A)
truncated_completion_count += int(truncated_completion_B != completion_B)
input_payloads.append(
{
"user_prompt": user_prompt,
"completion_A": truncate(completion_A, max_len=truncate_input_chars),
"completion_B": truncate(completion_B, max_len=truncate_input_chars),
"completion_A": truncated_completion_A,
"completion_B": truncated_completion_B,
}
)
if truncated_completion_count:
print(
"Warning: truncated "
f"{truncated_completion_count} judge completions to "
f"{truncate_input_chars} characters before evaluation."
)

Collaborator (inline comment on the print() call above):
Flagging for a follow-up PR: the codebase mixes print() for warnings, progress, and debug info, making it hard to filter by severity or redirect output. We should migrate to Python's logging module (or at minimum a thin wrapper like logger = logging.getLogger(__name__)). What do you think @geoalgo
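One possible shape for that follow-up, sketched under assumed names (warn_truncated is hypothetical and not part of this PR):

import logging

logger = logging.getLogger(__name__)

def warn_truncated(count: int, limit: int | None) -> None:
    # A WARNING-level record can be filtered by severity or redirected via
    # standard logging configuration, unlike a bare print().
    if count:
        logger.warning(
            "Truncated %d completions to %s characters before evaluation.",
            count,
            limit,
        )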
inputs = prompt_template.batch(input_payloads)

print(f"Start LLM judge annotation ({len(inputs)} annotations).")
judge_completions = do_inference(
chat_model=judge_chat_model,
inputs=inputs,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=usage_model_spec,
)

annotations = []
@@ -363,6 +408,9 @@ def judge_and_parse_prefs(
user_prompt_template: str | None = None,
truncate_input_chars: int = 8192,
use_tqdm: bool = False,
usage_tracker: OpenRouterReferencePricingTracker | None = None,
usage_phase: str | None = None,
usage_model_spec: str | None = None,
) -> tuple[list[JudgeAnnotation], list[JudgeAnnotation] | None, pd.Series]:
"""Run judge annotation and parse preferences, handling swap_mode='both'.

@@ -388,6 +436,9 @@
user_prompt_template=user_prompt_template,
truncate_input_chars=truncate_input_chars,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=usage_model_spec,
)

annotations_reversed = None
@@ -402,6 +453,9 @@
user_prompt_template=user_prompt_template,
truncate_input_chars=truncate_input_chars,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=usage_model_spec,
)

def _none_to_nan(x):
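Background on the swap_mode='both' path above: judging each battle in both (A, B) and (B, A) order and averaging is a standard mitigation for position bias. The helper below is a sketch, assuming preferences lie in [0, 1] with 1.0 meaning the second-listed completion wins; the PR's actual combination code is hidden by the fold.

import math

def combine_orientations(pref_original: float, pref_swapped: float) -> float:
    # In the swapped run the judge saw (B, A), so its preference points at the
    # opposite completion and must be flipped before averaging. NaN (an
    # unparseable verdict, see _none_to_nan) propagates through the mean.
    return (pref_original + (1.0 - pref_swapped)) / 2.0

assert combine_orientations(1.0, 0.0) == 1.0  # B wins in both orientations
assert combine_orientations(1.0, 1.0) == 0.5  # always prefers position 2: a tie
assert math.isnan(combine_orientations(float("nan"), 0.0))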
36 changes: 32 additions & 4 deletions judgearena/generate.py
@@ -15,6 +15,8 @@ def generate_instructions(
max_tokens: int | None = 32768,
use_tqdm: bool = True,
system_prompt: str | None = None,
usage_tracker=None,
usage_phase: str | None = None,
**engine_kwargs,
) -> pd.DataFrame:
chat_model = make_model(model, max_tokens=max_tokens, **engine_kwargs)
@@ -41,6 +43,9 @@
chat_model=chat_model,
inputs=inputs,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=model,
)
df_outputs = pd.DataFrame(
data={
@@ -69,6 +74,8 @@ def _infer_grouped_by_temperature(
inputs: list,
temperatures: list[float],
use_tqdm: bool,
usage_tracker=None,
usage_phase: str | None = None,
) -> list[str]:
outputs: list[str] = [""] * len(inputs)
groups: dict[float, list[int]] = {}
@@ -91,6 +98,9 @@
chat_model=group_model,
inputs=group_inputs,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=model_spec,
)
for i, out in zip(idxs, group_outs, strict=True):
outputs[i] = out
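The bucket-and-scatter pattern used by _infer_grouped_by_temperature, as a runnable standalone sketch (the per-temperature model and do_inference are faked with a pure function; only the index bookkeeping is the point):

def run_grouped_by_temperature(inputs: list[str], temperatures: list[float]) -> list[str]:
    outputs: list[str] = [""] * len(inputs)
    groups: dict[float, list[int]] = {}
    # Bucket input indices by sampling temperature so each group is one batch call.
    for i, temp in enumerate(temperatures):
        groups.setdefault(temp, []).append(i)
    for temp, idxs in groups.items():
        # Stand-in for do_inference with a model bound to this temperature.
        group_outs = [f"completion(t={temp}): {inputs[i]}" for i in idxs]
        # Scatter results back so outputs line up with the original input order.
        for i, out in zip(idxs, group_outs, strict=True):
            outputs[i] = out
    return outputs

assert run_grouped_by_temperature(["q1", "q2"], [0.0, 0.7]) == [
    "completion(t=0.0): q1",
    "completion(t=0.7): q2",
]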
@@ -105,6 +115,8 @@ def generate_multiturn(
max_tokens: int | None = 8192,
use_tqdm: bool = True,
temperature_config: dict[str, float] | None = None,
usage_tracker=None,
usage_phase: str | None = None,
**model_kwargs,
) -> pd.DataFrame:
"""Generate two-turn completions for MT-Bench style questions."""
@@ -148,12 +160,17 @@
inputs=turn1_inputs,
temperatures=temperatures,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
)
else:
completions_turn_1 = do_inference(
chat_model=chat_model,
inputs=turn1_inputs,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=model,
)

turn2_inputs = []
@@ -195,12 +212,17 @@
inputs=turn2_inputs,
temperatures=temperatures,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
)
else:
completions_turn_2 = do_inference(
chat_model=chat_model,
inputs=turn2_inputs,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=model,
)

return pd.DataFrame(
@@ -218,20 +240,26 @@ def generate_base(
truncate_input_chars: int | None = 8192,
max_tokens: int | None = 32768,
use_tqdm: bool = False,
usage_tracker=None,
usage_phase: str | None = None,
**engine_kwargs,
) -> pd.DataFrame:
model_spec = model
model = make_model(model_spec, max_tokens=max_tokens, **engine_kwargs)

inputs = [
truncate(instruction, max_len=truncate_input_chars)
for instruction in instructions
]

completions = do_inference(
chat_model=model,
inputs=inputs,
max_tokens=max_tokens,
use_tqdm=use_tqdm,
usage_tracker=usage_tracker,
usage_phase=usage_phase,
usage_model_spec=model_spec,
)
completions = [x.content if hasattr(x, "content") else x for x in completions]

df_outputs = pd.DataFrame(
data={