@@ -634,3 +634,185 @@ def test_planner_with_virtual_table(self) -> None:
         self.assertTrue(
             any("Min HBM: 0.256 GB on ranks [0, 1]" in line for line in stats)
         )
+
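+        # Constrain the first half of the tables to row-wise sharding on the
+        # dram_virtual_table kernel with explicit key-value cache sizes; the second
+        # half only pins the cache eviction policy to LRU and is expected to land on
+        # fused_uvm_caching (see the kernel assertion below).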
+        constraints = {
+            **{
+                f"table_{i}": ParameterConstraints(
+                    sharding_types=["row_wise"],
+                    compute_kernels=["dram_virtual_table"],
+                    key_value_params=KeyValueParams(
+                        l2_cache_size=64, max_l1_cache_size=128
+                    ),
+                )
+                for i in range(table_count // 2)
+            },
+            **{
+                f"table_{i}": ParameterConstraints(
+                    cache_params=CacheParams(algorithm=CacheAlgorithm.LRU),
+                )
+                for i in range(table_count // 2, table_count)
+            },
+        }
+
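+        # Two-rank topology with 2 GiB of HBM and 256 GiB of DDR per rank.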
+        topology = Topology(
+            world_size=2,
+            hbm_cap=1024 * 1024 * 1024 * 2,
+            ddr_cap=1024 * 1024 * 1024 * 256,
+            compute_device="cuda",
+        )
+
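+        # The scale-up proposer may grow the cache load factor of the UVM-cached
+        # tables to use HBM left over after the virtual tables are placed.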
+        planner = EmbeddingShardingPlanner(
+            topology=topology,
+            proposer=EmbeddingOffloadScaleupProposer(),
+            constraints=constraints,
+        )
+        sharding_plan = planner.plan(
+            module=model, sharders=[EmbeddingCollectionSharder()]  # pyre-ignore
+        )
+
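+        # Every table should be sharded row-wise across both ranks, with the virtual
+        # tables on dram_virtual_table and the remaining tables on fused_uvm_caching.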
+        expected_ranks = [[0, 1], [0, 1], [0, 1], [0, 1]]
+        ranks = [
+            cast(List[int], param_shard.ranks)
+            for param_shard in cast(
+                EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
+            ).values()
+        ]
+        compute_kernels = {
+            param_shard.compute_kernel
+            for param_shard in cast(
+                EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"]
+            ).values()
+        }
+        self.assertEqual(sorted(expected_ranks), sorted(ranks))
+        self.assertSetEqual(
+            {
+                EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            },
+            compute_kernels,
+        )
+
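+        # Rebuild the model with concrete table configs: the first half are virtual
+        # tables (10,000 rows, 10 buckets), the second half are plain 100,000-row tables.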
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=10000,
+                embedding_dim=64,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+                use_virtual_table=True,
+                total_num_buckets=10,
+            )
+            for i in range(table_count // 2)
+        ] + [
+            EmbeddingConfig(
+                num_embeddings=100_000,
+                embedding_dim=64,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(table_count // 2, table_count)
+        ]
+
+        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
+
+        planner = EmbeddingShardingPlanner(
+            topology=topology,
+            proposer=EmbeddingOffloadScaleupProposer(),
+            constraints=constraints,
+        )
+
+        # Max L1 cache size (128MB) > size of embedding table * default cache load factor
+
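+        # Row-wise sharding over the two ranks splits each table evenly: 5,000-row
+        # shards for the virtual tables and 50,000-row shards for the others.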
+        sharding_plan = planner.plan(
+            module=model, sharders=[EmbeddingCollectionSharder()]  # pyre-ignore
+        )
+        for table_index in range(4):
+            shards = sharding_plan.plan["sparse.ec"][
+                f"table_{table_index}"
+            ].sharding_spec.shards
+            self.assertEqual(len(shards), 2)
+            self.assertEqual(shards[0].shard_offsets, [0, 0])
+            self.assertEqual(
+                shards[0].shard_sizes,
+                [5000 if table_index < 2 else 50_000, 64],
+            )
+            self.assertEqual(
+                shards[1].shard_offsets,
+                [5000 if table_index < 2 else 50_000, 0],
+            )
+            self.assertEqual(
+                shards[1].shard_sizes,
+                [5000 if table_index < 2 else 50_000, 64],
+            )
+        stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
+        # Max L1 cache size of 128MB > size of embedding table * cache load factor. We use the smaller value.
+        # L2 cache size is 64GB per shard per table.
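+        # DDR for the fused_uvm_caching tables is the full embedding weight: about
+        # 2 tables * 100,000 rows * 64 dims * 4 bytes ~= 0.048 GiB (assuming fp32 weights).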
+        self.assertTrue(
+            any(
+                "dram_virtual_table: HBM: 0.002 GB, DDR: 256.0 GB" in line
+                for line in stats
+            )
+        )
+        self.assertTrue(
+            any(
+                "fused_uvm_caching: HBM: 0.011 GB, DDR: 0.048 GB" in line
+                for line in stats
+            )
+        )
+        self.assertTrue(
+            any("Max HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
+        )
+        self.assertTrue(
+            any("Min HBM: 0.007 GB on ranks [0, 1]" in line for line in stats)
+        )
+
+        # Override cache load factor
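+        # Raising cache_load_factor to 0.5 should grow the HBM footprint of the
+        # cached tables while their DDR footprint stays the same.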
+        planner = EmbeddingShardingPlanner(
+            topology=topology,
+            proposer=EmbeddingOffloadScaleupProposer(),
+            constraints=constraints,
+        )
+        sharding_plan = planner.plan(
+            module=model,
+            sharders=[  # pyre-ignore
+                EmbeddingCollectionSharder(fused_params={"cache_load_factor": 0.5})
+            ],
+        )
+        for table_index in range(4):
+            shards = sharding_plan.plan["sparse.ec"][
+                f"table_{table_index}"
+            ].sharding_spec.shards
+            self.assertEqual(len(shards), 2)
+            self.assertEqual(shards[0].shard_offsets, [0, 0])
+            self.assertEqual(
+                shards[0].shard_sizes,
+                [5000 if table_index < 2 else 50_000, 64],
+            )
+            self.assertEqual(
+                shards[1].shard_offsets,
+                [5000 if table_index < 2 else 50_000, 0],
+            )
+            self.assertEqual(
+                shards[1].shard_sizes,
+                [5000 if table_index < 2 else 50_000, 64],
+            )
+        stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
+        # Max L1 cache size of 128MB > size of embedding table * cache load factor. We use the smaller value.
+        # L2 cache size is 64GB per shard per table.
+        self.assertTrue(
+            any(
+                "dram_virtual_table: HBM: 0.005 GB, DDR: 256.0 GB" in line
+                for line in stats
+            )
+        )
+        self.assertTrue(
+            any(
+                "fused_uvm_caching: HBM: 0.027 GB, DDR: 0.048 GB" in line
+                for line in stats
+            )
+        )
+        self.assertTrue(
+            any("Max HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
+        )
+        self.assertTrue(
+            any("Min HBM: 0.016 GB on ranks [0, 1]" in line for line in stats)
+        )