We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 93500ae commit 7706e76Copy full SHA for 7706e76
dacapo/experiments/model.py
@@ -185,7 +185,9 @@ def __get_output_shape(
185
The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions.
186
187
"""
188
- dummy_data = torch.zeros((1, in_channels) + input_shape, device=self.get_device())
+ dummy_data = torch.zeros(
189
+ (1, in_channels) + input_shape, device=self.get_device()
190
+ )
191
with torch.no_grad():
192
out = self.forward(dummy_data)
193
return out.shape[1], Coordinate(out.shape[2:])
0 commit comments