#!/usr/bin/env python3
"""
structo_eval.py - Evaluation script for structured output research.
Evaluate language models on structured output datasets using lm-evaluation-harness
with LiteLLM for model calls and Hugging Face datasets for data loading.
Usage:
python structo_eval.py <lm_eval_task_name> <model_spec>
Where <model_spec> is of the form "<wrapper>+<base-model>", for example:
- natlang+gpt-4o
- structo+gpt-4o
Examples:
python structo_eval.py bbh_zeroshot_tracking_shuffled_objects_five_objects natlang+gpt-4o
python structo_eval.py bbh_zeroshot_tracking_shuffled_objects_five_objects structo+gpt-4o
python structo_eval.py gsm8k_main natlang+gpt-4o --limit 5
python structo_eval.py unscramble_random_insertion structo+gpt-4o
"""
import argparse
import sys
import time
from pathlib import Path

import dotenv

from src.structo.dataset_loader import load_evaluation_dataset, get_dataset_info, get_preferred_match_filter
from src.structo.results import display_and_save_results
from src.structo.lmeval import evaluate_with_filters
from src.structo.integration_factory import parse_model_identifier, get_supported_wrappers, get_all_wrappers

dotenv.load_dotenv()
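
# A model spec pairs a prompting wrapper with a LiteLLM base model, e.g.
# "natlang+gpt-4o" or "structo+gpt-4o". As a rough sketch (the authoritative
# parsing lives in src.structo.integration_factory.parse_model_identifier;
# this split is an assumption shown for illustration only):
#
#     wrapper, base_model = "natlang+gpt-4o".split("+", 1)
#     # wrapper == "natlang", base_model == "gpt-4o"
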
def parse_arguments():
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace with lm_eval_task_name, model_identifier, and option flags
    """
    parser = argparse.ArgumentParser(
        description="Evaluate language models on structured output datasets.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s bbh_zeroshot_tracking_shuffled_objects_five_objects natlang+gpt-4o
  %(prog)s gsm8k_main structo+gpt-4o
  %(prog)s unscramble_random_insertion natlang+gpt-4o --limit 5
  %(prog)s bbh_zeroshot_tracking_shuffled_objects_five_objects --all --base-model gpt-4o

For more information, see README.md
"""
    )
    parser.add_argument(
        'lm_eval_task_name',
        help='lm-eval task name (e.g., bbh_zeroshot_tracking_shuffled_objects_five_objects, gsm8k_main, unscramble_random_insertion)'
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        'model_identifier',
        nargs='?',
        help='Model specification in the form "<wrapper>+<base-model>" '
             '(e.g., natlang+gpt-4o). Required unless --all is used.'
    )
    group.add_argument(
        '--all',
        action='store_true',
        help='Run evaluation for all integrations (excludes optional integrations by default)'
    )
    parser.add_argument(
        '--base-model',
        help='Base model to use when --all flag is specified (e.g., gpt-4o)'
    )
    parser.add_argument(
        '--include-optional',
        action='store_true',
        help='Include optional integrations when using --all flag (structo_precog, structo_small_steps)'
    )
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help='Limit the number of examples to evaluate (default: None, evaluates all examples)'
    )
    return parser.parse_args()
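
# extract_accuracy() below scans an lm-evaluation-harness results dict for the
# 'exact_match,flexible-extract' metric. An illustrative (not authoritative)
# shape of that dict:
#
#     {"results": {"gsm8k_main": {"exact_match,flexible-extract": 0.84, ...}}}
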
def extract_accuracy(results):
    """
    Extract accuracy from evaluation results.

    Looks for 'exact_match,flexible-extract' metric in the results.

    Args:
        results: Dictionary with evaluation results from lm-evaluation-harness

    Returns:
        float | None: Accuracy value if found, None otherwise
    """
    if 'results' not in results:
        return None
    for task, metrics in results['results'].items():
        if isinstance(metrics, dict):
            accuracy_key = "exact_match,flexible-extract"
            if accuracy_key in metrics:
                return metrics[accuracy_key]
    return None
def main():
    """Main entry point for the evaluation script."""
    args = parse_arguments()
    lm_eval_task_name = args.lm_eval_task_name
    limit = args.limit
    model_identifier = args.model_identifier  # None when --all is used
    start_time = time.time()
    try:
        # Handle --all flag
        if args.all:
            if not args.base_model:
                print("Error: --base-model is required when using --all flag", file=sys.stderr)
                return 1
            base_model = args.base_model.strip()
            wrappers = get_all_wrappers(include_optional=args.include_optional)
            print(f"Running evaluation for {len(wrappers)} integrations with base model: {base_model}")
            if args.include_optional:
                print("Including optional integrations")
            else:
                print("Excluding optional integrations (use --include-optional to include them)")

            # Load dataset for info display
            print(f"\nLoading dataset for task: {lm_eval_task_name}")
            dataset = load_evaluation_dataset(lm_eval_task_name)
            dataset_info = get_dataset_info(dataset)
            print(f"Dataset loaded: {dataset_info['num_examples']} examples")
            if limit:
                print(f"Limiting evaluation to {limit} examples")

            # Get the preferred match filter
            preferred_filter = get_preferred_match_filter(lm_eval_task_name)
            if preferred_filter:
                print(f"Using preferred match filter: {preferred_filter}")

            # Run evaluation for each wrapper
            all_results = {}
            for wrapper in wrappers:
                model_identifier = f"{wrapper}+{base_model}"
                print(f"\n{'='*70}")
                print(f"Evaluating: {model_identifier}")
                print('='*70)
                try:
                    results = evaluate_with_filters(
                        lm_eval_task_name=lm_eval_task_name,
                        base_model=base_model,
                        wrapper=wrapper,
                        filter_names=[preferred_filter] if preferred_filter else None,
                        limit=limit,
                    )
                    duration = time.time() - start_time
                    print(f"Evaluation complete for {model_identifier}!")
                    # Display and save results for this integration
                    display_and_save_results(results, lm_eval_task_name, model_identifier, duration)
                    # Store results for summary
                    all_results[wrapper] = {
                        'results': results,
                        'model_identifier': model_identifier,
                        'duration': duration,
                    }
                except Exception as e:
                    print(f"Error evaluating {model_identifier}: {e}", file=sys.stderr)
                    all_results[wrapper] = {
                        'error': str(e),
                        'model_identifier': model_identifier,
                    }

            # Print summary of all integrations
            total_duration = time.time() - start_time
            print(f"\n{'='*70}")
            print("SUMMARY - All Integrations".center(70))
            print('='*70)
            print(f"\nDataset: {lm_eval_task_name}")
            print(f"Base Model: {base_model}")
            print(f"Total Duration: {total_duration:.1f}s ({total_duration/60:.1f} minutes)")
            print("-"*70)
            print(f"{'Integration':<30} {'Accuracy':>15} {'Status':>15}")
            print("-"*70)
            for wrapper in wrappers:
                if wrapper in all_results:
                    result_data = all_results[wrapper]
                    if 'error' in result_data:
                        print(f"{wrapper:<30} {'ERROR':>15} {'Failed':>15}")
                    else:
                        accuracy = extract_accuracy(result_data['results'])
                        if accuracy is not None:
                            accuracy_str = f"{accuracy:.2%}"
                            print(f"{wrapper:<30} {accuracy_str:>15} {'Success':>15}")
                        else:
                            print(f"{wrapper:<30} {'N/A':>15} {'Success':>15}")
            print("-"*70)
            print(f"\n{'='*70}")
            return 0

        # Single integration evaluation (original behavior)
        if not model_identifier:
            print("Error: model_identifier is required when not using --all flag", file=sys.stderr)
            return 1

        # Validate and parse model identifier using the central factory helper.
        print("Validating model identifier...")
        wrapper, base_model = parse_model_identifier(model_identifier)

        # Load dataset for info display (optional, but useful for user feedback)
        print(f"Loading dataset for task: {lm_eval_task_name}")
        dataset = load_evaluation_dataset(lm_eval_task_name)
        dataset_info = get_dataset_info(dataset)
        print(f"Dataset loaded: {dataset_info['num_examples']} examples")
        if limit:
            print(f"Limiting evaluation to {limit} examples")

        # Get the preferred match filter from TASK_PREFIX_DATASET_SOURCE
        preferred_filter = get_preferred_match_filter(lm_eval_task_name)
        if preferred_filter:
            print(f"Using preferred match filter: {preferred_filter}")

        # Delegate to Structo's lm-eval helper which builds a task_dict,
        # applies the preferred filter, and calls the low-level evaluate() API.
        results = evaluate_with_filters(
            lm_eval_task_name=lm_eval_task_name,
            base_model=base_model,
            wrapper=wrapper,
            filter_names=[preferred_filter] if preferred_filter else None,
            limit=limit,
        )
        duration = time.time() - start_time
        print("Evaluation complete!")

        # Display and save results
        display_and_save_results(results, lm_eval_task_name, model_identifier, duration)
        return 0
    except KeyboardInterrupt:
        print("\n\nEvaluation interrupted by user.")
        return 130
    except Exception as e:
        # Print a targeted hint for common error types, then re-raise.
        error_type = type(e).__name__
        if "DatasetNotFoundError" in error_type or "dataset" in str(e).lower():
            print(f"\nError: Dataset not found for task '{lm_eval_task_name}'.", file=sys.stderr)
            print("Please verify the task name matches an entry in TASK_PREFIX_DATASET_SOURCE.", file=sys.stderr)
        elif "APIError" in error_type or "api" in str(e).lower():
            print(f"\nError: API error with model '{model_identifier}'.", file=sys.stderr)
            print("Please check your API credentials and model identifier.", file=sys.stderr)
        elif "RateLimitError" in error_type or "rate limit" in str(e).lower():
            print(f"\nError: Rate limit exceeded for model '{model_identifier}'.", file=sys.stderr)
            print("Please wait and try again later.", file=sys.stderr)
        elif "AuthenticationError" in error_type or "authentication" in str(e).lower():
            print(f"\nError: Authentication failed for model '{model_identifier}'.", file=sys.stderr)
            print("Please check your API credentials (e.g., OPENAI_API_KEY).", file=sys.stderr)
        raise
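
# Exit codes: 0 on success, 1 on usage errors (missing --base-model or model
# identifier), 130 on Ctrl-C (the conventional SIGINT exit status); any other
# exception re-raises after printing a hint, so the traceback is preserved.
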
if __name__ == "__main__":
    sys.exit(main())