|
36 | 36 | ShardingPlan,
|
37 | 37 | ShardingType,
|
38 | 38 | )
|
| 39 | +from torchrec.distributed.utils import none_throws |
39 | 40 | from torchrec.modules.embedding_configs import (
|
40 | 41 | DataType,
|
41 | 42 | EmbeddingBagConfig,
|
@@ -100,15 +101,17 @@ def initialize_and_test_parameters(
|
100 | 101 | )
|
101 | 102 | elif isinstance(model.state_dict()[key], ShardedTensor):
|
102 | 103 | if ctx.rank == 0:
|
103 |
| - gathered_tensor = torch.empty_like(embedding_tables.state_dict()[key]) |
| 104 | + gathered_tensor = torch.empty_like( |
| 105 | + embedding_tables.state_dict()[key], device=ctx.device |
| 106 | + ) |
104 | 107 | else:
|
105 | 108 | gathered_tensor = None
|
106 | 109 |
|
107 | 110 | model.state_dict()[key].gather(dst=0, out=gathered_tensor)
|
108 | 111 |
|
109 | 112 | if ctx.rank == 0:
|
110 | 113 | torch.testing.assert_close(
|
111 |
| - gathered_tensor, |
| 114 | + none_throws(gathered_tensor).to("cpu"), |
112 | 115 | embedding_tables.state_dict()[key],
|
113 | 116 | )
|
114 | 117 | elif isinstance(model.state_dict()[key], torch.Tensor):
|
@@ -160,7 +163,6 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
|
160 | 163 |
|
161 | 164 | # Initialize embedding table on a non-meta device (the module's default device, since no explicit device is passed)
|
162 | 165 | embedding_tables = EmbeddingCollection(
|
163 |
| - device=torch.device("cuda:0"), |
164 | 166 | tables=[
|
165 | 167 | EmbeddingConfig(
|
166 | 168 | name=table_name,
|
@@ -210,7 +212,6 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
|
210 | 212 |
|
211 | 213 | # Initialize embedding bag on a non-meta device (the module's default device, since no explicit device is passed)
|
212 | 214 | embedding_tables = EmbeddingBagCollection(
|
213 |
| - device=torch.device("cuda:0"), |
214 | 215 | tables=[
|
215 | 216 | EmbeddingBagConfig(
|
216 | 217 | name=table_name,
|
|
0 commit comments