From 3eab8104d4c16c971ce6090aed6af80783cac9d7 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 16 Mar 2026 08:42:33 +0000 Subject: [PATCH 1/6] attempt to add vllmomni plugin Signed-off-by: tjtanaa --- .../attentions/diffusion_fa_utils.py | 264 ++++++++++++++++++ .../attentions/diffusion_flash_attn.py | 129 +++++++++ atom/plugin/vllm_omni/__init__.py | 5 + atom/plugin/vllm_omni/platform.py | 65 +++++ atom/plugin/vllm_omni/register.py | 68 +++++ pyproject.toml | 6 + 6 files changed, 537 insertions(+) create mode 100644 atom/model_ops/attentions/diffusion_fa_utils.py create mode 100644 atom/model_ops/attentions/diffusion_flash_attn.py create mode 100644 atom/plugin/vllm_omni/__init__.py create mode 100644 atom/plugin/vllm_omni/platform.py create mode 100644 atom/plugin/vllm_omni/register.py diff --git a/atom/model_ops/attentions/diffusion_fa_utils.py b/atom/model_ops/attentions/diffusion_fa_utils.py new file mode 100644 index 000000000..1474598d7 --- /dev/null +++ b/atom/model_ops/attentions/diffusion_fa_utils.py @@ -0,0 +1,264 @@ +# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py +import torch +import torch.nn.functional as F + +from vllm_omni.platforms import current_omni_platform + +# Flash Attention function detection with fallback chain +flash_attn_func = None +flash_attn_varlen_func = None + +if current_omni_platform.is_rocm(): + # ROCm: try Aiter first + try: + from vllm._aiter_ops import is_aiter_found_and_supported + + if is_aiter_found_and_supported(): + from aiter import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass +elif current_omni_platform.is_xpu(): + try: + from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass +else: + # CUDA: try FA3 -> FA2 fallback chain + # Try FA3 from fa3-fwd PyPI package + try: + from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + # Fallback: Try FA3 from flash-attention source build + if flash_attn_func is None: + try: + from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + # Fallback: Try FA2 from flash-attn package (try multiple import paths) + if flash_attn_func is None: + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # noqa: F401 + except (ImportError, ModuleNotFoundError): + pass + + if flash_attn_func is None: + try: + from flash_attn.flash_attn_interface import ( # noqa: F401 + flash_attn_func, + flash_attn_varlen_func, + ) + except (ImportError, ModuleNotFoundError): + pass + +# If no FA backend available, SDPA backend will be selected at the platform level +# flash_attn_func and flash_attn_varlen_func will be None +HAS_FLASH_ATTN = flash_attn_func is not None or flash_attn_varlen_func is not None + + +def _index_first_axis(tensor, indices): 
+ """ + A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, + after flattening the first two dimensions of the tensor. This is functionally equivalent to + FA2's `index_first_axis` and replaces the need to import it. + """ + # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first + # two dimensions to get (total_tokens, ...) before indexing. + reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) + return reshaped_tensor[indices] + + +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
+ """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + _index_first_axis(hidden_states, indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def _pad_input(hidden_states, indices, batch, seqlen): + """ + pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. + + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is (batch_size + 1,). 
+ max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, + # this might cause a graph break + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + unpad_input_func, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong + to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in + order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + unpad_input_func: + The function to use for unpadding the input tensors. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. 
Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into + ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, + `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + if torch.compiler.is_compiling(): + # allow PyTorch compiler to include operations that return scalar values (like .item() + torch._dynamo.config.capture_scalar_outputs = True + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + # With static caches, the k/v states may be larger than the mask -> + # we need to slice them to avoid generating garbage + # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores + if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): + key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] + + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = _index_first_axis(key_layer, indices_k) + value_layer = _index_first_axis(value_layer, indices_k) + if query_length == kv_seq_len: + query_layer = _index_first_axis(query_layer, indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def _is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. + we have multiple increasing sequences + """ + if position_ids is None: + return False + + increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() + return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() diff --git a/atom/model_ops/attentions/diffusion_flash_attn.py b/atom/model_ops/attentions/diffusion_flash_attn.py new file mode 100644 index 000000000..ab1f19dc4 --- /dev/null +++ b/atom/model_ops/attentions/diffusion_flash_attn.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) + +logger = init_logger(__name__) + + +class ATOMDiffusionFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @classmethod + def supports_attention_mask(cls) -> bool: + return True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 96, 128, 192, 256] + + @staticmethod + def get_name() -> str: + return "ATOM_DIFFUSION_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> type["ATOMDiffusionFlashAttentionImpl"]: + return ATOMDiffusionFlashAttentionImpl + + +class 
ATOMDiffusionFlashAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.causal = causal + self.softmax_scale = softmax_scale + + @staticmethod + def _unwrap_flash_output(out: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor: + # FA3 may return (out, lse), FA2 returns out + return out[0] if isinstance(out, tuple) else out + + def _forward_varlen_masked( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + from atom.model_ops.attentions.diffusion_fa_utils import ( + _pad_input, + _unpad_input, + _upad_input, + flash_attn_varlen_func, + ) + + assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)" + query_length = query.size(1) + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query, key, value, attention_mask, query_length, _unpad_input + ) + + out_unpad = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **{ + "causal": self.causal, + "softmax_scale": self.softmax_scale, + }, + ) + out_unpad = self._unwrap_flash_output(out_unpad) + return _pad_input(out_unpad, indices_q, query.size(0), query_length) + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + """CUDA/ROCm flash attention implementation.""" + from atom.model_ops.attentions.diffusion_fa_utils import ( + HAS_FLASH_ATTN, + flash_attn_func, + ) + + if not HAS_FLASH_ATTN: + raise ImportError( + "FlashAttentionBackend requires Flash Attention. " + "Please install one of: fa3-fwd, flash-attention, or flash-attn. 
" + "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA" + ) + + attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None + + if attention_mask is not None and torch.any(~attention_mask): + return self._forward_varlen_masked( + query, + key, + value, + attention_mask, + ) + + out = flash_attn_func( + query, + key, + value, + causal=self.causal, + softmax_scale=self.softmax_scale, + ) + return self._unwrap_flash_output(out) diff --git a/atom/plugin/vllm_omni/__init__.py b/atom/plugin/vllm_omni/__init__.py new file mode 100644 index 000000000..9009e4b4a --- /dev/null +++ b/atom/plugin/vllm_omni/__init__.py @@ -0,0 +1,5 @@ +"""vLLM-Omni plugin integration for ATOM.""" + +from .register import register_omni_model, register_omni_platform + +__all__ = ["register_omni_platform", "register_omni_model"] diff --git a/atom/plugin/vllm_omni/platform.py b/atom/plugin/vllm_omni/platform.py new file mode 100644 index 000000000..5d32172eb --- /dev/null +++ b/atom/plugin/vllm_omni/platform.py @@ -0,0 +1,65 @@ +"""ATOM vLLM-Omni platform integration. + +This module contains the vLLM-Omni `OmniPlatform` subclass used in ATOM's +vLLM-Omni plugin mode. Overrides both AR and diffusion attention backend +selection to use ATOM implementations. +""" + +import logging + +from atom.utils import envs + +logger = logging.getLogger("atom") +# This flag is used to enable the vLLM-Omni plugin mode. 
+disable_vllm_plugin = envs.ATOM_DISABLE_VLLM_PLUGIN +disable_vllm_plugin_attention = envs.ATOM_DISABLE_VLLM_PLUGIN_ATTENTION + +if not disable_vllm_plugin: + from vllm_omni.platforms.rocm.platform import RocmOmniPlatform + + class ATOMOmniPlatform(RocmOmniPlatform): + @classmethod + def get_attn_backend_cls( + cls, selected_backend, attn_selector_config, num_heads + ) -> str: + if disable_vllm_plugin_attention: + logger.info("Fallback to original vLLM attention backend") + return super().get_attn_backend_cls( + selected_backend, attn_selector_config, num_heads + ) + + logger.info("Use atom attention backend") + if attn_selector_config.use_mla: + return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" + return "atom.model_ops.attentions.aiter_attention.AiterBackend" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, selected_backend: str | None, head_size: int + ) -> str: + if disable_vllm_plugin_attention: + logger.info( + "Fallback to original vLLM-Omni diffusion attention backend" + ) + return super().get_diffusion_attn_backend_cls( + selected_backend, head_size + ) + + # Respect env var override for non-FLASH_ATTN backends + # (TORCH_SDPA, SAGE_ATTN, etc.) 
+ if ( + selected_backend is not None + and selected_backend.upper() != "FLASH_ATTN" + ): + return super().get_diffusion_attn_backend_cls( + selected_backend, head_size + ) + + logger.info("Use ATOM diffusion attention backend") + return ( + "atom.model_ops.attentions.diffusion_flash_attn" + ".ATOMDiffusionFlashAttentionBackend" + ) + +else: + ATOMOmniPlatform = None diff --git a/atom/plugin/vllm_omni/register.py b/atom/plugin/vllm_omni/register.py new file mode 100644 index 000000000..91d3f53a3 --- /dev/null +++ b/atom/plugin/vllm_omni/register.py @@ -0,0 +1,68 @@ +from typing import Optional +import logging + +import torch +from atom.plugin.prepare import _set_framework_backbone +from atom.utils import envs +from atom.plugin.vllm.mla_patch import patch_vllm_mla_attention +from atom.plugin.vllm.register import ( + _patch_vllm_attention_process_weights_after_loading, + _VLLM_MODEL_REGISTRY_OVERRIDES, +) + +logger = logging.getLogger("atom") + +# this flag is used to enable the vllm-omni plugin mode +disable_vllm_plugin = envs.ATOM_DISABLE_VLLM_PLUGIN + + +def register_omni_platform() -> Optional[str]: + + if disable_vllm_plugin: + logger.info("Disable ATOM OOT plugin platforms (vllm-omni)") + return None + + _set_framework_backbone("vllm") + + # return the ATOM omni platform to vllm-omni + return "atom.plugin.vllm_omni.platform.ATOMOmniPlatform" + + +def register_omni_model() -> None: + if disable_vllm_plugin: + logger.info("Disable ATOM model register (vllm-omni)") + return + + from vllm_omni.model_executor.models.registry import OmniModelRegistry + import vllm.model_executor.models.registry as vllm_model_registry + + any_updated = False + for arch, qual in _VLLM_MODEL_REGISTRY_OVERRIDES.items(): + module_name, class_name = qual.split(":", 1) + existing = OmniModelRegistry.models.get(arch) + if existing is not None: + # If already overridden to the same target, skip re-registering. 
+ if ( + getattr(existing, "module_name", None) == module_name + and getattr(existing, "class_name", None) == class_name + ): + continue + + logger.info(f"Register model {arch} to vLLM-Omni with {qual}") + OmniModelRegistry.register_model(arch, qual) + any_updated = True + + # clear lru cache + if any_updated: + vllm_model_registry._try_load_model_cls.cache_clear() + vllm_model_registry._try_inspect_model_cls.cache_clear() + + patch_vllm_mla_attention() + # patch attention process weights after loading + try: + from vllm.attention.layer import Attention, MLAAttention + except ImportError: + from vllm.model_executor.layers.attention import Attention, MLAAttention + + _patch_vllm_attention_process_weights_after_loading(Attention) + _patch_vllm_attention_process_weights_after_loading(MLAAttention) diff --git a/pyproject.toml b/pyproject.toml index 9b3b1a809..3b27b0353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,3 +37,9 @@ atom = "atom.plugin.vllm.register:register_platform" # but the plugin mode for models can be disabled by # ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 atom_model_registry = "atom.plugin.vllm.register:register_model" + +[project.entry-points."vllm_omni.platform_plugins"] +atom = "atom.plugin.vllm_omni.register:register_omni_platform" + +[project.entry-points."vllm_omni.general_plugins"] +atom_model_registry = "atom.plugin.vllm_omni.register:register_omni_model" From 31be6ce905684aa02f93a50d9552cf3de2a3e791 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 31 Mar 2026 07:12:39 +0000 Subject: [PATCH 2/6] enable vllm omni diffusion model plugin; add qwen-image edit recipe Signed-off-by: tjtanaa --- atom/plugin/attention_mha.py | 15 ++++++- .../diffusion_attention_backend/fa_utils.py} | 0 .../flash_attn.py} | 4 +- atom/plugin/vllm_omni/platform.py | 2 +- recipes/atom_vllmomni/Qwen-Image.md | 42 +++++++++++++++++++ 5 files changed, 59 insertions(+), 4 deletions(-) rename atom/{model_ops/attentions/diffusion_fa_utils.py => 
plugin/vllm_omni/diffusion_attention_backend/fa_utils.py} (100%) rename atom/{model_ops/attentions/diffusion_flash_attn.py => plugin/vllm_omni/diffusion_attention_backend/flash_attn.py} (95%) create mode 100644 recipes/atom_vllmomni/Qwen-Image.md diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index 367fd1c4f..e12178221 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -722,7 +722,8 @@ def forward_impl_plugin_mode( ) else: # Qwen only uses gluon pa decode when bs=64 - if num_decodes == _QWEN_GLUON_PA_DECODE_BS: + if False: + # if num_decodes == _QWEN_GLUON_PA_DECODE_BS: self.paged_attention_triton_plugin_mode( q=query[:num_decode_tokens], k_cache=new_key_cache, @@ -749,6 +750,17 @@ def forward_impl_plugin_mode( return output + def do_kv_cache_update( + self, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + ) -> None: + return + def PagedAttentionImplDecoratorForPluginMode(cls): method_names = [ @@ -759,6 +771,7 @@ def PagedAttentionImplDecoratorForPluginMode(cls): "extend_for_sliding_window", "extend_forward", "forward_impl_plugin_mode", + "do_kv_cache_update", ] logger.info( diff --git a/atom/model_ops/attentions/diffusion_fa_utils.py b/atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py similarity index 100% rename from atom/model_ops/attentions/diffusion_fa_utils.py rename to atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py diff --git a/atom/model_ops/attentions/diffusion_flash_attn.py b/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py similarity index 95% rename from atom/model_ops/attentions/diffusion_flash_attn.py rename to atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py index ab1f19dc4..9bfa2c876 100644 --- a/atom/model_ops/attentions/diffusion_flash_attn.py +++ b/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py @@ -60,7 +60,7 @@ 
def _forward_varlen_masked( value: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - from atom.model_ops.attentions.diffusion_fa_utils import ( + from atom.plugin.vllm_omni.diffusion_attention_backend.fa_utils import ( _pad_input, _unpad_input, _upad_input, @@ -97,7 +97,7 @@ def forward_cuda( attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: """CUDA/ROCm flash attention implementation.""" - from atom.model_ops.attentions.diffusion_fa_utils import ( + from atom.plugin.vllm_omni.diffusion_attention_backend.fa_utils import ( HAS_FLASH_ATTN, flash_attn_func, ) diff --git a/atom/plugin/vllm_omni/platform.py b/atom/plugin/vllm_omni/platform.py index 5d32172eb..16dcc863d 100644 --- a/atom/plugin/vllm_omni/platform.py +++ b/atom/plugin/vllm_omni/platform.py @@ -57,7 +57,7 @@ def get_diffusion_attn_backend_cls( logger.info("Use ATOM diffusion attention backend") return ( - "atom.model_ops.attentions.diffusion_flash_attn" + "atom.plugin.vllm_omni.diffusion_attention_backend.flash_attn" ".ATOMDiffusionFlashAttentionBackend" ) diff --git a/recipes/atom_vllmomni/Qwen-Image.md b/recipes/atom_vllmomni/Qwen-Image.md new file mode 100644 index 000000000..04e3fe5e0 --- /dev/null +++ b/recipes/atom_vllmomni/Qwen-Image.md @@ -0,0 +1,42 @@ +# Qwen-Image with ATOM vLLM-Omni Plugin Backend Usage Guide + +[Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) is an image generation foundation model in the Qwen series developed by Alibaba. It achieves significant advances in complex text rendering and precise image editing. The model demonstrates strong general capabilities in both image generation and editing, with exceptional performance in text rendering. 
+ +## Launching server + +### BF16 on 1xMI300X/MI355X GPUs + +```bash +vllm serve Qwen/Qwen-Image --omni \ + --host localhost \ + --port 8091 \ + --tensor-parallel-size 1 +``` + +### Interact with the model + +The command is extracted from https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/text_to_image + +```python +from openai import OpenAI +import base64 + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="none") + +response = client.chat.completions.create( + model="Qwen/Qwen-Image", + messages=[{"role": "user", "content": "A beautiful landscape painting"}], + extra_body={ + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42, + }, +) + +img_url = response.choices[0].message.content[0]["image_url"]["url"] +_, b64_data = img_url.split(",", 1) +with open("output.png", "wb") as f: + f.write(base64.b64decode(b64_data)) +``` \ No newline at end of file From ad8178534a17cae40f9a6eff37e6b652c0081322 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 31 Mar 2026 07:25:43 +0000 Subject: [PATCH 3/6] remove unnecessary changes to mha Signed-off-by: tjtanaa --- atom/plugin/attention_mha.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index e12178221..86848d890 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -750,17 +750,6 @@ def forward_impl_plugin_mode( return output - def do_kv_cache_update( - self, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - ) -> None: - return - def PagedAttentionImplDecoratorForPluginMode(cls): method_names = [ @@ -771,7 +760,6 @@ def PagedAttentionImplDecoratorForPluginMode(cls): "extend_for_sliding_window", "extend_forward", "forward_impl_plugin_mode", - "do_kv_cache_update", ] logger.info( From ac4247ce42bd8170a74145582586fcedf8ee256a Mon Sep 17 
00:00:00 2001 From: tjtanaa Date: Tue, 31 Mar 2026 07:50:08 +0000 Subject: [PATCH 4/6] clean up files Signed-off-by: tjtanaa --- .../diffusion_attention_backend/fa_utils.py | 264 ------------------ .../diffusion_attention_backend/flash_attn.py | 27 +- atom/plugin/vllm_omni/platform.py | 4 +- 3 files changed, 10 insertions(+), 285 deletions(-) delete mode 100644 atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py diff --git a/atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py b/atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py deleted file mode 100644 index 1474598d7..000000000 --- a/atom/plugin/vllm_omni/diffusion_attention_backend/fa_utils.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py -import torch -import torch.nn.functional as F - -from vllm_omni.platforms import current_omni_platform - -# Flash Attention function detection with fallback chain -flash_attn_func = None -flash_attn_varlen_func = None - -if current_omni_platform.is_rocm(): - # ROCm: try Aiter first - try: - from vllm._aiter_ops import is_aiter_found_and_supported - - if is_aiter_found_and_supported(): - from aiter import flash_attn_func, flash_attn_varlen_func # noqa: F401 - except (ImportError, ModuleNotFoundError): - pass -elif current_omni_platform.is_xpu(): - try: - from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func # noqa: F401 - except (ImportError, ModuleNotFoundError): - pass -else: - # CUDA: try FA3 -> FA2 fallback chain - # Try FA3 from fa3-fwd PyPI package - try: - from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 - except (ImportError, ModuleNotFoundError): - pass - - # Fallback: Try FA3 from flash-attention source build - if flash_attn_func is None: - try: - from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401 - except (ImportError, ModuleNotFoundError): - pass - - # Fallback: Try FA2 from flash-attn package (try multiple import paths) - if flash_attn_func is None: - try: - from flash_attn import flash_attn_func, flash_attn_varlen_func # noqa: F401 - except (ImportError, ModuleNotFoundError): - pass - - if flash_attn_func is None: - try: - from flash_attn.flash_attn_interface import ( # noqa: F401 - flash_attn_func, - flash_attn_varlen_func, - ) - except (ImportError, ModuleNotFoundError): - pass - -# If no FA backend available, SDPA backend will be selected at the platform level -# flash_attn_func and flash_attn_varlen_func will be None -HAS_FLASH_ATTN = flash_attn_func is not None or flash_attn_varlen_func is not None - - -def _index_first_axis(tensor, indices): 
- """ - A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis, - after flattening the first two dimensions of the tensor. This is functionally equivalent to - FA2's `index_first_axis` and replaces the need to import it. - """ - # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first - # two dimensions to get (total_tokens, ...) before indexing. - reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:]) - return reshaped_tensor[indices] - - -def _unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. - cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. 
- """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - - return ( - _index_first_axis(hidden_states, indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def _pad_input(hidden_states, indices, batch, seqlen): - """ - pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. - - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[1:] - output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) - output[indices] = hidden_states - return output.view(batch, seqlen, *dim) - - -def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. - `cu_seqlens` shape is (batch_size + 1,). 
- max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, - # this might cause a graph break - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - unpad_input_func, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong - to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in - order to avoid the recomputation of the same intermediary tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - unpad_input_func: - The function to use for unpadding the input tensors. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. 
Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into - ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, - `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - if torch.compiler.is_compiling(): - # allow PyTorch compiler to include operations that return scalar values (like .item() - torch._dynamo.config.capture_scalar_outputs = True - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - # With static caches, the k/v states may be larger than the mask -> - # we need to slice them to avoid generating garbage - # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores - if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]): - key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :] - - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = _index_first_axis(key_layer, indices_k) - value_layer = _index_first_axis(value_layer, indices_k) - if query_length == kv_seq_len: - query_layer = _index_first_axis(query_layer, indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def _is_packed_sequence(position_ids, batch_size): - """ - Check the position ids whether packed sequences are indicated or not - 1. Position ids exist - 2. Flattened sequences only are supported - 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. - we have multiple increasing sequences - """ - if position_ids is None: - return False - - increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min() - return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool() diff --git a/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py b/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py index 9bfa2c876..5cf05be8f 100644 --- a/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py +++ b/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py @@ -13,7 +13,7 @@ logger = init_logger(__name__) -class ATOMDiffusionFlashAttentionBackend(AttentionBackend): +class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @classmethod @@ -26,14 +26,14 @@ def get_supported_head_sizes() -> list[int]: @staticmethod def get_name() -> str: - return "ATOM_DIFFUSION_FLASH_ATTN" + return "AITER_DIFFUSION_FLASH_ATTN" @staticmethod - def get_impl_cls() -> type["ATOMDiffusionFlashAttentionImpl"]: - return ATOMDiffusionFlashAttentionImpl + def get_impl_cls() -> type["AiterDiffusionFlashAttentionImpl"]: + return AiterDiffusionFlashAttentionImpl -class ATOMDiffusionFlashAttentionImpl(AttentionImpl): +class AiterDiffusionFlashAttentionImpl(AttentionImpl): def __init__( self, num_heads: int, @@ -60,11 
+60,11 @@ def _forward_varlen_masked( value: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - from atom.plugin.vllm_omni.diffusion_attention_backend.fa_utils import ( + from aiter import flash_attn_varlen_func + from vllm_omni.diffusion.attention.backends.utils.fa import ( _pad_input, _unpad_input, _upad_input, - flash_attn_varlen_func, ) assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)" @@ -96,18 +96,7 @@ def forward_cuda( value: torch.Tensor, attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: - """CUDA/ROCm flash attention implementation.""" - from atom.plugin.vllm_omni.diffusion_attention_backend.fa_utils import ( - HAS_FLASH_ATTN, - flash_attn_func, - ) - - if not HAS_FLASH_ATTN: - raise ImportError( - "FlashAttentionBackend requires Flash Attention. " - "Please install one of: fa3-fwd, flash-attention, or flash-attn. " - "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA" - ) + from aiter import flash_attn_func attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None diff --git a/atom/plugin/vllm_omni/platform.py b/atom/plugin/vllm_omni/platform.py index 16dcc863d..80dc00292 100644 --- a/atom/plugin/vllm_omni/platform.py +++ b/atom/plugin/vllm_omni/platform.py @@ -55,10 +55,10 @@ def get_diffusion_attn_backend_cls( selected_backend, head_size ) - logger.info("Use ATOM diffusion attention backend") + logger.info("Use atom diffusion attention backend") return ( "atom.plugin.vllm_omni.diffusion_attention_backend.flash_attn" - ".ATOMDiffusionFlashAttentionBackend" + ".AiterFlashAttentionBackend" ) else: From 7bc16baf8dcbc65b6ba04b896e82226b1a01531f Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 31 Mar 2026 07:52:24 +0000 Subject: [PATCH 5/6] remove unused torch import Signed-off-by: tjtanaa --- atom/plugin/vllm_omni/register.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/plugin/vllm_omni/register.py b/atom/plugin/vllm_omni/register.py index 
91d3f53a3..c787f0298 100644 --- a/atom/plugin/vllm_omni/register.py +++ b/atom/plugin/vllm_omni/register.py @@ -1,7 +1,6 @@ from typing import Optional import logging -import torch from atom.plugin.prepare import _set_framework_backbone from atom.utils import envs from atom.plugin.vllm.mla_patch import patch_vllm_mla_attention From e7a86b7a3fba8785a737cf8c21643b46aa9e22d1 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Fri, 10 Apr 2026 02:08:11 +0000 Subject: [PATCH 6/6] enabled wan22 and Qwen Image Signed-off-by: tjtanaa --- .../attention_backend}/flash_attn.py | 0 .../vllm_omni/diffusion/models/README.md | 174 ++++++++++++ .../diffusion/models/qwen_image/__init__.py | 0 .../qwen_image/qwen_image_transformer.py | 249 ++++++++++++++++++ .../diffusion/models/wan2_2/__init__.py | 0 .../models/wan2_2/wan2_2_transformer.py | 179 +++++++++++++ atom/plugin/vllm_omni/platform.py | 2 +- atom/plugin/vllm_omni/register.py | 161 ++++++++--- 8 files changed, 728 insertions(+), 37 deletions(-) rename atom/plugin/vllm_omni/{diffusion_attention_backend => diffusion/attention_backend}/flash_attn.py (100%) create mode 100644 atom/plugin/vllm_omni/diffusion/models/README.md create mode 100644 atom/plugin/vllm_omni/diffusion/models/qwen_image/__init__.py create mode 100644 atom/plugin/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py create mode 100644 atom/plugin/vllm_omni/diffusion/models/wan2_2/__init__.py create mode 100644 atom/plugin/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py diff --git a/atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py b/atom/plugin/vllm_omni/diffusion/attention_backend/flash_attn.py similarity index 100% rename from atom/plugin/vllm_omni/diffusion_attention_backend/flash_attn.py rename to atom/plugin/vllm_omni/diffusion/attention_backend/flash_attn.py diff --git a/atom/plugin/vllm_omni/diffusion/models/README.md b/atom/plugin/vllm_omni/diffusion/models/README.md new file mode 100644 index 000000000..275a46a2c --- 
/dev/null +++ b/atom/plugin/vllm_omni/diffusion/models/README.md @@ -0,0 +1,174 @@ +# ATOM vLLM-Omni Diffusion Model Plugin + +Models under this directory run with the **vLLM-Omni plugin** — they cannot run standalone with native ATOM. For native ATOM models, see `atom/models/` instead. + +## What the Plugin Does + +The ATOM plugin replaces vLLM's linear layers (`vllm.model_executor.layers.linear`) with ATOM's AITER-accelerated equivalents (`atom.model_ops.linear`), enabling ROCm-optimized quantized GEMM kernels for diffusion model inference. + +The plugin hooks into vllm-omni at startup via `register_omni_model()` in `atom/plugin/vllm_omni/register.py`. It uses **monkey-patching** rather than registering new pipeline classes: the stock vllm-omni pipelines are left in place, but the transformer class they instantiate is swapped out before any model is loaded. + +--- + +## How to Add a New Model + +Follow the pattern used for Wan2.2 in `wan2_2/wan2_2_transformer.py`. + +### Step 1: Identify what to replace + +Open the stock vllm-omni transformer file for your model (e.g. `vllm_omni/diffusion/models//`). Look for uses of: + +```python +from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +``` + +These are the layers to replace with their `atom.model_ops.linear` equivalents. + +### Step 2: Create an ATOM transformer file + +Create `atom/plugin/vllm_omni/diffusion/models//` and add a `_transformer.py`. 
+ +**Import pattern:** + +```python +from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm_omni.diffusion.models.._transformer import ( + StockSelfAttention, + StockCrossAttention, + StockFeedForward, + StockTransformerBlock, + StockTransformerModel, + # any helper functions needed in forward() overrides +) +``` + +**For each layer class** that uses vLLM linears, create an ATOM subclass: + +```python +class ATOMStockSelfAttention(StockSelfAttention): + def __init__(self, ...): + super().__init__(...) + # Replace linear layers after super().__init__() creates the vllm ones + self.to_qkv = QKVParallelLinear(hidden_size=dim, head_size=head_dim, + total_num_heads=num_heads, bias=True) + self.num_heads = self.to_qkv.num_heads # refresh from atom layer + self.num_kv_heads = self.to_qkv.num_kv_heads + self.to_out = RowParallelLinear(inner_dim, dim, bias=True) +``` + +**Check if `forward()` needs an override.** Two cases require it: + +| Situation | What to do | +|-----------|-----------| +| Stock `forward()` does `out, _ = self.layer(x)` (tuple unpack) | Override `forward()` — atom layers return a plain tensor, not `(tensor, None)` | +| Stock `forward()` does `out = self.layer(x)` | No override needed — atom and vllm (with `return_bias=False`) both return plain tensors | + +The `QKVParallelLinear` case always requires an override because vLLM returns a tuple: + +```python + def forward(self, hidden_states, ...): + # atom returns plain tensor; vllm returns (tensor, None) + qkv = self.to_qkv(hidden_states) # NOT: qkv, _ = self.to_qkv(hidden_states) + ... +``` + +**For feedforward layers** that wrap `ColumnParallelLinear` inside a helper (e.g. 
`ColumnParallelGELU`), replace the inner `.proj` attribute: + +```python +class ATOMStockFeedForward(StockFeedForward): + def __init__(self, dim, inner_dim, dim_out=None, bias=True): + super().__init__(dim=dim, inner_dim=inner_dim, dim_out=dim_out, bias=bias) + dim_out = dim_out or dim + self.net_0.proj = ColumnParallelLinear(dim, inner_dim, bias=bias) + self.net_2 = RowParallelLinear(inner_dim, dim_out, bias=bias) + # forward() inherited — helper's forward() calls self.proj(x) → plain tensor ✓ +``` + +**Compose into a block and top-level model:** + +```python +class ATOMStockTransformerBlock(StockTransformerBlock): + def __init__(self, dim, ffn_dim, num_heads, eps=1e-6, ...): + super().__init__(...) + head_dim = dim // num_heads + self.attn1 = ATOMStockSelfAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps) + self.attn2 = ATOMStockCrossAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps) + self.ffn = ATOMStockFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim) + # forward() inherited from StockTransformerBlock unchanged + +class ATOMStockTransformerModel(StockTransformerModel): + def __init__(self, ..., num_layers=N, ...): + super().__init__(...) # builds rope, embeddings, norm, proj_out + inner_dim = num_attention_heads * attention_head_dim + # Replace all blocks after super() creates the stock ones + self.blocks = nn.ModuleList([ + ATOMStockTransformerBlock(inner_dim, ffn_dim, num_attention_heads, eps, ...) + for _ in range(num_layers) + ]) + # forward(), load_weights(), _sp_plan all inherited from StockTransformerModel +``` + +### Step 3: Register via monkey-patch in `register.py` + +Open `atom/plugin/vllm_omni/register.py` and add to the monkey-patch block at the end of `register_omni_model()`: + +```python +import vllm_omni.diffusion.models..pipeline_ as __pipeline +from atom.plugin.vllm_omni.diffusion.models.._transformer import ATOMTransformerModel +__pipeline. 
= ATOMTransformerModel +``` + +Python resolves module-level names at call time, so patching the name in the pipeline module's namespace causes all subsequent `create_transformer_from_config()` calls to instantiate the ATOM model — no pipeline file copies needed. + +**You only need to patch the base pipeline module.** If variant pipelines (e.g. i2v, ti2v) import `create_transformer_from_config` *from* the base pipeline rather than defining their own, they will automatically pick up the patch — patching the same name twice in different modules would be redundant. Check the variant pipeline's imports to confirm: + +```python +# If you see this in pipeline__i2v.py, one patch covers all variants: +from vllm_omni.diffusion.models..pipeline_ import create_transformer_from_config +``` + +**Do not copy pipeline files.** If the stock pipeline needs no changes beyond the transformer class swap, patching is sufficient. Only create a new pipeline class if you need to change the pipeline's own logic (e.g. different preprocessing, scheduler, or VAE). + +### Step 4: Update `__init__.py` + +Add your model's ATOM transformer class to `atom/plugin/vllm_omni/diffusion/models//__init__.py` (if the directory needs one). Re-export stock pipeline helpers from `vllm_omni` directly rather than copying them. 
+ +--- + +## API Compatibility Notes + +### `atom.model_ops.linear` vs `vllm.model_executor.layers.linear` + +| vLLM class | ATOM equivalent | Notes | +|---|---|---| +| `ColumnParallelLinear(in, out, bias, gather_output=False, return_bias=False)` | `ColumnParallelLinear(in, out, bias)` | Extra kwargs absorbed via `**kwargs`, silently ignored | +| `RowParallelLinear(in, out, bias, input_is_parallel=True, return_bias=False)` | `RowParallelLinear(in, out, bias)` | `reduce_results=True` by default — matches vLLM behavior | +| `QKVParallelLinear(hidden_size, head_size, total_num_heads, bias)` | `QKVParallelLinear(hidden_size, head_size, total_num_heads, bias)` | Same constructor; **different return type** (see below) | + +### Critical: `QKVParallelLinear` return type difference + +```python +# vLLM: returns (tensor, None) tuple +qkv, _ = self.to_qkv(hidden_states) + +# ATOM: returns plain tensor — must NOT unpack +qkv = self.to_qkv(hidden_states) +``` + +`ColumnParallelLinear` and `RowParallelLinear` forward signatures are compatible — both return a plain tensor when vLLM's `return_bias=False` (the standard config for diffusion models). + +### `atom.model_ops.linear` forward signature + +```python +def forward(self, x: Tensor, x_scale: Tensor | None = None, otype=bf16) -> Tensor +``` + +Calling `layer(x)` works as expected; `x_scale` and `otype` are used for quantized inference and default safely to unquantized bfloat16. 
# --- atom/plugin/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py ---
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ATOM (AITER-accelerated) drop-in replacements for the stock vllm-omni
QwenImage transformer layers. Each class subclasses the stock layer, lets
``super().__init__()`` build everything, then swaps the vLLM parallel-linear
layers for ``atom.model_ops.linear`` equivalents."""

import torch
import torch.nn as nn  # hoisted from a function-local import in ATOMQwenImageTransformer2DModel
import torch.nn.functional as F
from vllm.logger import init_logger

from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import (
    FeedForward,
    QwenImageCrossAttention,
    QwenImageTransformerBlock,
    QwenImageTransformer2DModel,
)

logger = init_logger(__name__)


class ATOMFeedForward(FeedForward):
    """Stock QwenImage FeedForward with its linears swapped for ATOM layers."""

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        activation_fn: str = "gelu-approximate",
        inner_dim: int | None = None,
        bias: bool = True,
        quant_config=None,
        prefix: str = "",
    ):
        super().__init__(
            dim=dim, dim_out=dim_out, mult=mult, activation_fn=activation_fn,
            inner_dim=inner_dim, bias=bias, quant_config=quant_config, prefix=prefix,
        )
        inner_dim_val = inner_dim or int(dim * mult)
        dim_out_val = dim_out or dim
        # Replace ColumnParallelApproxGELU's inner proj with ATOM ColumnParallelLinear.
        # ColumnParallelApproxGELU.forward() calls self.proj(x) -> plain tensor, so no
        # forward() override is needed.
        self.net[0].proj = ColumnParallelLinear(dim, inner_dim_val, bias=bias)
        # Replace net[2] (RowParallelLinear) with ATOM version.
        self.net[2] = RowParallelLinear(inner_dim_val, dim_out_val, bias=bias)
        # forward() inherited: iterates self.net.


