Skip to content

Commit 1e2f924

Browse files
MJ10harsha-simhadri
authored andcommitted
FastGRNNCUDA Fix (#137)
* fixes for installation and fastgrnncuda * ensure input tensors are on device * ensure tensors on device for fastgrnncudacell * add batch_first support * fix forward params * minor variable name fix
1 parent 58b3f4a commit 1e2f924

File tree

1 file changed

+1
-1
lines changed
  • pytorch/edgeml_pytorch/graph

1 file changed

+1
-1
lines changed

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,7 @@ def forward(self, input, hiddenState, cell_state=None):
11701170
input = input.to(self.device)
11711171
if hiddenState is None:
11721172
hiddenState = torch.zeros(
1173-
[input.shape[1], self.hidden_size]).to(self.device)
1173+
[input.shape[1], self._hidden_size]).to(self.device)
11741174
if not hiddenState.is_cuda:
11751175
hiddenState = hiddenState.to(self.device)
11761176
return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, hiddenState,

0 commit comments

Comments
 (0)