# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import itertools
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
MultipleOf,
)
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationMode,
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
update_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
get_tp_group,
graph_capture,
is_global_first_rank,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
is_mixture_of_experts,
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
supports_transcription,
)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling,
is_pooling_model,
is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import (
get_dtype_size,
kv_cache_dtype_str_to_dtype,
supports_dynamo,
)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
split_attn_metadata,
)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
ECConnectorOutput,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
PoolerOutput,
SamplerOutput,
make_empty_encoder_model_runner_output,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from .utils import (
AttentionGroup,
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
import requests
from vllm.v1.hiddenstate.classifier import HiddenstateClassifier
import json
from pathlib import Path
# TOKEN_FILES_DIR = Path("/work/nvme/bcjw/bhuang4/DeepSeek-R1-0528-Qwen3-8B")
# with (TOKEN_FILES_DIR / "tokens_with_double_newline.json").open() as fp:
# DOUBLE_NEWLINE_TOKEN_IDS = json.load(fp)
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
def __init__(
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
logprobs_tensors: torch.Tensor | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
):
self._model_runner_output = model_runner_output
self._invalid_req_indices = invalid_req_indices
# Event on the copy stream so we can synchronize the non-blocking copy.
self.async_copy_ready_event = torch.cuda.Event()
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
self._sampled_token_ids = sampled_token_ids
self._logprobs_tensors = logprobs_tensors
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self.sampled_token_ids_cpu = self._sampled_token_ids.to(
"cpu", non_blocking=True
)
self._logprobs_tensors_cpu = (
self._logprobs_tensors.to_cpu_nonblocking()
if self._logprobs_tensors
else None
)
self.async_copy_ready_event.record()
def get_output(self) -> ModelRunnerOutput:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
This function blocks until the copy is finished.
"""
self.async_copy_ready_event.synchronize()
# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
# for async sched + spec decode + logprobs compatibility.
output.logprobs = self._logprobs_tensors_cpu.tolists()
return output
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None
class GPUModelRunner(
LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
cache_config.cache_dtype, self.model_config
)
self.is_pooling_model = model_config.runner_type == "pooling"
self.enable_prompt_embeds = model_config.enable_prompt_embeds
self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model
)
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
# Always set to false after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
self.broadcast_pp_output = (
self.parallel_config.distributed_executor_backend == "external_launcher"
and len(get_pp_group().ranks) > 0
)
# Model-related.
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
# Only relevant for models using ALiBi (e.g., MPT)
self.use_alibi = model_config.uses_alibi
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
if self.model_config.is_encoder_decoder:
# Maximum length of the encoder input, only for encoder-decoder
# models.
self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens
else:
self.max_encoder_len = 0
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None
"""
State of the expert parallelism load balancer.
Will be lazily initialized when the model is loaded.
"""
# Lazy initializations
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: (
NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
)
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa":
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
)
else:
raise ValueError(
"Unknown speculative decoding method: "
f"{self.speculative_config.method}"
)
self.rejection_sampler = RejectionSampler(self.sampler)
# Request states.
self.requests: dict[str, CachedRequestState] = {}
self.comm_stream = torch.cuda.Stream()
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
# https://github.com/vllm-project/vllm/pull/18298, due to some unknown
# reasons, we have to initialize the input batch before `load_model`;
# otherwise, quantization + weight offloading will fail. As a temporary
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
custom_logitsprocs = model_config.logits_processors
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoder
# models because of the KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
kernel_block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config,
self.device,
self.pin_memory,
self.is_pooling_model,
custom_logitsprocs,
),
# We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Separate cuda stream for overlapping transfer of sampled token ids from
# GPU to CPU when async scheduling is enabled.
self.async_output_copy_stream: torch.cuda.Stream | None = None
# cuda event to synchronize use of reused CPU tensors between steps
# when async scheduling is enabled.
self.prepare_inputs_event: torch.cuda.Event | None = None
if self.use_async_scheduling:
self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.cuda.Event()
# self.cudagraph_batch_sizes is sorted in ascending order.
if (
self.compilation_config.cudagraph_capture_sizes
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
self.cudagraph_batch_sizes = sorted(
self.compilation_config.cudagraph_capture_sizes
)
# Cache the device properties.
self._init_device_properties()
# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
self.query_start_loc = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
# Because inputs_embeds may be bfloat16 and we don't need a numpy
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
self.inputs_embeds = self._make_buffer(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
)
self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
self.discard_request_indices = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
)
self.num_discarded_requests = 0
self.num_decode_draft_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.num_accepted_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
)
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(
max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
dtype=np.int64,
)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()
self.kv_sharing_fast_prefill_logits_indices = None
if self.cache_config.kv_sharing_fast_prefill:
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device
)
self.uniform_decode_query_len = (
1
if not self.speculative_config
else 1 + self.speculative_config.num_speculative_tokens
)
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (
MultiModalBudget(
self.model_config,
self.scheduler_config,
self.mm_registry,
)
if self.supports_mm_inputs
else None
)
self.reorder_batch_threshold: int | None = None
# Attention layers that are only in the KVCacheConfig of the runner
# (e.g., KV sharing, encoder-only attention), but not in the
# KVCacheConfig of the scheduler.
self.runner_only_attn_layers: set[str] = set()
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
self.transfer_event = torch.cuda.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_num_reqs, 1),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
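# Per-request hidden states captured from the most recent forward pass,
# plus the set of request ids whose hidden states should be captured on
# the next forward pass (consumed by the step scorer below).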
self._last_hidden_states_by_req: dict[str, torch.Tensor] = {}
self._capture_hidden_states_next_forward: set[str] = set()
# Hidden-state classifier (step scorer).
self.step_scorer: HiddenstateClassifier | None = None
self.load_step_scorer_checkpoint(
model_path=self.vllm_config.STEP_config.step_scorer_path
)
def load_step_scorer_checkpoint(self, model_path: str) -> None:
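"""Load the hidden-state classifier ("step scorer") weights.
The checkpoint at ``model_path`` may be either a plain state dict or a
dict containing a ``"model_state_dict"`` entry. The model is placed on
the runner's device and switched to eval mode.
"""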
model = HiddenstateClassifier(input_dim=self.hidden_size).to(self.device)
ckpt = torch.load(model_path, map_location=self.device)
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
state_dict = ckpt["model_state_dict"]
else:
state_dict = ckpt
model.load_state_dict(state_dict)
self.step_scorer = model
self.step_scorer.eval()
def step_scorer_evaluate(self, req_ids, num_scheduled_tokens):
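"""Score the captured hidden states for the given request ids.
Returns a dict mapping each request id that has cached hidden states to
a list of sigmoid probabilities produced by the step scorer.
``num_scheduled_tokens`` is forwarded to ``get_hidden_states_by_req``.
A minimal, hypothetical usage sketch (request id and token count are
made up for illustration):
    scores = runner.step_scorer_evaluate(["req-0"], {"req-0": 4})
    # e.g. {"req-0": [0.73, 0.12, ...]}
"""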
hs = self.get_hidden_states_by_req(req_ids, num_scheduled_tokens)
if self.step_scorer is None:
raise RuntimeError("classifier not loaded")
with torch.inference_mode():
return {
rid: torch.sigmoid(
self.step_scorer(
t.to(device=self.device, dtype=torch.float32, non_blocking=True)
)
).flatten()
.tolist()
for rid, t in hs.items()
}
def enable_hidden_states_capture(self) -> None:
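"""Turn on hidden-state capture via the STEP config flag."""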
self.vllm_config.STEP_config.enable = True
def disable_hidden_states_capture(self) -> None:
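"""Turn off hidden-state capture and drop any cached hidden states."""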
self.vllm_config.STEP_config.enable = False
self._capture_hidden_states_next_forward.clear()
self._last_hidden_states_by_req.clear()
def clear_hidden_states(self) -> None:
self._last_hidden_states_by_req.clear()
def get_hidden_states_debug_info(self) -> dict[str, int]:
return {
"enabled": int(self.vllm_config.STEP_config.enable),
"num_requests": len(self._last_hidden_states_by_req),
}
def get_hidden_states_by_req(
self, req_ids: list[str], num_scheduled_tokens: dict[str, int]
) -> dict[str, torch.Tensor]:
# Return the hidden states cached by the most recent forward pass for the requested ids.
result: dict[str, torch.Tensor] = {}
for req_id in req_ids:
if req_id in self._last_hidden_states_by_req:
result[req_id] = self._last_hidden_states_by_req[req_id]
return result
def fetch_hidden_states(self, req_ids, num_scheduled_tokens):
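"""Thin alias for ``get_hidden_states_by_req``."""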
return self.get_hidden_states_by_req(req_ids, num_scheduled_tokens)
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:
return self.mrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory,
with_numpy=numpy,
)
def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
if not self.is_pooling_model:
return model_kwargs
num_reqs = self.input_batch.num_reqs
pooling_params = self.input_batch.get_pooling_params()
token_type_id_requests = dict[int, Any]()
for i, param in enumerate(pooling_params):
if (
param.extra_kwargs is not None
and (token_types := param.extra_kwargs.get("compressed_token_type_ids"))
is not None
):
token_type_id_requests[i] = token_types
if len(token_type_id_requests) == 0:
return model_kwargs
seq_lens = self.seq_lens.gpu[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
pos = token_type_id_requests.get(i, seq_lens[i])
ids = (torch.arange(seq_lens[i]) >= pos).int()
token_type_ids.append(ids)
model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
device=self.device
)
return model_kwargs
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
"""
# Attention-free models have zero kv_cache_groups. However, models
# like Mamba are also attention-free but use the kv_cache to keep
# their internal state. This is why we check the number of
# kv_cache groups instead of relying solely on
# self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
if self.reorder_batch_threshold is not None:
# NOTE(lucas): currently no backend supports the custom masking
# required for DCP with q_len > 1, so we assert here. Remove this
# assert once custom mask support is added to FA3.
if (
self.dcp_world_size > 1
and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
):
assert self.reorder_batch_threshold == 1, (
"DCP not support reorder_batch_threshold > 1 now."
)
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold,
)
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
# Note: used for model runner override.
def _sync_device(self) -> None:
torch.cuda.synchronize()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
The updated states are used by the `_prepare_inputs` function to create
the input GPU tensors for the model.
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
# them from the persistent batch but keep their cached states since
# they will be scheduled again sometime in the future.
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
cached_req_ids = self.input_batch.req_id_to_index.keys()
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
if (
sampling_params
and sampling_params.sampling_type == SamplingType.RANDOM_SEED
):
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
generator = None
if self.is_pooling_model:
assert pooling_params is not None
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_features=new_req_data.mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
req_index = self.input_batch.req_id_to_index.get(req_id)
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (
num_computed_tokens + len(new_token_ids) - req_state.num_tokens
)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
if req_index is not None:
end_idx = (
self.input_batch.num_prompt_tokens[req_index]
+ num_output_tokens
)
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx
# Update the block IDs.
if not resumed_from_preemption:
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
block_ids.extend(new_ids)
else:
assert req_index is None
assert new_block_ids is not None
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
if self.use_async_scheduling and num_output_tokens > 0:
# We must recover the output token ids for resumed requests in the
# async scheduling case, so that correct input_ids are obtained.
resumed_token_ids = req_data.all_token_ids[req_id]
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
reqs_to_add.append(req_state)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
if new_block_ids is not None:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_token_index:end_token_index
] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, []
)
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index
] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
# When speculative decoding is used with structured output,
# the scheduler can drop draft tokens that do not
# conform to the schema. This can result in
# scheduler_output.scheduled_spec_decode_tokens being empty,
# even when speculative decoding is enabled.
self.input_batch.spec_token_ids[req_index] = spec_token_ids
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
self.input_batch.add_request(request)
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
# Allow the attention backend to reorder the batch if needed.
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor
) -> None:
"""Update the cached states after model execution.
This is used for MTP/EAGLE with hybrid models: in linear attention,
only the last token's state is kept. In MTP/EAGLE, the states for draft
tokens are kept until we decide how many tokens are accepted for each
sequence, and a shift is applied during the next iteration based on
the number of accepted tokens.
"""
if not self.model_config.is_hybrid or not self.speculative_config:
return
# Find the number of accepted tokens for each sequence.
num_accepted_tokens = (
(
torch.cat(
[
output_token_ids,
torch.full(
(output_token_ids.size(0), 1),
-1,
device=output_token_ids.device,
),
],
dim=1,
)
== -1
)
.int()
.argmax(-1)
.cpu()
.numpy()
)
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
assert supports_mrope(model), "M-RoPE support is not implemented."
req_state.mrope_positions, req_state.mrope_position_delta = (
model.get_mrope_input_positions(
req_state.prompt_token_ids,
req_state.mm_features,
)
)
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
) -> BatchedTensorInputs:
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {}
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
for feature in req.mm_features:
if feature.data is not None:
mm_kwargs.append(feature.data)
# Input all modalities at once
model = cast(SupportsMultiModal, self.model)