@@ -633,6 +633,11 @@ class TrainingArguments:
         },
     )
 
+    load_sharded_model_remap_parameter_name: bool = field(
+        default=False,
+        metadata={"help": "Whether to remap parameter name when load_sharded_model = true."},
+    )
+
     tensor_parallel_degree: int = field(
         default=-1,
         metadata={
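Not part of the diff: a minimal usage sketch for the new flag, assuming TrainingArguments is importable from paddlenlp.trainer and that output_dir is its only required argument; the checkpoint path is purely illustrative.

# Hedged sketch: set the new option alongside the existing load_sharded_model flag.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",  # illustrative path, not from the diff
    load_sharded_model=True,  # existing flag referenced by the new option's help text
    load_sharded_model_remap_parameter_name=True,  # field added in the hunk above
)
assert args.load_sharded_model_remap_parameter_name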
@@ -2039,6 +2044,11 @@ def _post_init_parallel_degree(self):
             sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
         )
 
+        if expert_parallel_degree > 1:
+            assert (
+                self.expert_tensor_parallel_degree <= 1
+            ), "expert_tensor_parallel_degree > 1 is not supported when expert_parallel_degree > 1"
+
         assert not (
             self.data_parallel_degree > 1 and expert_parallel_degree > 1
         ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}."
@@ -2227,6 +2237,17 @@ def pipeline_parallel_rank(self):
         else:
             return 0
 
+    @property
+    def expert_parallel_rank(self):
+        if self.use_hybrid_parallel:
+            hcg = fleet.get_hybrid_communicate_group()
+            if hasattr(hcg, "get_expert_parallel_rank"):
+                return max(hcg.get_expert_parallel_rank(), 0)
+            else:
+                return 0
+        else:
+            return 0
+
     @property
     def context_parallel_rank(self):
         if self.use_hybrid_parallel:
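The new expert_parallel_rank property mirrors the other rank properties but guards against fleet builds whose hybrid communicate group predates expert parallelism. Below is a self-contained sketch of that fallback using stand-in classes rather than Paddle APIs.

# Stand-ins for the hybrid communicate group; not Paddle APIs.
class _LegacyHCG:
    pass  # no expert-parallel support exposed


class _MoEAwareHCG:
    def get_expert_parallel_rank(self):
        return -1  # illustrates why the property clamps with max(..., 0)


def expert_parallel_rank(hcg):
    # Same fallback as the property above: missing method or negative rank -> 0.
    if hasattr(hcg, "get_expert_parallel_rank"):
        return max(hcg.get_expert_parallel_rank(), 0)
    return 0


assert expert_parallel_rank(_LegacyHCG()) == 0
assert expert_parallel_rank(_MoEAwareHCG()) == 0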
@@ -2252,7 +2273,7 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
@@ -2268,7 +2289,7 @@ def weight_name_suffix(self):
                 name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_expert_parallel:
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
 
@@ -2277,7 +2298,9 @@ def weight_name_suffix(self):
                 return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
-    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None, sharding_parallel_degree=None):
+        if sharding_parallel_degree is None:
+            sharding_parallel_degree = self.sharding_parallel_degree
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
@@ -2287,12 +2310,12 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
                     pp_id = self.pipeline_parallel_rank
                 assert isinstance(pp_id, int)
                 name.append(self._format_name("pp", pp_id, self.pipeline_parallel_degree))
-            if self.sharding_parallel_degree > 1:
+            if sharding_parallel_degree > 1:
                 if shard_id is None:
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
-                name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
-            if self.use_expert_parallel:
+                name.append(self._format_name("shard", shard_id, sharding_parallel_degree))
+            if self.use_expert_parallel and self.expert_parallel_degree <= 1:
                 if moe_id is None:
                     moe_id = self.data_parallel_rank
                 assert isinstance(moe_id, int)
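A standalone sketch of how the suffix composes once a caller passes an explicit sharding_parallel_degree (for example, to regenerate checkpoint names saved under a different sharding layout). _format_name here is a stand-in; its zero-padding rule is an assumption, not the real helper.

def _format_name(prefix, rank, degree):
    # Stand-in for TrainingArguments._format_name; padding width is an assumption.
    return f"{prefix}{rank:0{len(str(degree - 1))}d}"


def sharded_name_suffix(shard_id, sharding_parallel_degree, tp_rank=0, tp_degree=2):
    # Mirrors the logic in the hunk above, with tp/pp simplified for illustration.
    name = []
    if tp_degree > 1:
        name.append(_format_name("tp", tp_rank, tp_degree))
    if sharding_parallel_degree > 1:
        name.append(_format_name("shard", shard_id, sharding_parallel_degree))
    return "_".join(name)


# Overriding the degree lets suffixes be computed for a sharding layout other
# than the one currently configured on self.
print(sharded_name_suffix(shard_id=3, sharding_parallel_degree=16))  # tp0_shard03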
@@ -2418,9 +2441,7 @@ def should_save_sharding_stage1_model(self):
     def should_load_sharding_stage1_model(self):
         if self.enable_auto_parallel:
             return False
-        return (
-            ShardingOption.SHARD_OP in self.sharding and self.sharding_parallel_degree > 1 and self.load_sharded_model
-        )
+        return self.load_sharded_model
 
     @property
     def should_load_dataset(self):