diff --git a/atom/plugin/vllm_omni/__init__.py b/atom/plugin/vllm_omni/__init__.py new file mode 100644 index 000000000..9009e4b4a --- /dev/null +++ b/atom/plugin/vllm_omni/__init__.py @@ -0,0 +1,5 @@ +"""vLLM-Omni plugin integration for ATOM.""" + +from .register import register_omni_model, register_omni_platform + +__all__ = ["register_omni_platform", "register_omni_model"] diff --git a/atom/plugin/vllm_omni/diffusion/attention_backend/flash_attn.py b/atom/plugin/vllm_omni/diffusion/attention_backend/flash_attn.py new file mode 100644 index 000000000..5cf05be8f --- /dev/null +++ b/atom/plugin/vllm_omni/diffusion/attention_backend/flash_attn.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from vllm.logger import init_logger + +from vllm_omni.diffusion.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) + +logger = init_logger(__name__) + + +class AiterFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @classmethod + def supports_attention_mask(cls) -> bool: + return True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 96, 128, 192, 256] + + @staticmethod + def get_name() -> str: + return "AITER_DIFFUSION_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> type["AiterDiffusionFlashAttentionImpl"]: + return AiterDiffusionFlashAttentionImpl + + +class 
AiterDiffusionFlashAttentionImpl(AttentionImpl): + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.causal = causal + self.softmax_scale = softmax_scale + + @staticmethod + def _unwrap_flash_output(out: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor: + # FA3 may return (out, lse), FA2 returns out + return out[0] if isinstance(out, tuple) else out + + def _forward_varlen_masked( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + ) -> torch.Tensor: + from aiter import flash_attn_varlen_func + from vllm_omni.diffusion.attention.backends.utils.fa import ( + _pad_input, + _unpad_input, + _upad_input, + ) + + assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)" + query_length = query.size(1) + q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( + query, key, value, attention_mask, query_length, _unpad_input + ) + + out_unpad = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **{ + "causal": self.causal, + "softmax_scale": self.softmax_scale, + }, + ) + out_unpad = self._unwrap_flash_output(out_unpad) + return _pad_input(out_unpad, indices_q, query.size(0), query_length) + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + from aiter import flash_attn_func + + attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None + + if attention_mask is not None and torch.any(~attention_mask): + return self._forward_varlen_masked( + query, + key, + value, + attention_mask, + ) + + out = flash_attn_func( + query, + key, + value, 
+ causal=self.causal, + softmax_scale=self.softmax_scale, + ) + return self._unwrap_flash_output(out) diff --git a/atom/plugin/vllm_omni/diffusion/models/README.md b/atom/plugin/vllm_omni/diffusion/models/README.md new file mode 100644 index 000000000..275a46a2c --- /dev/null +++ b/atom/plugin/vllm_omni/diffusion/models/README.md @@ -0,0 +1,174 @@ +# ATOM vLLM-Omni Diffusion Model Plugin + +Models under this directory run with the **vLLM-Omni plugin** — they cannot run standalone with native ATOM. For native ATOM models, see `atom/models/` instead. + +## What the Plugin Does + +The ATOM plugin replaces vLLM's linear layers (`vllm.model_executor.layers.linear`) with ATOM's AITER-accelerated equivalents (`atom.model_ops.linear`), enabling ROCm-optimized quantized GEMM kernels for diffusion model inference. + +The plugin hooks into vllm-omni at startup via `register_omni_model()` in `atom/plugin/vllm_omni/register.py`. It uses **monkey-patching** rather than registering new pipeline classes: the stock vllm-omni pipelines are left in place, but the transformer class they instantiate is swapped out before any model is loaded. + +--- + +## How to Add a New Model + +Follow the pattern used for Wan2.2 in `wan2_2/wan2_2_transformer.py`. + +### Step 1: Identify what to replace + +Open the stock vllm-omni transformer file for your model (e.g. `vllm_omni/diffusion/models//`). Look for uses of: + +```python +from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +``` + +These are the layers to replace with their `atom.model_ops.linear` equivalents. + +### Step 2: Create an ATOM transformer file + +Create `atom/plugin/vllm_omni/diffusion/models//` and add a `_transformer.py`. 
+ +**Import pattern:** + +```python +from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm_omni.diffusion.models.._transformer import ( + StockSelfAttention, + StockCrossAttention, + StockFeedForward, + StockTransformerBlock, + StockTransformerModel, + # any helper functions needed in forward() overrides +) +``` + +**For each layer class** that uses vLLM linears, create an ATOM subclass: + +```python +class ATOMStockSelfAttention(StockSelfAttention): + def __init__(self, ...): + super().__init__(...) + # Replace linear layers after super().__init__() creates the vllm ones + self.to_qkv = QKVParallelLinear(hidden_size=dim, head_size=head_dim, + total_num_heads=num_heads, bias=True) + self.num_heads = self.to_qkv.num_heads # refresh from atom layer + self.num_kv_heads = self.to_qkv.num_kv_heads + self.to_out = RowParallelLinear(inner_dim, dim, bias=True) +``` + +**Check if `forward()` needs an override.** Two cases require it: + +| Situation | What to do | +|-----------|-----------| +| Stock `forward()` does `out, _ = self.layer(x)` (tuple unpack) | Override `forward()` — atom layers return a plain tensor, not `(tensor, None)` | +| Stock `forward()` does `out = self.layer(x)` | No override needed — atom and vllm (with `return_bias=False`) both return plain tensors | + +The `QKVParallelLinear` case always requires an override because vLLM returns a tuple: + +```python + def forward(self, hidden_states, ...): + # atom returns plain tensor; vllm returns (tensor, None) + qkv = self.to_qkv(hidden_states) # NOT: qkv, _ = self.to_qkv(hidden_states) + ... +``` + +**For feedforward layers** that wrap `ColumnParallelLinear` inside a helper (e.g. 
`ColumnParallelGELU`), replace the inner `.proj` attribute: + +```python +class ATOMStockFeedForward(StockFeedForward): + def __init__(self, dim, inner_dim, dim_out=None, bias=True): + super().__init__(dim=dim, inner_dim=inner_dim, dim_out=dim_out, bias=bias) + dim_out = dim_out or dim + self.net_0.proj = ColumnParallelLinear(dim, inner_dim, bias=bias) + self.net_2 = RowParallelLinear(inner_dim, dim_out, bias=bias) + # forward() inherited — helper's forward() calls self.proj(x) → plain tensor ✓ +``` + +**Compose into a block and top-level model:** + +```python +class ATOMStockTransformerBlock(StockTransformerBlock): + def __init__(self, dim, ffn_dim, num_heads, eps=1e-6, ...): + super().__init__(...) + head_dim = dim // num_heads + self.attn1 = ATOMStockSelfAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps) + self.attn2 = ATOMStockCrossAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps) + self.ffn = ATOMStockFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim) + # forward() inherited from StockTransformerBlock unchanged + +class ATOMStockTransformerModel(StockTransformerModel): + def __init__(self, ..., num_layers=N, ...): + super().__init__(...) # builds rope, embeddings, norm, proj_out + inner_dim = num_attention_heads * attention_head_dim + # Replace all blocks after super() creates the stock ones + self.blocks = nn.ModuleList([ + ATOMStockTransformerBlock(inner_dim, ffn_dim, num_attention_heads, eps, ...) + for _ in range(num_layers) + ]) + # forward(), load_weights(), _sp_plan all inherited from StockTransformerModel +``` + +### Step 3: Register via monkey-patch in `register.py` + +Open `atom/plugin/vllm_omni/register.py` and add to the monkey-patch block at the end of `register_omni_model()`: + +```python +import vllm_omni.diffusion.models..pipeline_ as __pipeline +from atom.plugin.vllm_omni.diffusion.models.._transformer import ATOMTransformerModel +__pipeline. 
{TransformerClass} = ATOMTransformerModel +``` + +Python resolves module-level names at call time, so patching the name in the pipeline module's namespace causes all subsequent `create_transformer_from_config()` calls to instantiate the ATOM model — no pipeline file copies needed. + +**You only need to patch the base pipeline module.** If variant pipelines (e.g. i2v, ti2v) import `create_transformer_from_config` *from* the base pipeline rather than defining their own, they will automatically pick up the patch — patching the same name twice in different modules would be redundant. Check the variant pipeline's imports to confirm: + +```python +# If you see this in pipeline_{model}_i2v.py, one patch covers all variants: +from vllm_omni.diffusion.models.{model}.pipeline_{model} import create_transformer_from_config +``` + +**Do not copy pipeline files.** If the stock pipeline needs no changes beyond the transformer class swap, patching is sufficient. Only create a new pipeline class if you need to change the pipeline's own logic (e.g. different preprocessing, scheduler, or VAE). + +### Step 4: Update `__init__.py` + +Add your model's ATOM transformer class to `atom/plugin/vllm_omni/diffusion/models/{model}/__init__.py` (if the directory needs one). Re-export stock pipeline helpers from `vllm_omni` directly rather than copying them.
+ +--- + +## API Compatibility Notes + +### `atom.model_ops.linear` vs `vllm.model_executor.layers.linear` + +| vLLM class | ATOM equivalent | Notes | +|---|---|---| +| `ColumnParallelLinear(in, out, bias, gather_output=False, return_bias=False)` | `ColumnParallelLinear(in, out, bias)` | Extra kwargs absorbed via `**kwargs`, silently ignored | +| `RowParallelLinear(in, out, bias, input_is_parallel=True, return_bias=False)` | `RowParallelLinear(in, out, bias)` | `reduce_results=True` by default — matches vLLM behavior | +| `QKVParallelLinear(hidden_size, head_size, total_num_heads, bias)` | `QKVParallelLinear(hidden_size, head_size, total_num_heads, bias)` | Same constructor; **different return type** (see below) | + +### Critical: `QKVParallelLinear` return type difference + +```python +# vLLM: returns (tensor, None) tuple +qkv, _ = self.to_qkv(hidden_states) + +# ATOM: returns plain tensor — must NOT unpack +qkv = self.to_qkv(hidden_states) +``` + +`ColumnParallelLinear` and `RowParallelLinear` forward signatures are compatible — both return a plain tensor when vLLM's `return_bias=False` (the standard config for diffusion models). + +### `atom.model_ops.linear` forward signature + +```python +def forward(self, x: Tensor, x_scale: Tensor | None = None, otype=bf16) -> Tensor +``` + +Calling `layer(x)` works as expected; `x_scale` and `otype` are used for quantized inference and default safely to unquantized bfloat16. 
+ +--- + +## Current Models + +| Model | Transformer file | Registered via | +|-------|-----------------|----------------| +| Wan2.2 (T2V / I2V / TI2V) | `wan2_2/wan2_2_transformer.py` | monkey-patch in `register.py` | diff --git a/atom/plugin/vllm_omni/diffusion/models/qwen_image/__init__.py b/atom/plugin/vllm_omni/diffusion/models/qwen_image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/atom/plugin/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py new file mode 100644 index 000000000..0f6a45464 --- /dev/null +++ b/atom/plugin/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F +from vllm.logger import init_logger + +from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + FeedForward, + QwenImageCrossAttention, + QwenImageTransformerBlock, + QwenImageTransformer2DModel, +) + +logger = init_logger(__name__) + + +class ATOMFeedForward(FeedForward): + + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + activation_fn: str = "gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + quant_config=None, + prefix: str = "", + ): + super().__init__( + dim=dim, dim_out=dim_out, mult=mult, activation_fn=activation_fn, + inner_dim=inner_dim, bias=bias, quant_config=quant_config, prefix=prefix, + ) + inner_dim_val = inner_dim or int(dim * mult) + dim_out_val = dim_out or dim + # Replace ColumnParallelApproxGELU's inner proj with ATOM ColumnParallelLinear. 
+ # ColumnParallelApproxGELU.forward() calls self.proj(x) → plain tensor ✓ + self.net[0].proj = ColumnParallelLinear(dim, inner_dim_val, bias=bias) + # Replace net[2] (RowParallelLinear) with ATOM version. + self.net[2] = RowParallelLinear(inner_dim_val, dim_out_val, bias=bias) + # forward() inherited: iterates self.net ✓ + + +class ATOMQwenImageCrossAttention(QwenImageCrossAttention): + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + added_kv_proj_dim: int, + window_size: tuple[int, int] = (-1, -1), + out_bias: bool = True, + qk_norm: bool = True, + eps: float = 1e-6, + pre_only: bool = False, + context_pre_only: bool = False, + out_dim: int | None = None, + quant_config=None, + ): + super().__init__( + dim=dim, num_heads=num_heads, head_dim=head_dim, + added_kv_proj_dim=added_kv_proj_dim, window_size=window_size, + out_bias=out_bias, qk_norm=qk_norm, eps=eps, pre_only=pre_only, + context_pre_only=context_pre_only, out_dim=out_dim, quant_config=quant_config, + ) + # Replace vLLM QKVParallelLinear with ATOM versions; refresh head counts. + self.to_qkv = QKVParallelLinear( + hidden_size=dim, head_size=head_dim, total_num_heads=num_heads, bias=True, + ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, head_size=head_dim, total_num_heads=num_heads, bias=True, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + + inner_dim = out_dim if out_dim is not None else head_dim * num_heads + # Replace vLLM RowParallelLinear with ATOM versions. 
+ self.to_out = RowParallelLinear(inner_dim, dim, bias=out_bias) + self.to_add_out = RowParallelLinear(inner_dim, dim, bias=out_bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + vid_freqs: torch.Tensor, + txt_freqs: torch.Tensor, + hidden_states_mask: torch.Tensor | None = None, + encoder_hidden_states_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + logger.debug("ATOMQwenImageCrossAttention forward") + # CRITICAL: ATOM QKVParallelLinear returns a plain tensor; vLLM returns (tensor, None). + img_qkv = self.to_qkv(hidden_states) + q_size = self.query_num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1) + + txt_qkv = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_query_num_heads * self.head_dim + add_kv_size = self.add_kv_num_heads * self.head_dim + txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1) + + img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim)) + img_key = img_key.unflatten( -1, (self.kv_num_heads, self.head_dim)) + img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim)) + + txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim)) + txt_key = txt_key.unflatten( -1, (self.add_kv_num_heads, self.head_dim)) + txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + img_cos = vid_freqs.real.to(img_query.dtype) + img_sin = vid_freqs.imag.to(img_query.dtype) + txt_cos = txt_freqs.real.to(txt_query.dtype) + txt_sin = txt_freqs.imag.to(txt_query.dtype) + + img_query = self.rope(img_query, img_cos, img_sin) + img_key = self.rope(img_key, img_cos, img_sin) + txt_query = self.rope(txt_query, txt_cos,
txt_sin) + txt_key = self.rope(txt_key, txt_cos, txt_sin) + + seq_len_txt = encoder_hidden_states.shape[1] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + if ( + self.parallel_config is not None + and self.parallel_config.sequence_parallel_size > 1 + and not get_forward_context().split_text_embed_in_sp + ): + attn_metadata = AttentionMetadata( + joint_query=txt_query, + joint_key=txt_key, + joint_value=txt_value, + joint_strategy="front", + ) + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + if encoder_hidden_states_mask is not None: + attn_metadata.joint_attn_mask = encoder_hidden_states_mask + + joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata) + else: + attn_metadata = None + if hidden_states_mask is not None or encoder_hidden_states_mask is not None: + mask_list: list[torch.Tensor] = [] + if encoder_hidden_states_mask is not None: + mask_list.append(encoder_hidden_states_mask) + else: + mask_list.append( + torch.ones( + encoder_hidden_states.shape[:2], + dtype=torch.bool, + device=encoder_hidden_states.device, + ) + ) + if hidden_states_mask is not None: + mask_list.append(hidden_states_mask) + else: + mask_list.append( + torch.ones( + hidden_states.shape[:2], + dtype=torch.bool, + device=hidden_states.device, + ) + ) + joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] + attn_metadata = AttentionMetadata(attn_mask=joint_mask) + + joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata) + + joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype) + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] + + # ATOM RowParallelLinear returns plain tensor + performs all-reduce ✓ + img_attn_output = self.to_out(img_attn_output) + txt_attn_output 
= self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class ATOMQwenImageTransformerBlock(QwenImageTransformerBlock): + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + zero_cond_t: bool = False, + quant_config=None, + ): + super().__init__( + dim=dim, num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, qk_norm=qk_norm, + eps=eps, zero_cond_t=zero_cond_t, quant_config=quant_config, + ) + # Replace joint cross-attention with ATOM version (QKV + Row parallel layers). + # img_mod and txt_mod use ReplicatedLinear — not replaced (broadcast, not sharded). + self.attn = ATOMQwenImageCrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + added_kv_proj_dim=dim, + context_pre_only=False, + ) + # Replace feedforward layers with ATOM versions. + self.img_mlp = ATOMFeedForward(dim=dim, dim_out=dim) + self.txt_mlp = ATOMFeedForward(dim=dim, dim_out=dim) + # forward() inherited from QwenImageTransformerBlock unchanged ✓ + + +class ATOMQwenImageTransformer2DModel(QwenImageTransformer2DModel): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Collect block constructor args from the already-built first block to stay DRY. + num_attention_heads = self.transformer_blocks[0].num_attention_heads + attention_head_dim = self.transformer_blocks[0].attention_head_dim + zero_cond_t = self.transformer_blocks[0].zero_cond_t + num_layers = len(self.transformer_blocks) + # Replace all QwenImageTransformerBlocks with ATOM versions. + # img_in, txt_in, time_text_embed, norm_out.linear, proj_out use ReplicatedLinear — kept. 
+ import torch.nn as nn + self.transformer_blocks = nn.ModuleList([ + ATOMQwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + zero_cond_t=zero_cond_t, + ) + for _ in range(num_layers) + ]) + # forward(), load_weights(), _sp_plan, _repeated_blocks all inherited ✓ diff --git a/atom/plugin/vllm_omni/diffusion/models/wan2_2/__init__.py b/atom/plugin/vllm_omni/diffusion/models/wan2_2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/atom/plugin/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/atom/plugin/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py new file mode 100644 index 000000000..47c9dfbe3 --- /dev/null +++ b/atom/plugin/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from atom.model_ops.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear + +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import ( + ColumnParallelGELU, + WanCrossAttention, + WanFeedForward, + WanSelfAttention, + WanTransformerBlock, + WanTransformer3DModel, + apply_rotary_emb_wan, +) + +logger = init_logger(__name__) + + +class ATOMWanCrossAttention(WanCrossAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Replace vllm ColumnParallelLinear with atom versions + self.to_q = ColumnParallelLinear(self.dim, self.inner_dim, bias=True) + self.to_k = ColumnParallelLinear(self.dim, self.kv_inner_dim, bias=True) + self.to_v = ColumnParallelLinear(self.dim, self.kv_inner_dim, bias=True) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = ColumnParallelLinear(self.added_kv_proj_dim, self.inner_dim, bias=True) + 
self.add_v_proj = ColumnParallelLinear(self.added_kv_proj_dim, self.inner_dim, bias=True) + else: + self.add_k_proj = None + self.add_v_proj = None + self.norm_added_k = None + + # Replace vllm RowParallelLinear with atom version + self.to_out = RowParallelLinear(self.inner_dim, self.dim, bias=True) + # Inherited forward() works: atom Col/RowParallelLinear.forward() returns plain tensor, + # same as vllm with return_bias=False. + + +class ATOMWanSelfAttention(WanSelfAttention): + + def __init__(self, dim: int, num_heads: int, head_dim: int, eps: float = 1e-5, dropout: float = 0.0): + super().__init__(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps, dropout=dropout) + # Replace vllm QKVParallelLinear with atom version + self.to_qkv = QKVParallelLinear( + hidden_size=dim, head_size=head_dim, total_num_heads=num_heads, bias=True, + ) + # Refresh head counts from the atom layer + self.num_heads = self.to_qkv.num_heads + self.num_kv_heads = self.to_qkv.num_kv_heads + # Replace vllm RowParallelLinear with atom version + self.to_out = RowParallelLinear(self.inner_dim, dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + logger.debug("ATOMWanSelfAttention forward") + # CRITICAL: atom QKVParallelLinear returns a plain tensor; + # the stock WanSelfAttention.forward() does `qkv, _ = self.to_qkv(x)` (tuple unpack).
+ qkv = self.to_qkv(hidden_states) + + q_size = self.num_heads * self.head_dim + kv_size = self.num_kv_heads * self.head_dim + query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1) + + query = self.norm_q(query) + key = self.norm_k(key) + query = query.unflatten(2, (self.num_heads, self.head_dim)) + key = key.unflatten( 2, (self.num_kv_heads, self.head_dim)) + value = value.unflatten(2, (self.num_kv_heads, self.head_dim)) + + if rotary_emb is not None: + freqs_cos, freqs_sin = rotary_emb + query = apply_rotary_emb_wan(query, freqs_cos, freqs_sin) + key = apply_rotary_emb_wan(key, freqs_cos, freqs_sin) + + attn_metadata = AttentionMetadata(attn_mask=attn_mask) if attn_mask is not None else None + hidden_states = self.attn(query, key, value, attn_metadata) + hidden_states = hidden_states.flatten(2, 3).type_as(query) + hidden_states = self.to_out(hidden_states) # atom RowParallelLinear: tensor + all-reduce + return self.dropout(hidden_states) + + +class ATOMWanFeedForward(WanFeedForward): + + def __init__(self, dim: int, inner_dim: int, dim_out: int | None = None, bias: bool = True): + super().__init__(dim=dim, inner_dim=inner_dim, dim_out=dim_out, bias=bias) + dim_out = dim_out or dim + # Replace net_0.proj (inside ColumnParallelGELU) with atom ColumnParallelLinear. + # ColumnParallelGELU.forward() calls self.proj(x) expecting a plain tensor — + # atom ColumnParallelLinear.forward() satisfies this (no tuple). + self.net_0.proj = ColumnParallelLinear(dim, inner_dim, bias=bias) + # Replace net_2 with atom RowParallelLinear. 
+ self.net_2 = RowParallelLinear(inner_dim, dim_out, bias=bias) + # forward() inherited from WanFeedForward: net_0 → net_1 (Identity) → net_2 + + +class ATOMWanTransformerBlock(WanTransformerBlock): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + cross_attn_norm: bool = False, + ): + super().__init__( + dim=dim, ffn_dim=ffn_dim, num_heads=num_heads, eps=eps, + added_kv_proj_dim=added_kv_proj_dim, cross_attn_norm=cross_attn_norm, + ) + head_dim = dim // num_heads + self.attn1 = ATOMWanSelfAttention(dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps) + self.attn2 = ATOMWanCrossAttention( + dim=dim, num_heads=num_heads, head_dim=head_dim, eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + ) + self.ffn = ATOMWanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim) + # forward() inherited from WanTransformerBlock unchanged + + +class ATOMWanTransformer3DModel(WanTransformer3DModel): + + def __init__( + self, + patch_size: tuple[int, int, int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + eps: float = 1e-6, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: int | None = None, + ): + super().__init__( + patch_size=patch_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + text_dim=text_dim, + freq_dim=freq_dim, + ffn_dim=ffn_dim, + num_layers=num_layers, + cross_attn_norm=cross_attn_norm, + eps=eps, + image_dim=image_dim, + added_kv_proj_dim=added_kv_proj_dim, + rope_max_seq_len=rope_max_seq_len, + pos_embed_seq_len=pos_embed_seq_len, + ) + inner_dim = num_attention_heads * attention_head_dim + # Replace all 
WanTransformerBlocks with ATOMWanTransformerBlocks. + # rope, patch_embedding, condition_embedder, norm_out, proj_out are kept from super(). + self.blocks = nn.ModuleList([ + ATOMWanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, eps, added_kv_proj_dim, cross_attn_norm + ) + for _ in range(num_layers) + ]) + # forward(), load_weights(), _sp_plan, _repeated_blocks all inherited from WanTransformer3DModel diff --git a/atom/plugin/vllm_omni/platform.py b/atom/plugin/vllm_omni/platform.py new file mode 100644 index 000000000..de1ecedf0 --- /dev/null +++ b/atom/plugin/vllm_omni/platform.py @@ -0,0 +1,65 @@ +"""ATOM vLLM-Omni platform integration. + +This module contains the vLLM-Omni `OmniPlatform` subclass used in ATOM's +vLLM-Omni plugin mode. Overrides both AR and diffusion attention backend +selection to use ATOM implementations. +""" + +import logging + +from atom.utils import envs + +logger = logging.getLogger("atom") +# This flag is used to enable the vLLM-Omni plugin mode. 
+disable_vllm_plugin = envs.ATOM_DISABLE_VLLM_PLUGIN +disable_vllm_plugin_attention = envs.ATOM_DISABLE_VLLM_PLUGIN_ATTENTION + +if not disable_vllm_plugin: + from vllm_omni.platforms.rocm.platform import RocmOmniPlatform + + class ATOMOmniPlatform(RocmOmniPlatform): + @classmethod + def get_attn_backend_cls( + cls, selected_backend, attn_selector_config, num_heads + ) -> str: + if disable_vllm_plugin_attention: + logger.info("Fallback to original vLLM attention backend") + return super().get_attn_backend_cls( + selected_backend, attn_selector_config, num_heads + ) + + logger.info("Use atom attention backend") + if attn_selector_config.use_mla: + return "atom.model_ops.attentions.aiter_mla.AiterMLABackend" + return "atom.model_ops.attentions.aiter_attention.AiterBackend" + + @classmethod + def get_diffusion_attn_backend_cls( + cls, selected_backend: str | None, head_size: int + ) -> str: + if disable_vllm_plugin_attention: + logger.info( + "Fallback to original vLLM-Omni diffusion attention backend" + ) + return super().get_diffusion_attn_backend_cls( + selected_backend, head_size + ) + + # Respect env var override for non-FLASH_ATTN backends + # (TORCH_SDPA, SAGE_ATTN, etc.) 
+ if ( + selected_backend is not None + and selected_backend.upper() != "FLASH_ATTN" + ): + return super().get_diffusion_attn_backend_cls( + selected_backend, head_size + ) + + logger.info("Use atom diffusion attention backend") + return ( + "atom.plugin.vllm_omni.diffusion.attention_backend.flash_attn" + ".AiterFlashAttentionBackend" + ) + +else: + ATOMOmniPlatform = None diff --git a/atom/plugin/vllm_omni/register.py b/atom/plugin/vllm_omni/register.py new file mode 100644 index 000000000..770d80397 --- /dev/null +++ b/atom/plugin/vllm_omni/register.py @@ -0,0 +1,156 @@ +from typing import Optional +import logging +import torch + +from atom.plugin.prepare import _set_framework_backbone +from atom.utils import envs + + +logger = logging.getLogger("atom") + +# this flag is used to enable the vllm-omni plugin mode +disable_vllm_plugin = envs.ATOM_DISABLE_VLLM_PLUGIN + + +_VLLM_OMNI_DIFFUSION_MODEL_REGISTRY_OVERRIDES = { + +} + +def _ensure_atom_config_for_diffusion(od_config) -> None: + """Set a minimal ATOM config if not already set, so LinearBase.__init__ can read torch_dtype. + + In the vLLM OOT LLM plugin, generate_atom_config_for_plugin_mode(vllm_config) sets this + inside ATOMModelBase.__init__. For diffusion models, no full VllmConfig exists, so we + construct a lightweight stand-in from od_config.dtype. + + Only torch_dtype is accessed from the config in the diffusion construction path + (LinearBase.__init__ line 263, bias tensor allocation). A SimpleNamespace suffices. + """ + import atom.config as _atom_cfg + if _atom_cfg._current_atom_config is not None: + return # Already set (e.g. 
vLLM OOT LLM plugin ran first) + + import types + torch_dtype = getattr(od_config, "dtype", torch.bfloat16) + _atom_cfg.set_current_atom_config(types.SimpleNamespace(torch_dtype=torch_dtype)) + logger.info(f"ATOM: set minimal diffusion atom config (torch_dtype={torch_dtype})") + + +def _ensure_aiter_tp_initialized() -> None: + """Reuse vLLM's TP group for aiter if not already initialized. + + Mirrors init_aiter_dist() in the vLLM OOT plugin (called from ATOMModelBase.__init__). + Called lazily at model-load time via the wrapped initialize_model, so vLLM's TP + group is guaranteed to be ready. One central call covers all diffusion models. + """ + from aiter.dist import parallel_state as aiter_ps + if aiter_ps._TP is not None: + return # Already initialized (e.g. regular vLLM plugin path ran first) + + import vllm.distributed.parallel_state as vllm_ps + tp_size = vllm_ps.get_tensor_model_parallel_world_size() + + from atom.plugin.vllm.tp_group_reuse import init_aiter_tp_from_vllm + if init_aiter_tp_from_vllm(tp_size): + return # TP>1: reused vLLM's group + aiter ca_comm (optimal path) + + # Fallback for TP=1 or no ca_comm: minimal adapter backed by vLLM's ProcessGroups. + # LinearBase.forward() never calls all_reduce when tp_size==1 (guarded by tp_size>1). + from aiter.dist.parallel_state import GroupCoordinator as AiterGroupCoordinator, _register_group + vllm_tp = vllm_ps.get_tp_group() + + class _AiterTPFromVllm(AiterGroupCoordinator): + def __init__(self): + # Skip GroupCoordinator.__init__ to avoid creating new ProcessGroups. 
+ self.unique_name = "tp:0" + _register_group(self) + self.rank = vllm_tp.rank + self.local_rank = vllm_tp.local_rank + self.ranks = vllm_tp.ranks + self.world_size = vllm_tp.world_size + self.rank_in_group = vllm_tp.rank_in_group + self.cpu_group = vllm_tp.cpu_group + self.device_group = vllm_tp.device_group + self.device = vllm_tp.device + self.use_device_communicator = False + self.device_communicator = None + self.mq_broadcaster = None + + aiter_ps._TP = _AiterTPFromVllm() + logger.info( + "ATOM: initialized aiter TP group from vLLM " + f"(world_size={vllm_tp.world_size}, rank={vllm_tp.rank_in_group})" + ) + + +def register_omni_platform() -> Optional[str]: + + if disable_vllm_plugin: + logger.info("Disable ATOM OOT plugin platforms (vllm-omni)") + return None + + _set_framework_backbone("vllm") + + # return the ATOM omni platform to vllm-omni + return "atom.plugin.vllm_omni.platform.ATOMOmniPlatform" + + +def register_omni_model() -> None: + if disable_vllm_plugin: + logger.info("Disable ATOM model register (vllm-omni)") + return + + try: + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image as _qwen_t2i + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit as _qwen_edit + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus as _qwen_edit_plus + import vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered as _qwen_layered + from atom.plugin.vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( + ATOMQwenImageTransformer2DModel, + ) + # Each pipeline has already captured QwenImageTransformer2DModel as a local binding + # (via vllm_omni/diffusion/models/qwen_image/__init__.py eager import). Patching the + # source transformer module is too late — we must patch each pipeline's local binding. 
+        for _m in [_qwen_t2i, _qwen_edit, _qwen_edit_plus, _qwen_layered]:
+            _m.QwenImageTransformer2DModel = ATOMQwenImageTransformer2DModel
+        logger.info("Patched QwenImageTransformer2DModel → ATOMQwenImageTransformer2DModel in qwen_image pipelines")
+    except ImportError as e:
+        logger.warning(f"Could not patch qwen_image pipelines with ATOM transformer: {e}")
+
+    try:
+        from atom.plugin.vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import (
+            ATOMWanTransformer3DModel,
+        )
+
+        # Approach 1 (works): patch each pipeline module's local binding.
+        import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as pipeline_wan2_2
+        import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v as pipeline_wan2_2_i2v
+        import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v as pipeline_wan2_2_ti2v
+        pipeline_wan2_2.WanTransformer3DModel = ATOMWanTransformer3DModel
+        pipeline_wan2_2_i2v.WanTransformer3DModel = ATOMWanTransformer3DModel
+        pipeline_wan2_2_ti2v.WanTransformer3DModel = ATOMWanTransformer3DModel
+
+        # Approach 2 (does not work): patching the source transformer module is too late.
+        # import vllm_omni.diffusion.models.wan2_2.wan2_2_transformer as _wan2_2_transformer
+        # _wan2_2_transformer.WanTransformer3DModel = ATOMWanTransformer3DModel  # does not work
+        logger.info("Patched WanTransformer3DModel → ATOMWanTransformer3DModel in wan2_2 pipelines")
+    except ImportError as e:
+        logger.warning(f"Could not patch wan2_2 pipelines with ATOM transformer: {e}")
+
+    # Wrap initialize_model to call aiter TP init before every diffusion model is loaded.
+    # Mirrors ATOMModelBase.__init__ → _prepare_env() in the vLLM OOT plugin:
+    # one central point covers all diffusion models, no per-model initialization needed.
+    #
+    # Must patch diffusers_loader (the call site), not registry (the definition site):
+    # diffusers_loader does `from vllm_omni.diffusion.registry import initialize_model`,
+    # creating a local binding that is unaffected by patching the registry module.
+ import vllm_omni.diffusion.model_loader.diffusers_loader as _diffusers_loader + _orig_initialize_model = _diffusers_loader.initialize_model + + def _atom_initialize_model(od_config): + _ensure_aiter_tp_initialized() + _ensure_atom_config_for_diffusion(od_config) + return _orig_initialize_model(od_config) + + _diffusers_loader.initialize_model = _atom_initialize_model + logger.info("Wrapped vllm_omni initialize_model with ATOM aiter TP initialization") diff --git a/pyproject.toml b/pyproject.toml index 434c21ddf..714539ad3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,3 +37,9 @@ atom = "atom.plugin.vllm.register:register_platform" # but the plugin mode for models can be disabled by # ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1 atom_model_registry = "atom.plugin.vllm.register:register_model" + +[project.entry-points."vllm_omni.platform_plugins"] +atom = "atom.plugin.vllm_omni.register:register_omni_platform" + +[project.entry-points."vllm_omni.general_plugins"] +atom_model_registry = "atom.plugin.vllm_omni.register:register_omni_model" diff --git a/recipes/atom_vllmomni/Qwen-Image.md b/recipes/atom_vllmomni/Qwen-Image.md new file mode 100644 index 000000000..04e3fe5e0 --- /dev/null +++ b/recipes/atom_vllmomni/Qwen-Image.md @@ -0,0 +1,42 @@ +# Qwen-Image with ATOM vLLM-Omni Plugin Backend Usage Guide + +[Qwen-Image](https://huggingface.co/Qwen/Qwen-Image) is an image generation foundation model in the Qwen series developed by Alibaba. It achieves significant advances in complex text rendering and precise image editing. The model demonstrates strong general capabilities in both image generation and editing, with exceptional performance in text rendering. 
+ +## Launching server + +### BF16 on 1xMI300X/MI355X GPUs + +```bash +vllm serve Qwen/Qwen-Image --omni \ + --host localhost \ + --port 8091 \ + --tensor-parallel-size 1 +``` + +### Interact with the model + +The command is extracted from https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/text_to_image + +```python +from openai import OpenAI +import base64 + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="none") + +response = client.chat.completions.create( + model="Qwen/Qwen-Image", + messages=[{"role": "user", "content": "A beautiful landscape painting"}], + extra_body={ + "height": 1024, + "width": 1024, + "num_inference_steps": 50, + "true_cfg_scale": 4.0, + "seed": 42, + }, +) + +img_url = response.choices[0].message.content[0]["image_url"]["url"] +_, b64_data = img_url.split(",", 1) +with open("output.png", "wb") as f: + f.write(base64.b64decode(b64_data)) +``` \ No newline at end of file