Skip to content

Commit 58b3f4a

Browse files
MJ10 authored and harsha-simhadri committed
FastGRNNCUDA: Fixes (#136)
* fixes for installation and fastgrnncuda
* ensure input tensors are on device
* ensure tensors on device for fastgrnncudacell
* add batch_first support
* fix forward params
1 parent 5176ca3 commit 58b3f4a

File tree

2 files changed

+51
-27
lines changed

2 files changed

+51
-27
lines changed

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
330330
self._zetaInit = zetaInit
331331
self._nuInit = nuInit
332332
self._name = name
333-
333+
self.device = torch.device("cuda")
334+
334335
if wRank is not None:
335336
self._num_W_matrices += 1
336337
self._num_weight_matrices[0] = self._num_W_matrices
@@ -340,29 +341,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
340341
self._name = name
341342

342343
if wRank is None:
343-
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
344+
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device))
344345
self.W1 = torch.empty(0)
345346
self.W2 = torch.empty(0)
346347
else:
347348
self.W = torch.empty(0)
348-
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
349-
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
349+
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device))
350+
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device))
350351

351352
if uRank is None:
352-
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
353+
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device))
353354
self.U1 = torch.empty(0)
354355
self.U2 = torch.empty(0)
355356
else:
356357
self.U = torch.empty(0)
357-
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
358-
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
358+
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device))
359+
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device))
359360

360361
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
361362

362-
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
363-
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
364-
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
365-
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
363+
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device))
364+
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device))
365+
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device))
366+
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device))
366367

367368
@property
368369
def name(self):
@@ -374,7 +375,11 @@ def cellType(self):
374375

375376
def forward(self, input, state):
376377
# Calls the custom autograd function while invokes the CUDA implementation
377-
return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
378+
if not input.is_cuda:
379+
input.to(self.device)
380+
if not state.is_cuda:
381+
state.to(self.device)
382+
return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, state,
378383
self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
379384

380385
def getVars(self):
@@ -1103,7 +1108,7 @@ class FastGRNNCUDA(nn.Module):
11031108
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
11041109
update_nonlinearity="tanh", wRank=None, uRank=None,
11051110
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
1106-
name="FastGRNNCUDACell"):
1111+
batch_first=False, name="FastGRNNCUDA"):
11071112
super(FastGRNNCUDA, self).__init__()
11081113
if utils.findCUDA() is None:
11091114
raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1113,7 +1118,17 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
11131118
self._zetaInit = zetaInit
11141119
self._nuInit = nuInit
11151120
self._name = name
1116-
1121+
self._num_W_matrices = 1
1122+
self._num_U_matrices = 1
1123+
self._num_biases = 2
1124+
self._num_weight_matrices = [self._num_W_matrices, self._num_U_matrices, self._num_biases]
1125+
self._wRank = wRank
1126+
self._uRank = uRank
1127+
self._wSparsity = wSparsity
1128+
self._uSparsity = uSparsity
1129+
self.oldmats = []
1130+
self.device = torch.device("cuda")
1131+
self.batch_first = batch_first
11171132
if wRank is not None:
11181133
self._num_W_matrices += 1
11191134
self._num_weight_matrices[0] = self._num_W_matrices
@@ -1123,33 +1138,42 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
11231138
self._name = name
11241139

11251140
if wRank is None:
1126-
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
1141+
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
11271142
self.W1 = torch.empty(0)
11281143
self.W2 = torch.empty(0)
11291144
else:
11301145
self.W = torch.empty(0)
1131-
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
1132-
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
1146+
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
1147+
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
11331148

11341149
if uRank is None:
1135-
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
1150+
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
11361151
self.U1 = torch.empty(0)
11371152
self.U2 = torch.empty(0)
11381153
else:
11391154
self.U = torch.empty(0)
1140-
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
1141-
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
1155+
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
1156+
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
11421157

11431158
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
11441159

1145-
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
1146-
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
1147-
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
1148-
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
1160+
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
1161+
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
1162+
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
1163+
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
11491164

1150-
def forward(self, input, h_state, cell_state=None):
1165+
def forward(self, input, hiddenState, cell_state=None):
11511166
# input: [timesteps, batch, features, state_size]
1152-
return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
1167+
if self.batch_first:
1168+
input = input.transpose(0, 1)
1169+
if not input.is_cuda:
1170+
input = input.to(self.device)
1171+
if hiddenState is None:
1172+
hiddenState = torch.zeros(
1173+
[input.shape[1], self.hidden_size]).to(self.device)
1174+
if not hiddenState.is_cuda:
1175+
hiddenState = hiddenState.to(self.device)
1176+
return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, hiddenState,
11531177
self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
11541178

11551179
def getVars(self):

pytorch/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
version='0.3.0',
2323
description='PyTorch code for ML algorithms for edge devices developed at Microsoft Research India.',
2424
author_email="edgeml@microsoft.com",
25-
packages=['edgeml_pytorch'],
25+
packages=['edgeml_pytorch', 'edgeml_pytorch.trainer', 'edgeml_pytorch.graph'],
2626
license='MIT License',
2727
long_description=open('README.md').read(),
2828
url='https://github.com/Microsoft/EdgeML',

0 commit comments

Comments (0)