Describe the bug
I am unable to use the DGMR model with input_channels set to anything other than 1. The forward pass fails with a channel-mismatch RuntimeError inside the Sampler's first ConvGRU (convGRU1).
To Reproduce
Steps to reproduce the behavior:
Expected behavior
The model should run without error.
Stack trace
Traceback (most recent call last):
  File "/my/scripts/dgmr_channel_issue_test_case.py", line 20, in <module>
    output = model(dummy_input)
             ^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/dgmr.py", line 136, in forward
    x = self.generator(x)
        ^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/generators.py", line 216, in forward
    x = self.sampler(conditioning_states, latent_dim)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/generators.py", line 154, in forward
    hidden_states = self.convGRU1(hidden_states, init_states[3])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/layers/ConvGRU.py", line 107, in forward
    output, hidden_state = self.cell(x[step], hidden_state)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/dgmr/layers/ConvGRU.py", line 72, in forward
    read_gate = F.sigmoid(self.read_gate_conv(xh))
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 554, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/my/python/env/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
    return F.conv2d(
           ^^^^^^^^^
RuntimeError: Given groups=1, weight of size [384, 1152, 3, 3], expected input[1, 1536, 4, 4] to have 1152 channels, but got 1536 channels instead
Additional context
While digging into this, I noticed that the Sampler is instantiated without being passed the output_channels parameter. I tried passing it in. That alone did not fix the error, but I assume it would also need to be fixed.
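The channel counts in the RuntimeError fit the following bookkeeping. This is a rough sketch and is not taken from the dgmr source. It assumes three things:

- The GRU cell concatenates its input and hidden state along the channel axis before the read-gate convolution.
- The input to convGRU1 has 768 channels.
- The conditioning state init_states[3] has 384 * input_channels channels.

The constant and function names are hypothetical.

```python
# Hypothetical channel bookkeeping for the Sampler's first ConvGRU (convGRU1).
# ASSUMPTIONS (inferred from the error message, not read from the dgmr source):
#   - the cell concatenates input x and hidden state h along the channel axis
#     before read_gate_conv,
#   - x has LATENT_CHANNELS channels regardless of input_channels,
#   - h (init_states[3]) has CONTEXT_CHANNELS_PER_INPUT * input_channels channels.
LATENT_CHANNELS = 768
CONTEXT_CHANNELS_PER_INPUT = 384


def expected_in_channels() -> int:
    """Input channels the read-gate conv weight appears to be built for (as if input_channels=1)."""
    return LATENT_CHANNELS + CONTEXT_CHANNELS_PER_INPUT


def actual_in_channels(input_channels: int) -> int:
    """Input channels the read-gate conv would receive if h scales with input_channels."""
    return LATENT_CHANNELS + CONTEXT_CHANNELS_PER_INPUT * input_channels


if __name__ == "__main__":
    for c in (1, 2, 3):
        exp, got = expected_in_channels(), actual_in_channels(c)
        status = "ok" if exp == got else "mismatch"
        print(f"input_channels={c}: weight expects {exp}, input has {got} ({status})")
```

Under these assumptions, input_channels=2 reproduces the 1152 vs 1536 pair in the error. That would suggest the Sampler's ConvGRU layers are sized for a single input channel whatever input_channels is set to. This is a hypothesis to check against dgmr/generators.py, not a confirmed diagnosis.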