mark_forward_method does not work with ModelParallelStrategy #20710

Open
tonyf opened this issue Apr 12, 2025 · 0 comments · May be fixed by #20711
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.5.x

Comments

tonyf commented Apr 12, 2025

Bug description

When using the ModelParallelStrategy, methods registered with mark_forward_method raise an exception if their signature does not match the signature of the module's forward method. The failure occurs specifically when the number of positional or keyword arguments differs between the two methods.

For example, calling generate below fails in an FSDP2 setting with the error TypeError: Model.forward() got an unexpected keyword argument 'cfg':

import torch
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x

    def generate(self, x, y, cfg: float = 0.5):
        z_1 = self.forward(x, y)
        z_2 = self.forward(x, torch.zeros_like(y))
        ...
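For completeness, here is a minimal end-to-end sketch of how this is hit, reusing the Model class above. This assumes Fabric's ModelParallelStrategy with a trivial FSDP2 parallelize_fn; the mesh dimension name, launcher setup, and the fully_shard import path are assumptions and may vary across torch/lightning versions:

import torch
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp import fully_shard  # import path varies by torch version
from torch.distributed.device_mesh import DeviceMesh


def parallelize(model: Model, device_mesh: DeviceMesh) -> Model:
    # Minimal FSDP2 sharding; the exact sharding scheme is irrelevant to the bug.
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


fabric = L.Fabric(strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()

model = fabric.setup(Model())
model.mark_forward_method("generate")  # register generate for forward redirection

x = torch.randn(2, 4, device=fabric.device)
y = torch.randn(2, 4, device=fabric.device)
model.generate(x, y, cfg=0.5)  # TypeError: Model.forward() got an unexpected keyword argument 'cfg'

The same redirection works with other strategies; the failure appears to be specific to ModelParallelStrategy (FSDP2).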

What version are you seeing the problem on?

v2.5

Error messages and logs

[rank0]: │   473 │   │   ):                                                                                                                                      │
[rank0]: │   474 │   │   │   self.callbacks.on_validation_step_start(self, batch_idx)                                                                            │
[rank0]: │   475 │   │   │                                                                                                                                       │
[rank0]: │ ❱ 476 │   │   │   result = self.validation_step(batch, batch_idx)                                                                                     │
[rank0]: │   477 │   │   │   self.callbacks.on_validation_step_end(self, result, batch_idx)                                                                      │
[rank0]: │   478 │   │                                                                                                                                           │
[rank0]: │   479 │   │   result = self.on_validation_epoch_end()                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/models/flow_matching/stage_1_train.py:112 in validation_step                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │   109 │   │   B, _, T, H, W = samples.shape                                                                                                           │
[rank0]: │   110 │   │   ct, ch, cw = self.autoencoder.compression                                                                                               │
[rank0]: │   111 │   │                                                                                                                                           │
[rank0]: │ ❱ 112 │   │   samples = self.model.sample(                                                                                                            │
[rank0]: │   113 │   │   │   shape=(B, (T - 1) // ct + 1, H // ch, W // cw, self.autoencoder.latent_dim),                                                        │
[rank0]: │   114 │   │   │   text=text_embeds,                                                                                                                   │
[rank0]: │   115 │   │   │   sample_steps=self.config.sample_steps,                                                                                              │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:197 in call_forward_module                                │
[rank0]: │                                                                                                                                                       │
[rank0]: │   194 │   │   def call_forward_module(*args: Any, **kwargs: Any) -> Any:                                                                              │
[rank0]: │   195 │   │   │   # Patch the original_module's forward, so we can redirect the arguments back                                                        │
[rank0]: │   196 │   │   │   self._original_module.forward = wrapped_forward                                                                                     │
[rank0]: │ ❱ 197 │   │   │   return self.forward(*args, **kwargs)                                                                                                │
[rank0]: │   198 │   │                                                                                                                                           │
[rank0]: │   199 │   │   return call_forward_module                                                                                                              │
[rank0]: │   200                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/lightning/fabric/wrappers.py:136 in forward                                            │
[rank0]: │                                                                                                                                                       │
[rank0]: │   133 │   │   args, kwargs = precision.convert_input((args, kwargs))                                                                                  │
[rank0]: │   134 │   │                                                                                                                                           │
[rank0]: │   135 │   │   with precision.forward_context():                                                                                                       │
[rank0]: │ ❱ 136 │   │   │   output = self._forward_module(*args, **kwargs)                                                                                      │
[rank0]: │   137 │   │                                                                                                                                           │
[rank0]: │   138 │   │   output = precision.convert_output(output)                                                                                               │
[rank0]: │   139                                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:574 in _fn                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │    571 │   │   │   )                                                                                                                                  │
[rank0]: │    572 │   │   │                                                                                                                                      │
[rank0]: │    573 │   │   │   try:                                                                                                                               │
[rank0]: │ ❱  574 │   │   │   │   return fn(*args, **kwargs)                                                                                                     │
[rank0]: │    575 │   │   │   finally:                                                                                                                           │
[rank0]: │    576 │   │   │   │   # Restore the dynamic layer stack depth if necessary.                                                                          │
[rank0]: │    577 │   │   │   │   torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739 in _wrapped_call_impl                                  │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1736 │   │   if self._compiled_call_impl is not None:                                                                                               │
[rank0]: │   1737 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                             │
[rank0]: │   1738 │   │   else:                                                                                                                                  │
[rank0]: │ ❱ 1739 │   │   │   return self._call_impl(*args, **kwargs)                                                                                            │
[rank0]: │   1740 │                                                                                                                                              │
[rank0]: │   1741 │   # torchrec tests the code consistency with the following code                                                                              │
[rank0]: │   1742 │   # fmt: off                                                                                                                                 │
[rank0]: │                                                                                                                                                       │
[rank0]: │ /home/tony/workspace/models/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750 in _call_impl                                          │
[rank0]: │                                                                                                                                                       │
[rank0]: │   1747 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks                                                        │
[rank0]: │   1748 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                        │
[rank0]: │   1749 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                        │
[rank0]: │ ❱ 1750 │   │   │   return forward_call(*args, **kwargs)                                                                                               │
[rank0]: │   1751 │   │                                                                                                                                          │
[rank0]: │   1752 │   │   result = None                                                                                                                          │
[rank0]: │   1753 │   │   called_always_called_hooks = set()                                                                                                     │
[rank0]: ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
[rank0]: TypeError: Rem.forward() got an unexpected keyword argument 'shape'
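The frames above show the mechanism: _FabricModule.call_forward_module patches the original module's forward with a wrapper around the marked method, then re-dispatches through self.forward, so the marked method's arguments flow into the strategy's forward module. Under ModelParallelStrategy the FSDP2 module (with a dynamo-compiled forward, per the eval_frame frame) evidently ends up invoking the original, unpatched forward, so method-only arguments like shape reach Model.forward and fail. Until this is fixed (see the linked PR), a hedged stopgap is to bypass the redirection via the private _forward_module attribute visible in the trace, at the cost of Fabric's precision handling:

# Workaround sketch, not an official API: call the marked method on the
# wrapped forward module directly, skipping _FabricModule's redirection.
# Fabric's precision convert_input/forward_context are also skipped, so
# cast inputs manually when running in mixed precision.
output = model._forward_module.generate(x, y, cfg=0.5)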

Environment

Current environment

- PyTorch Lightning Version: 2.5.0.post
- PyTorch Version: 2.6.0+cu124
- Python version: 3.11
- OS: Linux
- CUDA/cuDNN version: 12.4
- GPU models and configuration: 8xH100
- How you installed Lightning (`conda`, `pip`, source): pip

More info

No response

@tonyf tonyf added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Apr 12, 2025