Commit 9d68828

Get weights and weight gradients as 1d

1 parent 21c5707 commit 9d68828

4 files changed: +26 -141 lines

src/nf/nf_dense_layer.f90

Lines changed: 4 additions & 4 deletions
@@ -100,8 +100,8 @@ end function get_params

    module subroutine get_params_ptr(self, w_ptr, b_ptr)
      class(dense_layer), intent(in), target :: self
-     real, pointer :: w_ptr(:,:)
-     real, pointer :: b_ptr(:)
+     real, pointer, intent(out) :: w_ptr(:)
+     real, pointer, intent(out) :: b_ptr(:)
    end subroutine get_params_ptr

    module function get_gradients(self) result(gradients)

@@ -115,8 +115,8 @@ end function get_gradients

    module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
      class(dense_layer), intent(in), target :: self
-     real, pointer :: dw_ptr(:,:)
-     real, pointer :: db_ptr(:)
+     real, pointer, intent(out) :: dw_ptr(:)
+     real, pointer, intent(out) :: db_ptr(:)
    end subroutine get_gradients_ptr

    module subroutine set_params(self, params)

src/nf/nf_dense_layer_submodule.f90

Lines changed: 6 additions & 6 deletions
@@ -79,9 +79,9 @@ end function get_params

  module subroutine get_params_ptr(self, w_ptr, b_ptr)
    class(dense_layer), intent(in), target :: self
-   real, pointer :: w_ptr(:,:)
-   real, pointer :: b_ptr(:)
-   w_ptr => self % weights
+   real, pointer, intent(out) :: w_ptr(:)
+   real, pointer, intent(out) :: b_ptr(:)
+   w_ptr(1:size(self % weights)) => self % weights
    b_ptr => self % biases
  end subroutine get_params_ptr

@@ -104,9 +104,9 @@ end function get_gradients

  module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
    class(dense_layer), intent(in), target :: self
-   real, pointer :: dw_ptr(:,:)
-   real, pointer :: db_ptr(:)
-   dw_ptr => self % dw
+   real, pointer, intent(out) :: dw_ptr(:)
+   real, pointer, intent(out) :: db_ptr(:)
+   dw_ptr(1:size(self % dw)) => self % dw
    db_ptr => self % db
  end subroutine get_gradients_ptr

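Note: the core of this change is the bounds-remapping pointer assignment `w_ptr(1:size(self % weights)) => self % weights`, which gives callers a flat, 1-d view of the 2-d weight and gradient arrays without copying. The following self-contained sketch illustrates that Fortran feature; the program and variable names are illustrative only and are not part of the library.

program flat_view_demo
  ! Illustrative sketch only (not from this repository): demonstrates the
  ! rank-remapping ("bounds remapping") pointer assignment used in
  ! get_params_ptr and get_gradients_ptr above.
  implicit none
  real, allocatable, target :: weights(:,:)
  real, pointer :: w_ptr(:)

  allocate(weights(3, 2))
  weights = reshape([1., 2., 3., 4., 5., 6.], [3, 2])

  ! Associate a 1-d pointer with the 2-d target in array-element
  ! (column-major) order; no data is copied. The target must be
  ! contiguous, which a whole allocatable array always is.
  w_ptr(1:size(weights)) => weights

  print *, w_ptr          ! prints 1. 2. 3. 4. 5. 6.
  w_ptr(4) = 42.
  print *, weights(1, 2)  ! prints 42. -- same storage, updated through the view
end program flat_view_demo

Because the pointer aliases the layer's own storage, an optimizer can update the weights in place through the flat view.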
src/nf/nf_network_submodule.f90

Lines changed: 1 addition & 1 deletion
@@ -649,7 +649,7 @@ module subroutine update(self, optimizer, batch_size)
    integer, intent(in), optional :: batch_size
    integer :: batch_size_
    real, allocatable :: params(:)
-   real, pointer :: weights(:,:), biases(:), dw(:,:), db(:)
+   real, pointer :: weights(:), biases(:), dw(:), db(:)
    integer :: n

    ! Passing the optimizer instance is optional. If not provided, and if the

src/nf/nf_optimizers.f90

Lines changed: 15 additions & 130 deletions
@@ -19,9 +19,7 @@ module nf_optimizers
     real :: learning_rate = 0.01
   contains
     procedure(init), deferred :: init
-    procedure(minimize_1d), deferred :: minimize_1d
-    procedure(minimize_2d), deferred :: minimize_2d
-    generic :: minimize => minimize_1d, minimize_2d
+    procedure(minimize), deferred :: minimize
   end type optimizer_base_type

   abstract interface

@@ -32,19 +30,12 @@ impure elemental subroutine init(self, num_params)
       integer, intent(in) :: num_params
     end subroutine init

-    pure subroutine minimize_1d(self, param, gradient)
+    pure subroutine minimize(self, param, gradient)
       import :: optimizer_base_type
       class(optimizer_base_type), intent(inout) :: self
       real, intent(inout) :: param(:)
       real, intent(in) :: gradient(:)
-    end subroutine minimize_1d
-
-    pure subroutine minimize_2d(self, param, gradient)
-      import :: optimizer_base_type
-      class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout) :: param(:,:)
-      real, intent(in) :: gradient(:,:)
-    end subroutine minimize_2d
+    end subroutine minimize

   end interface

@@ -55,8 +46,7 @@ end subroutine minimize_2d
     real, allocatable, private :: velocity(:)
   contains
     procedure :: init => init_sgd
-    procedure :: minimize_1d => minimize_sgd_1d
-    procedure :: minimize_2d => minimize_sgd_2d
+    procedure :: minimize => minimize_sgd
   end type sgd

   type, extends(optimizer_base_type) :: rmsprop

@@ -71,8 +61,7 @@ end subroutine minimize_2d
     real, allocatable, private :: rms_gradient(:)
   contains
     procedure :: init => init_rmsprop
-    procedure :: minimize_1d => minimize_rmsprop_1d
-    procedure :: minimize_2d => minimize_rmsprop_2d
+    procedure :: minimize => minimize_rmsprop
   end type rmsprop

   type, extends(optimizer_base_type) :: adam

@@ -95,8 +84,7 @@ end subroutine minimize_2d
     integer, private :: t = 0
   contains
     procedure :: init => init_adam
-    procedure :: minimize_1d => minimize_adam_1d
-    procedure :: minimize_2d => minimize_adam_2d
+    procedure :: minimize => minimize_adam
   end type adam

   type, extends(optimizer_base_type) :: adagrad

@@ -113,8 +101,7 @@ end subroutine minimize_2d
     integer, private :: t = 0
   contains
     procedure :: init => init_adagrad
-    procedure :: minimize_1d => minimize_adagrad_1d
-    procedure :: minimize_2d => minimize_adagrad_2d
+    procedure :: minimize => minimize_adagrad
   end type adagrad

 contains

@@ -129,7 +116,7 @@ impure elemental subroutine init_sgd(self, num_params)
   end subroutine init_sgd


-  pure subroutine minimize_sgd_1d(self, param, gradient)
+  pure subroutine minimize_sgd(self, param, gradient)
     !! Concrete implementation of a stochastic gradient descent optimizer
     !! update rule.
     class(sgd), intent(inout) :: self

@@ -152,33 +139,7 @@ pure subroutine minimize_sgd_1d(self, param, gradient)
       param = param - self % learning_rate * gradient
     end if

-  end subroutine minimize_sgd_1d
-
-
-  pure subroutine minimize_sgd_2d(self, param, gradient)
-    !! Concrete implementation of a stochastic gradient descent optimizer
-    !! update rule for 2D arrays.
-    class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    if (self % momentum > 0) then
-      ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
-        - self % learning_rate * reshape(gradient, [size(gradient)])
-      if (self % nesterov) then
-        ! Apply Nesterov update
-        param = param + reshape(self % momentum * self % velocity &
-          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
-      else
-        param = param + reshape(self % velocity, shape(param))
-      end if
-    else
-      ! Apply regular update
-      param = param - self % learning_rate * gradient
-    end if
-
-  end subroutine minimize_sgd_2d
+  end subroutine minimize_sgd


   impure elemental subroutine init_rmsprop(self, num_params)

@@ -191,7 +152,7 @@ impure elemental subroutine init_rmsprop(self, num_params)
   end subroutine init_rmsprop


-  pure subroutine minimize_rmsprop_1d(self, param, gradient)
+  pure subroutine minimize_rmsprop(self, param, gradient)
     !! Concrete implementation of a RMSProp optimizer update rule.
     class(rmsprop), intent(inout) :: self
     real, intent(inout) :: param(:)

@@ -205,24 +166,7 @@ pure subroutine minimize_rmsprop_1d(self, param, gradient)
     param = param - self % learning_rate &
       / sqrt(self % rms_gradient + self % epsilon) * gradient

-  end subroutine minimize_rmsprop_1d
-
-
-  pure subroutine minimize_rmsprop_2d(self, param, gradient)
-    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
-    class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
-      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
-
-    ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
-      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
-
-  end subroutine minimize_rmsprop_2d
+  end subroutine minimize_rmsprop


   impure elemental subroutine init_adam(self, num_params)

@@ -236,7 +180,7 @@ impure elemental subroutine init_adam(self, num_params)
   end subroutine init_adam


-  pure subroutine minimize_adam_1d(self, param, gradient)
+  pure subroutine minimize_adam(self, param, gradient)
     !! Concrete implementation of an Adam optimizer update rule.
     class(adam), intent(inout) :: self
     real, intent(inout) :: param(:)

@@ -264,38 +208,7 @@ pure subroutine minimize_adam_1d(self, param, gradient)

     end associate

-  end subroutine minimize_adam_1d
-
-
-  pure subroutine minimize_adam_2d(self, param, gradient)
-    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
-    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    self % t = self % t + 1
-
-    ! If weight_decay_l2 > 0, use L2 regularization;
-    ! otherwise, default to regular Adam.
-    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
-    end associate
-
-    ! Compute bias-corrected first and second moment estimates.
-    associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
-    )
-
-      ! Update parameters.
-      param = param &
-        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
-        - self % learning_rate * self % weight_decay_decoupled * param
-
-    end associate
-
-  end subroutine minimize_adam_2d
+  end subroutine minimize_adam


   impure elemental subroutine init_adagrad(self, num_params)

@@ -308,7 +221,7 @@ impure elemental subroutine init_adagrad(self, num_params)
   end subroutine init_adagrad


-  pure subroutine minimize_adagrad_1d(self, param, gradient)
+  pure subroutine minimize_adagrad(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule.
     class(adagrad), intent(inout) :: self
     real, intent(inout) :: param(:)

@@ -333,34 +246,6 @@ pure subroutine minimize_adagrad_1d(self, param, gradient)

     end associate

-  end subroutine minimize_adagrad_1d
-
-
-  pure subroutine minimize_adagrad_2d(self, param, gradient)
-    !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
-    class(adagrad), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Update the current time step
-    self % t = self % t + 1
-
-    associate( &
-      ! If weight_decay_l2 > 0, use L2 regularization;
-      ! otherwise, default to regular Adagrad.
-      g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]), &
-      ! Amortize the learning rate as function of the current time step.
-      learning_rate => self % learning_rate &
-        / (1 + (self % t - 1) * self % learning_rate_decay) &
-    )
-
-      self % sum_squared_gradient = self % sum_squared_gradient + g**2
-
-      param = param - learning_rate * reshape(g / (sqrt(self % sum_squared_gradient) &
-        + self % epsilon), shape(param))
-
-    end associate
-
-  end subroutine minimize_adagrad_2d
+  end subroutine minimize_adagrad

 end module nf_optimizers
