diff --git a/docs/user_manual/configure.rst b/docs/user_manual/configure.rst index 4bfb8a67..f1cbf9cd 100644 --- a/docs/user_manual/configure.rst +++ b/docs/user_manual/configure.rst @@ -253,7 +253,7 @@ Underneath you can find the list of all the available datasets. - ``text: str`` * - Image Generation - `LAION256 `_, `OpenImage `_, `COCO `_, `DrawBench `_, `PartiPrompts `_, `GenAIBench `_ - - ``image_generation_collate``, ``prompt_collate`` + - ``image_generation_collate``, ``prompt_with_auxiliaries_collate`` - ``text: str``, ``image: Optional[PIL.Image.Image]`` * - Image Classification - `ImageNet `_, `MNIST `_, `CIFAR10 `_ diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..373e7e78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ dependencies = [ "peft>=0.18.0", "trl<=0.21.0", "termcolor==2.3.0", - "realesrgan" + "realesrgan", ] [project.optional-dependencies] @@ -165,6 +165,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +evaluation = [ + "outlines>1.2.0,<2.0.0", + "litellm>=1.0.0", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", @@ -217,6 +221,7 @@ dev = [ "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", + "pruna[evaluation]", ] cpu = [] lmharness = [ diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index fd14a496..1f0ed5f6 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -34,7 +34,13 @@ setup_hps_dataset, setup_imgedit_dataset, setup_long_text_bench_dataset, + setup_oneig_anime_stylization_dataset, setup_oneig_dataset, + setup_oneig_general_object_dataset, + setup_oneig_knowledge_reasoning_dataset, + setup_oneig_multilingualism_dataset, + setup_oneig_portrait_dataset, + setup_oneig_text_rendering_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -103,19 +109,33 @@ "image_classification_collate", {"img_size": 224}, ), - "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), + "DrawBench": 
(setup_drawbench_dataset, "prompt_with_auxiliaries_collate", {}), "PartiPrompts": ( setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}, ), - "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "GenAIBench": (setup_genai_bench_dataset, "prompt_with_auxiliaries_collate", {}), "GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}), "HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}), "ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}), "LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}), "GEditBench": (setup_gedit_dataset, "prompt_with_auxiliaries_collate", {}), "OneIG": (setup_oneig_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGAnimeStylization": ( + setup_oneig_anime_stylization_dataset, + "prompt_with_auxiliaries_collate", + {}, + ), + "OneIGGeneralObject": (setup_oneig_general_object_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGKnowledgeReasoning": ( + setup_oneig_knowledge_reasoning_dataset, + "prompt_with_auxiliaries_collate", + {}, + ), + "OneIGMultilingualism": (setup_oneig_multilingualism_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGPortrait": (setup_oneig_portrait_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}), "DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 7764d23b..c1118c87 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -123,21 +123,92 @@ DPGCategory = Literal["entity", "attribute", "relation", "global", "other"] -def _to_oneig_record(row: dict, questions_by_key: dict[str, dict]) -> dict: - """Convert OneIG row 
to unified record format.""" +def _warn_ignored_benchmark_seed(seed: int | None, *, dataset: str) -> None: + if seed is not None: + pruna_logger.warning( + "%s: `seed` is ignored for this test-only benchmark; sampling does not shuffle the test split.", + dataset, + ) + + +def _oneig_alignment_language_zh(row: dict) -> bool: + """Return True when the official Q_D file for this row should use the ``*_zh`` graphs.""" + row_category = row.get("category", "") + if row_category == "Multilingualism": + return True + lang = row.get("language") or row.get("lang") + if isinstance(lang, str) and lang.lower() in {"zh", "zh-cn", "zh_cn", "chinese", "cn"}: + return True + if row.get("prompt_zh"): + return True + prompt = row.get("prompt") + prompt_en = row.get("prompt_en") + return bool(prompt and not (isinstance(prompt_en, str) and prompt_en.strip())) + + +def _oneig_qd_prefix(row: dict) -> str: + """Map dataset ``category`` (+ language) to Q_D JSON stem (e.g. ``object``, ``anime_zh``).""" + row_category = row.get("category", "") + use_zh = _oneig_alignment_language_zh(row) + if row_category == "Multilingualism": + return "multilingualism_zh" + base = _CATEGORY_TO_QD.get(row_category, "") + if not base: + return "" + return f"{base}_zh" if use_zh else base + + +def _to_oneig_record( + row: dict, + questions_by_key: dict[str, dict], + reasoning_gt_en: dict[str, str], + reasoning_gt_zh: dict[str, str], + reasoning_language: str = "EN", +) -> dict: + """Convert OneIG row to unified record format. + + Parameters + ---------- + row : dict + Raw Hugging Face row (``category``, ``id``, ``class``). EN configs use ``prompt_en``; the + ``OneIG-Bench-ZH`` **Multilingualism** split uses ``prompt_cn`` instead of ``prompt_en``. + questions_by_key : dict[str, dict] + Merged Q_D index keyed as ``{qd_stem}_{prompt_id}`` (see ``_fetch_oneig_alignment``). + reasoning_gt_en : dict[str, str] + Official ``gt_answer.json`` keyed by prompt id (e.g. ``"000"``). 
+ reasoning_gt_zh : dict[str, str] + Official ``gt_answer_zh.json`` keyed by prompt id. + reasoning_language : str, optional + Which reasoning GT to use: ``"EN"`` or ``"ZH"``. Default is ``"EN"``. + + Returns + ------- + dict + Unified record including ``questions``, ``dependencies``, and ``reasoning_gt_answer`` when + applicable (Knowledge_Reasoning only). + """ row_category = row.get("category", "") row_class = row.get("class", "None") or "None" - qd_name = _CATEGORY_TO_QD.get(row_category, "") - lookup_key = f"{qd_name}_{row.get('id', '')}" if qd_name else "" + prompt_id = str(row.get("id", "")) + qd_prefix = _oneig_qd_prefix(row) + lookup_key = f"{qd_prefix}_{prompt_id}" if qd_prefix else "" q_info = questions_by_key.get(lookup_key, {}) + text = row.get("prompt") or row.get("prompt_en") or row.get("prompt_cn") or "" + reasoning_gt_answer: str | None = None + if row_category == "Knowledge_Reasoning": + if reasoning_language.upper() == "ZH": + reasoning_gt_answer = reasoning_gt_zh.get(prompt_id) + else: + reasoning_gt_answer = reasoning_gt_en.get(prompt_id) return { - "text": row.get("prompt_en", row.get("prompt", "")), + "text": text, "subset": "Text_Rendering" if row_category in ("Text_Rendering", "Text Rendering") else row_category, "text_content": row_class if row_class != "None" else None, "category": row_category, "class": row_class, "questions": q_info.get("questions", {}), "dependencies": q_info.get("dependencies", {}), + "reasoning_gt_answer": reasoning_gt_answer, } @@ -159,7 +230,7 @@ def setup_drawbench_dataset() -> Tuple[Dataset, Dataset, Dataset]: def setup_parti_prompts_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -172,8 +243,8 @@ def setup_parti_prompts_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. 
fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -188,6 +259,7 @@ def setup_parti_prompts_dataset( Tuple[Dataset, Dataset, Dataset] The Parti Prompts dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="PartiPrompts") ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] if category is not None: @@ -226,7 +298,7 @@ def _generate_geneval_question(entry: dict) -> list[str]: def setup_geneval_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -239,8 +311,8 @@ def setup_geneval_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -255,6 +327,7 @@ def setup_geneval_dataset( Tuple[Dataset, Dataset, Dataset] The GenEval dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="GenEval") import json import requests @@ -286,7 +359,7 @@ def setup_geneval_dataset( def setup_hps_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -299,8 +372,8 @@ def setup_hps_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -315,6 +388,7 @@ def setup_hps_dataset( Tuple[Dataset, Dataset, Dataset] The HPD dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="HPS") import json from huggingface_hub import hf_hub_download @@ -338,7 +412,7 @@ def setup_hps_dataset( def setup_long_text_bench_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -350,8 +424,8 @@ def setup_long_text_bench_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -364,6 +438,7 @@ def setup_long_text_bench_dataset( Tuple[Dataset, Dataset, Dataset] The Long Text Bench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="LongTextBench") ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] ds = ds.rename_column("text", "text_content") ds = ds.rename_column("prompt", "text") @@ -390,7 +465,7 @@ def setup_genai_bench_dataset() -> Tuple[Dataset, Dataset, Dataset]: def setup_imgedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -403,8 +478,8 @@ def setup_imgedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -420,6 +495,7 @@ def setup_imgedit_dataset( Tuple[Dataset, Dataset, Dataset] The ImgEdit dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="ImgEdit") import json import requests @@ -466,18 +542,47 @@ def setup_imgedit_dataset( "General_Object": "object", } -_ONEIG_ALIGNMENT_BASE = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/41b49831e79e6dde5323618c164da1c4cf0f699d/scripts/alignment/Q_D" +_ONEIG_BENCHMARK_REF = "41b49831e79e6dde5323618c164da1c4cf0f699d" +_ONEIG_RAW_BASE = f"https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/{_ONEIG_BENCHMARK_REF}" +_ONEIG_ALIGNMENT_QD_URL = f"{_ONEIG_RAW_BASE}/scripts/alignment/Q_D" +_ONEIG_REASONING_GT_URL_EN = f"{_ONEIG_RAW_BASE}/scripts/reasoning/gt_answer.json" +_ONEIG_REASONING_GT_URL_ZH = f"{_ONEIG_RAW_BASE}/scripts/reasoning/gt_answer_zh.json" + +_ONEIG_QD_JSON_STEMS: tuple[str, ...] = ( + "anime", + "human", + "object", + "anime_zh", + "human_zh", + "object_zh", + "multilingualism_zh", +) def _fetch_oneig_alignment() -> dict[str, dict]: - """Fetch alignment questions from per-category Q_D files (InferBench-style).""" + """Load OneIG question/dependency graphs from the official repo (HTTP, no on-disk cache). + + Fetches every ``scripts/alignment/Q_D/*.json`` file used by upstream ``alignment_score.py`` (EN + ZH), + including ``multilingualism_zh.json``. Keys in the returned map are ``{stem}_{prompt_id}`` matching + upstream file stems (e.g. ``object_012``, ``multilingualism_zh_000``). + + Returns + ------- + dict[str, dict] + ``prompt_id``-level ``questions`` and ``dependencies`` dicts (parsed from JSON strings when needed). + + Raises + ------ + requests.HTTPError + If any asset URL is missing or the response is not successful. 
+ """ import json import requests questions_by_key: dict[str, dict] = {} - for qd_name in ("anime", "human", "object"): - url = f"{_ONEIG_ALIGNMENT_BASE}/{qd_name}.json" + for stem in _ONEIG_QD_JSON_STEMS: + url = f"{_ONEIG_ALIGNMENT_QD_URL}/{stem}.json" resp = requests.get(url, timeout=30) resp.raise_for_status() data = json.loads(resp.text) @@ -488,16 +593,55 @@ def _fetch_oneig_alignment() -> dict[str, dict]: q = json.loads(q) if isinstance(d, str): d = json.loads(d) - questions_by_key[f"{qd_name}_{prompt_id}"] = {"questions": q, "dependencies": d} + questions_by_key[f"{stem}_{prompt_id}"] = {"questions": q, "dependencies": d} return questions_by_key +def _fetch_oneig_reasoning_gt() -> tuple[dict[str, str], dict[str, str]]: + """Load official knowledge-reasoning reference answers (HTTP, no on-disk cache). + + Mirrors ``scripts/reasoning/gt_answer.json`` and ``gt_answer_zh.json`` from the same pinned commit as Q_D. + Keys are prompt ids (``str``), values are answer strings; downstream metrics may slice filenames to the + first three characters like ``reasoning_score.py``. + + Returns + ------- + tuple[dict[str, str], dict[str, str]] + ``(en_by_id, zh_by_id)``. + + Raises + ------ + requests.HTTPError + If any asset URL is missing or the response is not successful. 
+ """ + import json + + import requests + + def _load(url: str) -> dict[str, str]: + resp = requests.get(url, timeout=60) + resp.raise_for_status() + raw = json.loads(resp.text) + return {str(k): str(v) for k, v in raw.items()} + + return _load(_ONEIG_REASONING_GT_URL_EN), _load(_ONEIG_REASONING_GT_URL_ZH) + + +def _oneig_needs_zh_multilingualism_hub(category: OneIGCategory | list[OneIGCategory] | None) -> bool: + """Whether ``OneIG-Bench-ZH`` must be loaded for ``Multilingualism`` rows.""" + if category is None: + return True + categories = [category] if not isinstance(category, list) else category + return "Multilingualism" in categories + + def setup_oneig_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, category: OneIGCategory | list[OneIGCategory] | None = None, + reasoning_language: str = "EN", ) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the OneIG benchmark dataset. @@ -506,8 +650,8 @@ def setup_oneig_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -517,16 +661,43 @@ def setup_oneig_dataset( category : OneIGCategory | list[OneIGCategory] | None Filter by dataset category (Anime_Stylization, Portrait, etc.) or class (fauvism, watercolor, etc.). If None, returns all subsets. + reasoning_language : str, optional + Which reasoning GT to use for Knowledge_Reasoning rows: ``"EN"`` or ``"ZH"``. Default is ``"EN"``. Returns ------- Tuple[Dataset, Dataset, Dataset] - The OneIG dataset (dummy train, dummy val, test). + The OneIG dataset (dummy train, dummy val, test). 
Rows include ``questions`` and + ``dependencies`` from official Q_D JSON (EN + ZH stems, including ``multilingualism_zh``), + plus ``reasoning_gt_answer`` for ``Knowledge_Reasoning`` (language chosen by ``reasoning_language``). + Rows cover EN categories from ``OneIG-Bench`` plus ``Multilingualism`` from ``OneIG-Bench-ZH``. + Assets are downloaded over HTTP on each call (pinned commit ``_ONEIG_BENCHMARK_REF``); there is + no local disk cache. + + Notes + ----- + Non-multilingual prompts are loaded from the Hub config ``OneIG-Bench``; **Multilingualism** rows + are taken only from ``OneIG-Bench-ZH`` (they use ``prompt_cn``). The ZH config is fetched only when + the requested ``category`` is ``None`` (full suite) or explicitly includes ``Multilingualism``. + Q_D / reasoning JSON URLs are defined next to ``_fetch_oneig_alignment`` and + ``_fetch_oneig_reasoning_gt``. """ + _warn_ignored_benchmark_seed(seed, dataset="OneIG") questions_by_key = _fetch_oneig_alignment() - - ds_raw = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] - records = [_to_oneig_record(dict(row), questions_by_key) for row in ds_raw] + reasoning_gt_en, reasoning_gt_zh = _fetch_oneig_reasoning_gt() + + ds_en = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] + records = [ + _to_oneig_record(dict(row), questions_by_key, reasoning_gt_en, reasoning_gt_zh, reasoning_language) + for row in ds_en + ] + if _oneig_needs_zh_multilingualism_hub(category): + ds_zh = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench-ZH")["train"] # type: ignore[index] + ds_zh_ml = ds_zh.filter(lambda r: r["category"] == "Multilingualism") + records.extend( + _to_oneig_record(dict(row), questions_by_key, reasoning_gt_en, reasoning_gt_zh, reasoning_language) + for row in ds_zh_ml + ) ds = Dataset.from_list(records) if category is not None: @@ -544,8 +715,257 @@ def setup_oneig_dataset( return ds.select([0]), ds.select([0]), ds +def 
_setup_oneig_subset_with_fixed_category( + category: OneIGCategory, + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category=category, + reasoning_language=reasoning_language, + ) + + +def setup_oneig_anime_stylization_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Anime_Stylization``. + + ``functools.partial`` is not used so ``get_literal_values_from_param`` does not unwrap to + :func:`setup_oneig_dataset` and enumerate every ``OneIGCategory``. + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "Anime_Stylization", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + +def setup_oneig_general_object_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``General_Object``. 
+ + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "General_Object", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + +def setup_oneig_knowledge_reasoning_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Knowledge_Reasoning``. + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "Knowledge_Reasoning", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + +def setup_oneig_multilingualism_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Multilingualism``. 
+ + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "Multilingualism", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + +def setup_oneig_portrait_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Portrait``. + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "Portrait", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + +def setup_oneig_text_rendering_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Text_Rendering``. + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. 
+ fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for this subset. + """ + return _setup_oneig_subset_with_fixed_category( + "Text_Rendering", + seed, + fraction, + train_sample_size, + test_sample_size, + reasoning_language, + ) + + def setup_gedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -558,8 +978,8 @@ def setup_gedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -576,6 +996,7 @@ def setup_gedit_dataset( Tuple[Dataset, Dataset, Dataset] The GEditBench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="GEditBench") task_type_map = { "subject_add": "subject-add", "subject_remove": "subject-remove", @@ -613,7 +1034,7 @@ def setup_gedit_dataset( def setup_dpg_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -626,8 +1047,8 @@ def setup_dpg_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -642,6 +1063,7 @@ def setup_dpg_dataset( Tuple[Dataset, Dataset, Dataset] The DPG dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="DPG") import csv import io from collections import defaultdict diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 6d1eaadd..03003127 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -135,7 +135,7 @@ def from_string( tokenizer: AutoTokenizer | None = None, collate_fn_args: dict = dict(), dataloader_args: dict = dict(), - seed: int = 42, + seed: int | None = None, category: str | list[str] | None = None, fraction: float = 1.0, train_sample_size: int | None = None, @@ -154,8 +154,10 @@ def from_string( Any additional arguments for the collate function. dataloader_args : dict Any additional arguments for the dataloader. - seed : int - The seed to use. + seed : int | None, optional + Passed to dataset setup when the loader uses shuffled sampling. + If None, setups that require a seed default to 42; test-only benchmarks + omit seed so ordering stays deterministic without warnings. category : str | list[str] | None The category of the dataset. fraction : float @@ -177,7 +179,12 @@ def from_string( collate_fn_args = default_collate_fn_args if "seed" in inspect.signature(setup_fn).parameters: - setup_fn = partial(setup_fn, seed=seed) + seed_param = inspect.signature(setup_fn).parameters["seed"] + has_default = seed_param.default is not inspect.Parameter.empty + if seed is not None: + setup_fn = partial(setup_fn, seed=seed) + elif not has_default: + setup_fn = partial(setup_fn, seed=42) if "category" in inspect.signature(setup_fn).parameters: setup_fn = partial(setup_fn, category=category) diff --git a/src/pruna/data/utils.py b/src/pruna/data/utils.py index 2096f9e6..7cd323d4 100644 --- a/src/pruna/data/utils.py +++ b/src/pruna/data/utils.py @@ -34,20 +34,6 @@ from pruna.logging.logger import pruna_logger -class TokenizerMissingError(Exception): - """ - Custom exception raised when a tokenizer is required but not provided. 
- - Parameters - ---------- - message : str, optional - The message to display when the exception is raised. - """ - - def __init__(self, message: str = "Tokenizer is missing. Please provide a valid tokenizer.") -> None: - super().__init__(message) - - def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> list[str] | None: """ Extract Literal values from a function parameter's type annotation (handles Union). @@ -78,13 +64,13 @@ def get_literal_values_from_param(func: Callable[..., Any], param_name: str) -> except Exception: return None - def extract(ann: Any) -> list[str] | None: - if ann is None or ann is type(None): + def extract(annotation: Any) -> list[str] | None: + if annotation is None or annotation is type(None): return None - if get_origin(ann) is Literal: - args = get_args(ann) + if get_origin(annotation) is Literal: + args = get_args(annotation) return list(args) if args and all(isinstance(a, str) for a in args) else None - for arg in get_args(ann) or (): + for arg in get_args(annotation) or (): if (r := extract(arg)) is not None: return r return None @@ -92,6 +78,20 @@ def extract(ann: Any) -> list[str] | None: return extract(ann) +class TokenizerMissingError(Exception): + """ + Custom exception raised when a tokenizer is required but not provided. + + Parameters + ---------- + message : str, optional + The message to display when the exception is raised. + """ + + def __init__(self, message: str = "Tokenizer is missing. Please provide a valid tokenizer.") -> None: + super().__init__(message) + + def split_train_into_train_val_test(dataset: Dataset | IterableDataset, seed: int) -> Tuple[Dataset, Dataset, Dataset]: """ Split the training dataset into train, validation, and test. 
diff --git a/src/pruna/evaluation/benchmark_vlm_integration.py b/src/pruna/evaluation/benchmark_vlm_integration.py new file mode 100644 index 00000000..fc1ca63a --- /dev/null +++ b/src/pruna/evaluation/benchmark_vlm_integration.py @@ -0,0 +1,374 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for VLM benchmark integration runs (scripts and e2e tests).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from pruna.data.pruna_datamodule import PrunaDataModule +from pruna.evaluation.benchmarks import BenchmarkRegistry +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_base import VLM_METRIC_REGISTRY_NAMES, BaseVLM + +DEFAULT_SMOL = "HuggingFaceTB/SmolVLM-256M-Instruct" +DEFAULT_LITELLM = "openai/gpt-4o" + +_CATEGORY_DEFAULTS: dict[str, dict[str, Any]] = { + "GenEval": {"category": "single_object"}, + "ImgEdit": {"category": "replace"}, + "GEditBench": {"category": "background_change"}, +} + + +def discover_vlm_benchmark_jobs(include_oneig_reasoning: bool) -> list[tuple[str, str, str]]: + """ + List ``(lookup_key, benchmark display name, metric_name)`` for VLM-backed paper metrics. 
+ + Parameters + ---------- + include_oneig_reasoning : bool + If True, append ``oneig_reasoning`` for OneIG Knowledge Reasoning (LLM2CLIP, not SmolVLM). + + Returns + ------- + list[tuple[str, str, str]] + Sorted jobs for benchmarks that declare at least one metric in + :data:`VLM_METRIC_REGISTRY_NAMES`, plus optional reasoning jobs. + """ + jobs: list[tuple[str, str, str]] = [] + for key in sorted(BenchmarkRegistry.list()): + b = BenchmarkRegistry.get(key) + for m in b.metrics: + if m in VLM_METRIC_REGISTRY_NAMES: + jobs.append((key, b.name, m)) + if include_oneig_reasoning and "oneig_reasoning" in b.metrics: + tup = (key, b.name, "oneig_reasoning") + if tup not in jobs: + jobs.append(tup) + return jobs + + +def make_random_pred_images(batch_size: int, size: int = 224) -> torch.Tensor: + """ + Return a random RGB batch (placeholder generations for smoke integration). + + Parameters + ---------- + batch_size : int + Number of images in the batch dimension. + size : int, optional + Height and width of each square image (default 224). + + Returns + ------- + torch.Tensor + Tensor of shape ``(batch_size, 3, size, size)`` with values in ``[0, 1)``. + """ + return torch.rand(batch_size, 3, size, size) + + +def build_vlm_benchmark_metric( + metric_name: str, + benchmark_key: str, + *, + vlm_type: str, + model_name: str, + device: str, + vlm: BaseVLM | None = None, +) -> Any: + """ + Instantiate a metric for one benchmark VLM job. + + Parameters + ---------- + metric_name : str + Registry metric name (e.g. ``qa_accuracy``). + benchmark_key : str + Benchmark lookup key matching ``PrunaDataModule`` (e.g. ``GenEval``). + vlm_type : str + ``litellm`` or ``transformers`` when ``vlm`` is None. + model_name : str + Model id when ``vlm`` is None. + device : str + Device for metrics and optional local VLM. + vlm : BaseVLM | None + Pre-built VLM to reuse (e.g. session fixture); skips loading weights again. 
+ + Returns + ------- + Any + A :class:`~pruna.evaluation.metrics.metric_stateful.StatefulMetric` instance. + """ + if metric_name == "oneig_reasoning": + return MetricRegistry.get_metric(metric_name, device=device) + kw: dict[str, Any] = { + "vlm_type": vlm_type, + "model_name": model_name, + "device": device, + "structured_output": True, + } + if vlm is not None: + kw["vlm"] = vlm + if metric_name == "qa_accuracy" and benchmark_key == "GenEval": + kw["aggregation"] = "all_or_nothing" + return MetricRegistry.get_metric(metric_name, **kw) + + +@dataclass(frozen=True) +class BenchmarkVlmBatchOutcome: + """ + Outputs from a single benchmark row plus metric score. + + Parameters + ---------- + result : MetricResult + Aggregated metric output. + prompts : list[Any] + Prompt batch from the dataloader. + auxiliaries : list[Any] + Auxiliary fields per row (e.g. questions). + pred : torch.Tensor + Predicted image batch passed to the metric. + """ + + result: MetricResult + prompts: list[Any] + auxiliaries: list[Any] + pred: torch.Tensor + + +def run_benchmark_vlm_batch_full( + benchmark_key: str, + metric_name: str, + *, + vlm_type: str = "transformers", + model_name: str = DEFAULT_SMOL, + device: str = "cpu", + vlm: BaseVLM | None = None, +) -> BenchmarkVlmBatchOutcome: + """ + Load one test batch, run one VLM metric, return result and batch tensors. + + Parameters + ---------- + benchmark_key : str + Dataset lookup key for :meth:`PrunaDataModule.from_string`. + metric_name : str + Registry metric name. + vlm_type : str, optional + ``litellm`` or ``transformers`` when ``vlm`` is None (default ``transformers``). + model_name : str, optional + Model id when ``vlm`` is None (default HuggingFace SmolVLM). + device : str, optional + Device string (default ``cpu``). + vlm : BaseVLM | None, optional + Pre-built VLM to reuse. + + Returns + ------- + BenchmarkVlmBatchOutcome + Result, prompts, auxiliaries, and placeholder ``pred`` tensor. 
+ """ + dm_kw: dict[str, Any] = {"dataloader_args": {"batch_size": 1}} + dm_kw.update(_CATEGORY_DEFAULTS.get(benchmark_key, {})) + dm = PrunaDataModule.from_string(benchmark_key, **dm_kw) + dm.limit_datasets(1) + prompts, auxiliaries = next(iter(dm.test_dataloader())) + pred = make_random_pred_images(len(prompts)) + metric = build_vlm_benchmark_metric( + metric_name, + benchmark_key, + vlm_type=vlm_type, + model_name=model_name, + device=device, + vlm=vlm, + ) + metric.update(prompts, auxiliaries, pred) + mr = metric.compute() + return BenchmarkVlmBatchOutcome(result=mr, prompts=prompts, auxiliaries=auxiliaries, pred=pred) + + +def run_benchmark_metric_batch( + benchmark_key: str, + metric_name: str, + *, + vlm_type: str = "transformers", + model_name: str = DEFAULT_SMOL, + device: str = "cpu", + vlm: BaseVLM | None = None, +) -> MetricResult: + """ + Load one test batch from the benchmark, run one VLM metric, return :class:`MetricResult`. + + Uses random ``pred`` tensors as placeholder generations (same as the ``mine`` store script). + + Parameters + ---------- + benchmark_key : str + Dataset name for :meth:`PrunaDataModule.from_string`. + metric_name : str + Metric to run. + vlm_type : str + Backend when ``vlm`` is not provided. + model_name : str + Checkpoint or litellm id when ``vlm`` is not provided. + device : str + Torch device string. + vlm : BaseVLM | None + Optional shared VLM instance for faster multi-benchmark runs. + + Returns + ------- + MetricResult + Aggregated score from :meth:`~pruna.evaluation.metrics.metric_stateful.StatefulMetric.compute`. 
+ """ + return run_benchmark_vlm_batch_full( + benchmark_key, + metric_name, + vlm_type=vlm_type, + model_name=model_name, + device=device, + vlm=vlm, + ).result + + +def _short(obj: Any, max_len: int = 400) -> Any: + if isinstance(obj, str) and len(obj) > max_len: + return obj[:max_len] + "…" + return obj + + +def _aux_for_record(aux: dict[str, Any]) -> dict[str, Any]: + out: dict[str, Any] = {} + for k, v in aux.items(): + if k == "questions" and isinstance(v, dict): + out[k] = {qk: _short(str(qt), 200) for qk, qt in list(v.items())[:24]} + if len(v) > 24: + out["_truncated_questions"] = len(v) - 24 + else: + out[k] = _short(v) if isinstance(v, str) else v + return out + + +def _safe_json(obj: Any) -> Any: + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, dict): + return {str(k): _safe_json(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_safe_json(x) for x in obj] + if isinstance(obj, torch.Tensor): + return {"tensor_shape": list(obj.shape), "dtype": str(obj.dtype)} + return str(obj) + + +def _metric_result_record(mr: MetricResult) -> dict[str, Any]: + return { + "name": mr.name, + "result": float(mr.result), + "higher_is_better": mr.higher_is_better, + "metric_units": mr.metric_units, + } + + +def vlm_benchmark_batch_to_json_record( + outcome: BenchmarkVlmBatchOutcome, + *, + benchmark_key: str, + benchmark_name: str, + metric_name: str, + vlm_type: str, + model_name: str, + device: str, + pred_note: str | None = "random noise placeholder", +) -> dict[str, Any]: + """ + Build a JSON-serializable snapshot of one benchmark batch, preds, and metric output. + + Parameters + ---------- + outcome : BenchmarkVlmBatchOutcome + Batch prompts, auxiliaries, ``pred`` tensor, and computed :class:`MetricResult`. + benchmark_key : str + Registry / datamodule lookup key (e.g. ``GenEval``). + benchmark_name : str + Human-readable benchmark name. + metric_name : str + Metric id used for this run. 
+ vlm_type : str + Backend id (e.g. ``transformers``). + model_name : str + Model id or litellm route. + device : str + Torch device string. + pred_note : str | None, optional + Short note stored next to ``pred`` shape (placeholder generations in integration). + + Returns + ------- + dict[str, Any] + Nested dict safe for ``json.dumps`` (strings truncated; tensors summarized). + + Examples + -------- + >>> from pruna.evaluation.metrics.result import MetricResult + >>> import torch + >>> mr = MetricResult(name="m", params={}, result=1.0, higher_is_better=True) + >>> bo = BenchmarkVlmBatchOutcome( + ... result=mr, + ... prompts=["hi"], + ... auxiliaries=[{}], + ... pred=torch.zeros(1, 3, 2, 2), + ... ) + >>> rec = vlm_benchmark_batch_to_json_record( + ... bo, + ... benchmark_key="K", + ... benchmark_name="K", + ... metric_name="m", + ... vlm_type="transformers", + ... model_name="x", + ... device="cpu", + ... ) + >>> rec["metric_result"]["result"] + 1.0 + """ + a0 = outcome.auxiliaries[0] if outcome.auxiliaries and isinstance(outcome.auxiliaries[0], dict) else {} + pred_payload: dict[str, Any] = { + "shape": list(outcome.pred.shape), + "dtype": str(outcome.pred.dtype), + } + if pred_note is not None: + pred_payload["note"] = pred_note + record: dict[str, Any] = { + "benchmark_lookup_key": benchmark_key, + "benchmark_name": benchmark_name, + "metric_name": metric_name, + "dataset_name": benchmark_key, + "vlm_type": vlm_type, + "model_name": model_name, + "device": device, + "inputs": { + "prompts": [_short(p, 500) for p in outcome.prompts], + "auxiliary_0": _aux_for_record(a0) if isinstance(a0, dict) else _safe_json(a0), + }, + "pred": pred_payload, + "metric_result": _metric_result_record(outcome.result), + } + return _safe_json(record) diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e52ae463..34b5444a 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -12,6 +12,9 @@ # See the License 
for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import builtins from dataclasses import dataclass, field from pruna.data import base_datasets @@ -31,7 +34,10 @@ class Benchmark: description : str Description of what the benchmark evaluates. metrics : list[str] - List of metric names used for evaluation. + Metric names from ``MetricRegistry`` that the ``reference`` paper + explicitly names for that benchmark (not speculative proxies). Entries + with no matching registered name stay empty; pass metrics explicitly to + ``Task`` when running other evaluations. task_type : str Type of task the benchmark evaluates (e.g., 'text_to_image'). reference : str | None @@ -62,24 +68,17 @@ class BenchmarkRegistry: """ Registry for benchmarks. - Metrics per benchmark are set to those explicitly used in the reference - paper (see reference URL). All entries verified from paper evaluation - sections (ar5iv/HTML or PDF) as of verification pass: + Each entry's ``metrics`` lists only ``MetricRegistry`` names that have a + **directly named** counterpart in the ``reference`` paper (e.g. CLIPScore → + ``clip_score``, VQAScore → ``vqa``, Fréchet inception distance → ``fid``). + If the paper cites a method with no registered metric (HPS v2, Mask2Former, + mPLUG-large adjudication, …), the list is empty. + + See ``.mine/benchmark-paper-alignment/01-arxiv-literature-vs-pruna-metrics.md`` + for paper-by-paper notes and Pruna implementation gaps. - - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222. - - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP. - - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed). - - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric. - - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment. - - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy. 
- - WikiText (1609.07843 §5): perplexity on validation/test. - - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score. - - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2). - - ImgEdit (2505.20275 §4.2): GPT-4o 1–5 ratings and ImgEdit-Judge. - - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B). - - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). - - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.). - - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator. + OneIG is split into six subset benchmarks (plus full ``OneIG``); see + ``.mine/benchmark-paper-alignment/02-oneig-subset-metrics-verification.md`` for §4.1 mapping. """ _registry: dict[str, Benchmark] = {} @@ -88,9 +87,7 @@ class BenchmarkRegistry: def _register(cls, benchmark: Benchmark) -> None: missing = [m for m in benchmark.metrics if not MetricRegistry.has_metric(m)] if missing: - raise ValueError( - f"Benchmark '{benchmark.name}' references metrics not in MetricRegistry: {missing}." - ) + raise ValueError(f"Benchmark '{benchmark.name}' references metrics not in MetricRegistry: {missing}.") if benchmark.lookup_key not in base_datasets: available = ", ".join(base_datasets.keys()) raise ValueError( @@ -125,7 +122,7 @@ def get(cls, name: str) -> Benchmark: return cls._registry[key] @classmethod - def list(cls, task_type: str | None = None) -> list[str]: + def list(cls, task_type: str | None = None) -> builtins.list[str]: """ List available benchmark names. @@ -174,7 +171,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning " "(counting, comparison, logic/negation) with over 24k human ratings." 
), - metrics=[], # Paper uses VQAScore only; not in Pruna + metrics=["vqa", "clip_score"], # VQAScore + CLIPScore both named (arXiv:2406.13743) task_type="text_to_image", reference="https://arxiv.org/abs/2406.13743", ), @@ -195,7 +192,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports " "FID for fidelity and CLIP score for image-text alignment." ), - metrics=["fid", "clip_score"], # §4.1: FID + CLIP score + metrics=["fid", "clip_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2205.11487", ), @@ -223,10 +220,12 @@ def list(cls, task_type: str | None = None) -> list[str]: name="GenEval", description=( "Compositional text-to-image benchmark with 6 categories: single object, two object, " - "counting, colors, position, color attributes. Evaluates fine-grained alignment " - "between prompts and generated images via VQA-style questions." + "counting, colors, position, color attributes. Uses atomic yes/no questions per prompt; " + "``Task.from_benchmark`` wires ``qa_accuracy`` with strict per-image aggregation " + "(all questions must pass) plus ``clip_score``. For holistic VQAScore-style scoring " + "use GenAI Bench with ``vqa``." ), - metrics=["clip_score"], # §3.2: Mask2Former; not in Pruna + metrics=["qa_accuracy", "clip_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2310.11513", ), @@ -246,17 +245,19 @@ def list(cls, task_type: str | None = None) -> list[str]: "Image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, " "style, background, compose. Evaluates instruction-following for inpainting and editing." 
), - metrics=[], # Paper uses GPT-4o/ImgEdit-Judge; not in Pruna + metrics=["img_edit_score"], # Paper: GPT-4o rubric scores, FakeShield; img_edit_score is the registered VLM-based proxy task_type="text_to_image", reference="https://arxiv.org/abs/2505.20275", ), Benchmark( name="Long Text Bench", description=( - "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to " - "handle complex multi-clause descriptions and maintain coherence across long instructions." + "Text rendering benchmark evaluating whether T2I models correctly render specific text strings " + "specified in prompts. Provides ``text_content`` ground truth for OCR comparison via ``text_score`` " + "(default: mean character error rate; optional raw Levenshtein via ``text_distance='levenshtein'``). " + "Not to be confused with text-to-image alignment for long descriptive prompts." ), - metrics=[], # Paper uses text_score/TIT-Score; not in Pruna + metrics=["text_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2507.22058", ), @@ -265,20 +266,57 @@ description=( "General image editing benchmark with 11 task types: background change, color alter, " "material alter, motion change, style change, subject add/remove/replace, text change, " - "tone transfer, and human retouching." + "tone transfer, and human retouching. " + "When using VieScoreMetric with this benchmark, pass ``task_type='image_editing'`` to apply " + "the paper's 2-criterion SC scoring (execution success + overediting) instead of the default " + "text-to-image single-criterion SC. " + "The default metric implementation scores the edited image and instruction only; " + "full parity with reference VIEScore pipelines that condition on a source image may require " + "dataset fields and metric extensions not included here."
), - metrics=[], # Paper uses VIEScore; not in Pruna + metrics=["vie_score"], # VIEScore named in GEdit-Bench section task_type="text_to_image", reference="https://arxiv.org/abs/2504.17761", ), Benchmark( - name="OneIG", - description=( - "Omni-dimensional benchmark for text-to-image evaluation. Six dataset categories " - "(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, " - "Text_Rendering) plus fine-grained style classes. Includes alignment questions." - ), - metrics=[], # Paper uses dimension-specific metrics; not in Pruna + name="OneIG Anime Stylization", + description="OneIG subset: anime and stylized imagery.", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG General Object", + description="OneIG subset: everyday objects and scenes.", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Knowledge Reasoning", + description="OneIG subset: knowledge- and reasoning-heavy prompts.", + metrics=["oneig_reasoning"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Multilingualism", + description="OneIG subset: multilingual prompts (incl. 
Chinese splits).", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Portrait", + description="OneIG subset: people and portraits.", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Text Rendering", + description="OneIG subset: text and graphics painted into the image.", + metrics=["oneig_text_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2506.07977", ), diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 5b713dea..3e20e4a5 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -112,8 +112,8 @@ def from_benchmark( Examples -------- - >>> agent = EvaluationAgent.from_benchmark("Parti Prompts", model) - >>> agent = EvaluationAgent.from_benchmark("HPS", model, category="anime", fraction=0.1) + >>> agent = EvaluationAgent.from_benchmark("Parti Prompts") + >>> agent = EvaluationAgent.from_benchmark("HPS", category="anime", fraction=0.1) """ task = Task.from_benchmark( benchmark_name, diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 1a12f623..bf0a5ef0 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -15,16 +15,43 @@ from pruna.evaluation.metrics.registry import MetricRegistry # isort:skip from pruna.evaluation.metrics.aesthetic_laion import AestheticLAION +from pruna.evaluation.metrics.metric_alignment_score import AlignmentScoreMetric from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_dino_score import DinoScore -from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric -from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric +from 
pruna.evaluation.metrics.metric_elapsed_time import ( + LatencyMetric, + ThroughputMetric, + TotalTimeMetric, +) +from pruna.evaluation.metrics.metric_energy import ( + CO2EmissionsMetric, + EnergyConsumedMetric, +) from pruna.evaluation.metrics.metric_evalharness import LMEvalMetric -from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric -from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_memory import ( + DiskMemoryMetric, + InferenceMemoryMetric, + TrainingMemoryMetric, +) +from pruna.evaluation.metrics.metric_model_architecture import ( + TotalMACsMetric, + TotalParamsMetric, +) +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + LitellmVLM, + TransformersVLM, + get_vlm, +) __all__ = [ "MetricRegistry", @@ -45,4 +72,16 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", + "VQAMetric", + "AlignmentScoreMetric", + "ImageEditScoreMetric", + "QAAccuracyMetric", + "OneIGAlignmentMetric", + "TextScoreMetric", + "OneIGTextScoreMetric", + "VieScoreMetric", + "BaseVLM", + "LitellmVLM", + "TransformersVLM", + "get_vlm", ] diff --git a/src/pruna/evaluation/metrics/metric_alignment_score.py b/src/pruna/evaluation/metrics/metric_alignment_score.py new file 
mode 100644 index 00000000..c54e8197 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_alignment_score.py @@ -0,0 +1,141 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Alignment Score metric using VLM for image-text alignment evaluation.""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images + + +@MetricRegistry.register("alignment_score") +class AlignmentScoreMetric(StatefulMetric): + """ + Alignment Score metric using VLM. + + Assesses how well generated images match text prompts through structured questioning. + Higher scores indicate better alignment. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. 
Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "alignment_score" + runs_on: List[str] = ["cuda", "cpu"] + + def __init__( + self, + *args, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. 
+ outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"?' + score = self.vlm.score([image], [question], ["Yes"], response_format=self.response_format)[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the alignment score. + + Returns + ------- + MetricResult + The mean alignment score across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index c3689446..ccfc413c 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -198,9 +198,11 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] | # Measurement list_elapsed_times = [] with tqdm(total=self.n_iterations, desc="Measuring inference time", unit="iter") as pbar: + def measure_with_progress(m, x): list_elapsed_times.append(self._time_inference(m, x)) pbar.update(1) + self._measure(model, dataloader, self.n_iterations, measure_with_progress) total_elapsed_time = sum(list_elapsed_times) diff --git a/src/pruna/evaluation/metrics/metric_img_edit_score.py b/src/pruna/evaluation/metrics/metric_img_edit_score.py new file mode 100644 index 00000000..c21c5643 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_img_edit_score.py @@ -0,0 +1,151 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Image Edit Score metric. + +VLM-based instruction-following score for image editing. Evaluates how well an edited image +follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909), +ADIEE (ICCV 2025). +""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, _process_images, get_score_from_response + + +@MetricRegistry.register("img_edit_score") +class ImageEditScoreMetric(StatefulMetric): + """ + Image Edit Score metric. + + VLM-based instruction-following score for image editing. Evaluates how well an edited image + follows the given editing instruction. Higher scores indicate better editing quality. + + Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025). + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. 
If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "img_edit_score" + + def __init__( + self, + *args, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = FloatOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. 
+ + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (editing instructions). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output (edited) images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = ( + f'On a scale of 0 to 10, how well does this edited image follow the instruction "{prompt}"? ' + "0 = instruction not followed at all, 10 = perfectly executed. Reply with a single number." + ) + responses = self.vlm.generate([image], [question], response_format=self.response_format) + self.scores.append(get_score_from_response(responses[0])) + + def compute(self) -> MetricResult: + """ + Compute the image edit score. + + Returns + ------- + MetricResult + The mean image edit score across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_oneig_alignment.py b/src/pruna/evaluation/metrics/metric_oneig_alignment.py new file mode 100644 index 00000000..177cf148 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_oneig_alignment.py @@ -0,0 +1,196 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""OneIG alignment scoring with dependency masking (parent ``No`` gates children).""" + +from __future__ import annotations + +from typing import Any, Mapping + +import torch + +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.utils import metric_data_processor +from pruna.evaluation.metrics.vlm_utils import _process_images + + +def _int_dict_keys(mapping: Mapping[Any, Any]) -> dict[int, Any]: + return {int(k): v for k, v in mapping.items()} + + +def _normalize_dependencies(deps: Any) -> dict[int, list[int]]: + if not isinstance(deps, Mapping): + return {} + out: dict[int, list[int]] = {} + for k, v in deps.items(): + key = int(k) + if isinstance(v, list): + out[key] = [int(p) for p in v] + else: + out[key] = [] + return out + + +def apply_oneig_dependency_mask( + raw_scores: Mapping[int, float], + dependencies: Mapping[int, list[int]], +) -> dict[int, float]: + """ + Apply OneIG ``filter_score`` logic per dependency graph (single grid cell). + + Parents with semantic answer ``No`` (score ``0``) force dependent question + scores to ``0``. Parent id ``0`` is ignored, matching the reference script. + + Parameters + ---------- + raw_scores : Mapping[int, float] + Map question id → VLM score in ``{0, 1}`` (or float) before masking. + dependencies : Mapping[int, list[int]] + Map child question id → list of parent question ids (use ``[0]`` for roots). + + Returns + ------- + dict[int, float] + Copy of scores with dependent questions zeroed when any non-zero parent + scored ``0``. 
+ """ + filtered = {int(k): float(v) for k, v in raw_scores.items()} + deps = _normalize_dependencies(dependencies) + raw = dict(filtered) + for child_id, parent_ids in deps.items(): + if child_id not in filtered: + continue + any_parent_no = False + for parent_id in parent_ids: + if parent_id == 0: + continue + if parent_id not in raw: + continue + if raw[parent_id] == 0.0: + any_parent_no = True + break + if any_parent_no: + filtered[child_id] = 0.0 + return filtered + + +def aggregate_oneig_alignment_per_cell(filtered_scores: Mapping[int, float], question_ids: list[int]) -> float: + """ + Mean filtered score over all questions in the prompt (one grid cell). + + Parameters + ---------- + filtered_scores : Mapping[int, float] + Post-mask scores for each question id. + question_ids : list[int] + Ordered ids (typically sorted ascending) defining the denominator. + + Returns + ------- + float + Average score in ``[0, 1]`` if inputs are binary; ``0.0`` if ``question_ids`` is empty. + """ + if not question_ids: + return 0.0 + s = sum(float(filtered_scores[qid]) for qid in question_ids) + return s / float(len(question_ids)) + + +@MetricRegistry.register("oneig_alignment") +class OneIGAlignmentMetric(QAAccuracyMetric): + """ + OneIG alignment with dependency-aware aggregation. + + Reuses :class:`QAAccuracyMetric` VLM Yes/No scoring but aggregates like + ``OneIG-Benchmark`` ``alignment_score.py`` for a **single** grid cell (no + ``split_mxn_grid``): question ids are sorted numerically, raw scores are + masked when any non-root parent is ``No``, then the mean over all questions + is stored per image. + + Numerical parity with upstream also depends on the VLM (e.g. ``openai/gpt-4o`` via + litellm vs reference Qwen2.5-VL). + + Parameters + ---------- + *args : Any + Additional positional arguments for :class:`QAAccuracyMetric`. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. 
+ vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is ``"litellm"``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments for :class:`QAAccuracyMetric`. + """ + + metric_name: str = "oneig_alignment" + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Score each question with the VLM, apply dependency masking, append per-cell mean. + + Parameters + ---------- + x : list[Any] | torch.Tensor + Unused batch metadata (kept for metric interface). + gt : torch.Tensor + Ground-truth slot holding per-sample aux dicts with ``questions`` and + optionally ``dependencies``. + outputs : torch.Tensor + Model outputs (images) evaluated against the questions. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + aux_list = inputs[1] if len(inputs) > 1 else [] + if isinstance(aux_list, torch.Tensor): + aux_list = aux_list.tolist() + for i, image in enumerate(images): + aux = aux_list[i] if i < len(aux_list) else {} + if not isinstance(aux, dict): + raise ValueError( + "oneig_alignment requires aux[{}] to be a dict with 'questions'. 
Got: {!r}.".format(i, type(aux)) + ) + qs = aux.get("questions") + if not isinstance(qs, dict) or not qs: + raise ValueError( + f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}." + ) + qmap = _int_dict_keys(qs) + qids = sorted(qmap) + question_texts = [str(qmap[qi]) for qi in qids] + deps = _normalize_dependencies(aux.get("dependencies", {})) + raw_scores_list = self.vlm.score( + [image] * len(question_texts), + question_texts, + ["Yes"] * len(question_texts), + response_format=self.response_format, + ) + raw_map = {qid: float(raw_scores_list[j]) for j, qid in enumerate(qids)} + filtered = apply_oneig_dependency_mask(raw_map, deps) + self.scores.append(aggregate_oneig_alignment_per_cell(filtered, qids)) diff --git a/src/pruna/evaluation/metrics/metric_oneig_reasoning.py b/src/pruna/evaluation/metrics/metric_oneig_reasoning.py new file mode 100644 index 00000000..cf1b83a5 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_oneig_reasoning.py @@ -0,0 +1,350 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OneIG reasoning score via LLM2CLIP text-image similarity. + +Llama-derived checkpoints may require ``HF_TOKEN`` and ``huggingface-cli login``. 
+
+Hugging Face download tuning (optional):
+
+- ``PRUNA_ONEIG_HF_VERBOSE=1`` or ``HF_DEBUG=1`` — hub **debug** logging and tqdm
+  progress bars (helps when stderr is piped; pair with ``python -u`` or
+  ``PYTHONUNBUFFERED=1`` for line-buffered output).
+- ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1`` — enable **hf_transfer** multi-part downloads
+  (requires the ``hf_transfer`` package to be installed). Alternatively, set
+  ``HF_HUB_ENABLE_HF_TRANSFER=1`` **before** starting Python so the hub picks it up at
+  import time.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import torch
+
+from pruna.evaluation.metrics.metric_stateful import StatefulMetric
+from pruna.evaluation.metrics.registry import MetricRegistry
+from pruna.evaluation.metrics.result import MetricResult
+from pruna.evaluation.metrics.utils import (
+    SINGLE,
+    get_call_type_for_single_metric,
+    metric_data_processor,
+)
+from pruna.evaluation.metrics.vlm_utils import _process_images
+from pruna.logging.logger import pruna_logger
+
+
+def _env_truthy(raw: str | None) -> bool:
+    if raw is None:
+        return False
+    return raw.strip().upper() in {"1", "ON", "YES", "TRUE"}
+
+
+def _prepare_huggingface_hub_for_oneig_downloads() -> None:
+    """
+    Apply Hugging Face Hub verbosity and optional fast downloads before checkpoints load.
+
+    ``HF_HUB_ENABLE_HF_TRANSFER`` is read when ``huggingface_hub`` loads; if it was
+    false, we flip the in-module flag after importing ``hf_transfer`` when
+    ``PRUNA_ONEIG_HF_FAST_DOWNLOAD=1``.
+ """ + if _env_truthy(os.environ.get("PRUNA_ONEIG_HF_VERBOSE")) or _env_truthy(os.environ.get("HF_DEBUG")): + from huggingface_hub.utils import enable_progress_bars + from huggingface_hub.utils.logging import set_verbosity_debug + + set_verbosity_debug() + enable_progress_bars() + + if not _env_truthy(os.environ.get("PRUNA_ONEIG_HF_FAST_DOWNLOAD")): + return + + import hf_transfer # noqa: F401 # type: ignore[import-not-found] + import huggingface_hub.constants as hf_constants + + hf_constants.HF_HUB_ENABLE_HF_TRANSFER = True + pruna_logger.info("oneig_reasoning: enabled hf_transfer downloads (PRUNA_ONEIG_HF_FAST_DOWNLOAD=1).") + + +def _to_pil_list(images: list) -> list: + """Convert images to list of PIL.Image (RGB).""" + import numpy as np + from PIL import Image + + out: list = [] + for img in images: + if isinstance(img, Image.Image): + out.append(img.convert("RGB")) + elif isinstance(img, torch.Tensor): + if img.ndim == 4: + img = img[0] + if img.max() > 1: + img = img / 255.0 + np_img = (img.cpu().numpy() * 255).astype("uint8") + if np_img.shape[0] == 3: + np_img = np_img.transpose(1, 2, 0) + out.append(Image.fromarray(np_img)) + elif hasattr(img, "__array__"): + out.append(Image.fromarray(np.asarray(img)).convert("RGB")) + else: + out.append(img) + return out + + +class _LLM2CLIPScorer: + """ + Thin wrapper around LLM2CLIP text-image similarity. + + Accepts PIL images and a single answer string; returns per-image scores. + Best-effort alignment with OneIG-Benchmark scripts (CUDA + bfloat16). 
+ """ + + def __init__( + self, + processor_model: str = "openai/clip-vit-large-patch14-336", + model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336", + llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned", + device: str = "cuda", + ) -> None: + self.processor_model = processor_model + self.model_name = model_name + self.llm_model_name = llm_model_name + self.device = device + self._processor = None + self._clip_model = None + self._l2v = None + + def _load_models(self) -> None: + if self._clip_model is not None: + return + _prepare_huggingface_hub_for_oneig_downloads() + from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor + + from pruna.evaluation.metrics.vendor.oneig_llm2vec import LLM2Vec + from pruna.evaluation.metrics.vendor.oneig_llm2vec.modeling_llama_encoder import LlamaEncoderModel + + pruna_logger.info( + "oneig_reasoning: downloading or loading LLM2CLIP checkpoints " + "(%s, %s). First run can take many minutes and several gigabytes; " + "Hugging Face download progress may look idle when logs are piped.", + self.model_name, + self.llm_model_name, + ) + dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 + self._processor = CLIPImageProcessor.from_pretrained(self.processor_model) + self._clip_model = AutoModel.from_pretrained( + self.model_name, + dtype=dtype, + trust_remote_code=True, + ).to(self.device) + self._clip_model.train(mode=False) + + config = AutoConfig.from_pretrained(self.llm_model_name, trust_remote_code=True) + dev_str = str(self.device) + attn_impl = "sdpa" if dev_str == "cuda" or dev_str.startswith("cuda:") else "eager" + config.attn_implementation = attn_impl + if hasattr(config, "_attn_implementation"): + config._attn_implementation = attn_impl + llm_model = LlamaEncoderModel.from_pretrained( + self.llm_model_name, + dtype=dtype, + config=config, + trust_remote_code=True, + ) + llm_model.config._name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct" + tokenizer = 
AutoTokenizer.from_pretrained(self.llm_model_name) + self._l2v = LLM2Vec(llm_model, tokenizer, pooling_mode="mean", max_length=512, doc_max_length=512) + + def score(self, images: list, text_prompt: str) -> list[float] | None: + """ + Compute similarity scores between images and text. + + Parameters + ---------- + images : list + List of PIL.Image.Image. + text_prompt : str + Reference text (e.g. ground-truth answer). + + Returns + ------- + list[float] | None + Per-image scores, or None on failure. + """ + self._load_models() + pil_images = _to_pil_list(images) + if not pil_images: + return None + input_pixels = self._processor(images=pil_images, return_tensors="pt").pixel_values.to(self.device) + captions = [text_prompt] + text_features = self._l2v.encode(captions, convert_to_tensor=True, device=self.device).to(self.device) + text_features = self._clip_model.get_text_features(text_features) + + with torch.no_grad(): + if self.device == "cuda": + with torch.amp.autocast(device_type="cuda"): + image_features = self._clip_model.get_image_features(input_pixels) + else: + image_features = self._clip_model.get_image_features(input_pixels.float()) + + image_features = image_features.float() + text_features = text_features.float() + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + + text_probs = (image_features @ text_features.T).cpu().tolist() + return [p[0] for p in text_probs] + + +@MetricRegistry.register("oneig_reasoning") +class OneIGReasoningMetric(StatefulMetric): + """ + OneIG reasoning score: LLM2CLIP similarity between GT answer text and generated image. + + Uses ``reasoning_gt_answer`` from aux (populated by OneIG Knowledge_Reasoning loader; + language is chosen at dataset load via ``reasoning_language``). MVP: 1×1 grid (whole + image as single cell). Llama-derived checkpoints may require + ``HF_TOKEN`` and ``huggingface-cli login``. 
+ + Parameters + ---------- + processor_model : str, optional + CLIP processor model ID. + model_name : str, optional + LLM2CLIP model ID. + llm_model_name : str, optional + LLM2Vec model ID. + device : str | torch.device | None, optional + Device for inference. + scorer : _LLM2CLIPScorer | None, optional + Optional scorer instance for testing (injected mock). + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments for :class:`StatefulMetric`. + + Notes + ----- + Prompt benchmarks yield ``(prompts, aux_list)``. With default ``call_type`` + ``y_gt``, ``aux_list`` is the list (or tensor coerced to a list) of per-sample + dicts parallel to generated images. Each dict must include a non-empty + ``reasoning_gt_answer`` for Knowledge/Reasoning samples. Missing GT, scorer + failures, or :meth:`compute` with no scored samples raise ``ValueError`` or + ``RuntimeError`` instead of returning a placeholder score. + """ + + metric_name: str = "oneig_reasoning" + default_call_type: str = "y_gt" + higher_is_better: bool = True + runs_on: list[str] = ["cuda", "cpu"] + + def __init__( + self, + processor_model: str = "openai/clip-vit-large-patch14-336", + model_name: str = "microsoft/LLM2CLIP-Openai-L-14-336", + llm_model_name: str = "microsoft/LLM2CLIP-Llama-3-8B-Instruct-CC-Finetuned", + device: str | torch.device | None = None, + scorer: _LLM2CLIPScorer | None = None, + call_type: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(device=device, **kwargs) + self.call_type = get_call_type_for_single_metric( + call_type if call_type is not None else SINGLE, self.default_call_type + ) + self.processor_model = processor_model + self.model_name = model_name + self.llm_model_name = llm_model_name + self._scorer = scorer + self.add_state("scores", default=[]) + + def _get_scorer(self) -> _LLM2CLIPScorer: + if self._scorer is not None: + return self._scorer + return _LLM2CLIPScorer( + processor_model=self.processor_model, 
+ model_name=self.model_name, + llm_model_name=self.llm_model_name, + device=self.device, + ) + + def _get_gt_text(self, aux: dict) -> str: + val = aux.get("reasoning_gt_answer") + if val is None or (isinstance(val, str) and not val.strip()): + raise ValueError( + "oneig_reasoning requires 'reasoning_gt_answer' in aux for Knowledge_Reasoning rows. " + f"Got keys: {list(aux.keys())}." + ) + return str(val).strip() + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Score each image against its GT answer text via LLM2CLIP similarity. + + Parameters + ---------- + x : list[Any] | torch.Tensor + Unused batch metadata. + gt : torch.Tensor + Ground-truth slot with per-sample aux dicts containing ``reasoning_gt_answer``. + outputs : torch.Tensor + Model outputs (generated images). + + Raises + ------ + ValueError + If a per-sample aux entry is not a dict or lacks a non-empty + ``reasoning_gt_answer``. + RuntimeError + If the LLM2CLIP scorer returns no scores for a sample. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + aux_list = inputs[1] if len(inputs) > 1 else [] + if isinstance(aux_list, torch.Tensor): + aux_list = aux_list.tolist() + + scorer = self._get_scorer() + + for i, image in enumerate(images): + aux = aux_list[i] if i < len(aux_list) else {} + if not isinstance(aux, dict): + raise ValueError(f"oneig_reasoning requires aux[{i}] to be a dict. Got: {type(aux)}.") + text = self._get_gt_text(aux) + result = scorer.score([image], text) + if result is None or len(result) == 0: + raise RuntimeError(f"oneig_reasoning: LLM2CLIP scorer returned no scores for sample {i}.") + self.scores.append(float(sum(result) / len(result))) + + def compute(self) -> MetricResult: + """ + Compute the mean reasoning score across all samples. + + Returns + ------- + MetricResult + Mean LLM2CLIP similarity. 
+ + Raises + ------ + RuntimeError + If :meth:`update` was not called or scored no samples. + """ + if not self.scores: + raise RuntimeError( + "oneig_reasoning: no samples were scored; call update() with valid " + "batches and non-empty reasoning_gt_answer before compute()." + ) + mean_score = sum(self.scores) / len(self.scores) + return MetricResult(self.metric_name, self.__dict__, float(mean_score)) diff --git a/src/pruna/evaluation/metrics/metric_qa_accuracy.py b/src/pruna/evaluation/metrics/metric_qa_accuracy.py new file mode 100644 index 00000000..5207bce8 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_qa_accuracy.py @@ -0,0 +1,168 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""QA Accuracy metric using VLM for image understanding evaluation.""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images + + +@MetricRegistry.register("qa_accuracy") +class QAAccuracyMetric(StatefulMetric): + """ + QA Accuracy metric. + + Uses VLM to answer questions about images. + Higher scores indicate better image understanding. + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. Supports ``aggregation`` (e.g. 
``"all_or_nothing"`` for GenEval-style + wiring); stored on the metric instance. + """ + + scores: List[float] + default_call_type: str = "y_gt" + higher_is_better: bool = True + metric_name: str = "qa_accuracy" + + def __init__( + self, + *args, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + self.aggregation = kwargs.pop("aggregation", "mean") + + def _extract_questions(self, gt: Any, n: int) -> List[List[str]]: + if isinstance(gt, (list, tuple)) and len(gt) >= n: + out = [] + for i in range(n): + v = gt[i] + if isinstance(v, dict) and "questions" in v: + qs = v["questions"] + out.append(list(qs.values()) if isinstance(qs, dict) else list(qs)) + else: + out.append([]) + return out + return [[]] * n + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data. + gt : torch.Tensor + The ground truth (questions per image). + outputs : torch.Tensor + The output images. 
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 else [] + questions_per_image = self._extract_questions(auxiliaries, len(images)) + for i, image in enumerate(images): + questions = questions_per_image[i] if i < len(questions_per_image) else [] + if not questions: + aux = auxiliaries[i] if i < len(auxiliaries) else {} + raise ValueError( + "qa_accuracy requires 'questions' in auxiliaries. " + "Use a benchmark that provides it (e.g. GenEval, DPG, OneIG). " + f"Got aux keys: {list(aux.keys()) if isinstance(aux, dict) else 'not a dict'}." + ) + scores = self.vlm.score( + [image] * len(questions), + questions, + ["Yes"] * len(questions), + response_format=self.response_format, + ) + score = float(np.mean(scores)) + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the QA accuracy score. + + Returns + ------- + MetricResult + The mean QA accuracy across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_text_score.py b/src/pruna/evaluation/metrics/metric_text_score.py new file mode 100644 index 00000000..d2308f3e --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score.py @@ -0,0 +1,346 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``). + +OneIG composite: ``oneig_text_score`` / ``ocr_text_score``. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_text_score_utils import ( + levenshtein, + normalize_text_simple, + oneig_mean_text_score, + oneig_per_sample_contributions, +) +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response + +OCR_PROMPT = ( + "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " + "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, " + "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " + "If no text is recognized, reply with exactly: No text recognized" +) + + +class _BaseVLMOCRTextMetric(StatefulMetric): + """ + Shared VLM OCR over rendered images with ground truth in ``text_content``. + + Subclasses implement how OCR and GT strings are scored and aggregated. + + Parameters + ---------- + *args : Any + Additional positional arguments (unused; registry compatibility). + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. 
Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + """ + + default_call_type: str = "y_gt" + + def __init__( + self, + *args: Any, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = TextOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + + @abstractmethod + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + """Update metric state from one ground-truth / OCR pair.""" + + @abstractmethod + def _compute_result_value(self) -> float: + """Return the scalar reported as ``MetricResult.result``.""" + + def update(self, x: List[Any] | torch.Tensor, gt: List[str], outputs: torch.Tensor) -> None: + """ + Run OCR on 
outputs and score against ``text_content`` (or string list) auxiliaries. + + Parameters + ---------- + x : List[Any] | torch.Tensor + Batch prompts or metadata. + gt : list of dict or list of str + Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with + newlines), or plain strings per batch item. + outputs : torch.Tensor + Rendered images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) + for i, image in enumerate(images): + responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) + raw = responses[0] if responses else "" + ocr_text = get_text_from_response(raw) + aux = auxiliaries[i] if i < len(auxiliaries) else {} + text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) + if isinstance(text_gt, list): + text_gt = "\n".join(str(x) for x in text_gt) + if text_gt is None: + raise ValueError( + f"{self.metric_name} requires 'text_content' in auxiliaries. " + "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." + ) + self._accumulate_sample(text_gt, ocr_text) + + def compute(self) -> MetricResult: + """ + Aggregate batched contributions into a single metric value. + + Returns + ------- + MetricResult + Named result with ``higher_is_better`` taken from the class. + """ + value = self._compute_result_value() + return MetricResult(self.metric_name, self.__dict__, float(value)) + + +@MetricRegistry.register("ocr_levenshtein") +@MetricRegistry.register("text_score") +class TextScoreMetric(_BaseVLMOCRTextMetric): + """ + OCR then mean Levenshtein distance to ground truth (lower is better). + + Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy). + + Uses light normalization only (not the full OneIG preprocess). 
@MetricRegistry.register("ocr_text_score")
@MetricRegistry.register("oneig_text_score")
class OneIGTextScoreMetric(_BaseVLMOCRTextMetric):
    """
    OCR the image, then aggregate an OneIG-style composite text score (higher is better).

    Registered twice: ``ocr_text_score`` (descriptive name) and ``oneig_text_score``
    (protocol name). The aggregation mirrors ``OneIG-Benchmark/scripts/text/text_score.py``:
    it combines mean edit distance, completion rate, and word/character accuracy.

    Parameters
    ----------
    *args : Any
        Positional arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    language_mode : {'EN', 'ZH'}, optional
        Chooses the ``MAX_EDIT_DISTANCE`` cap (100 for EN, 50 for ZH).
    vlm : BaseVLM | None, optional
        Custom VLM instance. When given, ``vlm_type`` and ``model_name`` are ignored.
    vlm_type : {'litellm', 'transformers'}, optional
        VLM backend. Default is ``'litellm'``.
    model_name : str | None, optional
        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
        provided (e.g. ``openai/gpt-4o``).
    vlm_kwargs : dict, optional
        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
    structured_output : bool, optional
        Use structured generation (litellm pydantic; transformers outlines when applicable).
        Default is True.
    device : str | torch.device | None, optional
        Device for transformers VLM.
    api_key : str | None, optional
        API key for litellm.
    call_type : str, optional
        Call type for the metric.
    **kwargs : Any
        Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
    """

    # Per-sample accumulator states; registered through ``add_state`` in ``__init__``.
    edit_distances: List[float]
    completion_ratios: List[float]
    match_counts: List[int]
    gt_totals: List[int]

    higher_is_better: bool = True
    metric_name: str = "oneig_text_score"

    def __init__(
        self,
        *args: Any,
        language_mode: Literal["EN", "ZH"] = "EN",
        vlm: Optional[BaseVLM] = None,
        vlm_type: Literal["litellm", "transformers"] = "litellm",
        model_name: str | None = None,
        vlm_kwargs: Optional[dict[str, Any]] = None,
        structured_output: bool = True,
        device: str | torch.device | None = None,
        api_key: Optional[str] = None,
        call_type: str = SINGLE,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            *args,
            vlm=vlm,
            vlm_type=vlm_type,
            model_name=model_name,
            vlm_kwargs=vlm_kwargs,
            structured_output=structured_output,
            device=device,
            api_key=api_key,
            call_type=call_type,
            **kwargs,
        )
        self.language_mode = language_mode
        # One list-valued state per OneIG aggregation term.
        for state_name in ("edit_distances", "completion_ratios", "match_counts", "gt_totals"):
            self.add_state(state_name, [])

    def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
        # Record this sample's ED / CR / WAC contributions for later aggregation.
        distance, completed, overlap, gt_count = oneig_per_sample_contributions(text_gt, ocr_text)
        self.edit_distances.append(distance)
        self.completion_ratios.append(completed)
        self.match_counts.append(overlap)
        self.gt_totals.append(gt_count)

    def _compute_result_value(self) -> float:
        # Only the composite score is reported; the ED/CR/WAC means are intermediates.
        *_, composite = oneig_mean_text_score(
            self.edit_distances,
            self.completion_ratios,
            self.match_counts,
            self.gt_totals,
            self.language_mode,
        )
        return composite
+ """ + cleaned = re.sub( + r"[^\u4e00-\u9fa5a-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + "", + s or "", + ) + return re.sub(r"\s+", " ", cleaned).strip() + + +def levenshtein(s1: str, s2: str) -> float: + """ + Symmetric Levenshtein edit distance. + + Parameters + ---------- + s1 : str + First string. + s2 : str + Second string. + + Returns + ------- + float + Edit distance. + """ + if len(s1) < len(s2): + return levenshtein(s2, s1) + prev = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + curr = [i + 1] + for j, c2 in enumerate(s2): + curr.append(min(prev[j] + (c1 != c2), prev[j + 1] + 1, curr[-1] + 1)) + prev = curr + return float(prev[-1]) + + +def contains_chinese(text: str) -> bool: + """ + Return True if ``text`` contains CJK unified ideographs. + + Parameters + ---------- + text : str + Input text. + + Returns + ------- + bool + Whether Chinese characters are present. + """ + return bool(re.search(r"[\u4e00-\u9fff]", text)) + + +def preprocess_string_oneig(s: str) -> str: + """ + OneIG ``preprocess_string``: charset filter, Chinese vs whitespace normalization. + + Parameters + ---------- + s : str + Raw string. + + Returns + ------- + str + Preprocessed string (ground truth or OCR). + """ + raw = s or "" + cleaned = re.sub( + r"[^\u4e00-\u9fa5a-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + "", + raw, + ) + if contains_chinese(cleaned): + pattern = re.compile( + r"[\u4e00-\u9fa5a-zA-Z0-9àâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + ) + return "".join(pattern.findall(raw)).strip() + return re.sub(r"\s+", " ", cleaned).strip() + + +def clean_oneig_ocr_hallucinations(text: str) -> str: + """ + Remove known OCR boilerplate substrings (OneIG ``clean_and_remove_hallucinations``). + + Parameters + ---------- + text : str + Raw OCR output. + + Returns + ------- + str + Cleaned OCR text. 
+ """ + out = text or "" + for keyword in _OCR_HALLUCINATION_KEYWORDS: + out = out.replace(keyword, "").replace(f"\n{keyword}", "").replace(f"{keyword}\n", "") + return out + + +def calculate_char_match_ratio( + text_gt: str, + ocr_str: str, +) -> tuple[int, float, int]: + """ + OneIG overlap stats: character multiset for ZH, word multiset for EN. + + Parameters + ---------- + text_gt : str + Preprocessed ground truth. + ocr_str : str + Preprocessed OCR. + + Returns + ------- + total_match_count : int + Overlap count used in WAC numerator aggregation. + ratio : float + Per-sample ratio (mean of ratios is not used in the official aggregate). + gt_total : int + Denominator term: ``sum(gt_counter.values())`` for WAC aggregation. + """ + if contains_chinese(text_gt): + gt_counter: Counter[str] = Counter(text_gt) + ocr_counter: Counter[str] = Counter(ocr_str) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + ratio = total_match_count / len(text_gt) if len(text_gt) > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + words_gt = text_gt.split() + words_ocr = ocr_str.split() + gt_counter = Counter(words_gt) + ocr_counter = Counter(words_ocr) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + total_gt_count = len(words_gt) + ratio = total_match_count / total_gt_count if total_gt_count > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + +def max_edit_distance_for_language(language_mode: Literal["EN", "ZH"]) -> int: + """ + OneIG ``MAX_EDIT_DISTANCE`` (100 for English, 50 for Chinese benchmark split). + + Parameters + ---------- + language_mode : {'EN', 'ZH'} + Benchmark language mode. + + Returns + ------- + int + Cap used in the composite text score. 
+ """ + return 50 if language_mode == "ZH" else 100 + + +def oneig_per_sample_contributions(text_gt: str, ocr_raw: str) -> tuple[float, float, int, int]: + """ + Per-sample terms for OneIG aggregation (ED, CR, WAC numerator/denominator parts). + + Parameters + ---------- + text_gt : str + Ground-truth text (dataset field). + ocr_raw : str + Raw OCR string from the VLM. + + Returns + ------- + edit_distance : float + Levenshtein distance after OneIG preprocess. + completion_ratio : float + 1.0 if distance is zero, else 0.0. + match_count : int + Overlap count for WAC. + gt_total : int + Ground-truth token count term for WAC denominator. + """ + ocr_clean = clean_oneig_ocr_hallucinations(ocr_raw) + gt_pre = preprocess_string_oneig(text_gt) + ocr_pre = preprocess_string_oneig(ocr_clean) + ed = levenshtein(ocr_pre, gt_pre) + cr = 1.0 if ed == 0.0 else 0.0 + match_count, _, gt_total = calculate_char_match_ratio(gt_pre, ocr_pre) + return ed, cr, match_count, gt_total + + +def oneig_mean_text_score( + edit_distances: list[float], + completion_ratios: list[float], + match_counts: list[int], + gt_totals: list[int], + language_mode: Literal["EN", "ZH"], +) -> tuple[float, float, float, float]: + """ + Aggregate OneIG ED, CR, WAC and composite text score (higher is better). + + Parameters + ---------- + edit_distances : list of float + Per-sample edit distances. + completion_ratios : list of float + Per-sample completion indicators. + match_counts : list of int + Per-sample WAC numerators. + gt_totals : list of int + Per-sample WAC denominator terms. + language_mode : {'EN', 'ZH'} + Selects ``MAX_EDIT_DISTANCE``. + + Returns + ------- + ed_mean : float + Mean edit distance. + cr_mean : float + Mean completion ratio. + wac : float + Micro-averaged WAC: ``sum(match_counts) / sum(gt_totals)``. + text_score : float + Composite: ``1 - min(MAX_ED, ED) * (1 - CR) * (1 - WAC) / MAX_ED``. 
+ """ + cap = float(max_edit_distance_for_language(language_mode)) + if not edit_distances: + return 0.0, 0.0, 0.0, 0.0 + ed_mean = float(sum(edit_distances) / len(edit_distances)) + cr_mean = float(sum(completion_ratios) / len(completion_ratios)) + denom = float(sum(gt_totals)) + wac = float(sum(match_counts) / denom) if denom > 0.0 else 0.0 + text_score = 1.0 - min(cap, ed_mean) * (1.0 - cr_mean) * (1.0 - wac) / cap + return ed_mean, cr_mean, wac, text_score diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..ea2365fa 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. + + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. 
@@ -246,6 +264,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +278,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/src/pruna/evaluation/metrics/metric_vie_score.py b/src/pruna/evaluation/metrics/metric_vie_score.py new file mode 100644 index 00000000..75ec7e57 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vie_score.py @@ -0,0 +1,178 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VIEScore metric for evaluating conditional image synthesis (semantic + quality). 
+ +Reference: VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation +(ACL 2024) - https://arxiv.org/abs/2312.14867, https://github.com/TIGER-AI-Lab/VIEScore +""" + +from __future__ import annotations + +import math +import re +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import FloatOutput, _process_images + + +@MetricRegistry.register("vie_score") +class VieScoreMetric(StatefulMetric): + """ + VIEScore metric for evaluating conditional image synthesis (semantic + quality). + + Uses VLM to assess both semantic alignment and visual quality. + Higher scores indicate better overall quality. + + Computes: + - Semantic score: How well image follows prompt + - Quality score: Naturalness and artifacts + - Overall: Geometric mean of semantic and quality + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. 
+ structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + + References + ---------- + VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation (ACL 2024) + https://arxiv.org/abs/2312.14867 + https://github.com/TIGER-AI-Lab/VIEScore + """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "vie_score" + + def __init__( + self, + *args, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = FloatOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. 
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + + sem_prompt = ( + f'On a scale of 0 to 10, how well does this image match the prompt "{prompt}"? ' + "0 = no match, 10 = perfect match. Reply with a single number." + ) + sem_resp = self.vlm.generate([image], [sem_prompt], response_format=self.response_format)[0] + sem_score = self._parse_score(sem_resp) + + qual_prompt = ( + "On a scale of 0 to 10, rate this image's naturalness and absence of artifacts. " + "0 = unnatural, heavy artifacts; 10 = natural, no artifacts. Reply with a single number." + ) + qual_resp = self.vlm.generate([image], [qual_prompt], response_format=self.response_format)[0] + qual_score = self._parse_score(qual_resp) + + score = math.sqrt(sem_score * qual_score) / 10.0 + self.scores.append(score) + + def _parse_score(self, response: str) -> float: + if isinstance(response, str): + numbers = re.findall(r"\d+", response) + return min(float(numbers[0]), 10.0) if numbers else 0.0 + return 0.0 + + def compute(self) -> MetricResult: + """ + Compute the VIEScore metric. + + Returns + ------- + MetricResult + The mean VIEScore across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/metric_vqa.py b/src/pruna/evaluation/metrics/metric_vqa.py new file mode 100644 index 00000000..53d03f6e --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vqa.py @@ -0,0 +1,165 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VQA (Visual Question Answering) metric. + +Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation +https://arxiv.org/abs/2404.01291 + +Note: VQAScore uses P(Yes) (probability of "Yes" answer) for ranking. With litellm, +use_probability=True (default) requests logprobs for soft scores when the provider supports it. +Set use_probability=False for binary 0/1. TransformersVLM always uses binary. +""" + +from __future__ import annotations + +from typing import Any, List, Literal, Optional + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images + + +@MetricRegistry.register("vqa") +class VQAMetric(StatefulMetric): + """ + VQA (Visual Question Answering) metric. + + Uses VLM to answer "Does this image show '{prompt}'?" and scores alignment. + Higher scores indicate better image-text alignment. + + VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. Default use_probability=True + with litellm requests logprobs for soft scores when supported. + + Parameters + ---------- + *args : Any + Additional positional arguments. 
+ vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend to use. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation for stable outputs (litellm pydantic; transformers outlines + when a string format is used). Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + use_probability : bool, optional + If True, use P(Yes) when backend supports logprobs (litellm). Otherwise binary 0/1. + Default is True for paper alignment. + **kwargs : Any + Additional arguments. 
+ """ + + scores: List[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "vqa" + + def __init__( + self, + *args, + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: Optional[dict] = None, + structured_output: bool = True, + device=None, + api_key: Optional[str] = None, + call_type: str = SINGLE, + use_probability: bool = True, + **kwargs, + ): + super().__init__(device=device) + self.device = set_to_best_available_device(device) + self.structured_output = structured_output + self.use_probability = use_probability + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = VQAnswer if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + + def update(self, x: List[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth / cached images. + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], list) else [""] * len(images) + + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"?' + score = self.vlm.score( + [image], + [question], + ["Yes"], + response_format=self.response_format, + use_probability=self.use_probability, + )[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the VQA score. 
+ + Returns + ------- + MetricResult + The mean VQA score across all updates. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index 5efd721a..e5d404e1 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib from functools import partial from inspect import isclass from typing import Any, Callable, Dict, Iterable, List @@ -32,6 +33,7 @@ class MetricRegistry: """ _registry: Dict[str, Callable[..., Any]] = {} + _lazy_metrics: frozenset[str] = frozenset({"oneig_reasoning"}) @classmethod def register(cls, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -104,7 +106,7 @@ def has_metric(cls, name: str) -> bool: bool True if the metric is registered, False otherwise. """ - return name in cls._registry + return name in cls._registry or name in cls._lazy_metrics @classmethod def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: @@ -122,6 +124,9 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: ------- The metric instance. """ + if name in cls._lazy_metrics and name not in cls._registry: + importlib.import_module("pruna.evaluation.metrics.metric_oneig_reasoning") + if name not in cls._registry: raise ValueError(f"Metric '{name}' is not registered.") diff --git a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..c6813872 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -56,13 +56,17 @@ def metric_data_processor( This function determines the order and selection of inputs to be passed to various metrics. 
     The function supports different input arrangements through the 'call_type' configuration:
-    - 'x_y': Uses input data (x) and model outputs
-    - 'gt_y': Uses ground truth (gt) and model outputs
-    - 'y_x': Uses model outputs and input data (x)
-    - 'y_gt': Uses model outputs and ground truth (gt)
-    - 'pairwise_gt_y': Uses cached base model outputs (gt) and smashed model outputs (y).
-    - 'pairwise_y_gt': Uses smashed model outputs (y) and cached base model outputs (gt).
-      The evaluation agent is expected to pass the cached base model outputs as gt.
+
+    - 'y_gt': Model's output first, then ground truth. Returns [outputs, gt].
+    - 'gt_y': Ground truth first, then model's output. Returns [gt, outputs].
+    - 'y_x': Model's output first, then input data. Returns [outputs, x].
+      Used by CLIPScore, AlignmentScore, VQA, ImageEditScore, VIEScore.
+    - 'x_y': Input data first, then model's output. Returns [x, outputs].
+    - 'x_gt': Input data first, then ground truth. Returns [x, gt].
+    - 'gt_x': Ground truth first, then input data. Returns [gt, x].
+    - 'pairwise_y_gt': Subsequent model's output first, then base model's output.
+    - 'pairwise_gt_y': Base model's output first, then subsequent model's output.
+    - 'y': Only the output is used; the metric has an internal dataset. Returns [outputs].

     Parameters
     ----------
@@ -85,7 +89,8 @@
     Raises
     ------
     ValueError
-        If the specified call_type is not one of: 'x_y', 'gt_y', 'y_x', 'y_gt', 'pairwise'.
+        If the specified call_type is not one of: 'y_gt', 'gt_y', 'y_x', 'x_y',
+        'x_gt', 'gt_x', 'pairwise_y_gt', 'pairwise_gt_y', 'y'.
Examples -------- @@ -106,11 +111,15 @@ def metric_data_processor( return [outputs, x] elif call_type == "y_gt": return [outputs, gt] + elif call_type == "x_gt": + return [x, gt] + elif call_type == "gt_x": + return [gt, x] elif call_type == "pairwise_gt_y": return [gt, outputs] elif call_type == "pairwise_y_gt": return [outputs, gt] - elif call_type == "y": # IQA metrics that have an internal dataset + elif call_type == "y": return [outputs] else: raise ValueError(f"Invalid call type: {call_type}") diff --git a/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec new file mode 100644 index 00000000..01654bd4 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec @@ -0,0 +1,12 @@ +LLM2Vec (llm2vec package) vendored from OneIG-Benchmark. + +Source: https://github.com/OneIG-Bench/OneIG-Benchmark +Commit: 41b49831e79e6dde5323618c164da1c4cf0f699d +Path: scripts/utils/llm2clip/llm2vec/ + +OneIG-Benchmark is licensed under the Apache License 2.0. +See the project repository for full license text. + +``oneig_llm2vec/modeling_llama_encoder.py`` is derived from +McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp (Hugging Face Hub); +Pruna relaxes the upstream flash-attention-only constraint for CPU use. diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py new file mode 100644 index 00000000..c1fb56c8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py @@ -0,0 +1,391 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). +# See NOTICE.oneig_llm2vec in parent directory. 
import json
import logging
import pathlib
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm import trange
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    LlamaConfig,
    PretrainedConfig,
)

from pruna.evaluation.metrics.vendor.oneig_llm2vec.models.bidirectional_llama import LlamaBiModel

logger = logging.getLogger(__name__)


def batch_to_device(batch, target_device: device | str):
    """Send a pytorch batch to a device (CPU/GPU).

    Only ``Tensor`` values are moved; any other entries are left untouched.
    The batch dict is modified in place and also returned for convenience.
    """
    for key in batch:
        if isinstance(batch[key], Tensor):
            batch[key] = batch[key].to(target_device)
    return batch


class LLM2Vec(nn.Module):
    """Bidirectional LLM wrapper with configurable pooling for dense embeddings."""

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        pooling_mode: str = "mean",
        max_length: int = 512,
        doc_max_length: int = 512,
        skip_instruction: bool = True,
    ):
        """Wrap ``model``/``tokenizer`` with pooling and truncation settings.

        Parameters
        ----------
        model : AutoModel
            The (optionally bidirectional) encoder model.
        tokenizer : AutoTokenizer
            Tokenizer matching ``model``.
        pooling_mode : str, optional
            Pooling strategy, e.g. ``"mean"`` or ``"eos_token"``.
        max_length : int, optional
            Maximum token length for queries.
        doc_max_length : int, optional
            Maximum token length for documents.
        skip_instruction : bool, optional
            Whether to exclude the instruction prefix from pooling.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_mode = pooling_mode
        self.skip_instruction = skip_instruction
        self.max_length = max_length
        # Fix: honor the constructor argument instead of hardcoding 512,
        # which silently ignored any user-supplied doc_max_length.
        self.doc_max_length = doc_max_length
        self.config = model.config

    @classmethod
    def _get_model_class(cls, config_class_name, enable_bidirectional):
        # Only Llama-family configs have a bidirectional variant vendored here.
        if not enable_bidirectional:
            return AutoModel
        elif config_class_name == "LlamaConfig":
            return LlamaBiModel
        else:
            raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.")

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        merge_peft=False,
        enable_bidirectional=True,
        extra_model_name_or_path=None,
        **kwargs,
    ):
        """Load tokenizer and encoder from Hub or a local path and return ``LLM2Vec``.

        Supports optional PEFT adapters, bidirectional Llama, and extra adapter paths;
        keyword args are forwarded to Hugging Face ``from_pretrained`` calls.

        Parameters
        ----------
        base_model_name_or_path : str
            Base checkpoint id or local directory.
        peft_model_name_or_path : str, optional
            PEFT adapter to load on top of the base model.
        merge_peft : bool, optional
            Merge the adapter weights into the base model after loading.
        enable_bidirectional : bool, optional
            Use the vendored bidirectional Llama implementation when applicable.
        extra_model_name_or_path : str | list of str, optional
            Additional adapter(s) merged in sequence.
        **kwargs : Any
            Extra options; encoder settings (pooling_mode, max_length, doc_max_length,
            skip_instruction) are extracted here, the rest goes to ``from_pretrained``.

        Raises
        ------
        ValueError
            If ``extra_model_name_or_path`` is neither a string nor a list of strings.
        """
        # Pull encoder-level settings out of kwargs before they reach HF loaders.
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None}

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        config_class_name = config.__class__.__name__

        model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional)
        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)

        # For local checkpoints, restore the original name recorded in config.json.
        base_path = pathlib.Path(base_model_name_or_path)
        config_json = base_path / "config.json"
        if base_path.is_dir() and config_json.exists():
            with open(config_json, encoding="utf-8") as config_file:
                config_dict = json.load(config_file)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        if hasattr(model, "peft_config"):
            model = PeftModel.from_pretrained(
                model,
                base_model_name_or_path,
            )
            model = model.merge_and_unload()

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            if merge_peft:
                model = model.merge_and_unload()
        if extra_model_name_or_path is not None:
            logger.info(f"Loading extra model from {extra_model_name_or_path}")
            # NOTE(review): if merge_peft is False here, model may not be a PeftModel
            # yet merge_and_unload() is called — confirm against upstream OneIG usage.
            if not merge_peft:
                model = model.merge_and_unload()
            if isinstance(extra_model_name_or_path, str):
                model = PeftModel.from_pretrained(
                    model,
                    extra_model_name_or_path,
                )
                peft_model_name_or_path = extra_model_name_or_path
                model = model.merge_and_unload()
            elif isinstance(extra_model_name_or_path, list):
                for extra_model in extra_model_name_or_path:
                    model = PeftModel.from_pretrained(
                        model,
                        extra_model,
                    )
                    peft_model_name_or_path = extra_model
                    model = model.merge_and_unload()
            else:
                raise ValueError("extra_model_name_or_path should be a string or a list of strings.")
        config = {}
        # Encoder settings come from the adapter's llm2vec_config.json when present,
        # with explicit constructor kwargs taking precedence.
        config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path
        llm2vec_config_path = pathlib.Path(config_addr) / "llm2vec_config.json"
        if llm2vec_config_path.exists():
            with open(llm2vec_config_path, encoding="utf-8") as config_file:
                llm2vec_config = json.load(config_file)
            config.update(llm2vec_config)
        logger.info(f"LLM2Vec config: {config}")
        for key, value in encoder_args.items():
            config[key] = value

        return cls(model=model, tokenizer=tokenizer, **config)

    def prepare_for_tokenization(self, text):
        """Apply model-specific chat or EOS wrappers so tokenization matches training."""
        if "Llama-3" in self.model.config._name_or_path and "Instruct" in self.model.config._name_or_path:
            text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
            return text
        if self.model.config._name_or_path == "microsoft/Phi-3.5-mini-instruct":
            text = "<|user|>\n" + text.strip() + "<|end|>\n"
            return text
        if self.pooling_mode == "eos_token":
            if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
                text = text.strip() + "<|end_of_text|>"
            elif isinstance(self.model.config, LlamaConfig):
                text = text.strip() + " "
        return text
+ max_length=self.max_length, + add_special_tokens=False, + ) + if embed_mask is None: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = e_m.unsqueeze(0) + else: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) + + original["embed_mask"] = embed_mask + return original + + def _skip_instruction(self, sentence_feature): + assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape + sentence_feature["attention_mask"] = sentence_feature["embed_mask"] + + def forward(self, sentence_feature: Dict[str, Tensor]): + """Run the encoder and return pooled sentence embeddings.""" + embed_mask = None + if "embed_mask" in sentence_feature: + embed_mask = sentence_feature.pop("embed_mask") + reps = self.model(**sentence_feature) + if embed_mask is not None: + sentence_feature["embed_mask"] = embed_mask + + return self.get_pooling(sentence_feature, reps.last_hidden_state) + + def get_pooling(self, features, last_hidden_states): + """Pool token hidden states according to ``pooling_mode``.""" + assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left." 
+ if self.skip_instruction: + self._skip_instruction(features) + seq_lengths = features["attention_mask"].sum(dim=-1) + if self.pooling_mode == "mean": + return torch.stack( + [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)], + dim=0, + ) + elif self.pooling_mode == "weighted_mean": + bs, seq_len, _ = last_hidden_states.shape + complete_weights = torch.zeros(bs, seq_len, device=last_hidden_states.device) + for i, seq_l in enumerate(seq_lengths): + if seq_l > 0: + complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 + complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9) + return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) + elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": + return last_hidden_states[:, -1] + elif self.pooling_mode == "bos_token": + return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id] + else: + raise ValueError(f"{self.pooling_mode} is not implemented yet.") + + def _convert_to_str(self, instruction, text): + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + while tokenized_q_length > self.doc_max_length: + reduction_ratio = self.doc_max_length / tokenized_q_length + reduced_length = int(len(text.split()) * reduction_ratio) + text = " ".join(text.split()[:reduced_length]) + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}" + + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = True, + convert_to_numpy: bool = False, + convert_to_tensor: 
bool = True, + device: Optional[str] = None, + ): + """Encode sentences (optionally instruction + document) to embedding tensors.""" + seq: Any = sentences + if isinstance(seq[0], str) and isinstance(seq[-1], int): + seq = [seq] + if isinstance(seq[0], str): + seq = [[""] + [sentence] for sentence in seq] + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + concatenated_input_texts = [] + for sentence in seq: + assert isinstance(sentence[0], str) + assert isinstance(sentence[1], str) + concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1])) + sentences = concatenated_input_texts + + self.train(mode=False) + + if convert_to_tensor: + convert_to_numpy = False + + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + all_embeddings = [] + + self.to(device) + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=True, + ): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy) + all_embeddings.append(embeddings) + + all_embeddings = torch.cat(all_embeddings, dim=0) + all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] + all_embeddings = all_embeddings.to(torch.float32) + return all_embeddings + + def save(self, output_path, merge_before_save=False, save_config=True): + """Persist model, tokenizer, and optional ``llm2vec_config.json`` to ``output_path``.""" + if merge_before_save and isinstance(self.model, PeftModel): + self.model = self.model.merge_and_unload() + if hasattr(self.model, "_hf_peft_config_loaded"): + setattr(self.model, "_hf_peft_config_loaded", False) + + self.model.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + llm2vec_config = { + "pooling_mode": self.pooling_mode, + "max_length": self.max_length, + 
"doc_max_length": self.doc_max_length, + "skip_instruction": self.skip_instruction, + } + + if save_config: + pathlib.Path(output_path).mkdir(exist_ok=True, parents=True) + config_out = pathlib.Path(output_path) / "llm2vec_config.json" + with open(config_out, "w", encoding="utf-8") as config_file: + json.dump(llm2vec_config, config_file, indent=4) + + def _encode( + self, + sentences_batch, + device: Optional[str] = None, + convert_to_numpy: bool = False, + multiprocessing=False, + ): + if multiprocessing: + rank = mp.current_process()._identity[0] + if device is None and torch.cuda.is_available(): + device = f"cuda:{rank % torch.cuda.device_count()}" + + use_device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu") + self.to(use_device) + features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch]) + features = batch_to_device(features, use_device) + + with torch.no_grad(): + embeddings = self.forward(features) + return embeddings + + def _text_length(self, text: Union[List[int], List[List[int]]]): + if isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0: + return len(text) + if isinstance(text, dict): + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): + return 1 + else: + return sum(len(t) if not isinstance(t, int) else 1 for t in text) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """Resize the underlying model token embedding matrix.""" + return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """Enable gradient checkpointing on the wrapped model.""" + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git 
a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py new file mode 100644 index 00000000..cf9b4df8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py @@ -0,0 +1,107 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Derived from McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp ``modeling_llama_encoder.py`` +# (Hugging Face Hub). Upstream requires ``flash_attention_2`` only; this copy allows ``eager`` +# and ``sdpa`` so ``oneig_reasoning`` can run on CPU without ``flash_attn``. See +# ``NOTICE.oneig_llm2vec`` in the parent ``vendor`` directory. + +import importlib.metadata + +from packaging import version +from torch import nn +from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_56_2() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.56.2. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.56.2; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.56.2") + + +class ModifiedLlamaAttention(LlamaAttention): + """ + Llama self-attention with ``is_causal`` disabled for encoder-style use. + + Parameters + ---------- + *args, **kwargs + Forwarded to :class:`~transformers.models.llama.modeling_llama.LlamaAttention`. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder block using :class:`ModifiedLlamaAttention` for bidirectional encoding. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + GradientCheckpointingLayer.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class LlamaEncoderModel(LlamaModel): + """ + Bidirectional Llama stack for LLM2Vec-style encoding (eager, SDPA, or flash attention). + + Parameters + ---------- + config : LlamaConfig + Model configuration (requires transformers >= 4.56.2 layout). 
+ """ + + def __init__(self, config: LlamaConfig) -> None: + if not is_transformers_attn_greater_or_equal_4_56_2(): + raise ValueError( + "The current implementation of LlamaEncoderModel follows modeling_llama.py " + "of transformers version >= 4.56.2" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + attn_impl = getattr(config, "_attn_implementation", getattr(config, "attn_implementation", "eager")) + self._use_sdpa = attn_impl == "sdpa" + self._use_flash_attention_2 = attn_impl == "flash_attention_2" + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py new file mode 100644 index 00000000..6e081ca8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py @@ -0,0 +1,237 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). 
+ +import importlib.metadata +from typing import cast + +import torch +from packaging import version +from peft import PeftModel +from torch import nn +from transformers import ( + LlamaConfig, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, +) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_38() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.38.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.38.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0") + + +def is_transformers_attn_greater_or_equal_4_40() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.40.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.40.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.40.0") + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder layer with non-causal self-attention when supported by the attention module. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. 
+ """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + if hasattr(self.self_attn, "is_causal"): + self.self_attn.is_causal = False + + +class LlamaBiModel(LlamaModel): + """ + Bidirectional Llama backbone for MNTP-style training (transformers >= 4.38). + + Parameters + ---------- + config : LlamaConfig + Model configuration. + """ + + _no_split_modules = ["ModifiedLlamaDecoderLayer"] + + def __init__(self, config: LlamaConfig): + if not is_transformers_attn_greater_or_equal_4_38(): + raise ValueError( + "The current implementation of LlamaBiModel follows modeling_llama.py " + "of transformers version >= 4.38.0" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.post_init() + + def _update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_seen_tokens=None, + output_attentions=False, + ): + attn_impl = getattr( + self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager") + ) + if attn_impl == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): + target_length = self.config.max_position_embeddings + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else ( + 
cache_position[-1] + 1 + if not is_transformers_attn_greater_or_equal_4_40() + else past_seen_tokens + sequence_length + 1 + ) + ) + + causal_mask = torch.zeros((sequence_length, target_length), dtype=dtype, device=device) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + + if attention_mask is not None: + causal_mask = causal_mask.clone() + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + offset = ( + cache_position[0] + if attention_mask.shape[-2] < cache_position[0] + sequence_length + else 0 + ) + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + attn_impl = getattr( + self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager") + ) + if ( + attn_impl == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + causal_mask = AttentionMaskConverter._unmask_unattended( + cast(torch.FloatTensor, causal_mask.to(dtype=torch.float32)), + min_dtype, + ) + + return causal_mask + + +class LlamaBiForMNTP(LlamaForCausalLM): + """ + Causal LM wrapper around :class:`LlamaBiModel` for MNTP with optional PEFT. + + Parameters + ---------- + config : LlamaConfig + Model configuration. 
+ """ + + def __init__(self, config: LlamaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.model = LlamaBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_model_for_peft(self) -> LlamaBiModel | PeftModel: + """ + Return the inner model for PEFT wrapping (base or wrapped). + + Returns + ------- + LlamaBiModel or PeftModel + ``self.model``, either a :class:`LlamaBiModel` or a :class:`peft.PeftModel`. + """ + return self.model + + def set_model_for_peft(self, model: PeftModel) -> None: + """ + Replace the inner model with a PEFT-wrapped model. + + Parameters + ---------- + model : PeftModel + PEFT model whose base matches the expected backbone. + """ + self.model = model + + def save_peft_model(self, path: str) -> None: + """ + Save the (possibly PEFT-wrapped) inner model to disk. + + Parameters + ---------- + path : str + Directory path passed to ``save_pretrained`` on the inner model. + """ + self.model.save_pretrained(path) diff --git a/src/pruna/evaluation/metrics/vlm_base.py b/src/pruna/evaluation/metrics/vlm_base.py new file mode 100644 index 00000000..011bc8ae --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_base.py @@ -0,0 +1,611 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +VLM (Vision-Language Model) base classes for metrics. 
+ +This module provides two VLM implementations: +1. LitellmVLM - Uses litellm for API-based VLM calls (supports 100+ providers) +2. TransformersVLM - Uses local VLM models from HuggingFace Transformers + +Both support structured generation for stable outputs: +- LitellmVLM: Uses pydantic models with response_format +- TransformersVLM: Uses outlines for constrained decoding. +""" + +from __future__ import annotations + +import base64 +import io +import math +import os +from abc import ABC, abstractmethod +from typing import Any, List, Literal, Optional, Type, TypeVar, Union + +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.logging.logger import pruna_logger + +T = TypeVar("T", bound=BaseModel) + +VLM_METRIC_REGISTRY_NAMES: frozenset[str] = frozenset( + ( + "vqa", + "qa_accuracy", + "alignment_score", + "img_edit_score", + "text_score", + "ocr_levenshtein", + "ocr_text_score", + "oneig_text_score", + "oneig_alignment", + "vie_score", + ) +) + + +def get_vlm( + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + *, + model_name: Optional[str] = None, + device: Optional[str | torch.device] = None, + api_key: Optional[str] = None, + structured_output: bool = True, + **vlm_kwargs: Any, +) -> BaseVLM: + """ + Create or return a VLM instance. + + Parameters + ---------- + vlm : BaseVLM | None + If provided, returned as-is. Otherwise a VLM is created. + vlm_type : {"litellm", "transformers"} + Backend when creating a VLM. + model_name : str | None + Model name for litellm (e.g. ``openai/gpt-4o``) or HuggingFace ``from_pretrained`` id. + **Required** when ``vlm`` is not provided. Ignored when ``vlm`` is provided. + device : str | torch.device | None + Device for transformers VLM. + api_key : str | None + API key for litellm. 
+ structured_output : bool + When True, litellm uses pydantic ``response_format`` from the metric; for + ``transformers``, enables outlines-based constrained decoding when a string + format is passed to ``generate``/``score``. + **vlm_kwargs : Any + Same dict as ``vlm_kwargs`` on VLM metrics: forwarded to the backend chosen by + ``vlm_type``. For ``"litellm"``, kwargs go to ``LitellmVLM`` (e.g. provider-specific + options). For ``"transformers"``, use ``model_load_kwargs`` for + ``AutoModelForImageTextToText.from_pretrained``; any other keys are passed to + ``TransformersVLM`` after ``model_load_kwargs`` is popped. + + Returns + ------- + BaseVLM + The VLM instance. + """ + if vlm is not None: + return vlm + if not model_name: + raise ValueError( + "get_vlm requires model_name when vlm is not provided " + '(pass model_name explicitly, e.g. model_name="openai/gpt-4o").' + ) + if vlm_type == "litellm": + return LitellmVLM(model_name=model_name, api_key=api_key, **vlm_kwargs) + model_load_kwargs = vlm_kwargs.pop("model_load_kwargs", {}) + return TransformersVLM( + model_name=model_name, + device=device, + use_outlines=structured_output, + model_load_kwargs=model_load_kwargs, + **vlm_kwargs, + ) + + +class BaseVLM(ABC): + """Base class for Vision-Language Models.""" + + @abstractmethod + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model (litellm) or format string: "integer", "yes_no", "json" (transformers/outlines). + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[str] + Generated responses. 
+ """ + pass + + @abstractmethod + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True and supported, return P(expected answer) instead of binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format. When set, uses generate() with this format and + extracts the answer field for comparison instead of raw string matching. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + pass + + +class LitellmVLM(BaseVLM): + """ + VLM using litellm for API-based inference. + + Supports 100+ LLM providers (OpenAI, Anthropic, Azure, etc.) + + Parameters + ---------- + model_name : str + Model name (e.g. ``openai/gpt-4o`` for litellm). Passed from :func:`get_vlm`. + api_key : str | None, optional + API key for the provider. Uses LITELLM_API_KEY or OPENAI_API_KEY env if None. + **kwargs : Any + Additional arguments passed to litellm. + """ + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.api_key = api_key or os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.extra_kwargs = kwargs + try: + import litellm + + litellm.drop_params = True + self._litellm = litellm + except ImportError: + pruna_logger.error("litellm not installed. 
Install with: pip install litellm") + raise + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model for structured output (litellm uses BaseModel). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[str] + Generated responses. + """ + results = [] + for image, prompt in zip(images, prompts): + try: + # Prepare message content + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + # Prepare completion kwargs + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + **self.extra_kwargs, + **kwargs, + } + # Add structured generation if requested (litellm uses pydantic models only) + if response_format is not None and isinstance(response_format, type): + completion_kwargs["response_format"] = response_format + # Use synchronous completion + response = self._litellm.completion(**completion_kwargs) + content_result = response.choices[0].message.content + # If using pydantic, content is already parsed + use_pydantic = ( + response_format is not None + and isinstance(response_format, type) + and isinstance(content_result, response_format) + ) + if use_pydantic: + # Return JSON string representation + results.append(content_result.model_dump_json()) + else: + results.append(content_result) + except Exception as e: + pruna_logger.error(f"Litellm generation failed: {e}") + results.append("") + return results + + def score( + self, + images: List[Image.Image], 
+ questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When use_probability=True, requests logprobs from the API and returns P(expected). + When response_format is set, uses structured generation and extracts the answer field. + Falls back to binary 0/1 if logprobs not available. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return P(expected) from logprobs when available. Default is False. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." + if use_probability: + score = self._score_with_logprobs(image, prompt, answer, **kwargs) + elif response_format is not None: + raw = self.generate([image], [prompt], response_format=response_format, **kwargs)[0] + response_answer = get_answer_from_response(raw) + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + else: + response = self.generate([image], [prompt], **kwargs)[0].lower() + score = 1.0 if answer.lower() in response else 0.0 + scores.append(score) + return scores + + def _score_with_logprobs(self, image: Image.Image, prompt: str, expected: str, **kwargs: Any) -> float: + """ + Get P(expected) from logprobs when available. 
+ + Parameters + ---------- + image : Image.Image + PIL Image to score. + prompt : str + Question prompt. + expected : str + Expected answer (e.g., "Yes"). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + float + Probability of expected answer (0-1), or binary 0/1 on fallback. + """ + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + "logprobs": True, + "top_logprobs": 5, + **self.extra_kwargs, + **kwargs, + } + try: + response = self._litellm.completion(**completion_kwargs) + choice = response.choices[0] + logprobs = getattr(choice, "logprobs", None) or getattr(choice.message, "logprobs", None) + if logprobs and hasattr(logprobs, "content"): + for tok in logprobs.content or []: + top = getattr(tok, "top_logprobs", None) or [] + for t in top: + token_str = getattr(t, "token", "") or str(t).lower() + if token_str and expected.lower() in token_str.lower(): + logprob = float(getattr(t, "logprob", -1e9) or -1e9) + return min(1.0, max(0.0, math.exp(logprob))) + content_str = (choice.message.content or "").lower() + if expected.lower() in content_str: + return 1.0 + return 0.0 + except Exception: + response = self.generate([image], [prompt], **kwargs)[0].lower() + return 1.0 if expected.lower() in response else 0.0 + + def _image_to_data_url(self, image: Image.Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + b64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +class TransformersVLM(BaseVLM): + """ + VLM using HuggingFace Transformers for local inference. + + Supports models like BLIP, LLaVA, SmolVLM, etc. + + Parameters + ---------- + model_name : str, optional + HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b". 
+ device : str | torch.device | None, optional + Device for inference. Auto-detected if None. + use_outlines : bool, optional + Whether to use outlines for constrained decoding when the caller passes a string + ``response_format``. Usually set from ``structured_output`` via :func:`get_vlm`. + model_load_kwargs : dict, optional + Kwargs passed to from_pretrained (e.g. dtype, attn_implementation). + **kwargs : Any + Additional arguments passed to model.generate. + """ + + def __init__( + self, + model_name: str = "Salesforce/blip2-opt-2.7b", + device: Optional[str | torch.device] = None, + use_outlines: bool = False, + model_load_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.use_outlines = use_outlines + self.model_load_kwargs = model_load_kwargs or {} + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + self.extra_kwargs = kwargs + self._model = None + self._processor = None + + def _load_model(self) -> None: + if self._model is not None: + return + try: + from transformers import AutoModelForImageTextToText, AutoProcessor + except ImportError: + pruna_logger.error("transformers not installed. Install with: pip install transformers") + raise + pruna_logger.info(f"Loading VLM model: {self.model_name}") + self._processor = AutoProcessor.from_pretrained(self.model_name) + self._model = AutoModelForImageTextToText.from_pretrained(self.model_name, **self.model_load_kwargs) + device = self.device + self._model.to(device) # type: ignore[invalid-argument-type] + self._model.eval() + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses using local VLM. 
+ + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Format constraint for outlines ("integer", "yes_no") or None. + **kwargs : Any + Additional arguments passed to model generate. + + Returns + ------- + List[str] + Generated responses. + """ + self._load_model() + results = [] + max_new_tokens = kwargs.get("max_new_tokens", 128) + format_str = response_format if isinstance(response_format, str) else None + if self.use_outlines and format_str: + results = self._generate_with_outlines(images, prompts, format_str, max_new_tokens) + else: + results = self._generate_standard(images, prompts, max_new_tokens) + return results + + def _generate_with_outlines( + self, + images: List[Image.Image], + prompts: List[str], + format_type: str, + max_new_tokens: int, + ) -> List[str]: + """Generate using outlines for constrained decoding.""" + try: + import outlines + except ImportError: + pruna_logger.warning("outlines not installed, using standard generation") + return self._generate_standard(images, prompts, max_new_tokens) + results = [] + # Define format constraints + if format_type == "json": + generator = outlines.generate.json(self._model) + elif format_type == "integer": + generator = outlines.generate.format(self._model, r"\d+") + elif format_type == "yes_no": + generator = outlines.generate.format(self._model, r"(Yes|No)") + else: + return self._generate_standard(images, prompts, max_new_tokens) + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + try: + inputs = self._prepare_inputs(image, prompt) + output = generator(**inputs, max_tokens=max_new_tokens) + response = self._decode_output(output[0]) + results.append(response) + except Exception as e: + pruna_logger.warning(f"Outlines generation failed: {e}, using standard") + results.append("") + return results + + def _prepare_inputs(self, image: Image.Image, prompt: str) 
-> dict: + """Prepare model inputs, supporting both BLIP-style and chat-template processors.""" + try: + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + except (ValueError, TypeError): + conversation = [ + {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]} + ] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _decode_output(self, output_ids: torch.Tensor) -> str: + """Decode model output to text.""" + if hasattr(self._processor, "batch_decode"): + return self._processor.batch_decode([output_ids], skip_special_tokens=True)[0] + return self._processor.decode(output_ids, skip_special_tokens=True) + + def _generate_standard( + self, + images: List[Image.Image], + prompts: List[str], + max_new_tokens: int, + ) -> List[str]: + """Standard generation without outlines.""" + results = [] + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._prepare_inputs(image, prompt) + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + response = self._decode_output(output[0]) + results.append(response) + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + use_probability is not supported for TransformersVLM; uses binary 0/1. + When response_format is set, uses structured generation and extracts the answer field. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. 
+ answers : List[str] + List of expected answers. + use_probability : bool, optional + Ignored; TransformersVLM always uses binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to generate. + + Returns + ------- + List[float] + Scores for each image-question pair (0 or 1). + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." + responses = self.generate([image], [prompt], response_format=response_format, **kwargs) + raw = responses[0] if responses else "" + response_answer = get_answer_from_response(raw) if response_format is not None else raw.lower() + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + scores.append(score) + return scores diff --git a/src/pruna/evaluation/metrics/vlm_utils.py b/src/pruna/evaluation/metrics/vlm_utils.py new file mode 100644 index 00000000..8e010826 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_utils.py @@ -0,0 +1,181 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared utilities and Pydantic models for VLM metrics.""" + +from __future__ import annotations + +import json +import re +from typing import Any, List + +import torch +from PIL import Image +from pydantic import BaseModel, Field + + +def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + tensor = tensor[0] + if tensor.max() > 1: + tensor = tensor / 255.0 + np_img = (tensor.cpu().numpy() * 255).astype("uint8") + return Image.fromarray(np_img.transpose(1, 2, 0)) + + +def _process_images(images: torch.Tensor) -> List[Any]: + return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in images] + + +class VQAnswer(BaseModel): + """ + Structured output for VQA questions (Yes/No or open-ended). + + Parameters + ---------- + answer : str + Answer to the question. Typically "Yes" or "No" for alignment metrics, + but can be any string for open-ended questions. + """ + + answer: str = Field(description="Answer to the question") + + +class FloatOutput(BaseModel): + """ + Structured output for numeric scoring (img_edit_score, VieScoreMetric). + + Parameters + ---------- + score : float + Score from 0 to 10. + """ + + score: float = Field(ge=0, le=10, description="Score from 0 to 10") + + +class TextOutput(BaseModel): + """ + Structured output for text extraction (text_score). + + Parameters + ---------- + text : str + Extracted text from the image, or 'No text recognized' if empty. + """ + + text: str = Field(description="Extracted text from the image, or 'No text recognized' if empty") + + +def get_answer_from_response(response: str | BaseModel | dict) -> str: + """ + Extract answer string from a VLM score() response (VQAnswer, dict, or raw string). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate() or vlm.score(). + + Returns + ------- + str + Extracted answer string, or empty string. 
+ """ + if response is None: + return "" + if isinstance(response, VQAnswer): + return response.answer + if isinstance(response, dict): + return response.get("answer", "") + raw = str(response).strip() + if raw.startswith("{"): + try: + return json.loads(raw).get("answer", raw) + except (json.JSONDecodeError, TypeError): + pass + return raw + + +def get_text_from_response(response: str | BaseModel | dict) -> str: + """ + Extract text from a VLM generate() response (str, pydantic, or dict). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + str + Extracted text, or empty string. + """ + if response is None: + return "" + if isinstance(response, TextOutput): + text = response.text + elif isinstance(response, dict): + text = response.get("text", "") + else: + text = (response or "").strip() + if text.startswith("{"): + try: + data = json.loads(text) + text = data.get("text", text) + except (json.JSONDecodeError, TypeError): + pass + for phrase in ("No text recognized", "no text recognized", "No text"): + text = text.replace(phrase, "").strip() + return (text or "").strip() + + +def get_score_from_response(response: str | BaseModel | dict) -> float: + """ + Extract numeric score (0-10) from a VLM generate() response. + + Handles: + + * ``FloatOutput`` instances (local / parsed Pydantic). + * ``dict`` with a ``"score"`` key. + * JSON **strings** (e.g. LitellmVLM returns ``model_dump_json()`` for structured output). + * Plain text with a number (first decimal or integer matched). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + float + Score in [0, 1] (normalized from 0-10). 
+ """ + if response is None: + return 0.0 + if isinstance(response, FloatOutput): + return min(float(response.score), 10.0) / 10.0 + if isinstance(response, dict): + return min(float(response.get("score", 0)), 10.0) / 10.0 + text = str(response or "").strip() + if text.startswith("{"): + try: + data = json.loads(text) + if isinstance(data, dict) and "score" in data: + return min(float(data["score"]), 10.0) / 10.0 + except (json.JSONDecodeError, TypeError, ValueError): + pass + match = re.search(r"\d+(?:\.\d+)?", text) + if match: + return min(float(match.group(0)), 10.0) / 10.0 + return 0.0 diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 0ae4ba8a..3e0866b5 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -27,6 +27,7 @@ from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.utils import get_hyperparameters +from pruna.evaluation.metrics.vlm_base import VLM_METRIC_REGISTRY_NAMES from pruna.logging.logger import pruna_logger AVAILABLE_REQUESTS = ("image_generation_quality", "text_generation_quality") @@ -102,6 +103,20 @@ def from_benchmark( dataloader_args=dataloader_args or {}, **kwargs, ) + if benchmark.lookup_key == "GenEval": + return cls( + request=[ + MetricRegistry.get_metric( + "qa_accuracy", + aggregation="all_or_nothing", + model_name="openai/gpt-4o", + ), + MetricRegistry.get_metric("clip_score"), + ], + datamodule=datamodule, + device=device, + low_memory=low_memory, + ) return cls( request=benchmark.metrics, datamodule=datamodule, @@ -295,9 +310,16 @@ def _process_metric_names( for metric_name in request: metric_name = cast(str, metric_name) new_requests.append(cast(str, metric_name)) - return MetricRegistry.get_metrics( - names=new_requests, inference_device=inference_device, stateful_metric_device=stateful_metric_device - ) + out: List[BaseMetric | StatefulMetric] = [] + for name in 
new_requests: + kwargs: dict[str, Any] = { + "inference_device": inference_device, + "stateful_metric_device": stateful_metric_device, + } + if name in VLM_METRIC_REGISTRY_NAMES: + kwargs["model_name"] = "openai/gpt-4o" + out.append(MetricRegistry.get_metric(name, **kwargs)) + return out def _get_lm_eval_task_metrics(task_name: str): diff --git a/tests/common.py b/tests/common.py index 2a58b698..90e467cf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -195,8 +195,15 @@ def check_docstrings_content(file: str) -> None: file : str The import statement to check. """ + # Nested callables use ``..`` in ``__qualname__`` (numpydoc cannot load them). + # Vendored ``llm2vec`` mirrors upstream docstrings; skip strict numpydoc for that module. n_invalid, report = numpydoc_validation.validate_recursive( - file, checks={"all", "ES01", "SA01", "EX01"}, exclude=set() + file, + checks={"all", "ES01", "SA01", "EX01"}, + exclude={ + r"\.\.", + r"vendor\.oneig_llm2vec\.llm2vec", + }, ) if n_invalid != 0: raise ValueError(report) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 103cadfb..f097e87e 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,3 +1,4 @@ +import importlib.util from typing import Any, Callable import pytest @@ -59,7 +60,6 @@ def _assert_at_least_one_sample(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), - pytest.param("GenEval", dict(), marks=pytest.mark.slow), pytest.param("HPS", dict(), marks=pytest.mark.slow), pytest.param("ImgEdit", dict(), marks=pytest.mark.slow), pytest.param("LongTextBench", dict(), marks=pytest.mark.slow), @@ -104,23 +104,24 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: str, collate_fn_args: d iterate_dataloaders(datamodule) -def _benchmarks_with_category() -> 
list[tuple[str, str]]: - """Benchmarks that have a category param: (dataset_name, category) for every category.""" +def _benchmark_category_smoke() -> list[tuple[str, str]]: + """One (dataset, category) per benchmark that supports ``category`` (stable, small smoke set).""" result = [] - for name in base_datasets: + for name in sorted(base_datasets): + if name == "VBench" and importlib.util.find_spec("vbench") is None: + continue setup_fn = base_datasets[name][0] literal_values = get_literal_values_from_param(setup_fn, "category") if literal_values: - for cat in literal_values: - result.append((name, cat)) + result.append((name, sorted(literal_values)[0])) return result @pytest.mark.cpu @pytest.mark.slow -@pytest.mark.parametrize("dataset_name, category", _benchmarks_with_category()) +@pytest.mark.parametrize("dataset_name, category", _benchmark_category_smoke()) def test_benchmark_category_filter(dataset_name: str, category: str) -> None: - """Test dataset loading with each category filter; dataset has at least one sample.""" + """Category filter loads and batches match the chosen category (one category per dataset).""" dm = PrunaDataModule.from_string(dataset_name, category=category, dataloader_args={"batch_size": 4}) _assert_at_least_one_sample(dm) dm.limit_datasets(10) @@ -143,20 +144,17 @@ def _category_in_aux(aux: dict, cat: str) -> bool: @pytest.mark.cpu @pytest.mark.slow -@pytest.mark.parametrize( - "dataset_name, required_aux_key", - [ +def test_prompt_benchmark_auxiliaries() -> None: + """Prompt-based benchmarks expose expected aux keys.""" + for dataset_name, required_aux_key in ( ("LongTextBench", "text_content"), ("OneIG", "text_content"), - ], -) -def test_prompt_benchmark_auxiliaries(dataset_name: str, required_aux_key: str) -> None: - """Test prompt-based benchmarks load with expected auxiliaries.""" - dm = PrunaDataModule.from_string(dataset_name, dataloader_args={"batch_size": 4}) - dm.limit_datasets(10) - batch = next(iter(dm.test_dataloader())) - 
prompts, auxiliaries = batch - - assert len(prompts) == 4 - assert all(isinstance(p, str) for p in prompts) - assert all(required_aux_key in aux for aux in auxiliaries) + ): + dm = PrunaDataModule.from_string(dataset_name, dataloader_args={"batch_size": 4}) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(required_aux_key in aux for aux in auxiliaries) diff --git a/tests/data/test_oneig_loader.py b/tests/data/test_oneig_loader.py new file mode 100644 index 00000000..e0ca83c3 --- /dev/null +++ b/tests/data/test_oneig_loader.py @@ -0,0 +1,112 @@ +"""Tests for OneIG-Bench prompt loading (Q_D graphs and reasoning ground truth).""" + +from __future__ import annotations + +import pytest + +from pruna.data.datasets import prompt as prompt_mod + + +def test_oneig_needs_zh_multilingualism_hub() -> None: + """ZH config is pulled only for full suite or when Multilingualism is requested.""" + assert prompt_mod._oneig_needs_zh_multilingualism_hub(None) is True + assert prompt_mod._oneig_needs_zh_multilingualism_hub("Multilingualism") is True + assert prompt_mod._oneig_needs_zh_multilingualism_hub("Portrait") is False + assert prompt_mod._oneig_needs_zh_multilingualism_hub(["Portrait", "General_Object"]) is False + assert prompt_mod._oneig_needs_zh_multilingualism_hub(["Portrait", "Multilingualism"]) is True + + +def test_oneig_qd_prefix_multilingualism() -> None: + """Multilingualism maps to the only upstream stem ``multilingualism_zh``.""" + row = {"category": "Multilingualism", "id": "000", "prompt_en": "x", "class": "None"} + assert prompt_mod._oneig_qd_prefix(row) == "multilingualism_zh" + + +def test_oneig_qd_prefix_anime_zh_hint() -> None: + """Rows marked Chinese use ``anime_zh`` when category is anime/stylization.""" + row = { + "category": "Anime_Stylization", + "id": "001", + "prompt_en": "hello", + "class": "None", + "language": 
"zh", + } + assert prompt_mod._oneig_qd_prefix(row) == "anime_zh" + + +def test_to_oneig_record_multilingualism_fills_questions() -> None: + """Synthetic Multilingualism row resolves Q_D from merged index.""" + qb = {"multilingualism_zh_000": {"questions": {"1": "现场是不是颁奖典礼?"}, "dependencies": {"1": [0]}}} + row = {"category": "Multilingualism", "id": "000", "prompt_en": " awards ", "class": "None"} + rec = prompt_mod._to_oneig_record(row, qb, {}, {}) + assert rec["questions"]["1"] == "现场是不是颁奖典礼?" + assert rec["dependencies"]["1"] == [0] + + +def test_to_oneig_record_knowledge_reasoning_gt() -> None: + """Knowledge_Reasoning rows attach official-style gt strings by id.""" + row = { + "category": "Knowledge_Reasoning", + "id": "000", + "prompt_en": "Peaks chart", + "class": "geography", + } + gt_en = {"000": "The world's five tallest peaks are Mount Everest"} + gt_zh = {"000": "中文答案"} + rec = prompt_mod._to_oneig_record(row, {}, gt_en, gt_zh, "EN") + assert rec["reasoning_gt_answer"] == gt_en["000"] + assert rec["questions"] == {} + rec_zh = prompt_mod._to_oneig_record(row, {}, gt_en, gt_zh, "ZH") + assert rec_zh["reasoning_gt_answer"] == gt_zh["000"] + + +def test_to_oneig_record_prefers_prompt_over_prompt_en() -> None: + """When ``prompt`` is set it wins for the unified ``text`` field.""" + row = { + "category": "General_Object", + "id": "000", + "prompt": "native", + "prompt_en": "english", + "class": "None", + } + rec = prompt_mod._to_oneig_record(row, {}, {}, {}) + assert rec["text"] == "native" + + +def test_to_oneig_record_uses_prompt_cn_for_zh_hub_rows() -> None: + """``OneIG-Bench-ZH`` Multilingualism rows expose Chinese text as ``prompt_cn``.""" + row = {"category": "Multilingualism", "id": "000", "prompt_cn": "中文提示", "class": "None"} + rec = prompt_mod._to_oneig_record(row, {}, {}, {}) + assert rec["text"] == "中文提示" + + +@pytest.mark.slow +def test_setup_oneig_lazyloads_zh_hub_only_when_needed(monkeypatch: pytest.MonkeyPatch) -> None: + """Portrait-only 
loads ``OneIG-Bench``; Multilingualism also loads ``OneIG-Bench-ZH``.""" + from datasets import load_dataset as real_load_dataset + + loaded: list[str] = [] + + def tracking_load(*args: object, **kwargs: object): + name = args[1] if len(args) > 1 else kwargs.get("name") + loaded.append(str(name)) + return real_load_dataset(*args, **kwargs) + + monkeypatch.setattr(prompt_mod, "load_dataset", tracking_load) + + prompt_mod.setup_oneig_dataset(category="Portrait", test_sample_size=1) + assert loaded == ["OneIG-Bench"] + + loaded.clear() + prompt_mod.setup_oneig_dataset(category="Multilingualism", test_sample_size=1) + assert loaded == ["OneIG-Bench", "OneIG-Bench-ZH"] + + +@pytest.mark.slow +def test_setup_oneig_knowledge_reasoning_loads_remote_gt() -> None: + """Integration: first reasoning sample has non-empty gt from the hub JSON.""" + _train, _val, test = prompt_mod.setup_oneig_dataset(category="Knowledge_Reasoning", test_sample_size=1) + row = test[0] + assert row["reasoning_gt_answer"] + assert isinstance(row["reasoning_gt_answer"], str) + assert len(row["reasoning_gt_answer"]) > 20 diff --git a/tests/evaluation/test_geneval_task_metrics.py b/tests/evaluation/test_geneval_task_metrics.py new file mode 100644 index 00000000..b898fa7a --- /dev/null +++ b/tests/evaluation/test_geneval_task_metrics.py @@ -0,0 +1,22 @@ +"""GenEval task wires strict multi-question QA plus CLIP.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from pruna.evaluation.task import Task + + +@pytest.mark.cpu +@patch("pruna.evaluation.task.PrunaDataModule.from_string") +def test_geneval_from_benchmark_uses_qa_accuracy_all_or_nothing(mock_from_string: MagicMock) -> None: + """GenEval uses strict per-image QA aggregation and CLIP.""" + mock_dm = MagicMock() + mock_dm.test_dataloader.return_value = iter([]) + mock_from_string.return_value = mock_dm + task = Task.from_benchmark("GenEval", dataloader_args={"batch_size": 1}) + qa = next(m 
for m in task.metrics if getattr(m, "metric_name", None) == "qa_accuracy") + assert qa.aggregation == "all_or_nothing" + assert any(getattr(m, "metric_name", None) == "clip_score" for m in task.metrics) diff --git a/tests/evaluation/test_oneig_alignment.py b/tests/evaluation/test_oneig_alignment.py new file mode 100644 index 00000000..1029e955 --- /dev/null +++ b/tests/evaluation/test_oneig_alignment.py @@ -0,0 +1,62 @@ +"""Tests for OneIG alignment dependency masking and metric wiring.""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from pruna.evaluation.metrics.metric_oneig_alignment import ( + OneIGAlignmentMetric, + aggregate_oneig_alignment_per_cell, + apply_oneig_dependency_mask, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM + + +def test_apply_oneig_dependency_mask_parent_no_zeros_child() -> None: + """Parent ``No`` forces dependent question score to zero.""" + raw = {1: 0.0, 2: 1.0} + deps = {1: [0], 2: [1]} + out = apply_oneig_dependency_mask(raw, deps) + assert out[1] == 0.0 + assert out[2] == 0.0 + assert aggregate_oneig_alignment_per_cell(out, [1, 2]) == 0.0 + + +def test_apply_oneig_dependency_mask_parent_yes_keeps_child() -> None: + """All ``Yes`` yields nonzero child and mean 1.0 over two questions.""" + raw = {1: 1.0, 2: 1.0} + deps = {1: [0], 2: [1]} + out = apply_oneig_dependency_mask(raw, deps) + assert out == {1: 1.0, 2: 1.0} + assert aggregate_oneig_alignment_per_cell(out, [1, 2]) == 1.0 + + +def test_apply_oneig_dependency_mask_uses_raw_parent_not_filtered_for_chain() -> None: + r"""Grandchild may stay 1 when parent's **raw** VLM score is Yes even if parent was masked to 0.""" + raw = {1: 0.0, 2: 1.0, 3: 1.0} + deps = {1: [0], 2: [1], 3: [2]} + out = apply_oneig_dependency_mask(raw, deps) + assert out[1] == 0.0 + assert out[2] == 0.0 + assert out[3] == 1.0 + + +@pytest.mark.cpu +def test_oneig_alignment_metric_respects_question_id_order() -> None: + """Questions are scored in numeric id order; masking 
uses aligned raw scores.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.0, 1.0] + + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = { + "questions": {"2": "second", "1": "first"}, + "dependencies": {"1": [0], "2": [1]}, + } + metric.update(["p"], [aux], images) + result = metric.compute() + assert result.name == "oneig_alignment" + assert result.result == 0.0 + call = mock_vlm.score.call_args + assert call[0][1] == ["first", "second"] diff --git a/tests/evaluation/test_oneig_reasoning.py b/tests/evaluation/test_oneig_reasoning.py new file mode 100644 index 00000000..ab06e934 --- /dev/null +++ b/tests/evaluation/test_oneig_reasoning.py @@ -0,0 +1,106 @@ +"""Tests for OneIG reasoning metric (LLM2CLIP text-image similarity).""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from pruna.evaluation.metrics.metric_oneig_reasoning import ( + OneIGReasoningMetric, + _LLM2CLIPScorer, +) + + +def _make_mock_scorer(return_value: float = 0.5) -> MagicMock: + mock = MagicMock(spec=_LLM2CLIPScorer) + mock.score.return_value = [return_value] + return mock + + +@pytest.mark.cpu +def test_oneig_reasoning_uses_gt_answer_from_aux() -> None: + """Metric reads reasoning_gt_answer from aux.""" + mock_scorer = _make_mock_scorer(0.7) + metric = OneIGReasoningMetric(scorer=mock_scorer, device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = {"reasoning_gt_answer": "A blue circle"} + metric.update(["p"], [aux], images) + result = metric.compute() + assert result.name == "oneig_reasoning" + assert result.result == 0.7 + mock_scorer.score.assert_called_once() + call_args = mock_scorer.score.call_args + assert call_args[0][1] == "A blue circle" + + +@pytest.mark.cpu +def test_oneig_reasoning_averages_per_sample_scores() -> None: + """Compute returns mean of per-sample scores.""" + mock_scorer = _make_mock_scorer(0.5) + metric = 
OneIGReasoningMetric(scorer=mock_scorer, device="cpu") + images = torch.rand(2, 3, 64, 64) + aux_list = [ + {"reasoning_gt_answer": "First answer"}, + {"reasoning_gt_answer": "Second answer"}, + ] + metric.update(["p1", "p2"], aux_list, images) + result = metric.compute() + assert result.result == 0.5 + assert mock_scorer.score.call_count == 2 + + +@pytest.mark.cpu +def test_oneig_reasoning_missing_gt_raises() -> None: + """Missing GT answer raises ValueError.""" + mock_scorer = _make_mock_scorer(0.8) + metric = OneIGReasoningMetric(scorer=mock_scorer, device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = {} + with pytest.raises(ValueError, match="reasoning_gt_answer"): + metric.update(["p"], [aux], images) + mock_scorer.score.assert_not_called() + + +@pytest.mark.cpu +def test_oneig_reasoning_scorer_none_raises() -> None: + """When scorer returns None, metric raises RuntimeError.""" + mock_scorer = _make_mock_scorer() + mock_scorer.score.return_value = None + metric = OneIGReasoningMetric(scorer=mock_scorer, device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = {"reasoning_gt_answer": "Some answer"} + with pytest.raises(RuntimeError, match="no scores"): + metric.update(["p"], [aux], images) + + +@pytest.mark.cpu +def test_oneig_reasoning_compute_without_update_raises() -> None: + """Compute with no updates raises RuntimeError.""" + mock_scorer = _make_mock_scorer() + metric = OneIGReasoningMetric(scorer=mock_scorer, device="cpu") + with pytest.raises(RuntimeError, match="no samples were scored"): + metric.compute() + + +@pytest.mark.cpu +def test_oneig_reasoning_has_metric_registered() -> None: + """oneig_reasoning is available via MetricRegistry (lazy).""" + from pruna.evaluation.metrics.registry import MetricRegistry + + assert MetricRegistry.has_metric("oneig_reasoning") + + +@pytest.mark.slow +@pytest.mark.skip(reason="Requires HF model download; run manually") +def test_oneig_reasoning_smoke_with_real_scorer() -> None: + """Optional: full LLM2CLIP 
scorer on one sample (slow).""" + from pruna.data.datasets.prompt import setup_oneig_knowledge_reasoning_dataset + + metric = OneIGReasoningMetric(device="cpu") + _train, _val, test = setup_oneig_knowledge_reasoning_dataset(test_sample_size=1) + row = test[0] + aux = {k: v for k, v in row.items() if k != "text"} + images = torch.rand(1, 3, 224, 224) + metric.update([row["text"]], [aux], images) + result = metric.compute() + assert 0 <= result.result <= 1 diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 281c6b7e..dbd84493 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -36,10 +36,10 @@ def make_mock_metric(metric_class): with patch.object(TorchMetrics, '_member_map_', {**TorchMetrics._member_map_, **mock_metrics}): yield -@pytest.mark.parametrize("metric_name", MetricRegistry()._registry) +@pytest.mark.parametrize("metric_name", sorted(MetricRegistry._registry)) def test_metric_initialization_from_metric_name(metric_name): datamodule = PrunaDataModule.from_string("LAION256") - Task(request=[metric_name], datamodule=datamodule) + Task(request=[metric_name], datamodule=datamodule, device="cpu") @device_parametrized diff --git a/tests/evaluation/test_vlm_metrics.py b/tests/evaluation/test_vlm_metrics.py new file mode 100644 index 00000000..a9a6036e --- /dev/null +++ b/tests/evaluation/test_vlm_metrics.py @@ -0,0 +1,220 @@ +"""Tests for VLM metrics (VQA, AlignmentScore, ImageEditScore, QAAccuracy, TextScore, VieScore).""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.metric_alignment_score import AlignmentScoreMetric +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric 
from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric
from pruna.evaluation.metrics.metric_vqa import VQAMetric
from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm

# Small open-weights VLM used by the slow local-inference smoke tests.
SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct"

# Full matrix of VLM-backed metrics; exercised against a mocked litellm backend.
_ALL_VLM = (
    VQAMetric,
    AlignmentScoreMetric,
    ImageEditScoreMetric,
    QAAccuracyMetric,
    OneIGAlignmentMetric,
    TextScoreMetric,
    OneIGTextScoreMetric,
    VieScoreMetric,
)

# Reduced matrix for the slow SmolVLM tests (real model load + inference).
_SLOW_SMOL_SUBSET = (
    VQAMetric,
    OneIGAlignmentMetric,
    ImageEditScoreMetric,
    VieScoreMetric,
)


def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor:
    """Return a random (batch, 3, size, size) float tensor standing in for real images."""
    return torch.rand(batch, 3, size, size)


def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None:
    """Call ``metric.update`` with the argument shape each metric family expects.

    OneIGAlignment takes per-sample question/dependency dicts, QAAccuracy takes
    question dicts only, the text-score metrics take ground-truth strings, and
    every other metric takes ``(prompts, images, images)``.
    """
    if isinstance(metric, OneIGAlignmentMetric):
        metric.update(
            prompts,
            [
                {
                    "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"},
                    # NOTE(review): dependency values look like question-index lists;
                    # confirm the expected schema against OneIGAlignmentMetric.
                    "dependencies": {"1": [0], "2": [1]},
                }
            ],
            images,
        )
    elif isinstance(metric, QAAccuracyMetric):
        metric.update(
            prompts,
            [{"questions": {"1": "Is there a cat?"}}],
            images,
        )
    elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)):
        # Text metrics compare a generation against a ground-truth string.
        metric.update(prompts, ["cat"], images)
    else:
        metric.update(prompts, images, images)


@pytest.mark.cpu
@pytest.mark.slow
@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET)
def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None:
    """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock)."""
    metric = metric_cls(
        vlm_type="transformers",
        model_name=SMOL_VLM,
        device="cpu",
        structured_output=True,
    )
    images = _dummy_image(batch=1)
    prompts = ["a cat"]
    _update_metric(metric, prompts, images)
    result = metric.compute()
    # Contract shared by all metrics: named float result with an orientation flag.
    assert result.name == metric.metric_name
    assert isinstance(result.result, float)
    if metric.higher_is_better:
        assert 0.0 <= result.result <= 1.0
    else:
        # Lower-is-better scores (e.g. edit distance) are only bounded below.
        assert result.result >= 0.0


@pytest.mark.cpu
@pytest.mark.parametrize("metric_cls", _ALL_VLM)
def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None:
    """Each VLM metric runs end-to-end with mocked litellm."""
    pytest.importorskip("litellm")
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    # Answer-style metrics parse an {"answer": ...} payload; the score-style
    # metrics parse {"score": ...} — pick the mock body accordingly.
    if metric_cls in (AlignmentScoreMetric, VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric):
        mock_response.choices[0].message.content = '{"answer": "Yes"}'
    else:
        mock_response.choices[0].message.content = '{"score": 8}'

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = mock_response

        metric = metric_cls(
            vlm_type="litellm",
            model_name="gpt-4o",
            device="cpu",
            structured_output=True,
        )
        images = _dummy_image(batch=1)
        prompts = ["a cat"]
        _update_metric(metric, prompts, images)
        result = metric.compute()

    assert result.name == metric.metric_name
    assert isinstance(result.result, float)
    # The mocked backend must actually have been hit.
    assert mock_completion.called


@pytest.mark.cpu
def test_vlm_metrics_empty_compute_returns_zero() -> None:
    """No updates → compute is 0.0 (same for all stateful VLM metrics)."""
    metric = VQAMetric(
        vlm_type="transformers",
        model_name=SMOL_VLM,
        device="cpu",
        structured_output=True,
    )
    assert metric.compute().result == 0.0


@pytest.mark.cpu
def test_vlm_metrics_custom_vlm() -> None:
    """A user-supplied BaseVLM is used directly; its score() drives the result."""
    mock_vlm = MagicMock(spec=BaseVLM)
    mock_vlm.generate.return_value = ["Yes"]
    mock_vlm.score.return_value = [1.0]

    metric = VQAMetric(
        vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True
    )
    images = _dummy_image(batch=1)
    prompts = ["a cat"]
    metric.update(prompts, images, images)
    assert metric.compute().result == 1.0
    mock_vlm.score.assert_called()


@pytest.mark.cpu
def test_get_vlm_returns_custom() -> None:
    """get_vlm returns a provided custom VLM instance unchanged."""
    custom = MagicMock(spec=BaseVLM)
    out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o")
    assert out is custom


@pytest.mark.cpu
def test_get_vlm_requires_model_name_without_vlm() -> None:
    """Without a vlm instance, get_vlm must be given model_name or raise."""
    with pytest.raises(ValueError, match="model_name"):
        get_vlm(vlm=None, vlm_type="litellm")


@pytest.mark.cpu
@pytest.mark.parametrize(
    "metric_cls, expected_name, expected_result",
    [
        (TextScoreMetric, "text_score", 0.0),
        (OneIGTextScoreMetric, "oneig_text_score", 1.0),
    ],
)
def test_text_metrics_list_str_gt(
    metric_cls: type, expected_name: str, expected_result: float
) -> None:
    """Text metrics accept list[str] ground truth.

    A generation identical to the ground truth yields 0.0 for TextScoreMetric
    and 1.0 for OneIGTextScoreMetric (per these fixtures' expected values).
    """
    mock_vlm = MagicMock(spec=BaseVLM)
    mock_vlm.generate.return_value = ["hello world"]

    metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu")
    images = _dummy_image(batch=1)
    metric.update(["a prompt"], ["hello world"], images)
    result = metric.compute()

    assert result.result == expected_result
    assert result.name == expected_name
    mock_vlm.generate.assert_called_once()


@pytest.mark.cpu
def test_text_score_registry_aliases() -> None:
    """The ocr_* registry aliases resolve to the two text-score metric classes."""
    from pruna.evaluation.metrics.registry import MetricRegistry

    lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o")
    comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o")
    assert type(lev).__name__ == "TextScoreMetric"
    assert type(comp).__name__ == "OneIGTextScoreMetric"
    # Aliases keep the canonical metric names, not the alias strings.
    assert lev.metric_name == "text_score"
    assert comp.metric_name == "oneig_text_score"


@pytest.mark.cpu
def test_oneig_text_score_utils_golden_composite() -> None:
    """Golden values for oneig_mean_text_score in EN and ZH language modes."""
    from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score

    # EN: single sample with edit distance 10, completion ratio 0,
    # 2 of 4 ground-truth words matched (word accuracy 0.5).
    ed, cr, wac, composite = oneig_mean_text_score(
        edit_distances=[10.0],
        completion_ratios=[0.0],
        match_counts=[2],
        gt_totals=[4],
        language_mode="EN",
    )
    assert ed == 10.0
    assert cr == 0.0
    assert wac == 0.5
    assert composite == pytest.approx(0.95)

    # ZH golden composite for a fully-missed sample.
    _, _, _, zh = oneig_mean_text_score(
        edit_distances=[30.0],
        completion_ratios=[0.0],
        match_counts=[0],
        gt_totals=[1],
        language_mode="ZH",
    )
    assert zh == pytest.approx(0.4)
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        (FloatOutput(score=8.0), 0.8),
        ({"score": 5.0}, 0.5),
        ('{"score": 7.5}', 0.75),
        ('{"score": 10}', 1.0),
        ("8", 0.8),
        ("Score: 7.5 out of 10", 0.75),
        ("", 0.0),
    ],
)
def test_get_score_from_response(raw: object, expected: float) -> None:
    """Each supported payload shape maps to its normalized score.

    Covers FloatOutput objects, dicts, JSON strings, bare numbers,
    free text, and the empty-response fallback of 0.0.
    """
    actual = get_score_from_response(raw)
    assert actual == pytest.approx(expected)