
Map Mistral-HF models back onto Mistral format on-the-fly #20471


Draft · wants to merge 12 commits into base: main
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mistral3.py
@@ -120,13 +120,13 @@ def __init__(self,
self.linear_1 = ColumnParallelLinear(vision_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
quant_config=None,
Contributor review comment (severity: medium): The quant_config is hardcoded to None. Before finalizing, replace this with a dynamic check to ensure correctness for checkpoints that may have a quantized multi_modal_projector.
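One possible shape for that check, as a minimal sketch: the helper name _projector_quant_config and the modules_to_not_convert key are assumptions (the exact field depends on the quantization method recorded in the checkpoint config), not an existing vLLM API.

def _projector_quant_config(hf_config, quant_config):
    """Return quant_config only if the projector itself was quantized."""
    # Assumption: hf_config.quantization_config is a dict carrying a
    # modules_to_not_convert list (AWQ/BNB style); other methods differ.
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_config is None or not quant_cfg:
        return None  # checkpoint is not quantized at all
    skipped = quant_cfg.get("modules_to_not_convert") or []
    if any("multi_modal_projector" in module for module in skipped):
        return None  # projector weights were left in full precision
    return quant_config  # projector appears to be quantized too

# Usage in the constructor would then look like (sketch):
#     quant_config=_projector_quant_config(config, quant_config),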

prefix=f"{prefix}.linear_1")
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = RowParallelLinear(text_hidden_size,
text_hidden_size,
bias=multimodal_projector_bias,
quant_config=quant_config,
quant_config=None,
prefix=f"{prefix}.linear_2")

def forward(self, image_features: torch.Tensor,
88 changes: 85 additions & 3 deletions vllm/model_executor/models/pixtral.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import re
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
@@ -25,6 +26,7 @@

from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -52,6 +54,8 @@
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

logger = init_logger(__name__)

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
@@ -334,6 +338,8 @@

raise ValueError("Only image modality is supported")

packed_modules_mapping = {}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -480,6 +486,66 @@
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

# Reverse mapping from HF to original Pixtral format
MISTRAL3_REVERSE_MAPPING = {
r"^language_model\.lm_head\.weight":
r"output.weight",
r"^language_model\.model\.norm\.weight":
r"norm.weight",
r"^language_model\.model\.embed_tokens\.weight":
r"tok_embeddings.weight",
r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight":
r"layers.\1.attention_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight":
r"layers.\1.ffn_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
r"layers.\1.attention.w\2.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight":
r"layers.\1.feed_forward.w1.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight":
r"layers.\1.feed_forward.w2.weight",
r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight":
r"layers.\1.feed_forward.w3.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight":
r"vision_encoder.transformer.layers.\1.attention_norm.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight":
r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight":
r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight":
r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight":
r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight":
r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
r"^multi_modal_projector\.linear_1":
r"vision_language_adapter.w_in",
r"^multi_modal_projector\.linear_2":
r"vision_language_adapter.w_out",
r"^vision_tower\.ln_pre\.weight":
r"vision_encoder.ln_pre.weight",
r"^vision_tower\.patch_conv\.weight":
r"vision_encoder.patch_conv.weight",
r"^multi_modal_projector\.patch_merger\.merging_layer\.weight":
r"patch_merger.merging_layer.weight",
r"^multi_modal_projector\.norm\.weight":
r"pre_mm_projector_norm.weight",
r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$":
r"layers.\1.\2.\3"
}

def maybe_remap_mistral3(self, name: str,
tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
"""Remap HF-style weight names back to original Pixtral format."""

for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items():
new_name, n_replace = re.subn(pattern, replacement, name)
if n_replace > 0:
logger.debug("remapped %s to %s for Pixtral compat", name,
new_name)
return new_name, tensor
return name, tensor # Return unchanged if no match

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
@@ -504,13 +570,28 @@
vision_lang_adapter_dict = dict(
self.vision_language_adapter.named_parameters())

def inverse_permute_for_rope(tensor, n_heads, dim1, dim2):
"""Reverse the permutation applied for ROPE in HF format."""
tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
tensor = tensor.transpose(1, 2)
tensor = tensor.reshape(dim1, dim2)
return tensor

def llm_weights_generator():
# Single pass over weights
for name, w in weights:
remapped_weights = (self.maybe_remap_mistral3(name, w)
for name, w in weights)
for name, w in remapped_weights:
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name:

Check failure on line 589 (GitHub Actions / pre-commit): Ruff (E501) vllm/model_executor/models/pixtral.py:589:81: Line too long (82 > 80)
n_heads = self.vision_args.num_attention_heads
dim1 = param.shape[0] # num_heads * head_dim
dim2 = param.shape[1] # hidden_size
w = inverse_permute_for_rope(w, n_heads, dim1, dim2)
logger.debug("reversed permute_for_rope for %s", name)
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
@@ -554,7 +635,7 @@
image_token_id: int
adapter_bias: bool = True
spatial_merge_size: int = 1
add_pre_mm_projector_layer_norm: bool = False
add_pre_mm_projector_layer_norm: bool = True
Contributor review comment (severity: critical): Changing the default value of add_pre_mm_projector_layer_norm to True is a breaking change for earlier Pixtral models. Revert this change and implement a mechanism to dynamically determine this value from the model's config.json.

Suggested change:
- add_pre_mm_projector_layer_norm: bool = True
+ add_pre_mm_projector_layer_norm: bool = False
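One way to satisfy that without changing the default, as a minimal sketch: keep the dataclass default at False and flip the flag per checkpoint. The weight_names iterable below is hypothetical (not an existing vLLM variable); the idea is to detect whether the checkpoint actually ships the pre-projector norm that this PR maps to pre_mm_projector_norm.weight.

def has_pre_mm_projector_norm(weight_names) -> bool:
    # True only if the checkpoint contains the norm applied before the
    # multimodal projector, under either naming convention.
    return any(
        name.endswith("multi_modal_projector.norm.weight")   # HF-style name
        or name.endswith("pre_mm_projector_norm.weight")     # Pixtral-style name
        for name in weight_names)

# The field default could then stay False, and VisionEncoderArgs be built with
#     add_pre_mm_projector_layer_norm=has_pre_mm_projector_norm(weight_names)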

mm_projector_id: str = ""


@@ -847,9 +928,10 @@
super().__init__()

mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

self.spatial_merge_size = spatial_merge_size
self.mlp_input_dim = mlp_input_dim
logger.debug("mlp_input_dim = %d (from %d * (%d ** 2))", mlp_input_dim,
vision_encoder_dim, spatial_merge_size)

self.merging_layer = nn.Linear(
mlp_input_dim,
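Taken together, the two conversions above can be exercised outside vLLM. The sketch below re-applies one entry of MISTRAL3_REVERSE_MAPPING to a sample weight name and round-trips inverse_permute_for_rope; the forward permute_for_rope here is an assumption about what HF-style Llama/Mistral conversion scripts apply, not code copied from this PR.

import re
import torch

# One entry of MISTRAL3_REVERSE_MAPPING from the diff above.
pattern = r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight"
replacement = r"layers.\1.attention.w\2.weight"
name = "language_model.model.layers.0.self_attn.q_proj.weight"
print(re.subn(pattern, replacement, name)[0])  # layers.0.attention.wq.weight

def permute_for_rope(w, n_heads, dim1, dim2):
    # Assumed HF-conversion permutation (interleaves the two rotary halves).
    return (w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
             .transpose(1, 2).reshape(dim1, dim2))

def inverse_permute_for_rope(w, n_heads, dim1, dim2):
    # Mirrors the helper added in load_weights above.
    return (w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
             .transpose(1, 2).reshape(dim1, dim2))

n_heads, head_dim, hidden = 16, 64, 1024
w = torch.randn(n_heads * head_dim, hidden)
dim1, dim2 = w.shape
assert torch.equal(
    inverse_permute_for_rope(permute_for_rope(w, n_heads, dim1, dim2),
                             n_heads, dim1, dim2), w)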
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -216,6 +216,7 @@
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501

Check failure on line 219 (GitHub Actions / pre-commit): Ruff (F601) vllm/model_executor/models/registry.py:219:5: Dictionary key literal `"Mistral3ForConditionalGeneration"` repeated
"QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501