Skip to content

Commit e5022ce

Browse files
committed
support w4a8(Decode)/C8/C8+TP4EP4/PD disaggregation + compatibility fixes
Squashed from 6 feature commits + 2 compatibility fix commits: - support w4a8(Decode) - support C8 KV cache quantization - support C8+TP4EP4 - bugfix C8 - bugfix pd+C8 - bugfix pd+mtp - fix: make weight_need_transpose conditional and remove hardcoded layer_id - fix: comprehensive compatibility fixes (Iluvatar platform, moe cast bug, mutable default, hardcoded magic number, unconditional XPU import, etc.)
1 parent f02b138 commit e5022ce

10 files changed

Lines changed: 239 additions & 28 deletions

File tree

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,14 @@ def allocated_slots(self, request: Request):
235235
return len(request.block_tables) * self.config.cache_config.block_size
236236

237237
def get_new_block_nums(self, request: Request, num_new_tokens: int):
238+
# Account for preallocated blocks that haven't been added to block_tables yet
239+
preallocated_count = len(getattr(request, "preallocated_blocks", []))
238240
block_num = (
239-
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
240-
) // self.config.cache_config.block_size - len(request.block_tables)
241+
(request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1)
242+
// self.config.cache_config.block_size
243+
- len(request.block_tables)
244+
- preallocated_count
245+
)
241246

