
[feat] Add support for Qwen3.5 and Qwen3-Next to ATOM-plugined SGLang#532

Open
wanzhenchn wants to merge 8 commits into ROCm:main from wanzhenchn:feat/qwen3.5-sgl-plugin-final

Conversation

@wanzhenchn wanzhenchn commented Apr 9, 2026

Motivation

Background: ROCm/ATOM#355 and ROCm/ATOM#359.

PR #355 integrated ATOM with upstream SGLang through the SGLANG_EXTERNAL_MODEL_PACKAGE out-of-tree mechanism, replacing a fork-based workflow and establishing atom.plugin.sglang.models as the external entry package for ATOM-backed architectures.

Building on that foundation, this PR extends the SGLang plugin path so that two major ATOM model families, Qwen3-Next (Qwen3NextForCausalLM) and Qwen3.5 (Qwen3_5ForConditionalGeneration / Qwen3_5MoeForConditionalGeneration), can run as first-class external models inside SGLang. The goal is parity with prior ATOM-in-SGLang accuracy while improving end-to-end inference performance on the supported paths (e.g. ATOM's fused kernels, quantization, and MLA / MoE handling tuned for ROCm), without requiring a patched SGLang tree: users continue to point SGLANG_EXTERNAL_MODEL_PACKAGE at atom.plugin.sglang.models and launch with the standard upstream sglang.launch_server.

Technical Details

  • Qwen3-next

    • Qwen3NextForCausalLM is registered under atom.plugin.sglang.models and subclasses _AtomCausalLMBaseForSglang, reusing the same SGLang-facing contract as other OOT entry points: the wrapper calls prepare_model(..., engine="sglang") to build the ATOM weight stack, runs the language model forward with pipeline-parallel state mapped from pp_proxy_tensors, applies LogitsProcessor on the last PP rank, and loads weights via load_model_in_plugin_mode.
    • The linear-attention (GDN) path is wrapped by Qwen3NextSglangModel plus sglang_gdn_bridge so GDN layers see the SGLang forward_batch context they expect. At prepare time, apply_qwen3_next_sglang_model_patch swaps atom.models.qwen3_next.Qwen3NextModel to that bridged implementation; the shared prepare hook defaults ATOM_SGLANG_USE_NATIVE_AITER_ATTN_BACKEND for Qwen3NextForCausalLM before register_ops_to_sglang.
  • Qwen3.5

    • Qwen3.5 text / MoE / multimodal stacks reuse SGLang’s in-tree Qwen3_5* container classes while the language tower is still constructed through ATOM prepare_model. apply_prepare_model_adaptations applies Qwen3.5-only config fixes (e.g. MoE text fields, quant remaps for ROCm) and _apply_qwen35_sglang_model_patch rebinds atom.models.qwen3_5.Qwen3_5Model to the SGLang-specific implementation.
    • Weights are not applied through SGLang’s default load_weights iterator; instead load_model_in_plugin_mode drives ATOM’s loader, with _QWEN35_OOT_PACKED_MODULES_MAPPING (and related mapping) kept aligned to SGLang’s packed/fused parameter layout so FP8 and fused checkpoints behave like the in-tree model.
    • The same prepare hook defaults native Aiter attention backend selection for Qwen3_5* (and Qwen3-next) unless the user has already set the env var.
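The "default the backend env var unless the user has already set it" behavior described above can be sketched with `os.environ.setdefault`. The function name and architecture matching below are illustrative assumptions, not ATOM's actual hook:

```python
import os

def default_attn_backend_env(arch: str) -> None:
    """Illustrative sketch: default ATOM_SGLANG_USE_NATIVE_AITER_ATTN_BACKEND
    for Qwen3_5* / Qwen3-Next architectures, leaving any value the user
    already exported untouched. The real hook lives in ATOM's shared
    prepare path; this helper name is a placeholder."""
    if arch.startswith(("Qwen3_5", "Qwen3Next")):
        # setdefault only writes when the key is absent
        os.environ.setdefault("ATOM_SGLANG_USE_NATIVE_AITER_ATTN_BACKEND", "1")

default_attn_backend_env("Qwen3NextForCausalLM")
assert "ATOM_SGLANG_USE_NATIVE_AITER_ATTN_BACKEND" in os.environ
```

A user-supplied `export ATOM_SGLANG_USE_NATIVE_AITER_ATTN_BACKEND=0` would therefore survive the prepare hook unchanged.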

How to Run

export SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models
python3 -m sglang.launch_server --model-path <model_path> ...

The following models are now supported:

Accuracy

#!/usr/bin/bash
# launch the OpenAI-compatible server

set -euxo pipefail

export PYTHONPATH=/opt/sglang/python
export SGLANG_DISABLE_CUDNN_CHECK=1

if [ $# -lt 4 ]; then
  echo "Usage: $0 <port> <model_path> <device_id> <enable_atom>"
  exit 1
fi

port=$1
model_path=$2
device_id=$3
enable_atom=$4 # true or false

if [ "${enable_atom}" == true ]; then
  export SGLANG_EXTERNAL_MODEL_PACKAGE=atom.plugin.sglang.models
fi

# count comma-separated devices (robust for multi-digit ids like "10,11")
tp=$(awk -F',' '{print NF}' <<< "${device_id}")

HIP_VISIBLE_DEVICES=${device_id} python3 -m sglang.launch_server \
  --model-path ${model_path} --port ${port} \
  --tensor-parallel-size ${tp} --mem-fraction-static 0.9 \
  --reasoning-parser qwen3 --disable-radix-cache

# accuracy evaluation (separate script, run against the launched server)
model_path=$1
port=$2
task=$3

models_url="http://localhost:${port}/v1/models"
echo "Waiting for OpenAI-compatible server at ${models_url} ..." >&2
until curl -sf --connect-timeout 2 --max-time 10 "${models_url}" >/dev/null; do
  sleep 2
done
echo "Server is up; starting lm_eval." >&2

lm_eval --model local-completions \
        --model_args model=${model_path},base_url=http://localhost:${port}/v1/completions,num_concurrent=64,max_retries=3,max_gen_toks=2048,tokenized_requests=False \
        --tasks ${task} \
        --batch_size auto \
        --trust_remote_code

  • Qwen3.5-397B-A17B-FP8

    • ATOM + SGLang: [accuracy screenshot]
    • SGLang only: [accuracy screenshot]
  • Qwen3-Next-80B-A3B-Instruct

    • ATOM + SGLang: [accuracy screenshot]
    • SGLang only: [accuracy screenshot]

Inference Perf.

  • Qwen3.5-397B-A17B-FP8 on MI355X: [performance screenshot]
  • Qwen3.5-397B-A17B-FP8 on MI308X: [performance screenshot]

@wuhuikx wuhuikx requested a review from ZhiweiYan-96 April 9, 2026 08:04
@wanzhenchn wanzhenchn force-pushed the feat/qwen3.5-sgl-plugin-final branch 3 times, most recently from c856071 to a8e9882 Compare April 13, 2026 23:55
@wanzhenchn wanzhenchn force-pushed the feat/qwen3.5-sgl-plugin-final branch from a8e9882 to 9295de9 Compare April 15, 2026 08:48
@wuhuikx wuhuikx marked this pull request as draft April 15, 2026 09:20
@wuhuikx (Collaborator) commented Apr 15, 2026

I've turned this to draft and will let it go through CI after we finish the code review.

@ganyi1996ppo @Yuechguo please help review this PR.

@wanzhenchn wanzhenchn force-pushed the feat/qwen3.5-sgl-plugin-final branch from 99ece56 to 66194d5 Compare April 16, 2026 03:19
@wanzhenchn wanzhenchn force-pushed the feat/qwen3.5-sgl-plugin-final branch 2 times, most recently from 87e8531 to 9ae4b3d Compare April 17, 2026 09:26
@wanzhenchn wanzhenchn force-pushed the feat/qwen3.5-sgl-plugin-final branch from 9ae4b3d to a4b019b Compare April 17, 2026 09:28
@wanzhenchn wanzhenchn marked this pull request as ready for review April 17, 2026 09:29
fwd_ctx: ForwardContext = get_forward_context()
gdn_metadata: GDNAttentionMetadata = fwd_ctx.attn_metadata.gdn_metadata

gdn_metadata, conv_state, ssm_state = self._resolve_runtime_state(fwd_ctx)
@ganyi1996ppo (Contributor) commented Apr 17, 2026:
Why abstract this part? The former version looks more straightforward, and I don't see any reuse of it.

@wanzhenchn (Author) replied Apr 17, 2026:

  • SGLang exposes the mamba/GDN conv cache as a row-major tensor shaped
    [slot, D, W].
  • ATOM's causal_conv1d_* kernels consume the same logical shape, but they
    require the feature dimension D to be contiguous in memory
    (stride(-2) == 1).

So we need a tensor-layout transformation for the conv state.

It is invoked in atom/plugin/sglang/attention_backend/attention_gdn.py:
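The layout requirement above can be illustrated with NumPy, with strides expressed in elements rather than bytes. The shapes and buffer names are illustrative only, not ATOM's or SGLang's actual tensors:

```python
import numpy as np

slot, D, W = 4, 128, 3  # illustrative sizes, not real cache dimensions

# SGLang-style row-major conv cache: logical shape [slot, D, W],
# with the window dimension W contiguous in memory.
sgl_cache = np.zeros((slot, D, W), dtype=np.float32)
assert sgl_cache.strides[-1] // sgl_cache.itemsize == 1  # W contiguous
assert sgl_cache.strides[-2] // sgl_cache.itemsize == W  # D strided

# Layout the causal-conv kernels require: the same logical [slot, D, W]
# view, but with the feature dimension D contiguous (stride(-2) == 1).
# One way to get it: allocate [slot, W, D] and swap the last two axes.
atom_cache = np.zeros((slot, W, D), dtype=np.float32).transpose(0, 2, 1)
assert atom_cache.shape == (slot, D, W)                    # same logical shape
assert atom_cache.strides[-2] // atom_cache.itemsize == 1  # D contiguous
```

The transpose changes only the view's strides, not the data, which is why the adapter can perform this transformation once after initialization.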

class GatedDeltaNet(AtomGatedDeltaNet):
"""SGLang adapter over the shared ATOM GDN implementation."""

def _resolve_runtime_state(
A contributor commented:

This looks like a redundant definition?

@wanzhenchn (Author) replied:

No, it is used to apply the conv-state tensor layout transformation after GatedDeltaNet is initialized.

from atom.utils.decorators import TorchCompileWrapperWithCustomDispatcher


class Qwen3NextSglangAttention(_CoreQwen3NextAttention):
A contributor commented:

I think you can trust the modeling we already have in ATOM for the execution part. If anything differs from what we already have, you can just add an if branch or abstract it into a function.

Another contributor commented:

We shouldn't add this sglang/models/ folder; you can patch qwen3_5.py or qwen3_next.py as needed.
