@@ -19,9 +19,7 @@ module nf_optimizers
    real :: learning_rate = 0.01
  contains
    procedure(init), deferred :: init
-    procedure(minimize_1d), deferred :: minimize_1d
-    procedure(minimize_2d), deferred :: minimize_2d
-    generic :: minimize => minimize_1d, minimize_2d
+    procedure(minimize), deferred :: minimize
  end type optimizer_base_type

  abstract interface
@@ -32,19 +30,12 @@ impure elemental subroutine init(self, num_params)
      integer, intent(in) :: num_params
    end subroutine init

-    pure subroutine minimize_1d(self, param, gradient)
+    pure subroutine minimize(self, param, gradient)
      import :: optimizer_base_type
      class(optimizer_base_type), intent(inout) :: self
      real, intent(inout) :: param(:)
      real, intent(in) :: gradient(:)
-    end subroutine minimize_1d
-
-    pure subroutine minimize_2d(self, param, gradient)
-      import :: optimizer_base_type
-      class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout) :: param(:,:)
-      real, intent(in) :: gradient(:,:)
-    end subroutine minimize_2d
+    end subroutine minimize


  end interface
@@ -55,8 +46,7 @@ end subroutine minimize_2d
    real, allocatable, private :: velocity(:)
  contains
    procedure :: init => init_sgd
-    procedure :: minimize_1d => minimize_sgd_1d
-    procedure :: minimize_2d => minimize_sgd_2d
+    procedure :: minimize => minimize_sgd
  end type sgd

  type, extends(optimizer_base_type) :: rmsprop
@@ -71,8 +61,7 @@ end subroutine minimize_2d
    real, allocatable, private :: rms_gradient(:)
  contains
    procedure :: init => init_rmsprop
-    procedure :: minimize_1d => minimize_rmsprop_1d
-    procedure :: minimize_2d => minimize_rmsprop_2d
+    procedure :: minimize => minimize_rmsprop
  end type rmsprop

  type, extends(optimizer_base_type) :: adam
@@ -95,8 +84,7 @@ end subroutine minimize_2d
    integer, private :: t = 0
  contains
    procedure :: init => init_adam
-    procedure :: minimize_1d => minimize_adam_1d
-    procedure :: minimize_2d => minimize_adam_2d
+    procedure :: minimize => minimize_adam
  end type adam

  type, extends(optimizer_base_type) :: adagrad
@@ -113,8 +101,7 @@ end subroutine minimize_2d
    integer, private :: t = 0
  contains
    procedure :: init => init_adagrad
-    procedure :: minimize_1d => minimize_adagrad_1d
-    procedure :: minimize_2d => minimize_adagrad_2d
+    procedure :: minimize => minimize_adagrad
  end type adagrad

contains
@@ -129,7 +116,7 @@ impure elemental subroutine init_sgd(self, num_params)
  end subroutine init_sgd


-  pure subroutine minimize_sgd_1d(self, param, gradient)
+  pure subroutine minimize_sgd(self, param, gradient)
    !! Concrete implementation of a stochastic gradient descent optimizer
    !! update rule.
    class(sgd), intent(inout) :: self
@@ -152,33 +139,7 @@ pure subroutine minimize_sgd_1d(self, param, gradient)
      param = param - self % learning_rate * gradient
    end if

-  end subroutine minimize_sgd_1d
-
-
-  pure subroutine minimize_sgd_2d(self, param, gradient)
-    !! Concrete implementation of a stochastic gradient descent optimizer
-    !! update rule for 2D arrays.
-    class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    if (self % momentum > 0) then
-      ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
-        - self % learning_rate * reshape(gradient, [size(gradient)])
-      if (self % nesterov) then
-        ! Apply Nesterov update
-        param = param + reshape(self % momentum * self % velocity &
-          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
-      else
-        param = param + reshape(self % velocity, shape(param))
-      end if
-    else
-      ! Apply regular update
-      param = param - self % learning_rate * gradient
-    end if
-
-  end subroutine minimize_sgd_2d
+  end subroutine minimize_sgd


  impure elemental subroutine init_rmsprop(self, num_params)
@@ -191,7 +152,7 @@ impure elemental subroutine init_rmsprop(self, num_params)
  end subroutine init_rmsprop


-  pure subroutine minimize_rmsprop_1d(self, param, gradient)
+  pure subroutine minimize_rmsprop(self, param, gradient)
    !! Concrete implementation of a RMSProp optimizer update rule.
    class(rmsprop), intent(inout) :: self
    real, intent(inout) :: param(:)
@@ -205,24 +166,7 @@ pure subroutine minimize_rmsprop_1d(self, param, gradient)
    param = param - self % learning_rate &
      / sqrt(self % rms_gradient + self % epsilon) * gradient

-  end subroutine minimize_rmsprop_1d
-
-
-  pure subroutine minimize_rmsprop_2d(self, param, gradient)
-    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
-    class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
-      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
-
-    ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
-      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
-
-  end subroutine minimize_rmsprop_2d
+  end subroutine minimize_rmsprop


  impure elemental subroutine init_adam(self, num_params)
@@ -236,7 +180,7 @@ impure elemental subroutine init_adam(self, num_params)
  end subroutine init_adam


-  pure subroutine minimize_adam_1d(self, param, gradient)
+  pure subroutine minimize_adam(self, param, gradient)
    !! Concrete implementation of an Adam optimizer update rule.
    class(adam), intent(inout) :: self
    real, intent(inout) :: param(:)
@@ -264,38 +208,7 @@ pure subroutine minimize_adam_1d(self, param, gradient)

    end associate

-  end subroutine minimize_adam_1d
-
-
-  pure subroutine minimize_adam_2d(self, param, gradient)
-    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
-    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    self % t = self % t + 1
-
-    ! If weight_decay_l2 > 0, use L2 regularization;
-    ! otherwise, default to regular Adam.
-    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
-    end associate
-
-    ! Compute bias-corrected first and second moment estimates.
-    associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
-    )
-
-      ! Update parameters.
-      param = param &
-        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
-        - self % learning_rate * self % weight_decay_decoupled * param
-
-    end associate
-
-  end subroutine minimize_adam_2d
+  end subroutine minimize_adam


  impure elemental subroutine init_adagrad(self, num_params)
@@ -308,7 +221,7 @@ impure elemental subroutine init_adagrad(self, num_params)
  end subroutine init_adagrad


-  pure subroutine minimize_adagrad_1d(self, param, gradient)
+  pure subroutine minimize_adagrad(self, param, gradient)
    !! Concrete implementation of an Adagrad optimizer update rule.
    class(adagrad), intent(inout) :: self
    real, intent(inout) :: param(:)
@@ -333,34 +246,6 @@ pure subroutine minimize_adagrad_1d(self, param, gradient)

    end associate

-  end subroutine minimize_adagrad_1d
-
-
-  pure subroutine minimize_adagrad_2d(self, param, gradient)
-    !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
-    class(adagrad), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Update the current time step
-    self % t = self % t + 1
-
-    associate( &
-      ! If weight_decay_l2 > 0, use L2 regularization;
-      ! otherwise, default to regular Adagrad.
-      g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]), &
-      ! Amortize the learning rate as function of the current time step.
-      learning_rate => self % learning_rate &
-        / (1 + (self % t - 1) * self % learning_rate_decay) &
-    )
-
-      self % sum_squared_gradient = self % sum_squared_gradient + g**2
-
-      param = param - learning_rate * reshape(g / (sqrt(self % sum_squared_gradient) &
-        + self % epsilon), shape(param))
-
-    end associate
-
-  end subroutine minimize_adagrad_2d
+  end subroutine minimize_adagrad

end module nf_optimizers
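
With the 2D overloads removed, a caller that keeps rank-2 parameters (for example a dense-layer weight matrix) flattens them before calling the single rank-1 minimize binding. The sketch below is a minimal, hypothetical caller-side example, not part of the patch: it assumes the nf_optimizers module compiles as shown in this diff, that the sgd type and its learning_rate and momentum components are accessible from outside the module, and the program name and array values are illustrative only.

! Hypothetical usage sketch: flatten a rank-2 weight array, update it through
! the unified rank-1 minimize binding, then restore its original shape.
program sgd_minimize_demo
  use nf_optimizers, only: sgd
  implicit none

  type(sgd) :: optimizer
  real :: weights(3, 2), gradients(3, 2)
  real, allocatable :: flat_weights(:)

  weights = 1.0
  gradients = 0.1

  ! Allocate the optimizer state (velocity) for the flattened parameter count.
  optimizer = sgd(learning_rate=0.01, momentum=0.9)
  call optimizer % init(size(weights))

  ! Reshape to rank 1, update in place, reshape back.
  flat_weights = reshape(weights, [size(weights)])
  call optimizer % minimize(flat_weights, reshape(gradients, [size(gradients)]))
  weights = reshape(flat_weights, shape(weights))

  print *, 'Updated weights:', weights

end program sgd_minimize_demo

This flatten-and-restore pattern mirrors what the deleted *_2d implementations did internally with reshape, which is why a single rank-1 deferred minimize is sufficient for the base type.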