Skip to content

Commit 0c3c165

Browse files
authored
Fix fp8 in dtype_byte_size
1 parent 5c482b6 commit 0c3c165

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

paddlenlp/transformers/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,8 @@ def dtype_byte_size(dtype):
850850
"""
851851
if dtype == paddle.bool:
852852
return 1 / 8
853+
if dtype == paddle.float8_e4m3fn or dtype == paddle.float8_e5m2:
854+
return 1
853855
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
854856
if bit_search is None:
855857
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")

0 commit comments

Comments
 (0)