class ATOMQwenImageCrossAttention(QwenImageCrossAttention):
    """Stock QwenImage joint cross-attention with ATOM QKV/Row parallel layers.

    forward() is overridden because ATOM QKVParallelLinear returns a plain
    tensor while vLLM's returns a (tensor, bias) tuple.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        head_dim: int,
        added_kv_proj_dim: int,
        window_size: tuple[int, int] = (-1, -1),
        out_bias: bool = True,
        qk_norm: bool = True,
        eps: float = 1e-6,
        pre_only: bool = False,
        context_pre_only: bool = False,
        out_dim: int | None = None,
        quant_config=None,
    ):
        super().__init__(
            dim=dim, num_heads=num_heads, head_dim=head_dim,
            added_kv_proj_dim=added_kv_proj_dim, window_size=window_size,
            out_bias=out_bias, qk_norm=qk_norm, eps=eps, pre_only=pre_only,
            context_pre_only=context_pre_only, out_dim=out_dim, quant_config=quant_config,
        )
        # Replace vLLM QKVParallelLinear with ATOM versions; refresh head counts
        # since the ATOM layer computes its own per-rank sharding.
        self.to_qkv = QKVParallelLinear(
            hidden_size=dim, head_size=head_dim, total_num_heads=num_heads, bias=True,
        )
        self.query_num_heads = self.to_qkv.num_heads
        self.kv_num_heads = self.to_qkv.num_kv_heads

        self.add_kv_proj = QKVParallelLinear(
            hidden_size=added_kv_proj_dim, head_size=head_dim, total_num_heads=num_heads, bias=True,
        )
        self.add_query_num_heads = self.add_kv_proj.num_heads
        self.add_kv_num_heads = self.add_kv_proj.num_kv_heads

        inner_dim = out_dim if out_dim is not None else head_dim * num_heads
        # Replace vLLM RowParallelLinear with ATOM versions.
        self.to_out = RowParallelLinear(inner_dim, dim, bias=out_bias)
        self.to_add_out = RowParallelLinear(inner_dim, dim, bias=out_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        vid_freqs: torch.Tensor,
        txt_freqs: torch.Tensor,
        hidden_states_mask: torch.Tensor | None = None,
        encoder_hidden_states_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # CRITICAL: ATOM QKVParallelLinear returns a plain tensor; vLLM returns (tensor, None).
        img_qkv = self.to_qkv(hidden_states)
        q_size = self.query_num_heads * self.head_dim
        kv_size = self.kv_num_heads * self.head_dim
        img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1)

        txt_qkv = self.add_kv_proj(encoder_hidden_states)
        add_q_size = self.add_query_num_heads * self.head_dim
        add_kv_size = self.add_kv_num_heads * self.head_dim
        txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1)

        # (batch, seq, heads*dim) -> (batch, seq, heads, dim)
        img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim))
        img_key = img_key.unflatten(-1, (self.kv_num_heads, self.head_dim))
        img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim))

        txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim))
        txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim))
        txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim))

        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        txt_query = self.norm_added_q(txt_query)
        txt_key = self.norm_added_k(txt_key)

        # Rotary frequencies arrive as complex tensors; split into cos/sin parts.
        img_cos = vid_freqs.real.to(img_query.dtype)
        img_sin = vid_freqs.imag.to(img_query.dtype)
        txt_cos = txt_freqs.real.to(txt_query.dtype)
        txt_sin = txt_freqs.imag.to(txt_query.dtype)

        img_query = self.rope(img_query, img_cos, img_sin)
        img_key = self.rope(img_key, img_cos, img_sin)
        txt_query = self.rope(txt_query, txt_cos, txt_sin)
        txt_key = self.rope(txt_key, txt_cos, txt_sin)

        # Joint attention: text tokens are prepended to image tokens.
        seq_len_txt = encoder_hidden_states.shape[1]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        if (
            self.parallel_config is not None
            and self.parallel_config.sequence_parallel_size > 1
            and not get_forward_context().split_text_embed_in_sp
        ):
            # Sequence-parallel path: the text tokens ride along as joint_* metadata
            # ("front" strategy) while only image tokens are sharded.
            attn_metadata = AttentionMetadata(
                joint_query=txt_query,
                joint_key=txt_key,
                joint_value=txt_value,
                joint_strategy="front",
            )
            if hidden_states_mask is not None:
                attn_metadata.attn_mask = hidden_states_mask
            if encoder_hidden_states_mask is not None:
                attn_metadata.joint_attn_mask = encoder_hidden_states_mask

            joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata)
        else:
            attn_metadata = None
            if hidden_states_mask is not None or encoder_hidden_states_mask is not None:
                # Build one concatenated [text | image] mask, filling any missing side
                # with all-ones so the joint mask covers the full joint sequence.
                mask_list: list[torch.Tensor] = []
                if encoder_hidden_states_mask is not None:
                    mask_list.append(encoder_hidden_states_mask)
                else:
                    mask_list.append(
                        torch.ones(
                            encoder_hidden_states.shape[:2],
                            dtype=torch.bool,
                            device=encoder_hidden_states.device,
                        )
                    )
                if hidden_states_mask is not None:
                    mask_list.append(hidden_states_mask)
                else:
                    mask_list.append(
                        torch.ones(
                            hidden_states.shape[:2],
                            dtype=torch.bool,
                            device=hidden_states.device,
                        )
                    )
                joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0]
                attn_metadata = AttentionMetadata(attn_mask=joint_mask)

            joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata)

        joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype)
        txt_attn_output = joint_hidden_states[:, :seq_len_txt, :]
        img_attn_output = joint_hidden_states[:, seq_len_txt:, :]

        # ATOM RowParallelLinear returns a plain tensor and performs the all-reduce.
        img_attn_output = self.to_out(img_attn_output)
        txt_attn_output = self.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output


class ATOMQwenImageTransformerBlock(QwenImageTransformerBlock):
    """Stock QwenImage block with attention and MLPs swapped for ATOM variants."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        qk_norm: str = "rms_norm",
        eps: float = 1e-6,
        zero_cond_t: bool = False,
        quant_config=None,
    ):
        super().__init__(
            dim=dim, num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim, qk_norm=qk_norm,
            eps=eps, zero_cond_t=zero_cond_t, quant_config=quant_config,
        )
        # Replace joint cross-attention with ATOM version (QKV + Row parallel layers).
        # img_mod and txt_mod use ReplicatedLinear — not replaced (broadcast, not sharded).
        self.attn = ATOMQwenImageCrossAttention(
            dim=dim,
            num_heads=num_attention_heads,
            head_dim=attention_head_dim,
            added_kv_proj_dim=dim,
            context_pre_only=False,
        )
        # Replace feedforward layers with ATOM versions.
        self.img_mlp = ATOMFeedForward(dim=dim, dim_out=dim)
        self.txt_mlp = ATOMFeedForward(dim=dim, dim_out=dim)
        # forward() inherited from QwenImageTransformerBlock unchanged.


