
Commit 406d84d

PaddleOCR-VL supports CPU and CUDA 11 (#4666)
1 parent b2ebed2 commit 406d84d

File tree

1 file changed (+10 -6)

  • paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_config.py


paddlex/inference/models/doc_vlm/modeling/paddleocr_vl/_config.py

Lines changed: 10 additions & 6 deletions
@@ -26,6 +26,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ......utils.device import parse_device
+from ......utils.env import get_paddle_cuda_version
 from ....common.vlm.transformers import PretrainedConfig
 
 
@@ -120,6 +122,8 @@ def __init__(
         vision_config=None,
         **kwargs,
     ):
+        import paddle
+
         # Set default for tied embeddings if not specified.
         super().__init__(
             pad_token_id=pad_token_id,
@@ -165,13 +169,13 @@ def __init__(
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         # Currently, these configuration items are hard-coded
-        from ......utils.env import get_paddle_cuda_version
 
-        cuda_version = get_paddle_cuda_version()
-        if cuda_version and cuda_version[0] > 11:
-            self.fuse_rms_norm = True
-        else:
-            self.fuse_rms_norm = False
+        self.fuse_rms_norm = False
+        device_type, _ = parse_device(paddle.device.get_device())
+        if device_type == "gpu":
+            cuda_version = get_paddle_cuda_version()
+            if cuda_version and cuda_version[0] > 11:
+                self.fuse_rms_norm = True
         self.use_sparse_flash_attn = True
         self.use_var_len_flash_attn = False
         self.scale_qk_coeff = 1.0
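
For context, the change makes fuse_rms_norm depend on the runtime device as well as the CUDA version: the fused kernel now defaults to off and is only enabled when the model actually runs on a GPU with CUDA newer than 11. Below is a minimal standalone sketch of that decision, using only stock Paddle APIs (paddle.device.get_device() and paddle.version.cuda()) in place of PaddleX's parse_device and get_paddle_cuda_version helpers; it approximates the logic in _config.py and is not the library's exact code.

# Minimal sketch (assumption): reproduce the fuse_rms_norm decision with
# stock Paddle APIs instead of the PaddleX helpers used in the diff.
import paddle

def should_fuse_rms_norm() -> bool:
    fuse = False                              # default: safe on CPU and CUDA 11
    device = paddle.device.get_device()       # e.g. "cpu", "gpu:0"
    if device.startswith("gpu"):
        cuda = paddle.version.cuda()          # e.g. "12.0"; "False" on CPU-only builds
        if cuda and cuda != "False" and int(cuda.split(".")[0]) > 11:
            fuse = True                       # fused RMSNorm only on CUDA 12+ GPUs
    return fuse

print(should_fuse_rms_norm())

The net effect matches the commit title: on CPU or CUDA 11 builds the fused RMSNorm kernel stays disabled, and it is enabled only when running on a GPU with CUDA 12 or newer.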

0 commit comments
