@@ -49,6 +49,8 @@ def get_env_device():
         return "gcu"
     elif "intel_hpu" in paddle.device.get_all_custom_device_type():
         return "intel_hpu"
+    elif "iluvatar_gpu" in paddle.device.get_all_custom_device_type():
+        return "iluvatar_gpu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():
@@ -61,7 +63,7 @@ def get_env_device():
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() in ["npu", "mlu", "gcu"]:
+    if get_env_device() in ["npu", "mlu", "gcu", "iluvatar_gpu"]:
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -84,7 +86,7 @@ def fusion_rope(
     rotary_emb,
     context_parallel_degree=-1,
 ):
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
@@ -93,7 +95,7 @@ def fusion_rope(
             get_env_device() == "gpu"
         ), "context parallel only support cuda device for now"
         kv_seq_len *= context_parallel_degree
-    if get_env_device() not in ["gcu", "intel_hpu"]:
+    if get_env_device() not in ["gcu", "intel_hpu", "iluvatar_gpu"]:
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[
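
For reference, a minimal standalone sketch of the device-detection logic this change extends. It is not part of the diff: the surrounding branches are paraphrased from the file above, and whether "iluvatar_gpu" actually appears depends on the installed PaddleCustomDevice plugin.

```python
import paddle

def get_env_device_sketch():
    # Custom device plugins register their type names with Paddle; the PR adds
    # "iluvatar_gpu" to the recognized names so the fused rotary-embedding
    # paths above can branch on it like they already do for npu/mlu/gcu.
    custom_types = paddle.device.get_all_custom_device_type()
    if "iluvatar_gpu" in custom_types:
        return "iluvatar_gpu"
    if paddle.is_compiled_with_rocm():
        return "rocm"
    if paddle.is_compiled_with_xpu():
        return "xpu"
    return "gpu" if paddle.is_compiled_with_cuda() else "cpu"

print(get_env_device_sketch())  # e.g. "iluvatar_gpu" when the Iluvatar plugin is installed
```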