Skip to content

Commit 80f2af6

Browse files
kausv and facebook-github-bot
authored and committed
Handle Virtual Table sizing in planner
Summary: Virtual Embedding Tables are a new feature in ZCH v.Next. They do not allocate the embedding table memory on init, which allows for larger embedding tables when paired with Embedding Offloading kernels that offload to DRAM and SSD. Since the memory usage is kernel-dependent, the planner checks for the specific compute kernel and overrides the HBM and DRAM sizes accordingly. Design doc: https://docs.google.com/document/d/1NjtP01PSOHKwyRxAicV7wrdWPVI1NN9A3aPYuBaOpsg/edit?tab=t.0 Differential Revision: D74557713
1 parent 67ebc8c commit 80f2af6

File tree

4 files changed

+251
-19
lines changed

4 files changed

+251
-19
lines changed

torchrec/distributed/planner/planners.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def plan(
422422
sharders=sharders,
423423
debug=self._debug,
424424
)
425+
logger.info(f"Found sharding plan {sharding_plan}")
425426
return sharding_plan
426427
else:
427428
global_storage_capacity = reduce(

torchrec/distributed/planner/shard_estimators.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
import torchrec.optim as trec_optim
16+
from libfb.py.pyre import none_throws
1617
from torch import nn
1718
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
1819
from torchrec.distributed.planner.constants import (
@@ -41,6 +42,7 @@
4142
from torchrec.distributed.types import (
4243
CacheStatistics,
4344
CommOp,
45+
KeyValueParams,
4446
ModuleSharder,
4547
PipelineType,
4648
ShardingType,
@@ -998,23 +1000,29 @@ def estimate(
9981000
if hasattr(sharder, "fused_params") and sharder.fused_params
9991001
else None
10001002
)
1001-
1002-
num_poolings = (
1003-
cast(List[float], self._constraints[sharding_option.name].num_poolings)
1003+
constraints: Optional[ParameterConstraints] = (
1004+
self._constraints.get(sharding_option.name, None)
10041005
if self._constraints
1005-
and self._constraints.get(sharding_option.name)
1006-
and self._constraints[sharding_option.name].num_poolings
1006+
else None
1007+
)
1008+
num_poolings = (
1009+
constraints.num_poolings
1010+
if constraints and constraints.num_poolings
10071011
else [1.0] * sharding_option.num_inputs
10081012
)
10091013
assert len(num_poolings) == sharding_option.num_inputs
10101014
batch_sizes = (
1011-
cast(List[int], self._constraints[sharding_option.name].batch_sizes)
1012-
if self._constraints
1013-
and self._constraints.get(sharding_option.name)
1014-
and self._constraints[sharding_option.name].batch_sizes
1015+
constraints.batch_sizes
1016+
if constraints and constraints.batch_sizes
10151017
else [sharding_option.batch_size] * sharding_option.num_inputs
10161018
)
10171019

1020+
key_value_params: Optional[KeyValueParams] = (
1021+
constraints.key_value_params
1022+
if constraints and constraints.key_value_params
1023+
else None
1024+
)
1025+
10181026
# hardcoded as 8 bytes
10191027
# input indices can be of int32, but in TBE they get converted to int64 anyway
10201028
input_data_type_size = BIGINT_DTYPE
@@ -1057,6 +1065,7 @@ def estimate(
10571065
count_ephemeral_storage_cost=self._run_embedding_at_peak_memory,
10581066
is_inference=self._is_inference,
10591067
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
1068+
key_value_params=key_value_params,
10601069
)
10611070
for shard, storage in zip(sharding_option.shards, shard_storages):
10621071
shard.storage = storage
@@ -1125,6 +1134,7 @@ def calculate_shard_storages(
11251134
count_ephemeral_storage_cost: bool = False,
11261135
is_inference: bool = False,
11271136
multipass_prefetch_max_pass: Optional[int] = None,
1137+
key_value_params: Optional[KeyValueParams] = None,
11281138
) -> List[Storage]:
11291139
"""
11301140
Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1151,6 +1161,7 @@ def calculate_shard_storages(
11511161
output_data_type_size (int): number of bytes of output data type.
11521162
pipeline_type: PipelineType: pipeline type if for training.
11531163
is_inference: bool, whether the model is for inference.
1164+
key_value_params (Optional[KeyValueParams]): fused params for SSD/DRAM KV cache.
11541165
11551166
Returns:
11561167
List[Storage]: storage object for each device in topology.
@@ -1184,13 +1195,6 @@ def calculate_shard_storages(
11841195
# TODO(wangj): for ssd/dram kv, most likely we use absolute L1 cache size instead of caching ratio, as denominator is huge
11851196
hbm_storage = round(ddr_storage * caching_ratio)
11861197
table_cached = True
1187-
if compute_kernel in {
1188-
EmbeddingComputeKernel.KEY_VALUE.value,
1189-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
1190-
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
1191-
}:
1192-
# TODO(wangj): update this to the L2 cache usage and add SSD usage
1193-
ddr_storage = 0
11941198

11951199
optimizer_class = getattr(tensor, "_optimizer_classes", [None])[0]
11961200

@@ -1212,6 +1216,36 @@ def calculate_shard_storages(
12121216
is_inference=is_inference,
12131217
)
12141218

1219+
if compute_kernel in {
1220+
EmbeddingComputeKernel.KEY_VALUE.value,
1221+
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
1222+
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
1223+
}:
1224+
assert (
1225+
key_value_params is not None
1226+
), "key_value_params cannot be None in ParameterConstraints of planner for embedding compute kernel: {}".format(
1227+
compute_kernel
1228+
)
1229+
assert (
1230+
key_value_params.max_l1_cache_size is not None
1231+
), "key_value_params.max_l1_cache_size cannot be None in ParameterConstraints of planner for embedding compute kernel: {}".format(
1232+
compute_kernel
1233+
)
1234+
assert (
1235+
key_value_params.l2_cache_size is not None
1236+
), "key_value_params.l2_cache_size cannot be None in ParameterConstraints of planner for embedding compute kernel: {}".format(
1237+
compute_kernel
1238+
)
1239+
# TODO(wangj): is this expected?
1240+
hbm_specific_sizes = [
1241+
none_throws(key_value_params.max_l1_cache_size) * 1024 * 1024
1242+
for _ in hbm_specific_sizes
1243+
]
1244+
ddr_specific_sizes = [
1245+
none_throws(key_value_params.l2_cache_size) * 1024 * 1024 * 1024
1246+
for _ in ddr_specific_sizes
1247+
]
1248+
12151249
hbm_sizes: List[int] = [
12161250
(
12171251
hbm_specific_size

torchrec/distributed/planner/tests/test_planners.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212

1313
import torch
1414
from torch import nn
15+
from torchrec import EmbeddingConfig
16+
from torchrec.distributed.embedding import EmbeddingCollectionSharder
1517
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
1618
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
1719
from torchrec.distributed.planner import ParameterConstraints
1820
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
1921
from torchrec.distributed.planner.proposers import EmbeddingOffloadScaleupProposer
22+
from torchrec.distributed.planner.stats import EmbeddingStats
2023
from torchrec.distributed.planner.types import (
2124
PlannerError,
2225
PlannerErrorType,
@@ -31,6 +34,7 @@
3134
CacheParams,
3235
DataType,
3336
EmbeddingModuleShardingPlan,
37+
KeyValueParams,
3438
ModuleSharder,
3539
ShardingPlan,
3640
ShardingType,
@@ -359,3 +363,149 @@ def test_auto_sharder_solution(self) -> None:
359363
self.assertSetEqual(
360364
{EmbeddingComputeKernel.FUSED_UVM_CACHING.value}, compute_kernels
361365
)
366+
367+
def test_planner_with_virtual_table(self) -> None:
    """
    Plan a model that mixes DRAM virtual tables with regular tables.

    The first half of the tables are virtual tables constrained to the
    ``dram_virtual_table`` compute kernel; the second half are regular
    tables. The test checks two things:

    1. Planning fails with an ``AssertionError`` when a virtual-table
       constraint carries no ``key_value_params``.
    2. With ``key_value_params`` supplied, every table is split into two
       equally sized shards across ranks 0 and 1, the expected compute
       kernels are chosen, and the planner stats report the expected
       HBM/DDR figures.
    """
    table_count = 4
    virtual_table_count = table_count // 2
    virtual_num_embeddings = 1_125_899_902_955_520
    regular_num_embeddings = 100_000
    embedding_dim = 64
    world_size = 2

    tables = [
        EmbeddingConfig(
            num_embeddings=virtual_num_embeddings,
            embedding_dim=embedding_dim,
            name="table_" + str(i),
            feature_names=["feature_" + str(i)],
            use_virtual_table=True,
            total_num_buckets=3_991_680,
        )
        for i in range(virtual_table_count)
    ] + [
        EmbeddingConfig(
            num_embeddings=regular_num_embeddings,
            embedding_dim=embedding_dim,
            name="table_" + str(i),
            feature_names=["feature_" + str(i)],
        )
        for i in range(virtual_table_count, table_count)
    ]
    model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))

    # Case 1: virtual-table constraints without key_value_params must be
    # rejected while planning.
    constraints = {
        **{
            f"table_{i}": ParameterConstraints(
                sharding_types=["row_wise"],
                compute_kernels=["dram_virtual_table"],
            )
            for i in range(virtual_table_count)
        },
        **{
            f"table_{i}": ParameterConstraints(
                enforce_hbm=False,
            )
            for i in range(virtual_table_count, table_count)
        },
    }
    planner = EmbeddingShardingPlanner(
        topology=self.topology,
        proposer=EmbeddingOffloadScaleupProposer(),
        constraints=constraints,
    )

    with self.assertRaisesRegex(
        AssertionError,
        "key_value_params cannot be None in ParameterConstraints of planner for embedding compute kernel: dram_virtual_table",
    ):
        planner.plan(
            module=model, sharders=[EmbeddingCollectionSharder()]  # pyre-ignore
        )

    # Case 2: the same layout with key_value_params supplied plans
    # successfully.
    constraints = {
        **{
            f"table_{i}": ParameterConstraints(
                sharding_types=["row_wise"],
                compute_kernels=["dram_virtual_table"],
                key_value_params=KeyValueParams(
                    l2_cache_size=64, max_l1_cache_size=128
                ),
            )
            for i in range(virtual_table_count)
        },
        **{
            f"table_{i}": ParameterConstraints(
                cache_params=CacheParams(algorithm=CacheAlgorithm.LRU),
            )
            for i in range(virtual_table_count, table_count)
        },
    }

    topology = Topology(
        world_size=world_size,
        hbm_cap=1024 * 1024 * 1024 * 2,
        ddr_cap=1024 * 1024 * 1024 * 256,
        compute_device="cuda",
    )

    planner = EmbeddingShardingPlanner(
        topology=topology,
        proposer=EmbeddingOffloadScaleupProposer(),
        constraints=constraints,
    )
    sharding_plan = planner.plan(
        module=model, sharders=[EmbeddingCollectionSharder()]  # pyre-ignore
    )
    module_plan = cast(EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"])

    expected_ranks = [[0, 1] for _ in range(table_count)]
    ranks = [
        cast(List[int], param_shard.ranks) for param_shard in module_plan.values()
    ]
    compute_kernels = {
        param_shard.compute_kernel for param_shard in module_plan.values()
    }
    self.assertEqual(sorted(expected_ranks), sorted(ranks))
    self.assertSetEqual(
        {
            EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
            EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
        },
        compute_kernels,
    )

    for table_index in range(table_count):
        # Each table is expected to be split evenly by rows across the
        # two ranks, so a shard holds num_embeddings // world_size rows.
        num_embeddings = (
            virtual_num_embeddings
            if table_index < virtual_table_count
            else regular_num_embeddings
        )
        rows_per_shard = num_embeddings // world_size

        shards = module_plan[f"table_{table_index}"].sharding_spec.shards  # pyre-ignore
        self.assertEqual(len(shards), world_size)
        self.assertEqual(shards[0].shard_offsets, [0, 0])
        self.assertEqual(shards[0].shard_sizes, [rows_per_shard, embedding_dim])
        self.assertEqual(shards[1].shard_offsets, [rows_per_shard, 0])
        self.assertEqual(shards[1].shard_sizes, [rows_per_shard, embedding_dim])

    stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
    for expected_fragment in (
        "dram_virtual_table: HBM: 0.501 GB, DDR: 256.0 GB",
        "fused_uvm_caching: HBM: 0.011 GB, DDR: 0.048 GB",
        "Max HBM: 0.256 GB on ranks [0, 1]",
        "Min HBM: 0.256 GB on ranks [0, 1]",
    ):
        self.assertTrue(
            any(expected_fragment in line for line in stats),
            f"expected planner stats to contain {expected_fragment!r}",
        )

torchrec/distributed/test_utils/test_model.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
import torch.nn as nn
1717
from tensordict import TensorDict
18+
from torchrec import EmbeddingCollection
1819
from torchrec.distributed.embedding import EmbeddingCollectionSharder
1920
from torchrec.distributed.embedding_tower_sharding import (
2021
EmbeddingTowerCollectionSharder,
@@ -1207,6 +1208,47 @@ def _post_sparsenn_forward(
12071208
)
12081209

12091210

1211+
class TestECSparseArch(nn.Module):
    """
    Minimal sparse arch for tests, backed by a single ``EmbeddingCollection``.

    Args:
        tables (List[EmbeddingConfig]): configs of the tables held by the
            collection.
        device (Optional[torch.device]): device passed to the collection;
            ``torch.device("cpu")`` is used when omitted.

    Call Args:
        features (KeyedJaggedTensor): input passed to the collection.
        weighted_features (Optional[KeyedJaggedTensor]): accepted but not
            used by this module.
        batch_size (Optional[int]): forwarded unchanged to
            ``_post_sparsenn_forward``.

    Returns:
        KeyedTensor: the value returned by ``_post_sparsenn_forward``.
    """

    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        target_device = torch.device("cpu") if device is None else device
        self.ec: EmbeddingCollection = EmbeddingCollection(
            tables=tables,
            device=target_device,
        )

    def forward(
        self,
        features: KeyedJaggedTensor,
        weighted_features: Optional[KeyedJaggedTensor] = None,
        batch_size: Optional[int] = None,
    ) -> KeyedTensor:
        embeddings = self.ec(features)
        # Only the first slot and batch_size are supplied; the two middle
        # arguments are passed as None.
        return _post_sparsenn_forward(embeddings, None, None, batch_size)
1250+
1251+
12101252
class TestSparseArch(nn.Module):
12111253
"""
12121254
Basic nn.Module for testing
@@ -1349,7 +1391,7 @@ class TestSparseNN(TestSparseNNBase, CopyableMixin):
13491391

13501392
def __init__(
13511393
self,
1352-
tables: List[EmbeddingBagConfig],
1394+
tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]],
13531395
num_float_features: int = 10,
13541396
weighted_tables: Optional[List[EmbeddingBagConfig]] = None,
13551397
embedding_groups: Optional[Dict[str, List[str]]] = None,
@@ -1373,14 +1415,19 @@ def __init__(
13731415
self.dense = TestDenseArch(num_float_features, dense_device)
13741416
if zch:
13751417
self.sparse: nn.Module = TestSparseArchZCH(
1376-
tables,
1418+
tables, # pyre-ignore
13771419
weighted_tables,
13781420
torch.device("meta"),
13791421
return_remapped=True,
13801422
)
1423+
elif isinstance(tables[0], EmbeddingConfig):
1424+
self.sparse = TestECSparseArch(
1425+
tables, # pyre-ignore [6]
1426+
sparse_device,
1427+
)
13811428
else:
13821429
self.sparse = TestSparseArch(
1383-
tables,
1430+
tables, # pyre-ignore
13841431
weighted_tables,
13851432
sparse_device,
13861433
max_feature_lengths,

0 commit comments

Comments
 (0)