From eb7a67cf051d031264fecdac85ea73cad8431e45 Mon Sep 17 00:00:00 2001 From: Xing-lil Date: Fri, 30 May 2025 15:43:08 +0800 Subject: [PATCH 1/3] remove redundant reshard when mesh==1 --- .../auto_parallel/reshard/p_to_r_reshard_function.cc | 4 ++++ .../auto_parallel/reshard/p_to_s_reshard_function.cc | 4 ++++ .../auto_parallel/reshard/r_to_p_reshard_function.cc | 5 +++++ .../auto_parallel/reshard/r_to_s_reshard_function.cc | 4 ++++ .../auto_parallel/reshard/s_to_r_reshard_function.cc | 9 ++++++++- .../auto_parallel/reshard/s_to_s_reshard_function.cc | 4 ++++ 6 files changed, 29 insertions(+), 1 deletion(-) diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc index 31a302bd2990cb..ee1ca15ee30f41 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc @@ -50,6 +50,10 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_dist_attr = in.dist_attr(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); + if (in_process_ids.size() == 1) { + *out = in; + return; + } const auto& in_partial_status = in_dist_attr.partial_status(); auto in_reduce_type = in_partial_status.at(0); bool reduce_mean = false; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc index 85eaf0a9b4072b..4d3da5a7ece13e 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc @@ -121,6 +121,10 @@ void PToSReshardFunction::Eval(DeviceContext* dev_ctx, int out_split_axis = GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first; int64_t 
num_of_process = in_process_mesh.size(); + if (num_of_process == 1) { + *out = in; + return; + } int64_t num_of_padding = in.dims()[out_split_axis] % num_of_process; bool is_balanced_split = (num_of_padding == 0); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc index dc19c04e6c2102..3adf0fe2c48a80 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc @@ -48,6 +48,11 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, const TensorDistAttr& out_dist_attr, DistTensor* out) { VLOG(3) << "Call " << Name(); + const auto& in_process_ids = in.dist_attr().process_mesh().process_ids(); + if (in_process_ids.size() == 1) { + *out = in; + return; + } const auto& out_process_mesh = out_dist_attr.process_mesh(); int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0]; const auto& in_reduce_type = out_dist_attr.partial_status().at(0); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc index 235598974c31de..95dad59f00e35c 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc @@ -59,6 +59,10 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second; int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; + if (num_of_process == 1) { + *out = in; + return; + } VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis << ". Split will use axis " << mesh_axis << " of process_mesh." 
<< " There will have " << num_of_process diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index 67c23fa2901186..d09db3031c480b 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -37,6 +37,10 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx, int64_t padding_nums, DenseTensor* out) { int64_t num_of_process = process_ids.size(); + if (num_of_process == 1) { + *out = in; + return; + } auto dtype = in.dtype(); // For balanced split to replicate, we need to do all gather first. @@ -109,7 +113,10 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_dims_mapping = in_dist_attr.dims_mapping(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); - + if (in_process_ids.size() == 1) { + *out = in; + return; + } int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; int64_t num_of_process = in_process_mesh.size(); int64_t num_of_padding = in.dims()[split_axis] % num_of_process; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc index 74851f3df90ebe..60a347c1693e64 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc @@ -54,6 +54,10 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, VLOG(3) << "Call " << Name(); const auto& in_process_mesh = in.dist_attr().process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); + if (in_process_ids.size() == 1) { + *out = in; + return; + } auto dtype = in.dtype(); const auto& logical_ddim = in.dims(); int64_t nranks = 
static_cast<int64_t>(in_process_ids.size()); From 0e10fb10d12e6e720ce585ed0a662554be715321 Mon Sep 17 00:00:00 2001 From: Xing-lil Date: Tue, 3 Jun 2025 11:37:56 +0800 Subject: [PATCH 2/3] update --- .../auto_parallel/reshard/p_to_r_reshard_function.cc | 2 +- .../auto_parallel/reshard/p_to_s_reshard_function.cc | 2 +- .../auto_parallel/reshard/r_to_p_reshard_function.cc | 2 +- .../auto_parallel/reshard/r_to_s_reshard_function.cc | 2 +- .../auto_parallel/reshard/s_to_r_reshard_function.cc | 3 ++- .../auto_parallel/reshard/s_to_s_reshard_function.cc | 2 +- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc index ee1ca15ee30f41..baee328843bcf2 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc @@ -51,7 +51,7 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); if (in_process_ids.size() == 1) { - *out = in; + SetValue(out, in.value()); return; } const auto& in_partial_status = in_dist_attr.partial_status(); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc index 4d3da5a7ece13e..8f182c4820631b 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc @@ -122,7 +122,7 @@ void PToSReshardFunction::Eval(DeviceContext* dev_ctx, GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first; int64_t num_of_process = in_process_mesh.size(); if (num_of_process == 1) { - *out = in; + SetValue(out, in.value()); return; } int64_t num_of_padding
= in.dims()[out_split_axis] % num_of_process; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc index 3adf0fe2c48a80..03e490876dcbdb 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc @@ -50,7 +50,7 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, VLOG(3) << "Call " << Name(); const auto& in_process_ids = in.dist_attr().process_mesh().process_ids(); if (in_process_ids.size() == 1) { - *out = in; + SetValue(out, in.value()); return; } const auto& out_process_mesh = out_dist_attr.process_mesh(); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc index 95dad59f00e35c..41fcd2b3369e9f 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc @@ -60,7 +60,7 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; if (num_of_process == 1) { - *out = in; + SetValue(out, in.value()); return; } VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index d09db3031c480b..e61d63b8414e14 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -114,9 +114,10 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); if 
(in_process_ids.size() == 1) { - *out = in; + SetValue(out, in.value()); return; } + int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; int64_t num_of_process = in_process_mesh.size(); int64_t num_of_padding = in.dims()[split_axis] % num_of_process; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc index 60a347c1693e64..7bdc940c708cad 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc @@ -55,7 +55,7 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, const auto& in_process_mesh = in.dist_attr().process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); if (in_process_ids.size() == 1) { - *out = in; + SetValue(out, in.value()); return; } auto dtype = in.dtype(); From 7b731a27bcdaed1261239a4ac80a8c33f8e00749 Mon Sep 17 00:00:00 2001 From: Xing-lil Date: Thu, 5 Jun 2025 13:48:52 +0800 Subject: [PATCH 3/3] update --- .../auto_parallel/reshard/p_to_r_reshard_function.cc | 1 + .../auto_parallel/reshard/p_to_s_reshard_function.cc | 1 + .../auto_parallel/reshard/r_to_p_reshard_function.cc | 1 + .../auto_parallel/reshard/r_to_s_reshard_function.cc | 1 + .../auto_parallel/reshard/s_to_r_reshard_function.cc | 5 +---- .../auto_parallel/reshard/s_to_s_reshard_function.cc | 1 + 6 files changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc index baee328843bcf2..3ba7b075fbe2f5 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.cc @@ -52,6 +52,7 @@ void PToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& 
in_process_ids = in_process_mesh.process_ids(); if (in_process_ids.size() == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } const auto& in_partial_status = in_dist_attr.partial_status(); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc index 8f182c4820631b..5bab28e0b87cf7 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc @@ -123,6 +123,7 @@ void PToSReshardFunction::Eval(DeviceContext* dev_ctx, int64_t num_of_process = in_process_mesh.size(); if (num_of_process == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } int64_t num_of_padding = in.dims()[out_split_axis] % num_of_process; diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc index 03e490876dcbdb..303f7655b7d3c8 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_p_reshard_function.cc @@ -51,6 +51,7 @@ void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx, const auto& in_process_ids = in.dist_attr().process_mesh().process_ids(); if (in_process_ids.size() == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } const auto& out_process_mesh = out_dist_attr.process_mesh(); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc index 41fcd2b3369e9f..36b9d0ca1387b8 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc +++ 
b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc @@ -61,6 +61,7 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, int64_t num_of_process = out_process_mesh.shape()[mesh_axis]; if (num_of_process == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index e61d63b8414e14..83ec3806134be5 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -37,10 +37,6 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx, int64_t padding_nums, DenseTensor* out) { int64_t num_of_process = process_ids.size(); - if (num_of_process == 1) { - *out = in; - return; - } auto dtype = in.dtype(); // For balanced split to replicate, we need to do all gather first. @@ -115,6 +111,7 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, const auto& in_process_ids = in_process_mesh.process_ids(); if (in_process_ids.size() == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc index 7bdc940c708cad..ae2c3d79792c80 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_s_reshard_function.cc @@ -56,6 +56,7 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, const auto& in_process_ids = in_process_mesh.process_ids(); if (in_process_ids.size() == 1) { SetValue(out, in.value()); + SetDistProps(out, in.dims(), out_dist_attr); return; } auto dtype = in.dtype();