Skip to content

Commit 1d61b69

Browse files
kausvfacebook-github-bot
authored andcommitted
Account for cache load factor in memory estimate (pytorch#3035)
Summary: Pull Request resolved: pytorch#3035 The actual L1 cache allocated is determined by cache load factor and ceiled by max_l1_cache_size config. This diff addresses that to improve the memory estimate. Reviewed By: emlin, duduyi2013 Differential Revision: D75893983
1 parent 64030cf commit 1d61b69

File tree

3 files changed

+201
-2
lines changed

3 files changed

+201
-2
lines changed

torchrec/distributed/planner/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
# with other devices such as the FE NIC.
3131
HBM_TO_DDR_MEM_BW: float = 32 * 1024 * 1024 * 1024 / 1000 # bytes/ms
3232
UVM_CACHING_RATIO: float = 0.2
33+
KV_CACHING_RATIO: float = 0.2
3334
BATCH_SIZE: int = 512
3435

3536
BATCHED_COPY_PERF_FACTOR: float = 2.455 # empirical studies

torchrec/distributed/planner/shard_estimators.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import logging
1111
import math
12+
from math import ceil
1213
from typing import cast, Dict, List, Optional, Tuple, Type
1314

1415
import torch
@@ -22,6 +23,7 @@
2223
FULL_BLOCK_EMB_DIM,
2324
HALF_BLOCK_PENALTY,
2425
kernel_bw_lookup,
26+
KV_CACHING_RATIO,
2527
QUARTER_BLOCK_PENALTY,
2628
UVM_CACHING_RATIO,
2729
WEIGHTED_KERNEL_MULTIPLIER,
@@ -1021,6 +1023,11 @@ def estimate(
10211023
if constraints and constraints.key_value_params
10221024
else None
10231025
)
1026+
kv_cache_load_factor: float = (
1027+
sharder.fused_params.get("cache_load_factor", KV_CACHING_RATIO)
1028+
if sharder.fused_params
1029+
else KV_CACHING_RATIO
1030+
)
10241031

10251032
# hardcoded as 8 bytes
10261033
# input indices can be of int32, but in TBE they get converted to int64 anyway
@@ -1065,6 +1072,7 @@ def estimate(
10651072
is_inference=self._is_inference,
10661073
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
10671074
key_value_params=key_value_params,
1075+
kv_cache_load_factor=kv_cache_load_factor,
10681076
)
10691077
for shard, storage in zip(sharding_option.shards, shard_storages):
10701078
shard.storage = storage
@@ -1134,6 +1142,7 @@ def calculate_shard_storages(
11341142
is_inference: bool = False,
11351143
multipass_prefetch_max_pass: Optional[int] = None,
11361144
key_value_params: Optional[KeyValueParams] = None,
1145+
kv_cache_load_factor: float = KV_CACHING_RATIO,
11371146
) -> List[Storage]:
11381147
"""
11391148
Calculates estimated storage sizes for each sharded tensor, comprised of input,
@@ -1191,7 +1200,6 @@ def calculate_shard_storages(
11911200
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value,
11921201
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
11931202
}:
1194-
# TODO(wangj): for ssd/dram kv, most likely we use absolute L1 cache size instead of caching ratio, as denominator is huge
11951203
hbm_storage = round(ddr_storage * caching_ratio)
11961204
table_cached = True
11971205

@@ -1225,7 +1233,15 @@ def calculate_shard_storages(
12251233
)
12261234

12271235
hbm_specific_sizes = [
1228-
(key_value_params.max_l1_cache_size or 0) * 1024 * 1024
1236+
min(
1237+
(key_value_params.max_l1_cache_size or 0) * 1024 * 1024,
1238+
ceil(
1239+
tensor.shape[0] # num_embeddings
1240+
* kv_cache_load_factor
1241+
* tensor.element_size() # size of one column
1242+
* tensor.shape[1], # number of columns in embedding
1243+
),
1244+
)
12291245
for _ in hbm_specific_sizes
12301246
]
12311247
ddr_specific_sizes = [

torchrec/distributed/planner/tests/test_planners.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,185 @@ def test_planner_with_virtual_table(self) -> None:
634634
self.assertTrue(
635635
any("Min HBM: 0.256 GB on ranks [0, 1]" in line for line in stats)
636636
)
637+
638+
constraints = {
639+
**{
640+
f"table_{i}": ParameterConstraints(
641+
sharding_types=["row_wise"],
642+
compute_kernels=["dram_virtual_table"],
643+
key_value_params=KeyValueParams(
644+
l2_cache_size=64, max_l1_cache_size=128
645+
),
646+
)
647+
for i in range(table_count // 2)
648+
},
649+
**{
650+
f"table_{i}": ParameterConstraints(
651+
cache_params=CacheParams(algorithm=CacheAlgorithm.LRU),
652+
)
653+
for i in range(table_count // 2, table_count)
654+
},
655+
}
656+
657+
topology = Topology(
658+
world_size=2,
659+
hbm_cap=1024 * 1024 * 1024 * 2,
660+
ddr_cap=1024 * 1024 * 1024 * 256,
661+
compute_device="cuda",
662+
)
663+
664+
planner = EmbeddingShardingPlanner(
665+
topology=topology,
666+
proposer=EmbeddingOffloadScaleupProposer(),
667+
constraints=constraints,
668+
)
669+
sharding_plan = planner.plan(
670+
module=model, sharders=[EmbeddingCollectionSharder()] # pyre-ignore
671+
)
672+
673+
expected_ranks = [[0, 1], [0, 1], [0, 1], [0, 1]]
674+
ranks = [
675+
cast(List[int], param_shard.ranks)
676+
for param_shard in cast(
677+
EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
678+
).values()
679+
]
680+
compute_kernels = {
681+
param_shard.compute_kernel
682+
for param_shard in cast(
683+
EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
684+
).values()
685+
}
686+
self.assertEqual(sorted(expected_ranks), sorted(ranks))
687+
self.assertSetEqual(
688+
{
689+
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
690+
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
691+
},
692+
compute_kernels,
693+
)
694+
695+
tables = [
696+
EmbeddingConfig(
697+
num_embeddings=10000,
698+
embedding_dim=64,
699+
name="table_" + str(i),
700+
feature_names=["feature_" + str(i)],
701+
use_virtual_table=True,
702+
total_num_buckets=10,
703+
)
704+
for i in range(table_count // 2)
705+
] + [
706+
EmbeddingConfig(
707+
num_embeddings=100_000,
708+
embedding_dim=64,
709+
name="table_" + str(i),
710+
feature_names=["feature_" + str(i)],
711+
)
712+
for i in range(table_count // 2, table_count)
713+
]
714+
715+
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
716+
717+
planner = EmbeddingShardingPlanner(
718+
topology=topology,
719+
proposer=EmbeddingOffloadScaleupProposer(),
720+
constraints=constraints,
721+
)
722+
723+
# L1 cache size > size of embedding table * default cache load factor
724+
725+
sharding_plan = planner.plan(
726+
module=model, sharders=[EmbeddingCollectionSharder()] # pyre-ignore
727+
)
728+
for table_index in range(4):
729+
shards = sharding_plan.plan["sparse.ec"][
730+
f"table_{table_index}"
731+
].sharding_spec.shards
732+
self.assertEqual(len(shards), 2)
733+
self.assertEqual(shards[0].shard_offsets, [0, 0])
734+
self.assertEqual(
735+
shards[0].shard_sizes,
736+
[5000 if table_index < 2 else 50_000, 64],
737+
)
738+
self.assertEqual(
739+
shards[1].shard_offsets,
740+
[5000 if table_index < 2 else 50_000, 0],
741+
)
742+
self.assertEqual(
743+
shards[1].shard_sizes,
744+
[5000 if table_index < 2 else 50_000, 64],
745+
)
746+
stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
747+
# L1 cache size of 64GB > size of embedding table * cache load factor. We use the smaller value.
748+
# L2 cache size is 128MB per shard per table
749+
self.assertTrue(
750+
any(
751+
"dram_virtual_table: HBM: 0.002 GB, DDR: 256.0 GB" in line
752+
for line in stats
753+
)
754+
)
755+
self.assertTrue(
756+
any(
757+
"fused_uvm_caching: HBM: 0.011 GB, DDR: 0.048 GB" in line
758+
for line in stats
759+
)
760+
)
761+
self.assertTrue(
762+
any("Max HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
763+
)
764+
self.assertTrue(
765+
any("Min HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
766+
)
767+
768+
# Override cache load factor
769+
planner = EmbeddingShardingPlanner(
770+
topology=topology,
771+
proposer=EmbeddingOffloadScaleupProposer(),
772+
constraints=constraints,
773+
)
774+
sharding_plan = planner.plan(
775+
module=model,
776+
sharders=[ # pyre-ignore
777+
EmbeddingCollectionSharder(fused_params={"cache_load_factor": 0.5})
778+
],
779+
)
780+
for table_index in range(4):
781+
shards = sharding_plan.plan["sparse.ec"][
782+
f"table_{table_index}"
783+
].sharding_spec.shards
784+
self.assertEqual(len(shards), 2)
785+
self.assertEqual(shards[0].shard_offsets, [0, 0])
786+
self.assertEqual(
787+
shards[0].shard_sizes,
788+
[5000 if table_index < 2 else 50_000, 64],
789+
)
790+
self.assertEqual(
791+
shards[1].shard_offsets,
792+
[5000 if table_index < 2 else 50_000, 0],
793+
)
794+
self.assertEqual(
795+
shards[1].shard_sizes,
796+
[5000 if table_index < 2 else 50_000, 64],
797+
)
798+
stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
799+
# L1 cache size of 64GB > size of embedding table * cache load factor. We use the smaller value.
800+
# L2 cache size is 128MB per shard per table
801+
self.assertTrue(
802+
any(
803+
"dram_virtual_table: HBM: 0.005 GB, DDR: 256.0 GB" in line
804+
for line in stats
805+
)
806+
)
807+
self.assertTrue(
808+
any(
809+
"fused_uvm_caching: HBM: 0.027 GB, DDR: 0.048 GB" in line
810+
for line in stats
811+
)
812+
)
813+
self.assertTrue(
814+
any("Max HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
815+
)
816+
self.assertTrue(
817+
any("Min HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
818+
)

0 commit comments

Comments
 (0)