diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc index 31a302bd2990cb..3ba7b075fbe2f5 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc @@ -50,6 +50,11 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_dist_attr = in.dist_attr(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); + if (in_process_ids.size() == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } const auto& in_partial_status = in_dist_attr.partial_status(); auto in_reduce_type = in_partial_status.at(0); bool reduce_mean = false; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc index 85eaf0a9b4072b..5bab28e0b87cf7 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc @@ -121,6 +121,11 @@ void PToSReshardFunction::Eval(DeviceContext* dev_ctx, int out_split_axis = GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first; int64_t num_of_process = in_process_mesh.size(); + if (num_of_process == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } int64_t num_of_padding = in.dims()[out_split_axis] % num_of_process; bool is_balanced_split = (num_of_padding == 0); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc index dc19c04e6c2102..303f7655b7d3c8 100644 --- 
a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc @@ -48,6 +48,12 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, const TensorDistAttr& out_dist_attr, DistTensor* out) { VLOG(3) << "Call " << Name(); + const auto& in_process_ids = in.dist_attr().process_mesh().process_ids(); + if (in_process_ids.size() == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } const auto& out_process_mesh = out_dist_attr.process_mesh(); int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0]; const auto& in_reduce_type = out_dist_attr.partial_status().at(0); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc index 235598974c31de..36b9d0ca1387b8 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc @@ -59,6 +59,11 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second; int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; + if (num_of_process == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis << ". Split will use axis " << mesh_axis << " of process_mesh." 
<< " There will have " << num_of_process diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index 67c23fa2901186..83ec3806134be5 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -109,6 +109,11 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_dims_mapping = in_dist_attr.dims_mapping(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); + if (in_process_ids.size() == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; int64_t num_of_process = in_process_mesh.size(); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc index 74851f3df90ebe..ae2c3d79792c80 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc @@ -54,6 +54,11 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, VLOG(3) << "Call " << Name(); const auto& in_process_mesh = in.dist_attr().process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); + if (in_process_ids.size() == 1) { + SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); + return; + } auto dtype = in.dtype(); const auto& logical_ddim = in.dims(); int64_t nranks = static_cast<int64_t>(in_process_ids.size());