
Commit 145441b

aliafzal authored and facebook-github-bot committed
Base planner class to decouple EmbeddingShardingPlanner from common planner components (#3228)
Summary:
Pull Request resolved: #3228

Context: New planner implementations are extending the OSS planner (EmbeddingShardingPlanner) to reuse the common features it provides. This is particularly dangerous because it limits the development of the OSS planner and adds an unnecessary dependency on the extended planners.

This diff introduces EmbeddingPlannerBase, which provides the common components and utils needed to support the OSS planner as well as new planner implementations.

Reviewed By: mserturk

Differential Revision: D77712273

fbshipit-source-id: a883a82f78190786b994875e45de3a82895518de
1 parent fc08e73 commit 145441b
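
To illustrate the intent of the change, a new planner implementation can now extend EmbeddingPlannerBase instead of EmbeddingShardingPlanner. The sketch below is illustrative only: MyCustomPlanner and its trivial plan() body are hypothetical, and the import paths are inferred from the file and types that appear in this commit.

    from typing import List

    import torch.nn as nn
    from torchrec.distributed.planner.planners import EmbeddingPlannerBase
    from torchrec.distributed.types import ModuleSharder, ShardingPlan


    class MyCustomPlanner(EmbeddingPlannerBase):
        """Inherits the common __init__ setup and collective_plan() from the base class."""

        def plan(
            self,
            module: nn.Module,
            sharders: List[ModuleSharder[nn.Module]],
        ) -> ShardingPlan:
            # A real implementation would search sharding options using the components
            # the base class already set up (topology, enumerator, storage reservation,
            # stats, callbacks, timeout). Returning an empty plan keeps the sketch runnable.
            return ShardingPlan(plan={})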

File tree

1 file changed: +121 −34 lines


torchrec/distributed/planner/planners.py

Lines changed: 121 additions & 34 deletions
@@ -8,7 +8,6 @@
 # pyre-strict

 import copy
-import hashlib
 import logging
 import time
 from functools import reduce
@@ -143,33 +142,24 @@ def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
     return merged_plan


-class EmbeddingShardingPlanner(ShardingPlanner):
+class EmbeddingPlannerBase(ShardingPlanner):
     """
-    Provides an optimized sharding plan for a given module with shardable parameters
-    according to the provided sharders, topology, and constraints.
+    Base class for embedding sharding planners that provides common initialization
+    and shared functionality.

     Args:
         topology (Optional[Topology]): the topology of the current process group.
         batch_size (Optional[int]): the batch size of the model.
         enumerator (Optional[Enumerator]): the enumerator to use
         storage_reservation (Optional[StorageReservation]): the storage reservation to use
-        proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
-        partitioner (Optional[Partitioner]): the partitioner to use
-        performance_model (Optional[PerfModel]): the performance model to use
         stats (Optional[Union[Stats, List[Stats]]]): the stats to use
         constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
             for sharding.
         debug (bool): whether to print debug information.
-
-    Example::
-
-        ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
-        planner = EmbeddingShardingPlanner()
-        plan = planner.plan(
-            module=ebc,
-            sharders=[EmbeddingBagCollectionSharder()],
-        )
-
+        callbacks (Optional[List[Callable[[List[ShardingOption]], List[ShardingOption]]]]):
+            callback functions to apply to plans.
+        timeout_seconds (Optional[int]): timeout for planning in seconds.
+        heuristical_storage_reservation_percentage (float): percentage of storage to reserve for sparse archs.
     """

     def __init__(
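
The docstring above lists the constructor arguments shared by all planners built on the base class. As a rough illustration (the argument values below are arbitrary), the kwargs forwarded to EmbeddingPlannerBase.__init__ might be used like this:

    from torchrec.distributed.planner.planners import EmbeddingShardingPlanner

    # debug, callbacks, and timeout_seconds are forwarded to EmbeddingPlannerBase.__init__
    # by the OSS planner's own __init__ (see the super().__init__ call later in this diff).
    planner = EmbeddingShardingPlanner(
        debug=True,           # whether to print debug information, per the docstring above
        timeout_seconds=600,  # must be positive; asserted in the base __init__
    )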
@@ -178,16 +168,14 @@ def __init__(
         batch_size: Optional[int] = None,
         enumerator: Optional[Enumerator] = None,
         storage_reservation: Optional[StorageReservation] = None,
-        proposer: Optional[Union[Proposer, List[Proposer]]] = None,
-        partitioner: Optional[Partitioner] = None,
-        performance_model: Optional[PerfModel] = None,
         stats: Optional[Union[Stats, List[Stats]]] = None,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
         debug: bool = True,
         callbacks: Optional[
             List[Callable[[List[ShardingOption]], List[ShardingOption]]]
         ] = None,
         timeout_seconds: Optional[int] = None,
+        heuristical_storage_reservation_percentage: float = 0.15,
     ) -> None:
         if topology is None:
             topology = Topology(
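
The new heuristical_storage_reservation_percentage keyword only controls the default HeuristicalStorageReservation that the base class builds when no explicit reservation is given; the OSS planner continues to accept a StorageReservation object directly. A small sketch of the two equivalent configurations (the 0.25 value is arbitrary, and the storage_reservations import path is an assumption based on the planner package layout):

    from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
    from torchrec.distributed.planner.storage_reservations import HeuristicalStorageReservation

    # OSS planner: pass the reservation object explicitly.
    planner = EmbeddingShardingPlanner(
        storage_reservation=HeuristicalStorageReservation(percentage=0.25),
    )

    # A planner extending EmbeddingPlannerBase can instead forward the percentage:
    #     super().__init__(heuristical_storage_reservation_percentage=0.25)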
@@ -210,7 +198,116 @@ def __init__(
         self._storage_reservation: StorageReservation = (
             storage_reservation
             if storage_reservation
-            else HeuristicalStorageReservation(percentage=0.15)
+            else HeuristicalStorageReservation(
+                percentage=heuristical_storage_reservation_percentage
+            )
+        )
+
+        if stats is not None:
+            self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats
+        else:
+            self._stats = [EmbeddingStats()]
+
+        self._debug = debug
+        self._callbacks: List[
+            Callable[[List[ShardingOption]], List[ShardingOption]]
+        ] = ([] if callbacks is None else callbacks)
+        if timeout_seconds is not None:
+            assert timeout_seconds > 0, "Timeout must be positive"
+        self._timeout_seconds = timeout_seconds
+
+    def collective_plan(
+        self,
+        module: nn.Module,
+        sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
+        pg: Optional[dist.ProcessGroup] = None,
+    ) -> ShardingPlan:
+        """
+        Call self.plan(...) on rank 0 and broadcast
+
+        Args:
+            module (nn.Module): the module to shard.
+            sharders (Optional[List[ModuleSharder[nn.Module]]]): the sharders to use for sharding
+            pg (Optional[dist.ProcessGroup]): the process group to use for collective operations
+
+        Returns:
+            ShardingPlan: the sharding plan for the module.
+        """
+        if pg is None:
+            assert dist.is_initialized(), (
+                "The default process group is not yet initialized. "
+                "Please call torch.distributed.init_process_group() first before invoking this. "
+                "If you are not within a distributed environment, use the single rank version plan() instead."
+            )
+            pg = none_throws(dist.GroupMember.WORLD)
+
+        if sharders is None:
+            sharders = get_default_sharders()
+        return invoke_on_rank_and_broadcast_result(
+            pg,
+            0,
+            self.plan,
+            module,
+            sharders,
+        )
+
+
+class EmbeddingShardingPlanner(EmbeddingPlannerBase):
+    """
+    Provides an optimized sharding plan for a given module with shardable parameters
+    according to the provided sharders, topology, and constraints.
+
+    Args:
+        topology (Optional[Topology]): the topology of the current process group.
+        batch_size (Optional[int]): the batch size of the model.
+        enumerator (Optional[Enumerator]): the enumerator to use
+        storage_reservation (Optional[StorageReservation]): the storage reservation to use
+        proposer (Optional[Union[Proposer, List[Proposer]]]): the proposer(s) to use
+        partitioner (Optional[Partitioner]): the partitioner to use
+        performance_model (Optional[PerfModel]): the performance model to use
+        stats (Optional[Union[Stats, List[Stats]]]): the stats to use
+        constraints (Optional[Dict[str, ParameterConstraints]]): per table constraints
+            for sharding.
+        debug (bool): whether to print debug information.
+
+    Example::
+
+        ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))
+        planner = EmbeddingShardingPlanner()
+        plan = planner.plan(
+            module=ebc,
+            sharders=[EmbeddingBagCollectionSharder()],
+        )
+
+    """
+
+    def __init__(
+        self,
+        topology: Optional[Topology] = None,
+        batch_size: Optional[int] = None,
+        enumerator: Optional[Enumerator] = None,
+        storage_reservation: Optional[StorageReservation] = None,
+        proposer: Optional[Union[Proposer, List[Proposer]]] = None,
+        partitioner: Optional[Partitioner] = None,
+        performance_model: Optional[PerfModel] = None,
+        stats: Optional[Union[Stats, List[Stats]]] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        debug: bool = True,
+        callbacks: Optional[
+            List[Callable[[List[ShardingOption]], List[ShardingOption]]]
+        ] = None,
+        timeout_seconds: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=enumerator,
+            storage_reservation=storage_reservation,
+            stats=stats,
+            constraints=constraints,
+            debug=debug,
+            callbacks=callbacks,
+            timeout_seconds=timeout_seconds,
         )
         self._partitioner: Partitioner = (
             partitioner if partitioner else GreedyPerfPartitioner()
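
The collective_plan method moved into the base class calls plan() on rank 0 and broadcasts the result to all ranks. A minimal usage sketch, assuming the script runs under torchrun (or a similar launcher) with torchrec installed; the single-table config is illustrative:

    import torch
    import torch.distributed as dist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection

    # The default process group must exist before collective_plan is called;
    # otherwise use the single-rank plan() instead.
    dist.init_process_group(backend="gloo")

    eb_configs = [
        EmbeddingBagConfig(
            name="t1", embedding_dim=64, num_embeddings=10_000, feature_names=["f1"]
        )
    ]
    ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device("meta"))

    planner = EmbeddingShardingPlanner()
    # plan() runs on rank 0 only; the resulting ShardingPlan is broadcast to every rank.
    plan = planner.collective_plan(
        module=ebc,
        sharders=[EmbeddingBagCollectionSharder()],
    )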
@@ -227,24 +324,14 @@ def __init__(
             UniformProposer(),
         ]
         self._perf_model: PerfModel = (
-            performance_model if performance_model else NoopPerfModel(topology=topology)
+            performance_model
+            if performance_model
+            else NoopPerfModel(topology=self._topology)
         )

-        if stats is not None:
-            self._stats: List[Stats] = [stats] if not isinstance(stats, list) else stats
-        else:
-            self._stats = [EmbeddingStats()]
-
-        self._debug = debug
         self._num_proposals: int = 0
         self._num_plans: int = 0
         self._best_plan: Optional[List[ShardingOption]] = None
-        self._callbacks: List[
-            Callable[[List[ShardingOption]], List[ShardingOption]]
-        ] = ([] if callbacks is None else callbacks)
-        if timeout_seconds is not None:
-            assert timeout_seconds > 0, "Timeout must be positive"
-        self._timeout_seconds = timeout_seconds

     def collective_plan(
         self,

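The callbacks now stored by the base class are described in its docstring as callback functions applied to plans, with the signature List[ShardingOption] -> List[ShardingOption]. A hedged sketch of a purely observational callback (log_chosen_options is hypothetical; the name and sharding_type fields come from torchrec's ShardingOption type):

    import logging
    from typing import List

    from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
    from torchrec.distributed.planner.types import ShardingOption

    logger = logging.getLogger(__name__)


    def log_chosen_options(options: List[ShardingOption]) -> List[ShardingOption]:
        # Observational callback: log each table's chosen sharding type, then
        # return the options unchanged so the plan is not modified.
        for option in options:
            logger.info("table=%s sharding=%s", option.name, option.sharding_type)
        return options


    planner = EmbeddingShardingPlanner(callbacks=[log_chosen_options])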