diff --git a/pyproject.toml b/pyproject.toml index 5b1eb704..db759302 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +llamacpp = [ + "llama-cpp-python>=0.2.78", # Required for running and inferencing Llama.cpp models + "gguf>=0.6.0", # Required for converting HF models to GGUF format +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", @@ -187,6 +191,8 @@ awq = [ full = [ "xformers>=0.0.30", "stable-fast-pruna==1.0.8", + "llama-cpp-python>=0.2.78", # Required for running and inferencing Llama.cpp models + "gguf>=0.6.0", # Required for converting HF models to GGUF format ] vbench = [ "vbench-pruna; sys_platform != 'darwin'", diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py index 0784069b..4d585eda 100644 --- a/src/pruna/algorithms/base/pruna_base.py +++ b/src/pruna/algorithms/base/pruna_base.py @@ -28,6 +28,7 @@ SAVE_FUNCTIONS, save_pruna_model, ) +from pruna.engine.utils import get_fn_name from pruna.logging.logger import pruna_logger @@ -365,7 +366,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any: # if the registered save function is None, the original saving function remains if self.save_fn is not None and self.save_fn != SAVE_FUNCTIONS.reapply: - smash_config.save_fns.append(self.save_fn.name) + fn_name = get_fn_name(self.save_fn) + smash_config.save_fns.append(fn_name) prefix = self.algorithm_name + "_" wrapped_config = SmashConfigPrefixWrapper(smash_config, prefix) diff --git a/src/pruna/algorithms/llama_cpp.py b/src/pruna/algorithms/llama_cpp.py new file mode 100644 index 00000000..657166f5 --- /dev/null +++ b/src/pruna/algorithms/llama_cpp.py @@ -0,0 +1,321 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import shutil +import subprocess # nosec B404 +import sys +import tempfile +import urllib.request +import weakref +from pathlib import Path +from typing import Any, Dict + +from ConfigSpace import OrdinalHyperparameter + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.config.hyperparameters import Int +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.model_checks import ( + is_causal_lm, + is_transformers_pipeline_with_causal_lm, +) +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.engine.utils import verify_sha256 +from pruna.logging.logger import pruna_logger + +# SHA256 hash for the pinned version (b3600) of convert_hf_to_gguf.py +LLAMA_CPP_CONVERSION_SCRIPT_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/b3600/convert_hf_to_gguf.py" +LLAMA_CPP_CONVERSION_SCRIPT_SHA256 = "f62ab712618231b3e76050f94e45dcf94567312c209b4b99bfc142229360b018" +LLAMA_CPP_CACHE_DIR = Path.home() / ".cache" / "pruna" / "scripts" / "llama_cpp" + + +class LlamaCpp(PrunaAlgorithmBase): + """ + Implement Llama.cpp as a quantizer. + + Converts Hugging Face models to GGUF format and quantizes them using the llama.cpp tools. 
+ """ + + algorithm_name: str = "llama_cpp" + group_tags: list[tags] = [tags.QUANTIZER] + references: dict[str, str] = { + "GitHub": "https://github.com/ggml-org/llama.cpp", + "Python Bindings": "https://github.com/abetlen/llama-cpp-python", + } + save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.llama_cpp + tokenizer_required: bool = False + processor_required: bool = False + dataset_required: bool = False + runs_on: list[str] = ["cpu", "cuda", "mps"] + compatible_before: list[str] = [] + compatible_after: list[str] = [] + + def get_hyperparameters(self) -> list: + """ + Configure all algorithm-specific hyperparameters with ConfigSpace. + + Returns + ------- + list + The hyperparameters. + """ + return [ + OrdinalHyperparameter( + "quantization_method", + sequence=[ + "q4_k_m", + "q4_k_s", + "q5_k_m", + "q8_0", + "f16" + ], + default_value="q4_k_m", + meta={"desc": "Quantization method for llama.cpp. Examples: q4_k_m, q8_0, f16."}, + ), + OrdinalHyperparameter( + "n_gpu_layers", + sequence=[0, 1, 4, 8, 16, 32, 999], + default_value=0, + meta={"desc": "Number of layers to offload to GPU. Use 999 for all layers."}, + ), + Int( + "main_gpu", + default=0, + meta={"desc": "The GPU to use for the main model tensors."}, + ), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is supported. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is supported, False otherwise. + """ + return is_causal_lm(model) or is_transformers_pipeline_with_causal_lm(model) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Quantize the model with Llama.cpp by converting to GGUF. + + Parameters + ---------- + model : Any + The model to quantize. + smash_config : SmashConfigPrefixWrapper + The configuration for the quantization. + + Returns + ------- + Any + The quantized Llama object. 
+ """ + imported_modules = self.import_algorithm_packages() + llama_cpp = imported_modules["llama_cpp"] + + # Ensure we have the causal lm if it's a pipeline + model_to_export = model.model if is_transformers_pipeline_with_causal_lm(model) else model + + quantization_method = self._get_quantization_method(model_to_export, smash_config["quantization_method"]) + pruna_logger.info(f"Quantizing model with llama.cpp using method {quantization_method}") + + _, f16_gguf_path, quant_gguf_path = self._get_cache_paths( + model_to_export, smash_config, quantization_method + ) + + # Create a temp directory to hold HF model if needed + temp_dir = Path(tempfile.mkdtemp()) + # Ensure cleanup even if save() is not called + weakref.finalize(self, shutil.rmtree, str(temp_dir), ignore_errors=True) + + try: + # Convert to F16 GGUF if needed + if not f16_gguf_path.exists(): + self._convert_to_gguf(model_to_export, f16_gguf_path, temp_dir, smash_config) + else: + pruna_logger.info(f"Using cached F16 GGUF model at {f16_gguf_path}") + + # Quantize GGUF if needed + if quantization_method != "f16": + if not quant_gguf_path.exists(): + self._quantize_gguf(llama_cpp, f16_gguf_path, quant_gguf_path, quantization_method) + else: + pruna_logger.info(f"Using cached quantized model at {quant_gguf_path}") + else: + quant_gguf_path = f16_gguf_path + + return self._load_quantized_model(llama_cpp, quant_gguf_path, smash_config, temp_dir) + + except Exception as e: + pruna_logger.error(f"Error during llama.cpp quantization: {e}") + shutil.rmtree(temp_dir, ignore_errors=True) + raise + + def _get_quantization_method(self, model: Any, default_method: str) -> str: + """Get the quantization method, defaulting to f16 for tiny models.""" + if ( + hasattr(model, "config") + and hasattr(model.config, "hidden_size") + and model.config.hidden_size < 32 + ): + pruna_logger.info("Tiny model detected. 
Bypassing quantized block sizes and using f16.") + return "f16" + return default_method + + def _load_quantized_model(self, llama_cpp: Any, quant_gguf_path: Path, smash_config: Any, temp_dir: Path) -> Any: + pruna_logger.info(f"Loading quantized model from {quant_gguf_path}") + n_gpu_layers = smash_config["n_gpu_layers"] + if n_gpu_layers == 999: + n_gpu_layers = -1 # llama-cpp-python uses -1 for all layers + quantized_model = llama_cpp.Llama( + model_path=str(quant_gguf_path), + n_gpu_layers=n_gpu_layers, + main_gpu=smash_config["main_gpu"], + ) + quantized_model.model_path = str(quant_gguf_path) + quantized_model._pruna_device = smash_config["device"] + return quantized_model + + def _get_cache_paths( + self, model: Any, smash_config: SmashConfigPrefixWrapper, q_method: str + ) -> tuple[Path, Path, Path]: + """Generate cache paths for the models.""" + llama_cpp_cache = Path(smash_config.cache_dir) / "llama_cpp" + llama_cpp_cache.mkdir(parents=True, exist_ok=True) + + model_id = "model" + if hasattr(model, "config") and hasattr(model.config, "_name_or_path"): + model_id = Path(model.config._name_or_path).name + + f16_gguf_path = llama_cpp_cache / f"{model_id}-f16.gguf" + quant_gguf_path = llama_cpp_cache / f"{model_id}-{q_method}.gguf" + return llama_cpp_cache, f16_gguf_path, quant_gguf_path + + def _convert_to_gguf( + self, + model: Any, + outfile: Path, + temp_dir: Path, + smash_config: SmashConfigPrefixWrapper + ) -> None: + """Save HF model and convert it to GGUF format.""" + with tempfile.TemporaryDirectory(dir=str(temp_dir)) as hf_model_dir: + model.save_pretrained(hf_model_dir) + if hasattr(smash_config, "tokenizer") and smash_config.tokenizer: + smash_config.tokenizer.save_pretrained(hf_model_dir) + + script_path = self._get_conversion_script() + pruna_logger.info(f"Converting Hugging Face model to GGUF format at {outfile}...") + + # Ensure inputs are properly sanitized and validated to prevent arg injection. 
+ for param in (script_path, hf_model_dir, outfile): + param_str = str(param) + if any(c in param_str for c in ("\0", "\n", "\r", ";", "&", "|", "`", "$")): + raise ValueError(f"Unsafe characters detected in subprocess argument: {param_str}") + + convert_cmd = [ + sys.executable, str(script_path), + hf_model_dir, + "--outfile", str(outfile), + "--outtype", "f16" + ] + try: + # subprocess needed because convert_hf_to_gguf.py is a standalone CLI script + subprocess.run(convert_cmd, check=True, capture_output=True, text=True) # nosec B603 + except subprocess.CalledProcessError as e: + pruna_logger.error(f"Conversion script failed with error: {e.stderr}") + raise + + def _quantize_gguf( + self, + llama_cpp: Any, + infile: Path, + outfile: Path, + method: str + ) -> None: + """Quantize a GGUF file using llama-cpp-python API.""" + pruna_logger.info(f"Quantizing GGUF model to {method} at {outfile}...") + + if not hasattr(llama_cpp, "llama_model_quantize"): + raise RuntimeError("llama_model_quantize API not available in llama-cpp-python.") + + params = llama_cpp.llama_model_quantize_default_params() + ftype_name = f"LLAMA_FTYPE_MOSTLY_{method.upper()}" + + if hasattr(llama_cpp, ftype_name): + params.ftype = getattr(llama_cpp, ftype_name) + else: + raise ValueError(f"Unknown quantization method: {method}") + + llama_cpp.llama_model_quantize( + str(infile).encode("utf-8"), + str(outfile).encode("utf-8"), + params, + ) + + def _get_conversion_script(self) -> Path: + """ + Get the conversion script from cache or download it. + + Returns + ------- + Path + The path to the conversion script. 
+ """ + LLAMA_CPP_CACHE_DIR.mkdir(parents=True, exist_ok=True) + script_path = LLAMA_CPP_CACHE_DIR / "convert_hf_to_gguf.py" + + # Validate URL scheme for security + if not LLAMA_CPP_CONVERSION_SCRIPT_URL.startswith("https://"): + raise ValueError(f"Insecure conversion script URL: {LLAMA_CPP_CONVERSION_SCRIPT_URL}") + + if not script_path.exists() or not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256): + pruna_logger.info(f"Downloading conversion script from {LLAMA_CPP_CONVERSION_SCRIPT_URL}") + urllib.request.urlretrieve(LLAMA_CPP_CONVERSION_SCRIPT_URL, script_path) # nosec B310 + + if not verify_sha256(script_path, LLAMA_CPP_CONVERSION_SCRIPT_SHA256): + script_path.unlink(missing_ok=True) + raise ValueError( + f"Integrity verification failed for {LLAMA_CPP_CONVERSION_SCRIPT_URL}. " + "The downloaded script may have been tampered with or the pinned version has changed." + ) + + return script_path + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Provide algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + try: + import llama_cpp + return dict(llama_cpp=llama_cpp) + except ImportError: + raise ImportError( + "Could not import llama_cpp. Please install it with `pip install llama-cpp-python`." + ) diff --git a/src/pruna/config/hyperparameters.py b/src/pruna/config/hyperparameters.py index d42ea506..928a6c81 100644 --- a/src/pruna/config/hyperparameters.py +++ b/src/pruna/config/hyperparameters.py @@ -16,10 +16,50 @@ from typing import Any -from ConfigSpace import CategoricalHyperparameter, Constant +from ConfigSpace import CategoricalHyperparameter, Constant, UniformIntegerHyperparameter from typing_extensions import override +class Int(UniformIntegerHyperparameter): + """ + Represents an integer hyperparameter. + + Parameters + ---------- + name : str + The name of the hyperparameter. + lower : int + The lower bound of the hyperparameter. + upper : int + The upper bound of the hyperparameter. 
+ default : int + The default value of the hyperparameter. + meta : Any + The metadata for the hyperparameter. + """ + + def __init__( + self, + name: str, + lower: int = 0, + upper: int = 2**31 - 1, + default: int = 0, + meta: Any = None, + ) -> None: + super().__init__(name, lower=lower, upper=upper, default_value=default, meta=meta) + + def __new__( + cls, + name: str, + lower: int = 0, + upper: int = 2**31 - 1, + default: int = 0, + meta: Any = None, + ) -> UniformIntegerHyperparameter: + """Create a new integer hyperparameter.""" + return UniformIntegerHyperparameter(name, lower=lower, upper=upper, default_value=default, meta=meta) + + class Boolean(CategoricalHyperparameter): """ Represents a boolean hyperparameter with choices True and False. diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 74b04b56..c55ce370 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -506,6 +506,39 @@ def load_quantized_model(quantized_path: str | Path) -> Any: ) +def load_llama_cpp(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: + """ + Load a model quantized with llama.cpp from the given model path. + + Parameters + ---------- + path : str | Path + The path to the model directory. + smash_config : SmashConfig + The SmashConfig object containing the device and device_map. + **kwargs : Any + Additional keyword arguments to pass to the model loading function. + + Returns + ------- + Any + The loaded llama.cpp model. 
+ """ + from pruna.algorithms.llama_cpp import LlamaCpp + + algorithm_packages = LlamaCpp().import_algorithm_packages() + llama_cpp = algorithm_packages["llama_cpp"] + + model_path = Path(path) / "model.gguf" + if not model_path.exists(): + raise FileNotFoundError(f"GGUF file not found at {model_path}") + + model = llama_cpp.Llama(model_path=str(model_path), **filter_load_kwargs(llama_cpp.Llama.__init__, kwargs)) + model.model_path = str(model_path) + model._pruna_device = smash_config["device"] + return model + + def load_hqq_diffusers(path: str | Path, smash_config: SmashConfig, **kwargs) -> Any: """ Load a diffusers model from the given model path. @@ -637,6 +670,7 @@ class LOAD_FUNCTIONS(Enum): # noqa: N801 pickled = member(load_pickled) hqq = member(load_hqq) hqq_diffusers = member(load_hqq_diffusers) + llama_cpp = member(load_llama_cpp) def __call__(self, *args, **kwargs) -> Any: """ diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index fa5fb763..5c4b727b 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -715,3 +715,20 @@ def is_gptq_model(model: Any) -> bool: True if the model is a GPTQ model, False otherwise. """ return "gptqmodel" in model.__class__.__module__ and "GPTQ" in model.__class__.__name__ + + +def is_llama_cpp_model(model: Any) -> bool: + """ + Check if the model is a llama.cpp Llama model. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a llama.cpp Llama model, False otherwise. 
+ """ + return model.__class__.__name__ == "Llama" and "llama_cpp" in str(model.__class__.__module__) diff --git a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index a0f34728..ce274bc6 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -178,6 +178,10 @@ def set_to_eval(self) -> None: """Set the model to evaluation mode.""" set_to_eval(self.model) + def save(self, model_path: str) -> None: + """Save the model.""" + self.save_pretrained(model_path) + def save_pretrained(self, model_path: str) -> None: """ Save the smashed model to the specified model path. diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 27101b31..2f91c31c 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -42,7 +42,7 @@ ) from pruna.engine.model_checks import get_helpers, is_janus_llamagen_ar from pruna.engine.save_artifacts import save_artifacts -from pruna.engine.utils import determine_dtype, monkeypatch +from pruna.engine.utils import determine_dtype, get_fn_name, monkeypatch from pruna.logging.logger import pruna_logger if TYPE_CHECKING: @@ -72,8 +72,7 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf pruna_logger.debug("Using model's original save function...") save_fn = original_save_fn - # if save-before-move was the last operation, we simply move the already saved files, we have delt with them before - elif smash_config.save_fns[-1] == SAVE_FUNCTIONS.save_before_apply.name: + elif len(smash_config.save_fns) > 0 and smash_config.save_fns[-1] == get_fn_name(SAVE_FUNCTIONS.save_before_apply): pruna_logger.debug("Moving saved model...") save_fn = save_before_apply @@ -470,6 +469,35 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name) +def save_model_llama_cpp(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: + """ + Save the model with llama.cpp 
functionality. + + Parameters + ---------- + model : Any + The model to save. + model_path : str | Path + The directory to save the model to. + smash_config : SmashConfig + The SmashConfig object containing the save and load functions. + """ + model_path = Path(model_path) + + if hasattr(model, "model_path"): + gguf_file = Path(model.model_path) + if gguf_file.exists(): + target_file = model_path / "model.gguf" + if gguf_file.resolve() != target_file.resolve(): + shutil.copy(gguf_file, target_file) + model.model_path = str(target_file) + smash_config.load_fns.append(LOAD_FUNCTIONS.llama_cpp.name) + else: + raise FileNotFoundError(f"GGUF file not found at {gguf_file}") + else: + raise AttributeError("Llama object does not have model_path attribute.") + + def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ Reapply the model. @@ -521,6 +549,7 @@ class SAVE_FUNCTIONS(Enum): # noqa: N801 pickled = member(save_pickled) hqq = member(save_model_hqq) hqq_diffusers = member(save_model_hqq_diffusers) + llama_cpp = member(save_model_llama_cpp) save_before_apply = member(save_before_apply) reapply = member(reapply) diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index a039fc24..e8e5064c 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -16,9 +16,11 @@ import contextlib import gc +import hashlib import inspect import json from contextlib import AbstractContextManager, contextmanager +from functools import partial from pathlib import Path from typing import Any @@ -38,6 +40,48 @@ def safe_memory_cleanup() -> None: torch.cuda.empty_cache() +def get_fn_name(obj: Any) -> str: + """ + Get the name of a function or a partial function. + + Parameters + ---------- + obj : Any + The function or partial function to get the name of. + + Returns + ------- + str + The name of the function. 
+ """ + if isinstance(obj, partial): + return get_fn_name(obj.func) + return getattr(obj, "name", getattr(obj, "__name__", str(obj))) + + +def verify_sha256(file_path: str | Path, expected_hash: str) -> bool: + """ + Verify the SHA256 hash of a file. + + Parameters + ---------- + file_path : str | Path + The path to the file to verify. + expected_hash : str + The expected SHA256 hash. + + Returns + ------- + bool + True if the hash matches, False otherwise. + """ + sha256_hash = hashlib.sha256() + with Path(file_path).open("rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() == expected_hash + + def load_json_config(path: str | Path, json_name: str) -> dict: """ Load and parse a JSON configuration file. @@ -364,6 +408,12 @@ def get_device(model: Any) -> str: if safe_is_instance(model, Pipeline): return get_device(model.model) + # function scoped import due to model_check's import of ModelContext + from pruna.engine.model_checks import is_llama_cpp_model + + if is_llama_cpp_model(model): + return _get_llama_cpp_device(model) + # a device map that points the whole model to the same device (only key is "") is not considered distributed # when casting a model like this with "to" the device map is not maintained, so we rely on the model.device attribute if hasattr(model, "hf_device_map") and model.hf_device_map is not None and list(model.hf_device_map.keys()) != [""]: @@ -375,6 +425,9 @@ def get_device(model: Any) -> str: model_device = next(model.parameters()).device except StopIteration: raise ValueError("Could not determine device of model, model has no device attribute.") + except AttributeError: + # Model does not use PyTorch parameters natively (e.g. llama_cpp), default to cpu string mapping + model_device = "cpu" # model_device.type ignores the device index. Added a new function to convert to string.
model_device = device_to_string(model_device) @@ -382,6 +435,25 @@ def get_device(model: Any) -> str: return model_device +def _get_llama_cpp_device(model: Any) -> str: + """ + Determine device for llama.cpp models. + + Parameters + ---------- + model : Any + The llama.cpp model. + + Returns + ------- + str + The device string. + """ + if hasattr(model, "_pruna_device"): + return device_to_string(model._pruna_device) + return "cpu" # Default for now, as it's the safest. + + def get_device_map(model: Any, subset_key: str | None = None) -> dict[str, str]: """ Get the device map of the model. diff --git a/tests/algorithms/testers/llama_cpp.py b/tests/algorithms/testers/llama_cpp.py new file mode 100644 index 00000000..f107ad27 --- /dev/null +++ b/tests/algorithms/testers/llama_cpp.py @@ -0,0 +1,20 @@ +from pruna.algorithms.llama_cpp import LlamaCpp + +from .base_tester import AlgorithmTesterBase + + +class TestLlamaCpp(AlgorithmTesterBase): + """Test the LlamaCpp quantizer.""" + + __test__ = True + + models = ["llama_3_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = False + algorithm_class = LlamaCpp + metrics = [] + + def pre_smash_hook(self, model): + """Skip test if llama_cpp is not installed.""" + import pytest + pytest.importorskip("llama_cpp") diff --git a/tests/algorithms/testers/moe_kernel_tuner.py b/tests/algorithms/testers/moe_kernel_tuner.py index 9a754cf3..85661a83 100644 --- a/tests/algorithms/testers/moe_kernel_tuner.py +++ b/tests/algorithms/testers/moe_kernel_tuner.py @@ -34,7 +34,6 @@ def post_smash_hook(self, model: PrunaModel) -> None: def _resolve_hf_cache_config_path(self) -> Path: """Read the saved artifact and compute the expected HF cache config path.""" - imported_packages = MoeKernelTuner().import_algorithm_packages() smash_cfg = SmashConfig()