python/sglang/srt/models/qwen2_5_vl.py (8 changes: 8 additions & 0 deletions)
@@ -23,6 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 import logging
+import re
 from functools import partial
 from typing import Iterable, List, Optional, Tuple, Type
 
@@ -534,6 +535,13 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds
 
+    _lora_pattern = re.compile(
+        r"^model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)$"
+    )
Contributor comment on lines +538 to +540 (severity: medium):
The current regular expression is slightly too broad. It allows for combinations of modules that do not exist in the model architecture, such as self_attn.down_proj. While this won't cause issues with the current model structure because named_modules() will only yield valid module names, making the regex more specific will improve its correctness and maintainability, especially for future model changes.

A more precise regex would explicitly group the allowed projections under their respective parent modules (self_attn or mlp).

Suggested change:
-    _lora_pattern = re.compile(
-        r"^model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)$"
-    )
+    _lora_pattern = re.compile(
+        r"^model\.layers\.(\d+)\.(?:(?:self_attn\.(?:qkv_proj|o_proj))|(?:mlp\.(?:gate_up_proj|down_proj)))$"
+    )
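To make the difference concrete, here is a small standalone check (a sketch using only the two patterns quoted above; the module names are illustrative) showing that the current pattern accepts a combination that does not exist in the architecture, while the suggested one rejects it and both agree on the real module names:

import re

broad = re.compile(
    r"^model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)$"
)
strict = re.compile(
    r"^model\.layers\.(\d+)\.(?:(?:self_attn\.(?:qkv_proj|o_proj))|(?:mlp\.(?:gate_up_proj|down_proj)))$"
)

# Nonexistent combination: an MLP projection hanging off self_attn.
bogus = "model.layers.0.self_attn.down_proj"
assert broad.match(bogus)           # the broad pattern accepts it
assert strict.match(bogus) is None  # the strict pattern rejects it

# Both patterns agree on module names that actually exist in the model.
for name in (
    "model.layers.3.self_attn.qkv_proj",
    "model.layers.3.self_attn.o_proj",
    "model.layers.3.mlp.gate_up_proj",
    "model.layers.3.mlp.down_proj",
):
    assert broad.match(name) and strict.match(name)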


+
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self._lora_pattern.match(module_name))
+
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.feature for item in items], dim=0).type(
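For context, a hypothetical caller sketch (not sglang's actual LoRA wiring; the model variable is assumed to be an instance of this class) of how a loader could use should_apply_lora to select LoRA target modules:

# Hypothetical: filter the model's named modules through should_apply_lora.
lora_targets = [
    name for name, _ in model.named_modules() if model.should_apply_lora(name)
]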