Right now the ViT attention backend selection is very messy, and the AMD CI is broken.
The ViT changes also do not take the other model.py files into account: they only change qwen_2_5_vl.py, which potentially breaks all the other model.py files.
vLLM is currently being refactored to introduce --mm-encoder-attn-backend for selecting the attention backend.
The refactor PR is vllm-project#27061, with a bugfix PR vllm-project#27124.
Since torch.compile was introduced into the ViT in PR vllm-project#23207 (so far only for the Qwen VL model), the AMD ViT code path has been broken. Multiple bugfix PR attempts have not worked.
Make sure that ViT attention backend selection is platform specific. The backend should be determined in the platform interface, and any override should also be performed there. We should avoid doing that in the model.py files.
The platform interface should only return the _MHA_Backend, not the functions. The functions should only be returned through maybe_get_vit_flash_attn_backend.
Honor --mm-encoder-attn-backend so that we can write unit tests covering all the different backends. AMD Instinct GPUs can test all backends; Radeon GPUs can only use the TORCH_SDPA code path.
Make sure that ViT attention backend selection is platform specific. The backend should be determined in the platform interface, and any override should also be performed there. We should avoid doing that in the model.py files.
get_vit_attn_backend in the platform interface has to be able to access --mm-encoder-attn-backend.
We need to deprecate this line: https://github.com/vllm-project/vllm/blob/33a0ea5f3264b5b2f571b8a53357e10efcc94670/vllm/model_executor/models/vision.py#L96. It uses VLLM_ATTENTION_BACKEND, which is meant for the text backbone. The ViT should not use this environment variable.
The platform interface should only return the _MHA_Backend, not the functions. The functions should only be returned through maybe_get_vit_flash_attn_backend.
Add a logger.info_once call so that users know which _MHA_Backend is selected in the end.
Clean up the CUDA code path. Since vllm.vllm_flash_attn is just a wrapper for the flash_attn library, on CUDA we always use vllm.vllm_flash_attn instead of flash_attn.
Write unit tests that exercise all the different backends. Because the models are large, the test checks the available VRAM and runs only if it is large enough. We provide such a unit test so that developers can run it locally.
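A minimal sketch of the VRAM gate described in the last item, using only the standard library's unittest plus an optional torch import. The names total_vram_bytes, has_enough_vram, REQUIRED_GIB and the 24 GiB threshold are hypothetical and not taken from vLLM; the test body is a placeholder.

```python
import unittest

GIB = 1024 ** 3
REQUIRED_GIB = 24  # hypothetical threshold; pick it per model size


def total_vram_bytes() -> int:
    """Total memory of GPU 0 in bytes, or 0 when torch or a GPU is missing."""
    try:
        import torch
    except ImportError:
        return 0
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.get_device_properties(0).total_memory


def has_enough_vram(required_gib: float, total_bytes: int) -> bool:
    """True when the device has at least `required_gib` GiB of memory."""
    return total_bytes >= required_gib * GIB


class ViTAttnBackendTest(unittest.TestCase):
    @unittest.skipUnless(
        has_enough_vram(REQUIRED_GIB, total_vram_bytes()),
        f"needs at least {REQUIRED_GIB} GiB of VRAM",
    )
    def test_all_backends(self) -> None:
        # Placeholder: load the model once per backend and compare outputs.
        pass
```

Skipping instead of failing keeps the test usable on small developer machines while still running on CI hosts with enough memory.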
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response
Before submitting a new issue...
Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Motivation.
Right now the ViT attention backend selection is very messy, and the AMD CI is broken.
The ViT changes also do not take the other model.py files into account: they only change qwen_2_5_vl.py, which potentially breaks all the other model.py files:
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/siglip2navit.py
vLLM is currently being refactored to introduce --mm-encoder-attn-backend for selecting the attention backend.
The refactor PR is vllm-project#27061, with a bugfix PR vllm-project#27124.
Since torch.compile was introduced into the ViT in PR vllm-project#23207 (so far only for the Qwen VL model), the AMD ViT code path has been broken. Multiple bugfix PR attempts have not worked.
First, we should shrink down _Backend (see https://github.com/vllm-project/vllm/pull/27061/files#r2443909604) by introducing another _MHA_Backend registry.
Make sure that ViT attention backend selection is platform specific. The backend should be determined in the platform interface, and any override should also be performed there. We should avoid doing that in the model.py files.
The platform interface should only return the _MHA_Backend, not the functions. The functions should only be returned through maybe_get_vit_flash_attn_backend.
Honor --mm-encoder-attn-backend so that we can write unit tests covering all the different backends. AMD Instinct GPUs can test all backends; Radeon GPUs can only use the TORCH_SDPA code path.
Proposed Change.
Changes
First, we should shrink down _Backend (see https://github.com/vllm-project/vllm/pull/27061/files#r2443909604) by introducing another _MHA_Backend registry.
Make sure that ViT attention backend selection is platform specific. The backend should be determined in the platform interface, and any override should also be performed there. We should avoid doing that in the model.py files.
get_vit_attn_backend in the platform interface has to be able to access --mm-encoder-attn-backend.
We need to deprecate this line: https://github.com/vllm-project/vllm/blob/33a0ea5f3264b5b2f571b8a53357e10efcc94670/vllm/model_executor/models/vision.py#L96. It uses VLLM_ATTENTION_BACKEND, which is meant for the text backbone. The ViT should not use this environment variable.
The platform interface should only return the _MHA_Backend, not the functions. The functions should only be returned through maybe_get_vit_flash_attn_backend.
Add a logger.info_once call so that users know which _MHA_Backend is selected in the end.
Clean up the CUDA code path. Since vllm.vllm_flash_attn is just a wrapper for the flash_attn library, on CUDA we always use vllm.vllm_flash_attn instead of flash_attn. See https://github.com/vllm-project/vllm/blob/ba33e8830dceb32e9b03508bbff435e3082759b8/vllm/attention/layer.py#L120-L125.
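The selection flow proposed above could look roughly like the following sketch. It is not vLLM code: MHABackend, Platform, InstinctPlatform and RadeonPlatform are hypothetical stand-ins, FLASH_ATTN is used only as an illustrative second backend, and raising on an unsupported override is an assumption (a real platform might fall back instead). The point is that the platform interface receives the user's override and returns only an enum member, never a function.

```python
from enum import Enum, auto
from typing import FrozenSet, Optional


class MHABackend(Enum):
    """Hypothetical stand-in for the proposed _MHA_Backend registry."""
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


class Platform:
    """Hypothetical platform interface: owns backend selection and overrides."""
    supported: FrozenSet[MHABackend] = frozenset({MHABackend.TORCH_SDPA})
    default: MHABackend = MHABackend.TORCH_SDPA

    @classmethod
    def get_vit_attn_backend(
        cls, override: Optional[MHABackend] = None
    ) -> MHABackend:
        # `override` models the value passed via --mm-encoder-attn-backend.
        if override is None:
            return cls.default
        if override not in cls.supported:
            # Assumption: reject unsupported choices instead of falling back.
            raise ValueError(f"{override.name} is not supported on {cls.__name__}")
        return override


class InstinctPlatform(Platform):
    """Illustrative platform that can run every backend."""
    supported = frozenset(MHABackend)
    default = MHABackend.FLASH_ATTN


class RadeonPlatform(Platform):
    """Illustrative platform limited to the TORCH_SDPA code path."""
```

With this shape, model files only ask the platform for a backend, and a unit test can pass each enum member as the override to cover every code path a platform supports.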