Skip to content

Commit aec596e

Browse files
bors[bot]ToucheSir
andauthored
Merge #1432
1432: Generalize train/testmode! to all Functors r=CarloLucibello a=ToucheSir Addresses #1044 (comment). See also https://discourse.julialang.org/t/do-i-have-to-implement-flux-testmode-for-my-own-models/52038. ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [ ] Final review from `@dhairyagandhi96` (for API changes). Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
2 parents 843b4d4 + 0fbb22d commit aec596e

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

src/functor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Possible values include:
1919
- `true` for testing
2020
- `:auto` or `nothing` for Flux to detect the mode automatically
2121
"""
22-
testmode!(m, mode = true) = m
22+
testmode!(m, mode = true) = (foreach(x -> testmode!(x, mode), trainable(m)); m)
2323

2424
"""
2525
trainmode!(m, mode = true)

src/layers/basic.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
3939

4040
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
4141

42-
testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
43-
4442
function Base.show(io::IO, c::Chain)
4543
print(io, "Chain(")
4644
join(io, c.layers, ", ")

test/utils.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,4 +272,17 @@ end
272272
testdense(re(p), bt)
273273
end
274274
end
275-
end
275+
end
276+
277+
@testset "Train and test mode" begin
278+
mutable struct DummyLayer
279+
testing::Bool
280+
end
281+
Flux.testmode!(m::DummyLayer, testing=true) = (m.testing = testing; m)
282+
283+
c = Chain(DummyLayer(true))
284+
testmode!(c)
285+
@test c[1].testing
286+
trainmode!(c)
287+
@test !c[1].testing
288+
end

0 commit comments

Comments
 (0)