From e3bec0b830a88a0db315e7071b7d93e1b4b23a9b Mon Sep 17 00:00:00 2001 From: LiuYinfeng01 Date: Wed, 25 Mar 2026 11:28:32 +0800 Subject: [PATCH] feat(utils): optional aiter swizzle GEMM for ROCm MI + aiter Add diffusers.utils.aiter_swizzle_gemm: can_shuffle (layout k multiple of 2*IK), apply_swizzle for eligible large nn.Linear layers using shuffle_weight + hipb_mm with bpreshuffle. Requires AMD MI-series (ROCm) with the aiter package installed; raises ImportError if aiter is missing when apply_swizzle is called. Exported from diffusers.utils alongside existing is_aiter_available checks. Made-with: Cursor --- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/aiter_swizzle_gemm.py | 105 ++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 src/diffusers/utils/aiter_swizzle_gemm.py diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 23d7ac7c6c2d..d6dd65727867 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -18,6 +18,7 @@ from packaging import version from .. import __version__ +from .aiter_swizzle_gemm import MIN_SWIZZLE_WEIGHT_ELEMS, SHUFFLE_LAYOUT, apply_swizzle, can_shuffle from .constants import ( CONFIG_NAME, DEFAULT_HF_PARALLEL_LOADING_WORKERS, diff --git a/src/diffusers/utils/aiter_swizzle_gemm.py b/src/diffusers/utils/aiter_swizzle_gemm.py new file mode 100644 index 000000000000..fa362de79b36 --- /dev/null +++ b/src/diffusers/utils/aiter_swizzle_gemm.py @@ -0,0 +1,105 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROCm: optional swizzled Linear weights + aiter hipb_mm (bpreshuffle) for eligible layers.""" + +from __future__ import annotations + +import types +from typing import Tuple + +import torch + +from .import_utils import is_aiter_available +from .logging import get_logger + + +logger = get_logger(__name__) + +SHUFFLE_LAYOUT: Tuple[int, int] = (16, 16) +MIN_SWIZZLE_WEIGHT_ELEMS = 1024 * 1024 + + +def can_shuffle(n: int, k: int, layout: tuple[int, int] = SHUFFLE_LAYOUT) -> bool: + IN, IK = layout + BK = IK * 2 + return (n % IN == 0) and (k % BK == 0) + + +def apply_swizzle(model: torch.nn.Module, model_name: str = "model") -> None: + """Patch eligible ``nn.Linear`` layers with shuffled weights and ``hipb_mm`` forward.""" + if not is_aiter_available(): + raise ImportError( + "apply_swizzle requires the `aiter` package (ROCm). Install aiter or skip this optimization." + ) + try: + from aiter import hipb_mm + from aiter.ops.shuffle import shuffle_weight + except (ImportError, OSError, RuntimeError) as e: + raise ImportError(f"aiter is required for apply_swizzle but failed to import: {e}") from e + + n_shuffled = 0 + n_skipped_shape = 0 + n_skipped_small = 0 + + for _, module in model.named_modules(): + if not isinstance(module, torch.nn.Linear): + continue + + weight = module.weight.data + n, k = weight.shape + + if n * k < MIN_SWIZZLE_WEIGHT_ELEMS: + n_skipped_small += 1 + continue + + if not can_shuffle(n, k): + n_skipped_shape += 1 + continue + + shuffled = shuffle_weight(weight, SHUFFLE_LAYOUT).t() + module.weight = torch.nn.Parameter(shuffled, requires_grad=False) + n_shuffled += 1 + + def _forward(self, x): + if x.dim() >= 3: + shape = x.shape + x_2d = x.reshape(-1, x.size(-1)) + out = hipb_mm( + x_2d, + self.weight, + solution_index=-1, + bias=self.bias, + out_dtype=torch.bfloat16, + bpreshuffle=True, + ) + return out.view(*shape[:-1], self.weight.shape[1]) + return hipb_mm( + x, + self.weight, + solution_index=-1, + bias=self.bias, + out_dtype=torch.bfloat16, + bpreshuffle=True, + ) + + module.forward = types.MethodType(_forward, module) + + logger.info( + " [%s] swizzled: %s, skipped(shape): %s, skipped(small): %s", + model_name, + n_shuffled, + n_skipped_shape, + n_skipped_small, + )