115 changes: 55 additions & 60 deletions xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
@@ -369,66 +369,61 @@ def fully_shard(
)
return self

def fast_pos_embed_interpolate(self, grid_thw):
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]

for t, h, w in zip(grid_ts, grid_hs, grid_ws):
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor

base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side

indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]

weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]

for i in range(4):
idx_list[i].extend(indices[i].tolist())
weight_list[i].extend(weights[i].tolist())

idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
weight_tensor = torch.tensor(
weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
)
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

patch_pos_embeds_permute = []
merge_size = self.config.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
# copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474
def fast_pos_embed_interpolate(self, grid_thw: torch.Tensor) -> torch.Tensor:
num_grid_per_side = self.num_grid_per_side
m_size = self.spatial_merge_size
hidden_dim = self.pos_embed.embedding_dim
device = self.pos_embed.weight.device

outputs = []
for t, h, w in grid_thw:
h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device)
w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device)

h_floor = h_idxs.to(torch.long)
w_floor = w_idxs.to(torch.long)
h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

dh = h_idxs - h_floor
dw = w_idxs - w_floor

# Create meshgrid view for all h, w vars
dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij')
h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij')
h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij')

# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11 = dh_grid * dw_grid
w10 = dh_grid - w11
w01 = dw_grid - w11
w00 = 1 - dh_grid - w01

h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid])
w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid])
h_grid_idx = h_grid * num_grid_per_side

indices = (h_grid_idx + w_grid).reshape(4, -1)
weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device)

embeds = self.pos_embed(indices)
embeds *= weights
combined = embeds.sum(dim=0)

combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim)
combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim)
repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim)
outputs.append(repeated)

return torch.cat(outputs, dim=0)

def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
merge_size = self.spatial_merge_size
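Note: the weight-reuse step in the new implementation (deriving w00, w01, and w10 from w11) is plain algebra over the standard bilinear-interpolation weights. A standalone sanity check of that identity, not part of this PR and assuming nothing beyond torch:

import torch

# Reference bilinear weights:
#   w00 = (1 - dh) * (1 - dw), w01 = (1 - dh) * dw,
#   w10 = dh * (1 - dw),       w11 = dh * dw
dh = torch.rand(5, 7)
dw = torch.rand(5, 7)

# Reused form: compute dh * dw once, derive the rest by subtraction.
w11 = dh * dw
w10 = dh - w11          # dh * (1 - dw)
w01 = dw - w11          # (1 - dh) * dw
w00 = 1 - dh - w01      # (1 - dh) * (1 - dw)

assert torch.allclose(w00, (1 - dh) * (1 - dw))
assert torch.allclose(w01, (1 - dh) * dw)
assert torch.allclose(w10, dh * (1 - dw))
assert torch.allclose(w11, dh * dw)
# The four weights still sum to 1 at every grid point.
assert torch.allclose(w00 + w01 + w10 + w11, torch.ones_like(dh))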
3 changes: 2 additions & 1 deletion xtuner/v1/module/rope/rope.py
@@ -77,7 +77,8 @@ def compute_default_rope_parameters(
attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
inv_freq = inv_freq.to(device=device)
return inv_freq, attention_factor


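Note: a minimal, self-contained sketch of the revised inverse-frequency computation above — float32 exponent math first, device transfer last. The helper name and the reading of the motivation are assumptions; neither is stated in this diff:

from typing import Optional

import torch

def default_rope_inv_freq(dim: int, base: float = 10000.0, device: Optional[torch.device] = None) -> torch.Tensor:
    # Exponents 0/dim, 2/dim, ..., (dim - 2)/dim are computed in float32,
    # and only the finished inv_freq tensor is moved to the target device.
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    return inv_freq.to(device=device)

# Example: a 128-dim rotary head yields 64 frequencies, from 1.0 down toward 1/base.
freqs = default_rope_inv_freq(128)
assert freqs.shape == (64,)
assert freqs[0].item() == 1.0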