Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup-uv-project/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ runs:
python-version: ${{ inputs.python-version }}

- shell: bash
run: uv sync --extra dev --extra lmharness --extra vllm
run: uv sync --extra dev
14 changes: 13 additions & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ jobs:
strategy:
matrix:
python-version: ["3.11"]
name: ["base", 'lmharness']
include:
- name: base
extras: ""
mark_filter: "cpu and not slow and not style and not requires_intel and not requires_lmharness"
- name: lmharness
extras: "--extra lmharness"
mark_filter: "requires_lmharness"

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand All @@ -80,6 +88,10 @@ jobs:
- uses: ./.github/actions/setup-uv-project
with:
python-version: ${{ matrix.python-version }}

- name: Install additional extras
if: matrix.extras != ''
run: uv sync --extra dev ${{ matrix.extras }}

- name: Cache Hugging Face datasets and models
uses: actions/cache@v5
Expand Down Expand Up @@ -123,4 +135,4 @@ jobs:
- name: Run tests with pytest-rerunfailures
run: |
echo "Running tests with up to 3 reruns on failure using $PYTEST_WORKERS workers..."
uv run pytest -n $PYTEST_WORKERS -m "not (slow or style or high_cpu or cuda or distributed)" --reruns 3 --reruns-delay 10 --maxfail=1
uv run pytest -n $PYTEST_WORKERS -m "${{ matrix.mark_filter }}" --reruns 3 --reruns-delay 10 --maxfail=1
37 changes: 19 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,14 @@ explicit = true
index-strategy = "first-index"

conflicts = [
[
{ extra = "awq" },
{ extra = "vbench" },
],
[
{ extra = "vllm" },
{ extra = "vbench" },
],
[
{ extra = "stable-fast-extraindex" },
{ extra = "stable-fast" },
],
[
{ extra = "stable-fast-extraindex" },
{ extra = "full" },
],
[{ extra = "awq" }, { extra = "vbench" }],
[{ extra = "vllm" }, { extra = "vbench" }],
[{ extra = "intel" }, { extra = "awq" }],
[{ extra = "gptq" }, { extra = "awq" }],
# intel is incompatible with all stable-fast variants and vllm
[{ extra = "intel" }, { extra = "stable-fast" }, { extra = "stable-fast-extraindex" }],
[{ extra = "intel" }, { extra = "full" }, { extra = "stable-fast-extraindex" }],
[{ extra = "intel" }, { extra = "vllm" }],
]

