
[Bug]: ShardedManagedCollisionEmbeddingCollection throws an IndexError when "return_remapped_features=True" #2838


Open
rayhuang90 opened this issue Mar 20, 2025 · 1 comment
rayhuang90 commented Mar 20, 2025

I use ManagedCollisionEmbeddingCollection with DistributedModelParallel to store hashID embeddings during distributed training.

An IndexError is raised when return_remapped_features=True is set with a single embedding table configuration; the error disappears once a second table configuration is added.

The expected behavior is that return_remapped_features=True works regardless of the number of embedding table configurations.

Below is a minimal reproducible Python code example:

mch_bug_reproduce.py.txt

torchrun --standalone --nnodes=1 --node-rank=0 --nproc-per-node=1 mch_bug_reproduce.py

Error message

[rank0]: Traceback (most recent call last):
[rank0]:   File "~/mch_bug_reproduce.py", line 72, in <module>
[rank0]:     emb_result, remapped_ids = dmp_mc_ec(mb)
[rank0]:                                ^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]:     return self._dmp_wrapped_module(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 998, in forward
[rank0]:     return self.compute_and_output_dist(ctx, dist_input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/types.py", line 982, in compute_and_output_dist
[rank0]:     return self.output_dist(ctx, output)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_embedding_modules.py", line 243, in output_dist
[rank0]:     kjt_awaitable = self._managed_collision_collection.output_dist(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/mc_modules.py", line 791, in output_dist
[rank0]:     awaitables_per_sharding.append(odist(remapped_ids, sharding_ctx))
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sequence_sharding.py", line 102, in forward
[rank0]:     return self._dist(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 1469, in forward
[rank0]:     embedding_dim=local_embs.shape[1],
[rank0]:                   ~~~~~~~~~~~~~~~~^^^
[rank0]: IndexError: tuple index out of range
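The last frame suggests that in the single-table case `local_embs` arrives as a 1-D tensor, so its shape tuple has no second entry to index. A minimal plain-Python illustration of that failure mode (not torchrec code, just the indexing behavior):

```python
# A 1-D tensor's shape is a 1-tuple, so shape[1] is out of range,
# matching the error raised at dist_data.py:1469.
shape = (1024,)  # e.g. the shape of a flattened, single-table local_embs
try:
    embedding_dim = shape[1]
except IndexError as e:
    print(e)  # tuple index out of range
```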

My current environment

fbgemm_gpu==1.1.0+cu118
numpy==2.1.2
protobuf==3.19.6
torch==2.6.0+cu118
torchrec==1.1.0+cu118
transformers==4.48.0
triton==3.2.0
kausv commented Mar 21, 2025

Checking

kausv self-assigned this Mar 21, 2025