Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,14 @@ def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size

def get_new_block_nums(self, request: Request, num_new_tokens: int):
# Account for preallocated blocks that haven't been added to block_tables yet
preallocated_count = len(getattr(request, "preallocated_blocks", []))
block_num = (
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size - len(request.block_tables)
(request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
- len(request.block_tables)
- preallocated_count
)

if self.config.speculative_config.method is not None:
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
Expand Down Expand Up @@ -800,8 +805,14 @@ def get_enough_request(request, scheduled_reqs):
self.allocated_slots(request) - request.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
):
# First, consume any preallocated blocks before allocating new ones
preallocated = getattr(request, "preallocated_blocks", [])
if preallocated:
request.block_tables.extend(preallocated)
request.preallocated_blocks = []
scheduled_reqs.append(self._prepare_decode_task(request))
# Allocation for next decoding blocks
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
elif self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
llm_logger.debug(
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
)
Expand Down Expand Up @@ -911,6 +922,12 @@ def _allocate_decode_and_extend():
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Merge preallocated blocks (from PD disaggregation) into block_tables
# so the attention kernel can access all reserved blocks.
preallocated = getattr(request, "preallocated_blocks", [])
if preallocated:
request.block_tables.extend(preallocated)
request.preallocated_blocks = []
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption
Expand All @@ -920,6 +937,11 @@ def _allocate_decode_and_extend():
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Merge preallocated blocks (from PD disaggregation) into block_tables
preallocated = getattr(request, "preallocated_blocks", [])
if preallocated:
request.block_tables.extend(preallocated)
request.preallocated_blocks = []
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
Expand Down Expand Up @@ -1403,9 +1425,10 @@ def preallocate_resource_in_d(self, request: Request):
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
actual_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num
) // self.config.cache_config.block_size
need_prealloc_prefill_blocks = actual_prefill_blocks + self.config.cache_config.enc_dec_block_num

with self.lock:
if len(self.waiting) > 0:
Expand All @@ -1416,11 +1439,14 @@ def preallocate_resource_in_d(self, request: Request):
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
return False

request.block_tables = self.cache_manager.allocate_gpu_blocks(
need_prealloc_prefill_blocks, request.request_id
)
all_blocks = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id)
# Only put the blocks that will actually contain prefilled KV data into block_tables.
# The extra enc_dec_block_num blocks are pre-reserved for future decode tokens and
# stored separately to avoid the attention kernel reading uninitialized KV cache data.
request.block_tables = all_blocks[:actual_prefill_blocks]
request.preallocated_blocks = all_blocks[actual_prefill_blocks:]
request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables
request.disaggregate_info["block_tables"] = all_blocks
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
Expand Down Expand Up @@ -1470,6 +1496,12 @@ def add_prefilled_request(self, request_output: RequestOutput):
self.running.append(request)

def _free_blocks(self, request: Request):
# Also free any preallocated blocks that haven't been consumed yet
preallocated = getattr(request, "preallocated_blocks", [])
if preallocated:
self.cache_manager.recycle_gpu_blocks(preallocated, request.request_id)
request.preallocated_blocks = []

if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode":
self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def _validate_split_kv_size(value: int) -> int:
"FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""),
# Whether to enable low latency in mixed scenario
"FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))),
# Whether to use yiyan model
"FD_XPU_USE_YIYAN_MODEL": lambda: bool(int(os.getenv("FD_XPU_USE_YIYAN_MODEL", "0"))),
# Whether to use phi FP8 quantization,if 1,use paddle default.
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
def __init__(self, quant_config):
super().__init__(quant_config)
if quant_config is None:
self.quant_config = WeightOnlyConfig(algo="weight_only_int8", is_checkpoint_bf16=True)
self.quant_config = WeightOnlyConfig(algo="weight_only_int8")
else:
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
Expand Down Expand Up @@ -480,21 +480,18 @@ def _process_quantize(weight_idx):
getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]), False)
getattr(layer, scale_name).copy_(scale, False)

if self.quant_config.is_checkpoint_bf16:
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"

if self.model_format == "torch":
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
return
weight_type = "down"

if self.model_format == "torch":
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])

def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
layer.hidden_size,
]
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
is_checkpoint_bf16 = self.quant_config.is_checkpoint_bf16 if self.quant_config is not None else True
if is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if layer.fd_config.load_config.load_choices == "default_v1":
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
Expand Down Expand Up @@ -184,10 +183,6 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
@paddle.no_grad()
def process_weights_after_loading(self, layer):
""" """
is_checkpoint_bf16 = self.quant_config.is_checkpoint_bf16 if self.quant_config is not None else True
if not is_checkpoint_bf16:
return

if self.quant_config is not None:
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
Expand Down
8 changes: 4 additions & 4 deletions fastdeploy/model_executor/layers/backends/xpu/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ def forward_mixed(
cache_v_scale = getattr(layer, "cache_v_scale", None)
cache_k_out_scale = getattr(layer, "cache_k_out_scale", None)
cache_v_out_scale = getattr(layer, "cache_v_out_scale", None)
cache_k_zp = getattr(self, "cache_k_zp", None)
cache_v_zp = getattr(self, "cache_v_zp", None)
cache_k_zp = getattr(layer, "cache_k_zp", None)
cache_v_zp = getattr(layer, "cache_v_zp", None)

if layer.use_qk_norm:
q_norm_weight = layer.q_norm_weight
Expand Down Expand Up @@ -220,8 +220,8 @@ def forward_mixed(
cache_v_scale,
cache_k_out_scale,
cache_v_out_scale,
cache_k_zp,
cache_v_zp,
cache_k_zp.astype("bfloat16") if cache_k_zp is not None else None, # for C8
cache_v_zp.astype("bfloat16") if cache_v_zp is not None else None, # for C8
None, # shift
None, # smooth
q_norm_weight,
Expand Down
54 changes: 52 additions & 2 deletions fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
set_weight_attrs(
getattr(layer, self.added_scale_attrs[0]),
{
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)
setattr(
layer,
self.added_scale_attrs[1],
Expand All @@ -277,6 +285,31 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
set_weight_attrs(
getattr(layer, self.added_scale_attrs[1]),
{
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)

set_weight_attrs(
layer.up_gate_proj_weight,
{
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)

if self.moe_quant_type in ["w8a8", "w4a8"]:
for in_scale_name in self.added_in_scale_attrs:
Expand All @@ -289,6 +322,25 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
default_initializer=paddle.nn.initializer.Constant(0),
),
)
set_weight_attrs(
layer.down_proj_in_scale,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)

set_weight_attrs(
layer.up_gate_proj_in_scale,
{
"SHARD_ID_TO_SHARDED_DIM": {"gate": None, "up": None, "down": None},
"weight_loader": extra_weight_attrs.get(
"weight_loader", default_weight_loader(layer.fd_config)
),
},
)

def process_loaded_weights(self, layer: nn.Layer, state_dict):
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
Expand Down Expand Up @@ -616,8 +668,6 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):

def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
Expand Down
Loading
Loading