Skip to content

Commit 70fdde8

Browse files
kausv authored and facebook-github-bot committed
Fix kernel test (pytorch#3026)
Summary: Pull Request resolved: pytorch#3026 Differential Revision: D75574459
1 parent 80a306b commit 70fdde8

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def __init__( # noqa C901
303303
sharded_t._local_shards[0].tensor
304304
for sharded_t in self._sharded_embedding_weight_ids
305305
]
306-
if self._sharded_embedding_weight_ids is not None
306+
if self._sharded_embedding_weight_ids
307307
else None
308308
)
309309

@@ -1439,7 +1439,13 @@ def _init_sharded_split_embedding_weights(
14391439
)
14401440
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
14411441
for emb_table in emb_table_config_copy:
1442-
emb_table.local_metadata.placement._device = torch.device("cpu")
1442+
none_throws(
1443+
none_throws(
1444+
emb_table.local_metadata,
1445+
f"local_metadata is None for emb_table: {emb_table.name}",
1446+
).placement,
1447+
"placement is None for local_metadata of emb table: {emb_table.name}",
1448+
)._device = torch.device("cpu")
14431449

14441450
pmt_sharded_t_list = create_virtual_sharded_tensors(
14451451
emb_table_config_copy,

torchrec/distributed/embedding_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
144144
key = get_key_from_embedding_table(embedding_table)
145145
assert embedding_table.use_virtual_table
146146

147-
assert embedding_table.global_metadata is not None and pg is not None
147+
assert embedding_table.global_metadata is not None
148148
global_metadata = copy.deepcopy(embedding_table.global_metadata)
149149
create_virtual_table_global_metadata(global_metadata, my_rank, param)
150150
key_to_global_metadata[key] = global_metadata

torchrec/distributed/tests/test_embedding_sharding.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@
3535
ShardedEmbeddingTable,
3636
)
3737
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext
38+
from torchrec.distributed.types import ShardedTensorMetadata, ShardMetadata
3839
from torchrec.modules.embedding_configs import DataType, PoolingType
3940
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
4041

41-
WORLD_SIZE = 4
42+
WORLD_SIZE = 2
4243

4344

4445
class TestGetWeightedAverageCacheLoadFactor(unittest.TestCase):
@@ -546,24 +547,16 @@ class TestECBucketMetadata(unittest.TestCase):
546547
data_type=st.sampled_from([DataType.FP16, DataType.FP32]),
547548
embedding_dim=st.sampled_from(list(range(160, 320, 40))),
548549
total_bucket=st.sampled_from([14, 20, 32, 40]),
549-
my_rank=st.integers(min_value=0, max_value=WORLD_SIZE),
550+
my_rank=st.integers(min_value=0, max_value=WORLD_SIZE - 1),
550551
)
551552
@settings(max_examples=10, deadline=10000)
552553
def test_bucket_metadata_calculation_util(
553554
self, data_type: DataType, embedding_dim: int, total_bucket: int, my_rank: int
554555
) -> None:
555-
compute_kernels = [
556-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
557-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
558-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
559-
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
560-
]
556+
compute_kernels = [EmbeddingComputeKernel.SSD_VIRTUAL_TABLE] * WORLD_SIZE
561557
fused_params_groups = [
562558
{"cache_load_factor": 0.5},
563-
{"cache_load_factor": 0.5},
564-
{"cache_load_factor": 0.5},
565-
{"cache_load_factor": 0.5},
566-
]
559+
] * WORLD_SIZE
567560
tables = [
568561
ShardedEmbeddingTable(
569562
name=f"table_{i}",
@@ -579,8 +572,27 @@ def test_bucket_metadata_calculation_util(
579572
num_embeddings=10000 * (2 * i + 1),
580573
total_num_buckets=total_bucket,
581574
use_virtual_table=True,
575+
local_metadata=ShardMetadata(
576+
shard_offsets=[i * (10000 * (2 * i + 1) // WORLD_SIZE), 0],
577+
shard_sizes=[10000 * (2 * i + 1) // WORLD_SIZE, embedding_dim],
578+
placement=f"rank:{i}/cuda:{i}",
579+
),
580+
global_metadata=ShardedTensorMetadata(
581+
shards_metadata=[
582+
ShardMetadata(
583+
shard_offsets=[j * (10000 * (2 * i + 1) // WORLD_SIZE), 0],
584+
shard_sizes=[
585+
10000 * (2 * i + 1) // WORLD_SIZE,
586+
embedding_dim,
587+
],
588+
placement=f"rank:{j}/cuda:{j}",
589+
)
590+
for j in range(WORLD_SIZE)
591+
],
592+
size=torch.Size([10000 * (2 * i + 1), embedding_dim]),
593+
),
582594
)
583-
for i in range(len(compute_kernels))
595+
for i in range(WORLD_SIZE)
584596
]
585597

586598
# since we don't have access to _group_tables_per_rank

0 commit comments

Comments
 (0)