class ATOMQwenImageTransformer2DModel(QwenImageTransformer2DModel):
    """Stock QwenImage 2D model with every transformer block swapped for ATOM blocks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Collect block constructor args from the already-built first block to stay DRY.
        num_attention_heads = self.transformer_blocks[0].num_attention_heads
        attention_head_dim = self.transformer_blocks[0].attention_head_dim
        zero_cond_t = self.transformer_blocks[0].zero_cond_t
        num_layers = len(self.transformer_blocks)
        # Replace all QwenImageTransformerBlocks with ATOM versions.
        # img_in, txt_in, time_text_embed, norm_out.linear, proj_out use ReplicatedLinear — kept.
        self.transformer_blocks = nn.ModuleList([
            ATOMQwenImageTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                zero_cond_t=zero_cond_t,
            )
            for _ in range(num_layers)
        ])
        # forward(), load_weights(), _sp_plan, _repeated_blocks all inherited.


# --- atom/plugin/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py ---
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ATOM (AITER-accelerated) drop-in replacements for the stock vllm-omni
Wan2.2 transformer layers, following the same subclass-and-swap pattern as the
QwenImage plugin module."""

import torch
import torch.nn as nn
from vllm.logger import init_logger

from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import (
    ColumnParallelGELU,
    WanCrossAttention,
    WanFeedForward,
    WanSelfAttention,
    WanTransformerBlock,
    WanTransformer3DModel,
    apply_rotary_emb_wan,
)

