@@ -330,7 +330,8 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-
+        self.device = torch.device("cuda")
+
         if wRank is not None:
             self._num_W_matrices += 1
             self._num_weight_matrices[0] = self._num_W_matrices
@@ -340,29 +341,29 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self._name = name
 
         if wRank is None:
-            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
             self.W1 = torch.empty(0)
             self.W2 = torch.empty(0)
         else:
             self.W = torch.empty(0)
-            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
-            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
 
         if uRank is None:
-            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
             self.U1 = torch.empty(0)
             self.U2 = torch.empty(0)
         else:
             self.U = torch.empty(0)
-            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
-            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
 
         self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
 
-        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
-        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
-        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
-        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
+        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
+        self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
+        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
+        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
 
     @property
     def name(self):
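For context, a minimal sketch (plain PyTorch, independent of this diff) of the low-rank factorization these parameters implement: when `wRank` is set, the dense `W` is replaced by the pair `W1`, `W2` with `W ≈ W2 @ W1`, which shrinks the parameter count when the rank is small. All sizes below are illustrative.

```python
import torch

input_size, hidden_size, wRank = 32, 64, 8

# Dense alternative: a single [hidden_size, input_size] matrix.
W_dense = torch.randn(hidden_size, input_size)    # 64 * 32 = 2048 values

# Low-rank pair as in the diff: W1 [wRank, input_size], W2 [hidden_size, wRank].
W1 = torch.randn(wRank, input_size)               # 8 * 32 = 256 values
W2 = torch.randn(hidden_size, wRank)              # 64 * 8 = 512 values

W_lowrank = W2 @ W1                               # same [64, 32] shape
print(W_dense.numel(), W1.numel() + W2.numel())   # 2048 vs. 768 parameters
```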
@@ -374,7 +375,11 @@ def cellType(self):
 
     def forward(self, input, state):
         # Calls the custom autograd function which invokes the CUDA implementation
-        return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
+        if not input.is_cuda:
+            input = input.to(self.device)
+        if not state.is_cuda:
+            state = state.to(self.device)
+        return FastGRNNFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, state,
                                       self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
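A hypothetical usage sketch for the cell's new device handling: tensors still on the CPU are moved onto `cell.device` inside `forward` before the fused CUDA kernel runs. Sizes and variable names are illustrative, and a CUDA build of the extension is assumed.

```python
cell = FastGRNNCUDACell(input_size=32, hidden_size=64, wRank=8, uRank=8)

x = torch.randn(16, 32)   # [batch, features], created on the CPU
h = torch.zeros(16, 64)   # [batch, hidden_size], created on the CPU
h_next = cell(x, h)       # forward() moves both onto cell.device
```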
@@ -1103,7 +1108,7 @@ class FastGRNNCUDA(nn.Module):
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
                  wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
-                 name="FastGRNNCUDACell"):
+                 batch_first=False, name="FastGRNNCUDA"):
         super(FastGRNNCUDA, self).__init__()
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -1113,7 +1118,17 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self._zetaInit = zetaInit
         self._nuInit = nuInit
         self._name = name
-
+        self._num_W_matrices = 1
+        self._num_U_matrices = 1
+        self._num_biases = 2
+        self._num_weight_matrices = [self._num_W_matrices, self._num_U_matrices, self._num_biases]
+        self._wRank = wRank
+        self._uRank = uRank
+        self._wSparsity = wSparsity
+        self._uSparsity = uSparsity
+        self.oldmats = []
+        self.device = torch.device("cuda")
+        self.batch_first = batch_first
         if wRank is not None:
             self._num_W_matrices += 1
             self._num_weight_matrices[0] = self._num_W_matrices
@@ -1123,33 +1138,42 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
         self._name = name
 
         if wRank is None:
-            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
             self.W1 = torch.empty(0)
             self.W2 = torch.empty(0)
         else:
             self.W = torch.empty(0)
-            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
-            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
 
         if uRank is None:
-            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size]))
+            self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
             self.U1 = torch.empty(0)
             self.U2 = torch.empty(0)
         else:
             self.U = torch.empty(0)
-            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
-            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
 
         self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
 
-        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
-        self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
-        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
-        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
+        self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
+        self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
+        self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
+        self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
 
-    def forward(self, input, h_state, cell_state=None):
+    def forward(self, input, hiddenState, cell_state=None):
         # input: [timesteps, batch, features]; hiddenState: [batch, hidden_size]
-        return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, h_state,
+        if self.batch_first:
+            input = input.transpose(0, 1)
+        if not input.is_cuda:
+            input = input.to(self.device)
+        if hiddenState is None:
+            hiddenState = torch.zeros(
+                [input.shape[1], self.hidden_size]).to(self.device)
+        if not hiddenState.is_cuda:
+            hiddenState = hiddenState.to(self.device)
+        return FastGRNNUnrollFunction.apply(input, self.bias_gate, self.bias_update, self.zeta, self.nu, hiddenState,
                                             self.W, self.U, self.W1, self.W2, self.U1, self.U2, self._gate_non_linearity)
 
     def getVars(self):
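And a hypothetical end-to-end sketch for the unrolled module, exercising the new `batch_first` flag and the zero-initialized hidden state. Shapes are illustrative, and a CUDA build of the extension is assumed.

```python
rnn = FastGRNNCUDA(input_size=32, hidden_size=64, batch_first=True)

x = torch.randn(16, 100, 32)   # [batch, timesteps, features] on the CPU
out = rnn(x, None)             # transposed to [timesteps, batch, features],
                               # moved to the GPU, hidden state zero-initialized
```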