@@ -26,6 +26,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ......utils.device import parse_device
+from ......utils.env import get_paddle_cuda_version
 from ....common.vlm.transformers import PretrainedConfig
 
 
@@ -120,6 +122,8 @@ def __init__(
         vision_config=None,
         **kwargs,
     ):
+        import paddle
+
         # Set default for tied embeddings if not specified.
         super().__init__(
             pad_token_id=pad_token_id,
@@ -165,13 +169,13 @@ def __init__(
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         # Currently, these configuration items are hard-coded
-        from ......utils.env import get_paddle_cuda_version
 
-        cuda_version = get_paddle_cuda_version()
-        if cuda_version and cuda_version[0] > 11:
-            self.fuse_rms_norm = True
-        else:
-            self.fuse_rms_norm = False
+        self.fuse_rms_norm = False
+        device_type, _ = parse_device(paddle.device.get_device())
+        if device_type == "gpu":
+            cuda_version = get_paddle_cuda_version()
+            if cuda_version and cuda_version[0] > 11:
+                self.fuse_rms_norm = True
         self.use_sparse_flash_attn = True
         self.use_var_len_flash_attn = False
         self.scale_qk_coeff = 1.0
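For context outside the PaddleX source tree, here is a minimal standalone sketch of the device-gated default this diff introduces. The two helper functions below are stand-ins for the internal `......utils.device.parse_device` and `......utils.env.get_paddle_cuda_version` utilities, so their exact signatures are assumptions; only `paddle.device.get_device()` and `paddle.version.cuda()` are public Paddle APIs.

```python
import paddle


def parse_device(device_str):
    # Stand-in for the internal parse_device helper (assumed signature):
    # split a Paddle device string such as "gpu:0" or "cpu" into
    # a (device_type, device_id) pair.
    device_type, _, device_id = device_str.partition(":")
    return device_type, device_id or None


def get_paddle_cuda_version():
    # Stand-in for the internal get_paddle_cuda_version helper (assumed
    # behavior): return the CUDA version Paddle was built against as an
    # int tuple, or None on CPU-only builds, where paddle.version.cuda()
    # returns the string "False".
    version_str = paddle.version.cuda()
    if not version_str or version_str == "False":
        return None
    return tuple(int(part) for part in version_str.split("."))


def default_fuse_rms_norm():
    # Mirrors the logic added by the diff: fused RMSNorm defaults to off
    # and is enabled only when the current device is a GPU and the CUDA
    # toolkit is newer than 11.x; CPU, XPU, and other devices keep the
    # unfused fallback.
    device_type, _ = parse_device(paddle.device.get_device())
    if device_type == "gpu":
        cuda_version = get_paddle_cuda_version()
        if cuda_version and cuda_version[0] > 11:
            return True
    return False


print(default_fuse_rms_norm())  # False on CPU installs; True with GPU + CUDA 12+
```

Compared with the removed code, which set `fuse_rms_norm` purely from the CUDA build version, the new structure defaults to `False` and only enables it inside the GPU branch, so a CUDA-enabled Paddle build running on a non-GPU device no longer turns on the fused kernel.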