242247
if self.config.speculative_config.method is not None:
243248
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
@@ -800,8 +805,14 @@ def get_enough_request(request, scheduled_reqs):
800805
self.allocated_slots(request) - request.num_total_tokens
801806
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
802807
):
808+
# First, consume any preallocated blocks before allocating new ones
809+
preallocated = getattr(request, "preallocated_blocks", [])
810+
if preallocated:
811+
request.block_tables.extend(preallocated)
812+
request.preallocated_blocks = []
813+
scheduled_reqs.append(self._prepare_decode_task(request))
803814
# Allocation for next decoding blocks
804-
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
815+
elif self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
805816
llm_logger.debug(
806817
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
807818
)
@@ -911,6 +922,12 @@ def _allocate_decode_and_extend():
911922
request.block_tables.extend(
912923
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
913924
)
925+
# Merge preallocated blocks (from PD disaggregation) into block_tables
926+
# so the attention kernel can access all reserved blocks.
927+
preallocated = getattr(request, "preallocated_blocks", [])
928+
if preallocated:
929+
request.block_tables.extend(preallocated)
930+
request.preallocated_blocks = []
914931
# Prepare prefill task
915932
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
916933
else: # Not enough blocks to allocate, trigger preemption
@@ -920,6 +937,11 @@ def _allocate_decode_and_extend():
920937
request.block_tables.extend(
921938
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
922939
)
940+
# Merge preallocated blocks (from PD disaggregation) into block_tables
941+
preallocated = getattr(request, "preallocated_blocks", [])
942+
if preallocated:
943+
request.block_tables.extend(preallocated)
944+
request.preallocated_blocks = []
923945
# Prepare prefill task
924946
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
925947
token_budget -= num_new_tokens
@@ -1403,9 +1425,10 @@ def preallocate_resource_in_d(self, request: Request):
14031425
"""
14041426
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
14051427
request.need_prefill_tokens = len(request.prompt_token_ids)
1406-
need_prealloc_prefill_blocks = (
1428+
actual_prefill_blocks = (
14071429
request.need_prefill_tokens + self.config.cache_config.block_size - 1
1408-
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num
1430+
) // self.config.cache_config.block_size
1431+
need_prealloc_prefill_blocks = actual_prefill_blocks + self.config.cache_config.enc_dec_block_num
14091432

14101433
with self.lock:
14111434
if len(self.waiting) > 0:
@@ -1416,11 +1439,14 @@ def preallocate_resource_in_d(self, request: Request):
14161439
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
14171440
return False
14181441

1419-
request.block_tables = self.cache_manager.allocate_gpu_blocks(
1420-
need_prealloc_prefill_blocks, request.request_id
1421-
)
1442+
all_blocks = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id)
1443+
# Only put the blocks that will actually contain prefilled KV data into block_tables.
1444+
# The extra enc_dec_block_num blocks are pre-reserved for future decode tokens and
1445+
# stored separately to avoid the attention kernel reading uninitialized KV cache data.
1446+
request.block_tables = all_blocks[:actual_prefill_blocks]
1447+
request.preallocated_blocks = all_blocks[actual_prefill_blocks:]
14221448
request.num_computed_tokens = request.need_prefill_tokens
1423-
request.disaggregate_info["block_tables"] = request.block_tables
1449+
request.disaggregate_info["block_tables"] = all_blocks
14241450
allocated_position = self.get_available_position()
14251451
request.idx = allocated_position
14261452
self.tasks_list[request.idx] = request
@@ -1470,6 +1496,12 @@ def add_prefilled_request(self, request_output: RequestOutput):
14701496
self.running.append(request)
14711497

14721498
def _free_blocks(self, request: Request):
1499+
# Also free any preallocated blocks that haven't been consumed yet
1500+
preallocated = getattr(request, "preallocated_blocks", [])
1501+
if preallocated:
1502+
self.cache_manager.recycle_gpu_blocks(preallocated, request.request_id)
1503+
request.preallocated_blocks = []
1504+
14731505
if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode":
14741506
self.cache_manager.release_block_ids(request)
14751507
self.cache_manager.recycle_gpu_blocks(

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def _validate_split_kv_size(value: int) -> int:
210210
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
211211
# Whether to enable low latency in mixed scenario
212212
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
213+
# Whether to use the Yiyan model
214+
"FD_XPU_USE_YIYAN_MODEL": lambda: bool(int(os.getenv("FD_XPU_USE_YIYAN_MODEL", "0"))),
213215
# Whether to use phi FP8 quantization; if 1, use Paddle default.
214216
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
215217
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,

fastdeploy/model_executor/layers/backends/xpu/attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def forward_mixed(
181181
cache_v_scale = getattr(layer, "cache_v_scale", None)
182182
cache_k_out_scale = getattr(layer, "cache_k_out_scale", None)
183183
cache_v_out_scale = getattr(layer, "cache_v_out_scale", None)
184-
cache_k_zp = getattr(self, "cache_k_zp", None)
185-
cache_v_zp = getattr(self, "cache_v_zp", None)
184+
cache_k_zp = getattr(layer, "cache_k_zp", None)
185+
cache_v_zp = getattr(layer, "cache_v_zp", None)
186186

187187
if layer.use_qk_norm:
188188
q_norm_weight = layer.q_norm_weight
@@ -220,8 +220,8 @@ def forward_mixed(
220220
cache_v_scale,
221221
cache_k_out_scale,
222222
cache_v_out_scale,
223-
cache_k_zp,
224-
cache_v_zp,
223+
cache_k_zp.astype("bfloat16") if cache_k_zp is not None else None, # for C8
224+
cache_v_zp.astype("bfloat16") if cache_v_zp is not None else None, # for C8
225225
None, # shift
226226
None, # smooth
227227
q_norm_weight,

fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
268268
default_initializer=paddle.nn.initializer.Constant(0),
269269
),
270270
)
271+
set_weight_attrs(
272+
getattr(layer, self.added_scale_attrs[0]),
273+
{
274+
"weight_loader": extra_weight_attrs.get(
275+
"weight_loader", default_weight_loader(layer.fd_config)
276+
),
277+
},
278+
)
271279
setattr(
272280
layer,
273281
self.added_scale_attrs[1],
@@ -277,6 +285,31 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
277285
default_initializer=paddle.nn.initializer.Constant(0),
278286
),
279287
)
288+
set_weight_attrs(
289+
getattr(layer, self.added_scale_attrs[1]),
290+
{
291+
"weight_loader": extra_weight_attrs.get(
292+
"weight_loader", default_weight_loader(layer.fd_config)
293+
),
294+
},
295+
)
296+
297+
set_weight_attrs(
298+
layer.up_gate_proj_weight,
299+
{
300+
"weight_loader": extra_weight_attrs.get(
301+
"weight_loader", default_weight_loader(layer.fd_config)
302+
),
303+
},
304+
)
305+
set_weight_attrs(
306+
layer.down_proj_weight,
307+
{
308+
"weight_loader": extra_weight_attrs.get(
309+
"weight_loader", default_weight_loader(layer.fd_config)
310+
),
311+
},
312+
)
280313

281314
if self.moe_quant_type in ["w8a8", "w4a8"]:
282315
for in_scale_name in self.added_in_scale_attrs:
@@ -289,6 +322,25 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
289322
default_initializer=paddle.nn.initializer.Constant(0),
290323
),
291324
)
325+
set_weight_attrs(
326+
layer.down_proj_in_scale,
327+
{
328+
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
329+
"weight_loader": extra_weight_attrs.get(
330+
"weight_loader", default_weight_loader(layer.fd_config)
331+
),
332+
},
333+
)
334+
335+
set_weight_attrs(
336+
layer.up_gate_proj_in_scale,
337+
{
338+
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
339+
"weight_loader": extra_weight_attrs.get(
340+
"weight_loader", default_weight_loader(layer.fd_config)
341+
),
342+
},
343+
)
292344

293345
def process_loaded_weights(self, layer: nn.Layer, state_dict):
294346
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)

fastdeploy/model_executor/layers/backends/xpu/quantization/kv_cache.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import paddle
2020
from paddle import nn
2121

22+
from fastdeploy import envs
2223
from fastdeploy.model_executor.layers.quantization.kv_cache import (
2324
KvCacheQuantzationTypes,
2425
)
@@ -42,6 +43,7 @@ def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_poi
4243
super().__init__()
4344
self.kv_cache_quant_type = kv_cache_quant_type
4445
self.is_channel_wise = is_channel_wise
46+
self.has_zero_point = has_zero_point
4547

4648
try:
4749
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
@@ -139,6 +141,62 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
139141
scale_shape = [layer.fd_config.model_config.num_key_value_heads]
140142
if self.cache_quant_config.is_channel_wise:
141143
scale_shape = [layer.kv_num_heads * layer.head_dim]
144+
# Custom weight_loader for C8+TP: the safetensors scale/zp shape is
145+
# [1, num_kv_heads, 1, head_dim]. We must split along the kv_heads
146+
# dimension (dim=1), not the last dimension. The default_weight_loader
147+
# treats output_dim as boolean and always splits along dim=-1, which
148+
# is incorrect for 4D tensors where we need to split along dim=1.
149+
fd_config = layer.fd_config
150+
total_kv_heads = fd_config.model_config.num_key_value_heads
151+
tp_size = fd_config.parallel_config.tensor_parallel_size
152+
tp_rank = fd_config.parallel_config.tensor_parallel_rank
153+
max_bound = self.cache_quant_config.max_bound
154+
155+
def _kv_scale_weight_loader(
156+
param,
157+
loaded_weight,
158+
shard_id=None,
159+
_total_kv_heads=total_kv_heads,
160+
_tp_size=tp_size,
161+
_tp_rank=tp_rank,
162+
_max_bound=max_bound,
163+
):
164+
loaded_weight = get_tensor(loaded_weight).cast("float32")
165+
# TP split along kv_heads dimension
166+
if _tp_size > 1 and not fd_config.load_config.is_pre_sharded:
167+
head_dim = loaded_weight.numel() // _total_kv_heads
168+
loaded_weight = loaded_weight.reshape([_total_kv_heads, head_dim])
169+
assert (
170+
_total_kv_heads % _tp_size == 0
171+
), f"num_kv_heads ({_total_kv_heads}) must be divisible by tp_size ({_tp_size})"
172+
kv_heads_per_rank = _total_kv_heads // _tp_size
173+
start = _tp_rank * kv_heads_per_rank
174+
end = start + kv_heads_per_rank
175+
loaded_weight = loaded_weight[start:end, :]
176+
loaded_weight = paddle.clip(loaded_weight, min=1e-8)
177+
loaded_weight = (_max_bound / loaded_weight).reshape(param.shape).cast(param.dtype)
178+
param.copy_(loaded_weight, False)
179+
180+
def _kv_zp_weight_loader(
181+
param, loaded_weight, shard_id=None, _total_kv_heads=total_kv_heads, _tp_size=tp_size, _tp_rank=tp_rank
182+
):
183+
loaded_weight = get_tensor(loaded_weight).cast(param.dtype)
184+
# TP split along kv_heads dimension
185+
if _tp_size > 1 and not fd_config.load_config.is_pre_sharded:
186+
head_dim = loaded_weight.numel() // _total_kv_heads
187+
loaded_weight = loaded_weight.reshape([_total_kv_heads, head_dim])
188+
kv_heads_per_rank = _total_kv_heads // _tp_size
189+
start = _tp_rank * kv_heads_per_rank
190+
end = start + kv_heads_per_rank
191+
loaded_weight = loaded_weight[start:end, :]
192+
loaded_weight = loaded_weight.reshape(param.shape)
193+
param.copy_(loaded_weight, False)
194+
195+
scale_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_scale_weight_loader}
196+
zp_weight_attrs = {**extra_weight_attrs, "weight_loader": _kv_zp_weight_loader}
197+
else:
198+
scale_weight_attrs = extra_weight_attrs
199+
zp_weight_attrs = extra_weight_attrs
142200

143201
layer.cache_k_scale = layer.create_parameter(
144202
shape=scale_shape,
@@ -154,13 +212,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
154212
set_weight_attrs(
155213
layer.cache_k_scale,
156214
{
157-
**extra_weight_attrs,
215+
**scale_weight_attrs,
158216
},
159217
)
160218
set_weight_attrs(
161219
layer.cache_v_scale,
162220
{
163-
**extra_weight_attrs,
221+
**scale_weight_attrs,
164222
},
165223
)
166224

@@ -189,13 +247,13 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
189247
set_weight_attrs(
190248
layer.cache_k_zp,
191249
{
192-
**extra_weight_attrs,
250+
**zp_weight_attrs,
193251
},
194252
)
195253
set_weight_attrs(
196254
layer.cache_v_zp,
197255
{
198-
**extra_weight_attrs,
256+
**zp_weight_attrs,
199257
},
200258
)
201259

@@ -219,10 +277,20 @@ def process_weights_after_loading(self, layer: nn.Layer):
219277
use for loader v1
220278
"""
221279
# cache_k_out_scale is the reciprocal of cache_k_scale
222-
if layer.cache_k_scale._is_initialized():
223-
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale) # cache_k_out_scale
224-
if layer.cache_v_scale._is_initialized():
225-
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
280+
if envs.FD_XPU_USE_YIYAN_MODEL:
281+
if layer.cache_k_scale._is_initialized():
282+
layer.cache_k_out_scale.set_value(
283+
self.cache_quant_config.max_bound / layer.cache_k_scale.cast("float32").reshape_([-1])
284+
)
285+
if layer.cache_v_scale._is_initialized():
286+
layer.cache_v_out_scale.set_value(
287+
self.cache_quant_config.max_bound / layer.cache_v_scale.cast("float32").reshape_([-1])
288+
)
289+
else:
290+
if layer.cache_k_scale._is_initialized():
291+
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale)
292+
if layer.cache_v_scale._is_initialized():
293+
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
226294

227295
def apply(self, layer):
228296
"""

0 commit comments

Comments
 (0)