35
35
ShardedEmbeddingTable ,
36
36
)
37
37
from torchrec .distributed .sharding .sequence_sharding import SequenceShardingContext
38
+ from torchrec .distributed .types import ShardedTensorMetadata , ShardMetadata
38
39
from torchrec .modules .embedding_configs import DataType , PoolingType
39
40
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
40
41
41
- WORLD_SIZE = 4
42
+ WORLD_SIZE = 2
42
43
43
44
44
45
class TestGetWeightedAverageCacheLoadFactor (unittest .TestCase ):
@@ -546,24 +547,16 @@ class TestECBucketMetadata(unittest.TestCase):
546
547
data_type = st .sampled_from ([DataType .FP16 , DataType .FP32 ]),
547
548
embedding_dim = st .sampled_from (list (range (160 , 320 , 40 ))),
548
549
total_bucket = st .sampled_from ([14 , 20 , 32 , 40 ]),
549
- my_rank = st .integers (min_value = 0 , max_value = WORLD_SIZE ),
550
+ my_rank = st .integers (min_value = 0 , max_value = WORLD_SIZE - 1 ),
550
551
)
551
552
@settings (max_examples = 10 , deadline = 10000 )
552
553
def test_bucket_metadata_calculation_util (
553
554
self , data_type : DataType , embedding_dim : int , total_bucket : int , my_rank : int
554
555
) -> None :
555
- compute_kernels = [
556
- EmbeddingComputeKernel .SSD_VIRTUAL_TABLE ,
557
- EmbeddingComputeKernel .SSD_VIRTUAL_TABLE ,
558
- EmbeddingComputeKernel .SSD_VIRTUAL_TABLE ,
559
- EmbeddingComputeKernel .SSD_VIRTUAL_TABLE ,
560
- ]
556
+ compute_kernels = [EmbeddingComputeKernel .SSD_VIRTUAL_TABLE ] * WORLD_SIZE
561
557
fused_params_groups = [
562
558
{"cache_load_factor" : 0.5 },
563
- {"cache_load_factor" : 0.5 },
564
- {"cache_load_factor" : 0.5 },
565
- {"cache_load_factor" : 0.5 },
566
- ]
559
+ ] * WORLD_SIZE
567
560
tables = [
568
561
ShardedEmbeddingTable (
569
562
name = f"table_{ i } " ,
@@ -579,8 +572,27 @@ def test_bucket_metadata_calculation_util(
579
572
num_embeddings = 10000 * (2 * i + 1 ),
580
573
total_num_buckets = total_bucket ,
581
574
use_virtual_table = True ,
575
+ local_metadata = ShardMetadata (
576
+ shard_offsets = [i * (10000 * (2 * i + 1 ) // WORLD_SIZE ), 0 ],
577
+ shard_sizes = [10000 * (2 * i + 1 ) // WORLD_SIZE , embedding_dim ],
578
+ placement = f"rank:{ i } /cuda:{ i } " ,
579
+ ),
580
+ global_metadata = ShardedTensorMetadata (
581
+ shards_metadata = [
582
+ ShardMetadata (
583
+ shard_offsets = [j * (10000 * (2 * i + 1 ) // WORLD_SIZE ), 0 ],
584
+ shard_sizes = [
585
+ 10000 * (2 * i + 1 ) // WORLD_SIZE ,
586
+ embedding_dim ,
587
+ ],
588
+ placement = f"rank:{ j } /cuda:{ j } " ,
589
+ )
590
+ for j in range (WORLD_SIZE )
591
+ ],
592
+ size = torch .Size ([10000 * (2 * i + 1 ), embedding_dim ]),
593
+ ),
582
594
)
583
- for i in range (len ( compute_kernels ) )
595
+ for i in range (WORLD_SIZE )
584
596
]
585
597
586
598
# since we don't have access to _group_tables_per_rank
0 commit comments