
Commit a872a5e

kausv authored and facebook-github-bot committed
Fix Init Param Test
Summary: This test was failing on the CUDA release run (https://www.internalfb.com/intern/test/281475075501969?ref_report_id=0) with ```invalid device pointer```. MultiProcessTest assigns devices on its own, so I removed the explicit device from the tests' EmbeddingConfig. Then distributed gather() failed because the tensor at position 0 was expected to be on a CUDA device but was found on CPU, so I allocated gathered_tensor on the process group's device. That in turn made assert_close fail because the devices no longer matched, so I copy the gathered tensor to CPU before the comparison.

Reviewed By: iamzainhuda

Differential Revision: D80547120
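The fix wraps the gathered tensor in none_throws before comparing, because only rank 0 allocates the output buffer and the variable is therefore Optional. As a hedged sketch, this is a minimal re-implementation of that helper's contract (the real one lives at torchrec.distributed.utils.none_throws; this stand-in only mirrors its behavior), followed by the rank-0-only pattern it supports:

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def none_throws(x: Optional[T], msg: str = "Unexpected None") -> T:
    """Narrow Optional[T] to T, raising if the value is None.

    Sketch of the torchrec helper used in this commit; the real helper
    is torchrec.distributed.utils.none_throws.
    """
    if x is None:
        raise AssertionError(msg)
    return x


# The pattern from the fix: only rank 0 allocates the gather output,
# so the variable is Optional and must be narrowed before use.
rank = 0
gathered: Optional[List[float]] = [1.0, 2.0] if rank == 0 else None
if rank == 0:
    result = none_throws(gathered)  # safe: rank 0 allocated it
```

This keeps type checkers happy on the rank-0 branch without sprinkling `# type: ignore` or bare asserts through the test.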
1 parent f9d4bbf commit a872a5e

File tree

1 file changed (+5, -4 lines)


torchrec/distributed/tests/test_init_parameters.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -36,6 +36,7 @@
     ShardingPlan,
     ShardingType,
 )
+from torchrec.distributed.utils import none_throws
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
```
```diff
@@ -100,15 +101,17 @@ def initialize_and_test_parameters(
         )
     elif isinstance(model.state_dict()[key], ShardedTensor):
         if ctx.rank == 0:
-            gathered_tensor = torch.empty_like(embedding_tables.state_dict()[key])
+            gathered_tensor = torch.empty_like(
+                embedding_tables.state_dict()[key], device=ctx.device
+            )
         else:
             gathered_tensor = None

         model.state_dict()[key].gather(dst=0, out=gathered_tensor)

         if ctx.rank == 0:
             torch.testing.assert_close(
-                gathered_tensor,
+                none_throws(gathered_tensor).to("cpu"),
                 embedding_tables.state_dict()[key],
             )
     elif isinstance(model.state_dict()[key], torch.Tensor):
```
```diff
@@ -160,7 +163,6 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:

         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
-            device=torch.device("cuda:0"),
             tables=[
                 EmbeddingConfig(
                     name=table_name,
```
```diff
@@ -210,7 +212,6 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:

         # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
-            device=torch.device("cuda:0"),
             tables=[
                 EmbeddingBagConfig(
                     name=table_name,
```
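The allocate-on-device, compare-on-CPU flow in the hunk above can be sketched on CPU alone (plain CPU tensors stand in for ctx.device and for the multi-process gather; `gathered.copy_(reference)` below is a hypothetical stand-in for `ShardedTensor.gather(dst=0, out=...)`, not the real collective):

```python
import torch

# Reference weights, standing in for embedding_tables.state_dict()[key].
reference = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Allocate the gather buffer on the process group's device, as the fix
# does with device=ctx.device (here CPU stands in for that device).
gathered = torch.empty_like(reference, device=torch.device("cpu"))

# Stand-in for ShardedTensor.gather(dst=0, out=gathered_tensor).
gathered.copy_(reference)

# assert_close requires matching devices, hence the .to("cpu") in the fix.
torch.testing.assert_close(gathered.to("cpu"), reference)
```

The `.to("cpu")` is a no-op when the buffer is already on CPU, but on a CUDA process group it is what lets the comparison against the CPU-resident reference state dict succeed.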
