Unable to use model with input_channels != 1 #85

@lopsided

Describe the bug

I am unable to use the DGMR model with input_channels set to anything other than 1.

To Reproduce

Steps to reproduce the behavior:

import torch
from dgmr import DGMR

input_channels = 2  # Works with 1; fails with 2
forecast_steps = 4
output_shape = 128
latent_channels = 768
context_channels = 384

model = DGMR(
    forecast_steps=forecast_steps,
    input_channels=input_channels,
    output_shape=output_shape,
    latent_channels=latent_channels,
    context_channels=context_channels,
)

dummy_input = torch.randn(1, 4, input_channels, output_shape, output_shape)

output = model(dummy_input)

Expected behavior

The model should run without error.

Stack trace

Traceback (most recent call last):
  File "/my/scripts/dgmr_channel_issue_test_case.py", line 20, in <module>
    output = model(dummy_input)
             ^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/dgmr.py", line 136, in forward
    x = self.generator(x)
        ^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/generators.py", line 216, in forward
    x = self.sampler(conditioning_states, latent_dim)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/generators.py", line 154, in forward
    hidden_states = self.convGRU1(hidden_states, init_states[3])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/layers/ConvGRU.py", line 107, in forward
    output, hidden_state = self.cell(x[step], hidden_state)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/layers/ConvGRU.py", line 72, in forward
    read_gate = F.sigmoid(self.read_gate_conv(xh))
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 554, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
    return F.conv2d(
           ^^^^^^^^^
RuntimeError: Given groups=1, weight of size [384, 1152, 3, 3], expected input[1, 1536, 4, 4] to have 1152 channels, but got 1536 channels instead

Additional context

While digging into this, I noticed that the Sampler is instantiated without being passed the output_channels parameter. I tried passing it in, which did not fix the error, but I assume this would also need to be fixed.
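
The channel counts in the error message can be decomposed with a bit of arithmetic. The sketch below is only a sanity check on the numbers, not a diagnosis of the library code. It assumes the ConvGRU gate receives the latent input concatenated with a conditioning state along the channel axis; the variable names `expected_in`, `actual_in`, `state_expected` and `state_actual` are made up for illustration.

```python
# Values taken from the reproduction script and the RuntimeError above.
input_channels = 2
latent_channels = 768
context_channels = 384
expected_in = 1152  # in_channels of the conv weight [384, 1152, 3, 3]
actual_in = 1536    # channels of the input tensor [1, 1536, 4, 4]

# Assumption: the gate input is the latent tensor concatenated with a
# conditioning state along the channel axis, so subtracting the latent part
# leaves the channels contributed by the conditioning state.
state_expected = expected_in - latent_channels
state_actual = actual_in - latent_channels

print(state_expected, state_actual)
print(state_expected == context_channels)
print(state_actual == context_channels * input_channels)
```

Both comparisons hold for these numbers. That would be consistent with the conditioning state having context_channels * input_channels channels while the ConvGRU is sized for context_channels only, but I have not confirmed this in the code.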