logger = init_logger(__name__)


class ATOMWanCrossAttention(WanCrossAttention):
    """Stock Wan cross-attention with its projections swapped for ATOM layers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace vllm ColumnParallelLinear with atom versions
        self.to_q = ColumnParallelLinear(self.dim, self.inner_dim, bias=True)
        self.to_k = ColumnParallelLinear(self.dim, self.kv_inner_dim, bias=True)
        self.to_v = ColumnParallelLinear(self.dim, self.kv_inner_dim, bias=True)

        if self.added_kv_proj_dim is not None:
            # Image-conditioned (I2V) variants carry extra k/v projections.
            self.add_k_proj = ColumnParallelLinear(self.added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = ColumnParallelLinear(self.added_kv_proj_dim, self.inner_dim, bias=True)
        else:
            self.add_k_proj = None
            self.add_v_proj = None
            self.norm_added_k = None

        # Replace vllm RowParallelLinear with atom version
        self.to_out = RowParallelLinear(self.inner_dim, self.dim, bias=True)
        # Inherited forward() works: atom Col/RowParallelLinear.forward() returns plain tensor,
        # same as vllm with return_bias=False.


class ATOMWanSelfAttention(WanSelfAttention):
    """Stock Wan self-attention with fused QKV and output proj swapped for ATOM layers.

    forward() is overridden because the stock implementation tuple-unpacks the
    vLLM QKV projection (`qkv, _ = self.to_qkv(x)`), while the ATOM layer
    returns a plain tensor.
    """

    def __init__(self, dim: int, num_heads: int, head_dim: int, eps: float = 1e-5, dropout: float = 0.0):
        super().__init__(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps, dropout=dropout)
        # Replace vllm QKVParallelLinear with atom version
        self.to_qkv = QKVParallelLinear(
            hidden_size=dim, head_size=head_dim, total_num_heads=num_heads, bias=True,
        )
        # Refresh head counts from the atom layer
        self.num_heads = self.to_qkv.num_heads
        self.num_kv_heads = self.to_qkv.num_kv_heads
        # Replace vllm RowParallelLinear with atom version
        self.to_out = RowParallelLinear(self.inner_dim, dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # CRITICAL: atom QKVParallelLinear returns a plain tensor (no tuple unpack).
        qkv = self.to_qkv(hidden_states)

        q_size = self.num_heads * self.head_dim
        kv_size = self.num_kv_heads * self.head_dim
        query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)

        query = self.norm_q(query)
        key = self.norm_k(key)
        # (batch, seq, heads*dim) -> (batch, seq, heads, dim)
        query = query.unflatten(2, (self.num_heads, self.head_dim))
        key = key.unflatten(2, (self.num_kv_heads, self.head_dim))
        value = value.unflatten(2, (self.num_kv_heads, self.head_dim))

        if rotary_emb is not None:
            freqs_cos, freqs_sin = rotary_emb
            query = apply_rotary_emb_wan(query, freqs_cos, freqs_sin)
            key = apply_rotary_emb_wan(key, freqs_cos, freqs_sin)

        attn_metadata = AttentionMetadata(attn_mask=attn_mask) if attn_mask is not None else None
        hidden_states = self.attn(query, key, value, attn_metadata)
        hidden_states = hidden_states.flatten(2, 3).type_as(query)
        hidden_states = self.to_out(hidden_states)  # atom RowParallelLinear: tensor + all-reduce
        return self.dropout(hidden_states)


class ATOMWanFeedForward(WanFeedForward):
    """Stock Wan feedforward with its linears swapped for ATOM layers."""

    def __init__(self, dim: int, inner_dim: int, dim_out: int | None = None, bias: bool = True):
        super().__init__(dim=dim, inner_dim=inner_dim, dim_out=dim_out, bias=bias)
        dim_out = dim_out or dim
        # Replace net_0.proj (inside ColumnParallelGELU) with atom ColumnParallelLinear.
        # ColumnParallelGELU.forward() calls self.proj(x) expecting a plain tensor —
        # atom ColumnParallelLinear.forward() satisfies this (no tuple).
        self.net_0.proj = ColumnParallelLinear(dim, inner_dim, bias=bias)
        # Replace net_2 with atom RowParallelLinear.
        self.net_2 = RowParallelLinear(inner_dim, dim_out, bias=bias)
        # forward() inherited from WanFeedForward: net_0 -> net_1 (Identity) -> net_2


class ATOMWanTransformerBlock(WanTransformerBlock):
    """Stock Wan block with attention and feedforward swapped for ATOM variants."""

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        cross_attn_norm: bool = False,
    ):
        super().__init__(
            dim=dim, ffn_dim=ffn_dim, num_heads=num_heads, eps=eps,
            added_kv_proj_dim=added_kv_proj_dim, cross_attn_norm=cross_attn_norm,
        )
        head_dim = dim // num_heads
        self.attn1 = ATOMWanSelfAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps)
        self.attn2 = ATOMWanCrossAttention(
            dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
        )
        self.ffn = ATOMWanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim)
        # forward() inherited from WanTransformerBlock unchanged


class ATOMWanTransformer3DModel(WanTransformer3DModel):
    """Stock Wan 3D model with every transformer block swapped for ATOM blocks."""

    def __init__(
        self,
        patch_size: tuple[int, int, int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        eps: float = 1e-6,
        image_dim: int | None = None,
        added_kv_proj_dim: int | None = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: int | None = None,
    ):
        super().__init__(
            patch_size=patch_size,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            text_dim=text_dim,
            freq_dim=freq_dim,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            cross_attn_norm=cross_attn_norm,
            eps=eps,
            image_dim=image_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            rope_max_seq_len=rope_max_seq_len,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        inner_dim = num_attention_heads * attention_head_dim
        # Replace all WanTransformerBlocks with ATOM versions.
        # NOTE(review): the original source was truncated at this point in the patch;
        # the block-replacement below is reconstructed from the pattern used by
        # ATOMQwenImageTransformer2DModel and the plugin README — confirm against the
        # applied file before merging.
        self.blocks = nn.ModuleList([
            ATOMWanTransformerBlock(
                dim=inner_dim,
                ffn_dim=ffn_dim,
                num_heads=num_attention_heads,
                eps=eps,
                added_kv_proj_dim=added_kv_proj_dim,
                cross_attn_norm=cross_attn_norm,
            )
            for _ in range(num_layers)
        ])
        # forward() and weight loading inherited from WanTransformer3DModel.
WanTransformerBlocks with ATOMWanTransformerBlocks. + # rope, patch_embedding, condition_embedder, norm_out, proj_out are kept from super(). + self.blocks = nn.ModuleList([ + ATOMWanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, eps, added_kv_proj_dim, cross_attn_norm + ) + for _ in range(num_layers) + ]) + # forward(), load_weights(), _sp_plan, _repeated_blocks all inherited from WanTransformer3DModel diff --git a/atom/plugin/vllm_omni/platform.py b/atom/plugin/vllm_omni/platform.py index 80dc00292..de1ecedf0 100644 --- a/atom/plugin/vllm_omni/platform.py +++ b/atom/plugin/vllm_omni/platform.py @@ -57,7 +57,7 @@ def get_diffusion_attn_backend_cls( logger.info("Use atom diffusion attention backend") return ( - "atom.plugin.vllm_omni.diffusion_attention_backend.flash_attn" + "atom.plugin.vllm_omni.diffusion.attention_backend.flash_attn" ".AiterFlashAttentionBackend" ) diff --git a/atom/plugin/vllm_omni/register.py b/atom/plugin/vllm_omni/register.py index c787f0298..770d80397 100644 --- a/atom/plugin/vllm_omni/register.py +++ b/atom/plugin/vllm_omni/register.py @@ -1,13 +1,10 @@ from typing import Optional import logging +import torch from atom.plugin.prepare import _set_framework_backbone from atom.utils import envs -from atom.plugin.vllm.mla_patch import patch_vllm_mla_attention -from atom.plugin.vllm.register import ( - _patch_vllm_attention_process_weights_after_loading, - _VLLM_MODEL_REGISTRY_OVERRIDES, -) + logger = logging.getLogger("atom") @@ -15,6 +12,77 @@ disable_vllm_plugin = envs.ATOM_DISABLE_VLLM_PLUGIN +_VLLM_OMNI_DIFFUSION_MODEL_REGISTRY_OVERRIDES = { + +} + +def _ensure_atom_config_for_diffusion(od_config) -> None: + """Set a minimal ATOM config if not already set, so LinearBase.__init__ can read torch_dtype. + + In the vLLM OOT LLM plugin, generate_atom_config_for_plugin_mode(vllm_config) sets this + inside ATOMModelBase.__init__. 
For diffusion models, no full VllmConfig exists, so we + construct a lightweight stand-in from od_config.dtype. + + Only torch_dtype is accessed from the config in the diffusion construction path + (LinearBase.__init__ line 263, bias tensor allocation). A SimpleNamespace suffices. + """ + import atom.config as _atom_cfg + if _atom_cfg._current_atom_config is not None: + return # Already set (e.g. vLLM OOT LLM plugin ran first) + + import types + torch_dtype = getattr(od_config, "dtype", torch.bfloat16) + _atom_cfg.set_current_atom_config(types.SimpleNamespace(torch_dtype=torch_dtype)) + logger.info(f"ATOM: set minimal diffusion atom config (torch_dtype={torch_dtype})") + + +def _ensure_aiter_tp_initialized() -> None: + """Reuse vLLM's TP group for aiter if not already initialized. + + Mirrors init_aiter_dist() in the vLLM OOT plugin (called from ATOMModelBase.__init__). + Called lazily at model-load time via the wrapped initialize_model, so vLLM's TP + group is guaranteed to be ready. One central call covers all diffusion models. + """ + from aiter.dist import parallel_state as aiter_ps + if aiter_ps._TP is not None: + return # Already initialized (e.g. regular vLLM plugin path ran first) + + import vllm.distributed.parallel_state as vllm_ps + tp_size = vllm_ps.get_tensor_model_parallel_world_size() + + from atom.plugin.vllm.tp_group_reuse import init_aiter_tp_from_vllm + if init_aiter_tp_from_vllm(tp_size): + return # TP>1: reused vLLM's group + aiter ca_comm (optimal path) + + # Fallback for TP=1 or no ca_comm: minimal adapter backed by vLLM's ProcessGroups. + # LinearBase.forward() never calls all_reduce when tp_size==1 (guarded by tp_size>1). + from aiter.dist.parallel_state import GroupCoordinator as AiterGroupCoordinator, _register_group + vllm_tp = vllm_ps.get_tp_group() + + class _AiterTPFromVllm(AiterGroupCoordinator): + def __init__(self): + # Skip GroupCoordinator.__init__ to avoid creating new ProcessGroups. 
+ self.unique_name = "tp:0" + _register_group(self) + self.rank = vllm_tp.rank + self.local_rank = vllm_tp.local_rank + self.ranks = vllm_tp.ranks + self.world_size = vllm_tp.world_size + self.rank_in_group = vllm_tp.rank_in_group + self.cpu_group = vllm_tp.cpu_group + self.device_group = vllm_tp.device_group + self.device = vllm_tp.device + self.use_device_communicator = False + self.device_communicator = None + self.mq_broadcaster = None + + aiter_ps._TP = _AiterTPFromVllm() + logger.info( + "ATOM: initialized aiter TP group from vLLM " + f"(world_size={vllm_tp.world_size}, rank={vllm_tp.rank_in_group})" + ) + + def register_omni_platform() -> Optional[str]: if disable_vllm_plugin: @@ -32,36 +100,57 @@ def register_omni_model() -> None: logger.info("Disable ATOM model register (vllm-omni)") return - from vllm_omni.model_executor.models.registry import OmniModelRegistry - import vllm.model_executor.models.registry as vllm_model_registry - - any_updated = False - for arch, qual in _VLLM_MODEL_REGISTRY_OVERRIDES.items(): - module_name, class_name = qual.split(":", 1) - existing = OmniModelRegistry.models.get(arch) - if existing is not None: - # If already overridden to the same target, skip re-registering. 
- if ( - getattr(existing, "module_name", None) == module_name - and getattr(existing, "class_name", None) == class_name - ): - continue - - logger.info(f"Register model {arch} to vLLM-Omni with {qual}") - OmniModelRegistry.register_model(arch, qual) - any_updated = True - - # clear lru cache - if any_updated: - vllm_model_registry._try_load_model_cls.cache_clear() - vllm_model_registry._try_inspect_model_cls.cache_clear() - - patch_vllm_mla_attention() - # patch attention process weights after loading try: - from vllm.attention.layer import Attention, MLAAttention - except ImportError: - from vllm.model_executor.layers.attention import Attention, MLAAttention + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image as _qwen_t2i + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit as _qwen_edit + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus as _qwen_edit_plus + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered as _qwen_layered + from atom.plugin.vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + ATOMQwenImageTransformer2DModel, + ) + # Each pipeline has already captured QwenImageTransformer2DModel as a local binding + # (via vllm_omni/diffusion/models/qwen_image/__init__.py eager import). Patching the + # source transformer module is too late — we must patch each pipeline's local binding. 
+ for _m in [_qwen_t2i, _qwen_edit, _qwen_edit_plus, _qwen_layered]:
+ _m.QwenImageTransformer2DModel = ATOMQwenImageTransformer2DModel
+ logger.info("Patched QwenImageTransformer2DModel → ATOMQwenImageTransformer2DModel in qwen_image pipelines")
+ except ImportError as e:
+ logger.warning(f"Could not patch qwen_image pipelines with ATOM transformer: {e}")
-
+ try:
+ from atom.plugin.vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import (
+ ATOMWanTransformer3DModel,
+ )
+
+ # Approach 1: works
+ import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as pipeline_wan2_2
+ import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v as pipeline_wan2_2_i2v
+ import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v as pipeline_wan2_2_ti2v
+ pipeline_wan2_2.WanTransformer3DModel = ATOMWanTransformer3DModel
+ pipeline_wan2_2_i2v.WanTransformer3DModel = ATOMWanTransformer3DModel
+ pipeline_wan2_2_ti2v.WanTransformer3DModel = ATOMWanTransformer3DModel
+
+ # Approach 2: doesn't work
+ # import vllm_omni.diffusion.models.wan2_2.wan2_2_transformer as _wan2_2_transformer
+ # _wan2_2_transformer.WanTransformer3DModel = ATOMWanTransformer3DModel # doesn't work
+ logger.info("Patched WanTransformer3DModel → ATOMWanTransformer3DModel in wan2_2 pipelines")
+ except ImportError as e:
+ logger.warning(f"Could not patch wan2_2 pipelines with ATOM transformer: {e}")
+
+ # Wrap initialize_model to call aiter TP init before every diffusion model is loaded.
+ # Mirrors ATOMModelBase.__init__ → _prepare_env() in the vLLM OOT plugin:
+ # one central point covers all diffusion models, no per-model initialization needed. 
+ # + # Must patch diffusers_loader (the call site), not registry (the definition site): + # diffusers_loader does `from vllm_omni.diffusion.registry import initialize_model`, + # creating a local binding that is unaffected by patching the registry module. + import vllm_omni.diffusion.model_loader.diffusers_loader as _diffusers_loader + _orig_initialize_model = _diffusers_loader.initialize_model + + def _atom_initialize_model(od_config): + _ensure_aiter_tp_initialized() + _ensure_atom_config_for_diffusion(od_config) + return _orig_initialize_model(od_config) + + _diffusers_loader.initialize_model = _atom_initialize_model + logger.info("Wrapped vllm_omni initialize_model with ATOM aiter TP initialization")