|
12 | 12 |
|
13 | 13 | import torch
|
14 | 14 | from torch import nn
|
| 15 | +from torchrec import EmbeddingConfig |
| 16 | +from torchrec.distributed.embedding import EmbeddingCollectionSharder |
15 | 17 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
|
16 | 18 | from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
|
17 | 19 | from torchrec.distributed.planner import ParameterConstraints
|
18 | 20 | from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
|
19 | 21 | from torchrec.distributed.planner.proposers import EmbeddingOffloadScaleupProposer
|
| 22 | +from torchrec.distributed.planner.stats import EmbeddingStats |
20 | 23 | from torchrec.distributed.planner.types import (
|
21 | 24 | PlannerError,
|
22 | 25 | PlannerErrorType,
|
|
31 | 34 | CacheParams,
|
32 | 35 | DataType,
|
33 | 36 | EmbeddingModuleShardingPlan,
|
| 37 | + KeyValueParams, |
34 | 38 | ModuleSharder,
|
35 | 39 | ShardingPlan,
|
36 | 40 | ShardingType,
|
@@ -359,3 +363,152 @@ def test_auto_sharder_solution(self) -> None:
|
359 | 363 | self.assertSetEqual(
|
360 | 364 | {EmbeddingComputeKernel.FUSED_UVM_CACHING.value}, compute_kernels
|
361 | 365 | )
|
| 366 | + |
def test_planner_with_virtual_table(self) -> None:
    """Plan a model that mixes virtual tables with regular tables.

    Half of the tables are declared with ``use_virtual_table=True`` and are
    constrained to row-wise sharding with the ``dram_virtual_table`` compute
    kernel. The test checks two things:

    1. Planning raises ``AssertionError`` when those constraints carry no
       ``key_value_params``.
    2. Once ``key_value_params`` is supplied, the plan places every table on
       both ranks, splits rows evenly between them, uses the expected
       compute kernels, and reports the expected storage figures in the
       planner stats table.
    """
    table_count = 4
    # The first half of the tables are virtual, the second half are regular.
    virtual_table_count = table_count // 2

    tables = [
        EmbeddingConfig(
            num_embeddings=1_125_899_902_955_520,
            embedding_dim=64,
            name=f"table_{i}",
            feature_names=[f"feature_{i}"],
            use_virtual_table=True,
            total_num_buckets=3_991_680,
        )
        for i in range(virtual_table_count)
    ] + [
        EmbeddingConfig(
            num_embeddings=100_000,
            embedding_dim=64,
            name=f"table_{i}",
            feature_names=[f"feature_{i}"],
        )
        for i in range(virtual_table_count, table_count)
    ]
    model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))

    # Phase 1: dram_virtual_table constraints without key_value_params must
    # be rejected at planning time.
    constraints = {
        **{
            f"table_{i}": ParameterConstraints(
                sharding_types=["row_wise"],
                compute_kernels=["dram_virtual_table"],
            )
            for i in range(virtual_table_count)
        },
        **{
            f"table_{i}": ParameterConstraints(
                enforce_hbm=False,
            )
            for i in range(virtual_table_count, table_count)
        },
    }
    # NOTE(review): self.topology is not defined in this method; presumably
    # it is created by the test class's setUp -- confirm.
    planner = EmbeddingShardingPlanner(
        topology=self.topology,
        proposer=EmbeddingOffloadScaleupProposer(),
        constraints=constraints,
    )

    self.assertRaisesRegex(
        AssertionError,
        "key_value_params cannot be None in ParameterConstraints of planner for embedding compute kernel: dram_virtual_table",
        planner.plan,
        module=model,
        sharders=[EmbeddingCollectionSharder()],
    )

    # Phase 2: the same virtual-table constraints, now with key_value_params,
    # planned against an explicit two-rank topology.
    constraints = {
        **{
            f"table_{i}": ParameterConstraints(
                sharding_types=["row_wise"],
                compute_kernels=["dram_virtual_table"],
                key_value_params=KeyValueParams(
                    l2_cache_size=64, max_l1_cache_size=128
                ),
            )
            for i in range(virtual_table_count)
        },
        **{
            f"table_{i}": ParameterConstraints(
                cache_params=CacheParams(algorithm=CacheAlgorithm.LRU),
            )
            for i in range(virtual_table_count, table_count)
        },
    }

    topology = Topology(
        world_size=2,
        hbm_cap=1024 * 1024 * 1024 * 2,
        ddr_cap=1024 * 1024 * 1024 * 256,
        compute_device="cuda",
    )

    planner = EmbeddingShardingPlanner(
        topology=topology,
        proposer=EmbeddingOffloadScaleupProposer(),
        constraints=constraints,
    )
    sharding_plan = planner.plan(
        module=model, sharders=[EmbeddingCollectionSharder()]  # pyre-ignore
    )

    # Narrow the module plan once and reuse it for every check below.
    ec_plan = cast(EmbeddingModuleShardingPlan, sharding_plan.plan["sparse.ec"])

    expected_ranks = [[0, 1]] * table_count
    ranks = [cast(List[int], param_shard.ranks) for param_shard in ec_plan.values()]
    compute_kernels = {param_shard.compute_kernel for param_shard in ec_plan.values()}
    self.assertEqual(sorted(expected_ranks), sorted(ranks))
    self.assertSetEqual(
        {
            EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value,
            EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
        },
        compute_kernels,
    )

    for table_index in range(table_count):
        # Each table is split into two equal row ranges; the expected sizes
        # are half of the num_embeddings declared for that table above.
        rows_per_shard = (
            562_949_951_477_760 if table_index < virtual_table_count else 50_000
        )
        # pyre-ignore
        shards = ec_plan[f"table_{table_index}"].sharding_spec.shards
        self.assertEqual(len(shards), 2)
        self.assertEqual(shards[0].shard_offsets, [0, 0])
        self.assertEqual(shards[0].shard_sizes, [rows_per_shard, 64])
        self.assertEqual(shards[1].shard_offsets, [rows_per_shard, 0])
        self.assertEqual(shards[1].shard_sizes, [rows_per_shard, 64])

    stats: List[str] = cast(EmbeddingStats, planner._stats[0])._stats_table
    expected_stat_fragments = (
        "dram_virtual_table: HBM: 0.501 GB, DDR: 256.0 GB",
        "fused_uvm_caching: HBM: 0.011 GB, DDR: 0.048 GB",
        "Max HBM: 0.256 GB on ranks [0, 1]",
        "Min HBM: 0.256 GB on ranks [0, 1]",
    )
    for fragment in expected_stat_fragments:
        self.assertTrue(
            any(fragment in line for line in stats),
            f"planner stats table is missing: {fragment!r}",
        )
0 commit comments