From 951adeb26a351aba5bf5c582a55d25b304a4570e Mon Sep 17 00:00:00 2001
From: nubot
Date: Tue, 24 Mar 2026 23:49:02 +0800
Subject: [PATCH] [bugfix] fallback to eager attn for MPS VL loaders

---
 swift/model/models/qwen.py    | 7 +++++++
 swift/model/models/stepfun.py | 2 ++
 2 files changed, 9 insertions(+)

diff --git a/swift/model/models/qwen.py b/swift/model/models/qwen.py
index b76b2c20b2..92c3cee956 100644
--- a/swift/model/models/qwen.py
+++ b/swift/model/models/qwen.py
@@ -1055,6 +1055,13 @@ def _check_qwen_vl_utils(self):
         require_version('qwen_vl_utils>=0.0.14')
         compat_qwen_vl_utils(image_patch_size=16)
 
+    def get_config(self, model_dir: str):
+        # torch SDPA on MPS currently mis-handles Qwen3-VL GQA during generation.
+        if self.attn_impl is None and self.model_kwargs.get('device_map') == 'mps':
+            self.attn_impl = 'eager'
+            logger.info('Setting attn_impl=eager for Qwen3-VL on MPS.')
+        return super().get_config(model_dir)
+
     def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
         from transformers import Qwen3VLForConditionalGeneration
         self.auto_model_cls = self.auto_model_cls or Qwen3VLForConditionalGeneration
diff --git a/swift/model/models/stepfun.py b/swift/model/models/stepfun.py
index b112029d4e..801980730c 100644
--- a/swift/model/models/stepfun.py
+++ b/swift/model/models/stepfun.py
@@ -133,6 +133,8 @@ def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
 class Step3VLLoader(ModelLoader):
 
     def get_config(self, model_dir: str) -> PretrainedConfig:
+        if self.attn_impl is None and self.model_kwargs.get('device_map') == 'mps':
+            self.attn_impl = 'eager'
         config = super().get_config(model_dir)
         config.vocab_size = config.text_config.vocab_size
         return config