Skip to content

Commit edb02d0

Browse files
committed
Refactoring LeakyReLU with IBufferDerivative
1 parent 87cb3b9 commit edb02d0

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

src/NeuralNet/ActivationFunctions/LeakyReLU/LeakyReLU.php

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,17 @@ public function activate(NDArray $input) : NDArray
8282
* f'(x) = 1 if x > 0
8383
* f'(x) = leakage if x ≤ 0
8484
*
85-
* @param NDArray $output Output matrix
85+
* @param NDArray $input Input matrix
8686
* @return NDArray Derivative matrix
8787
*/
88-
public function differentiate(NDArray $output) : NDArray
88+
public function differentiate(NDArray $input) : NDArray
8989
{
9090
// For x > 0: 1
91-
$positivePart = NumPower::greater($output, 0);
91+
$positivePart = NumPower::greater($input, 0);
9292

9393
// For x <= 0: leakage
9494
$negativePart = NumPower::multiply(
95-
NumPower::lessEqual($output, 0),
95+
NumPower::lessEqual($input, 0),
9696
$this->leakage
9797
);
9898

tests/NeuralNet/ActivationFunctions/LeakyReLU/LeakyReLUTest.php

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public static function computeProvider() : Generator
3535
[2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
3636
]),
3737
[
38-
[2.0, 1.0, -0.004999999888241291, 0.0, 20.0, -0.09999999403953552],
38+
[2.0, 1.0, -0.0049999, 0.0, 20.0, -0.0999999],
3939
],
4040
];
4141

@@ -46,9 +46,9 @@ public static function computeProvider() : Generator
4646
[0.05, -0.52, 0.54],
4747
]),
4848
[
49-
[-0.0011999999405816197, 0.3100000023841858, -0.004900000058114529],
50-
[0.9900000095367432, 0.07999999821186066, -0.00029999998514540493],
51-
[0.05000000074505806, -0.005199999548494816, 0.5400000214576721],
49+
[-0.0011999, 0.3100000, -0.0049000],
50+
[0.9900000, 0.0799999, -0.0002999],
51+
[0.0500000, -0.0051999, 0.5400000],
5252
],
5353
];
5454
}
@@ -60,10 +60,10 @@ public static function differentiateProvider() : Generator
6060
{
6161
yield [
6262
NumPower::array([
63-
[2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
63+
[4.0, 2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
6464
]),
6565
[
66-
[1.0, 1.0, 0.009999999776482582, 0.009999999776482582, 1.0, 0.009999999776482582],
66+
[1.0, 1.0, 1.0, 0.0099999, 0.0099999, 1.0, 0.0099999],
6767
],
6868
];
6969

@@ -74,9 +74,9 @@ public static function differentiateProvider() : Generator
7474
[0.05, -0.52, 0.54],
7575
]),
7676
[
77-
[0.009999999776482582, 1.0, 0.009999999776482582],
78-
[1.0, 1.0, 0.009999999776482582],
79-
[1.0, 0.009999999776482582, 1.0],
77+
[0.0099999, 1.0, 0.0099999],
78+
[1.0, 1.0, 0.0099999],
79+
[1.0, 0.0099999, 1.0],
8080
],
8181
];
8282
}
@@ -113,7 +113,7 @@ public static function boundaryProvider() : Generator
113113
]),
114114
[
115115

116-
[0.0010000000474974513, -0.000010000000656873453, 0.00009999999747378752, -0.0000009999999974752427],
116+
[0.0010000, -0.0000100, 0.0000999, -0.0000009],
117117
],
118118
];
119119
}
@@ -161,7 +161,7 @@ public function testActivate(NDArray $input, array $expected) : void
161161
{
162162
$activations = $this->activationFn->activate($input)->toArray();
163163

164-
static::assertEqualsWithDelta($expected, $activations, 1e-16);
164+
static::assertEqualsWithDelta($expected, $activations, 1e-7);
165165
}
166166

167167
#[Test]
@@ -171,16 +171,16 @@ public function testBoundaryActivate(NDArray $input, array $expected) : void
171171
{
172172
$activations = $this->activationFn->activate($input)->toArray();
173173

174-
static::assertEqualsWithDelta($expected, $activations, 1e-16);
174+
static::assertEqualsWithDelta($expected, $activations, 1e-7);
175175
}
176176

177177
#[Test]
178-
#[TestDox('Correctly differentiates the output')]
178+
#[TestDox('Correctly differentiates the input')]
179179
#[DataProvider('differentiateProvider')]
180-
public function testDifferentiate(NDArray $output, array $expected) : void
180+
public function testDifferentiate(NDArray $input, array $expected) : void
181181
{
182-
$derivatives = $this->activationFn->differentiate($output)->toArray();
182+
$derivatives = $this->activationFn->differentiate($input)->toArray();
183183

184-
static::assertEqualsWithDelta($expected, $derivatives, 1e-16);
184+
static::assertEqualsWithDelta($expected, $derivatives, 1e-7);
185185
}
186186
}

0 commit comments

Comments
 (0)