[tool.uv.sources]
Expand Down Expand Up @@ -147,8 +139,6 @@ dependencies = [
"timm",
"bitsandbytes; sys_platform != 'darwin' or platform_machine != 'arm64'",
"optimum-quanto>=0.2.5",
"ctranslate2==4.6.0",
"whisper-s2t==1.3.1",
"hqq==0.2.7.post1",
"torchao>=0.12.0,<0.16.0", # 0.16.0 breaks diffusers 0.36.0, torch+torch: https://github.com/pytorch/ao/issues/2919#issue-3375688762
"gliner; python_version >= '3.11'",
Expand All @@ -165,6 +155,13 @@ dependencies = [
]

[project.optional-dependencies]
# whisper-s2t and ctranslate2 are isolated because importing whisper_s2t
# at test time causes side-effects that break unrelated tests.

whisper = [
"ctranslate2==4.6.0",
"whisper-s2t==1.3.1",
]
vllm = [
"vllm>=0.16.0",
"ray",
Expand Down Expand Up @@ -230,8 +227,12 @@ cpu = []
lmharness = [
"lm-eval>=0.4.0"
]

# Intel extension is tightly coupled with the torch version
intel = [
"intel-extension-for-pytorch>=2.7.0",
"torch>=2.7.0,<2.9.0",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get that it's tightly coupled with torch, but if intel-extension already declares its torch dependency, we shouldn't need to declare it here, no? Same goes for torchvision.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean and I softly disagree. yes the version already defines which torch version it's supposed to use, but it doesn't define it in its dependencies. same for the torchvision version. So someone could install ipex with the absolutely wrong torch stack and get really cryptic errors. I think it's better to be explicit here.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, if the torch version is not in the package dependency then absolutely let's pin torch so that the intel install works

"torchvision>=0.22.0,<0.24.0",
]

[build-system]
Expand Down
1 change: 1 addition & 0 deletions src/pruna/algorithms/llm_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class LLMCompressor(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda"]
compatible_before: Iterable[str] = ["moe_kernel_tuner"]
compatible_after: Iterable[str] = ["sage_attn", "moe_kernel_tuner"]
required_install = "``uv pip install 'pruna[awq]'``"

def get_hyperparameters(self) -> list:
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/algorithms/test_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None:
@pytest.mark.parametrize(
"model_fixture, algorithms, allow_pickle_files, metric",
[
("sd_tiny_random", ["deepcache", "stable_fast"], False, 'cmmd'),
pytest.param("sd_tiny_random", ["deepcache", "stable_fast"], False, 'cmmd', marks=pytest.mark.requires_stable_fast),
("mobilenet_v2", ["torch_unstructured", "half"], True, 'latency'),
("sd_tiny_random", ["hqq_diffusers", "torch_compile"], False, 'cmmd'),
("flux_tiny_random", ["hqq_diffusers", "torch_compile"], False, 'cmmd'),
("sd_tiny_random", ["diffusers_int8", "torch_compile"], False, 'cmmd'),
("tiny_llama", ["gptq", "torch_compile"], True, 'perplexity'),
pytest.param("tiny_llama", ["gptq", "torch_compile"], True, 'perplexity', marks=pytest.mark.requires_gptq),
("llama_3_tiny_random_as_pipeline", ["llm_int8", "torch_compile"], True, 'perplexity'),
("flux_tiny_random", ["pab", "hqq_diffusers"], False, 'cmmd'),
("flux_tiny_random", ["pab", "diffusers_int8"], False, 'cmmd'),
Expand All @@ -58,9 +58,9 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None:
("flux_tiny_random", ["fora", "hqq_diffusers"], False, 'cmmd'),
("flux_tiny_random", ["fora", "diffusers_int8"], False, 'cmmd'),
("flux_tiny_random", ["fora", "torch_compile"], False, 'cmmd'),
("flux_tiny_random", ["fora", "stable_fast"], False, 'cmmd'),
pytest.param("flux_tiny_random", ["fora", "stable_fast"], False, 'cmmd', marks=pytest.mark.requires_stable_fast),
("tiny_janus", ["hqq", "torch_compile"], False, 'cmmd'),
pytest.param("flux_tiny", ["fora", "flash_attn3", "torch_compile"], False, 'cmmd', marks=pytest.mark.high),
pytest.param("flux_tiny", ["fora", "flash_attn3", "torch_compile"], False, 'cmmd', marks=pytest.mark.high_gpu),
],
indirect=["model_fixture"],
ids=lambda val: "+".join(val) if isinstance(val, list) else None,
Expand Down
2 changes: 2 additions & 0 deletions tests/algorithms/test_compatibility_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pruna.algorithms import AlgorithmRegistry

pytestmark = pytest.mark.cpu


def test_compatibility_symmetry():
pruna_algorithms = AlgorithmRegistry._registry
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.slow
@pytest.mark.requires_awq
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it not slow anymore?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought `slow` was the wrong tag here, since AWQ is not time-intensive like GPTQ, for instance — but how do you feel?

class TestLLMCompressor(AlgorithmTesterBase):
"""Test the LLM Compressor quantizer."""

Expand Down
3 changes: 3 additions & 0 deletions tests/algorithms/testers/cgenerate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pytest

from pruna import PrunaModel
from pruna.algorithms.c_translate import CGenerate

from .base_tester import AlgorithmTesterBase


@pytest.mark.requires_whisper
class TestCGenerate(AlgorithmTesterBase):
"""Test the c_generate algorithm."""

Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/flash_attn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.high
@pytest.mark.high_gpu
class TestFlashAttn3(AlgorithmTesterBase):
"""Test the flash attention 3 kernel."""

Expand Down
3 changes: 2 additions & 1 deletion tests/algorithms/testers/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@


@pytest.mark.slow
@pytest.mark.high
@pytest.mark.high_gpu
@pytest.mark.requires_gptq
class TestGPTQ(AlgorithmTesterBase):
"""Test the GPTQ quantizer."""

Expand Down
7 changes: 3 additions & 4 deletions tests/algorithms/testers/ipex_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
from .base_tester import AlgorithmTesterBase


# this prevents the test from running on GitHub Actions, which does not reliably provide Intel CPUs
@pytest.mark.high_cpu
@pytest.mark.requires_intel
class TestIPEXLLM(AlgorithmTesterBase):
"""Test the IPEX LLM algorithm."""

models = ["opt_tiny_random"]
models = ["opt_125m"]
reject_models = ["sd_tiny_random"]
allow_pickle_files = False
algorithm_class = IPEXLLM
metrics = ["latency"]
metrics = []
3 changes: 2 additions & 1 deletion tests/algorithms/testers/moe_kernel_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from pathlib import Path

import pytest
import torch

from pruna import PrunaModel, SmashConfig
Expand All @@ -11,6 +12,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.requires_vllm
class TestMoeKernelTuner(AlgorithmTesterBase):
"""Test the MoeKernelTuner."""

Expand All @@ -34,7 +36,6 @@ def post_smash_hook(self, model: PrunaModel) -> None:

def _resolve_hf_cache_config_path(self) -> Path:
"""Read the saved artifact and compute the expected HF cache config path."""

imported_packages = MoeKernelTuner().import_algorithm_packages()

smash_cfg = SmashConfig()
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/ring_distributer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.distributed
@pytest.mark.multi_gpu
class TestRingAttn(AlgorithmTesterBase):
"""Test the RingAttn algorithm."""

Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.high
@pytest.mark.high_gpu
class TestSageAttn(AlgorithmTesterBase):
"""Test the sage attention kernel."""

Expand Down
3 changes: 3 additions & 0 deletions tests/algorithms/testers/stable_fast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest

from pruna.algorithms.stable_fast import StableFast

from .base_tester import AlgorithmTesterBase


@pytest.mark.requires_stable_fast
class TestStableFast(AlgorithmTesterBase):
"""Test the stable_fast algorithm."""

Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/tti_inplace_perp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def assert_no_nan_values(module: Any) -> None:

# Our nightlies machine does not support efficient attention mechanisms and causes OOM errors with this test.
# This test does pass on modern architectures.
@pytest.mark.high
@pytest.mark.high_gpu
@pytest.mark.slow
class TestTTIInPlacePerp(AlgorithmTesterBase):
"""Test the TTI InPlace Perp recovery algorithm."""
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/tti_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def assert_no_nan_values(module: Any) -> None:

# Our nightlies machine does not support efficient attention mechanisms and causes OOM errors with this test.
# This test does pass on modern architectures.
@pytest.mark.high
@pytest.mark.high_gpu
@pytest.mark.slow
class TestTTILoRA(AlgorithmTesterBase):
"""Test the TTI LoRA recovery algorithm."""
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/testers/tti_perp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def assert_no_nan_values(module: Any) -> None:


@pytest.mark.slow
@pytest.mark.high
@pytest.mark.high_gpu
class TestTTIPerp(AlgorithmTesterBase):
"""Test the TTI Perp recovery algorithm."""

Expand Down
9 changes: 6 additions & 3 deletions tests/algorithms/testers/upscale.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import pytest

from pruna.algorithms.upscale import RealESRGAN

from .base_tester import AlgorithmTesterBase


@pytest.mark.cuda
# Takes too long to run on CPU, so we explicitly exclude it
class TestUpscale(AlgorithmTesterBase):
"""Test the Upscale algorithm."""

Expand All @@ -14,3 +12,8 @@ class TestUpscale(AlgorithmTesterBase):
allow_pickle_files = False
algorithm_class = RealESRGAN
metrics = ["cmmd"]

@classmethod
def compatible_devices(cls) -> list[str]:
"""Exclude CPU (too slow)."""
return [d for d in super().compatible_devices() if d != "cpu"]
2 changes: 1 addition & 1 deletion tests/algorithms/testers/whispers2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .base_tester import AlgorithmTesterBase


@pytest.mark.skip(reason="This test / the importing of whisper_s2t is affecting other tests.")
@pytest.mark.requires_whisper
@pytest.mark.slow
class TestWhisperS2T(AlgorithmTesterBase):
"""Test the WhisperS2T batcher."""
Expand Down
3 changes: 3 additions & 0 deletions tests/algorithms/testers/x_fast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest

from pruna.algorithms.x_fast import XFast

from .base_tester import AlgorithmTesterBase


@pytest.mark.requires_stable_fast
class TestXFast(AlgorithmTesterBase):
"""Test the X-Fast algorithm."""

Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def device_parametrized(cls: Any) -> Any:
pytest.param("cuda", marks=pytest.mark.cuda),
pytest.param(
"accelerate",
marks=pytest.mark.distributed,
marks=pytest.mark.multi_gpu,
),
pytest.param("cpu", marks=pytest.mark.cpu),
],
Expand Down
Loading
Loading