From be9a09b33c10e807d426b05619c1be7f54e5649e Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Wed, 13 Aug 2025 06:45:49 -0700
Subject: [PATCH 1/2] Moving hash function to base planner class (#3238)

Summary:
Moving the planner hash function into EmbeddingPlannerBase so that all planner implementations have access to this feature.

Reviewed By: aporialiao

Differential Revision: D78996456
---
 torchrec/distributed/planner/planners.py | 32 ++++++++++++------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 6084d08f9..06a9aea97 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -251,6 +251,22 @@ def collective_plan(
             sharders,
         )
 
+    def hash_planner_context_inputs(self) -> int:
+        """
+        Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
+        These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
+
+        Returns:
+            An integer hash capturing topology, batch size, enumerator, storage reservation, and constraints.
+        """
+        return hash_planner_context_inputs(
+            self._topology,
+            self._batch_size,
+            self._enumerator,
+            self._storage_reservation,
+            self._constraints,
+        )
+
 
 class EmbeddingShardingPlanner(EmbeddingPlannerBase):
     """
@@ -368,22 +384,6 @@ def collective_plan(
             sharders,
         )
 
-    def hash_planner_context_inputs(self) -> int:
-        """
-        Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
-        These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
-
-        Returns:
-            Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
-        """
-        return hash_planner_context_inputs(
-            self._topology,
-            self._batch_size,
-            self._enumerator,
-            self._storage_reservation,
-            self._constraints,
-        )
-
     def plan(
         self,
         module: nn.Module,

From d03cd98ef060f965cfe69651705d79f9e84a76dd Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Wed, 13 Aug 2025 06:45:49 -0700
Subject: [PATCH 2/2] Using EmbeddingPlannerBase for all planner implementations (#3232)

Summary:
Exporting EmbeddingPlannerBase from the planner package alongside EmbeddingShardingPlanner so that all planner implementations can be built on the shared base class.

Reviewed By: aporialiao

Differential Revision: D78887917
---
 torchrec/distributed/planner/__init__.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/planner/__init__.py b/torchrec/distributed/planner/__init__.py
index efd06bf02..3dd8289e2 100644
--- a/torchrec/distributed/planner/__init__.py
+++ b/torchrec/distributed/planner/__init__.py
@@ -21,6 +21,9 @@
 - automatically building and selecting an optimized sharding plan.
 """
 
-from torchrec.distributed.planner.planners import EmbeddingShardingPlanner  # noqa
+from torchrec.distributed.planner.planners import (  # noqa
+    EmbeddingPlannerBase,
+    EmbeddingShardingPlanner,
+)
 from torchrec.distributed.planner.types import ParameterConstraints, Topology  # noqa
 from torchrec.distributed.planner.utils import bytes_to_gb, sharder_name  # noqa
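
Usage note: the method hoisted in PATCH 1/2 is what lets callers check whether a previously generated sharding plan can be reused in a new context. A minimal sketch follows, assuming a plain in-memory dict as the cache; plan_with_cache and the cache itself are hypothetical illustrations, not torchrec APIs.

from typing import Dict, List

import torch.nn as nn
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.types import ModuleSharder, ShardingPlan


def plan_with_cache(
    planner: EmbeddingShardingPlanner,
    module: nn.Module,
    sharders: List[ModuleSharder[nn.Module]],
    cache: Dict[int, ShardingPlan],
) -> ShardingPlan:
    # Inherited from EmbeddingPlannerBase after PATCH 1/2: hashes topology,
    # batch size, enumerator, storage reservation, and constraints.
    key = planner.hash_planner_context_inputs()
    if key in cache:
        # The planner context is unchanged, so the cached plan is still valid.
        return cache[key]
    plan = planner.plan(module=module, sharders=sharders)
    cache[key] = plan
    return plan

Any planner that subclasses EmbeddingPlannerBase after this change gets hash_planner_context_inputs() for free, which is the point of moving it out of EmbeddingShardingPlanner.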
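
The free function hash_planner_context_inputs that the method delegates to comes from torchrec's planner utilities and is not shown in this patch. As a rough stand-in only (the real implementation may serialize and combine its inputs differently), the idea is to reduce each context input to one stable integer digest:

import hashlib
from typing import Any, Dict, Optional


def hash_planner_context_inputs_sketch(
    topology: Any,
    batch_size: int,
    enumerator: Any,
    storage_reservation: Any,
    constraints: Optional[Dict[str, Any]],
) -> int:
    # Illustration only: repr() is not a stable serialization across runs;
    # the real utility would hash a deterministic encoding of each input.
    payload = "|".join(
        repr(part)
        for part in (topology, batch_size, enumerator, storage_reservation, constraints)
    )
    return int(hashlib.sha256(payload.encode("utf-8")).hexdigest(), 16)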