@@ -24,7 +24,7 @@ function setup(rule, x; seen = Base.IdSet())
24
24
end
25
25
end
26
26
27
- subtract! (x, x̄) = iswriteable (x) ? (x .= x .- x̄) : eltype (x).(x .- x̄)
27
+ subtract! (x, x̄) = maywrite (x) ? (x .= x .- x̄) : eltype (x).(x .- x̄)
28
28
29
29
update! (:: Nothing , x, :: Zero , :: Zero... ) = nothing , x
30
30
update! (:: Nothing , x, x̄s... ) = nothing , x
@@ -44,8 +44,8 @@ function update!(tree, x, x̄s...)
44
44
end
45
45
46
46
function update (tree, x, x̄s... )
47
- t′ = fmap (copy, tree; exclude = iswriteable )
48
- x′ = fmap (copy, x; exclude = iswriteable )
47
+ t′ = fmap (copy, tree; exclude = maywrite )
48
+ x′ = fmap (copy, x; exclude = maywrite )
49
49
update! (t′, x′, x̄s... )
50
50
end
51
51
@@ -56,8 +56,17 @@ isnumeric(x::AbstractArray{<:Number}) = isleaf(x) # isleaf to allow for e.g. tr
56
56
isnumeric (x:: AbstractArray{<:Integer} ) = false
57
57
isnumeric (x) = false
58
58
59
- iswriteable (:: DenseArray ) = true # more elaborate versions are possible, wait until needed?
60
- iswriteable (_) = false
59
+ """
60
+ maywrite(x) -> Bool
61
+
62
+ Should return `true` if we are completely sure that `update!` can write new
63
+ values into `x`. Otherwise `false`, indicating a non-mutating path.
64
+ For now, simply `x isa DenseArray` allowing `Array`, `CuArray`, etc.
65
+ """
66
+ maywrite (:: DenseArray ) = true # see https://github.com/FluxML/Optimisers.jl/issues/99 for discussion
67
+ maywrite (_) = false
68
+
69
+ @deprecate iswriteable maywrite false # remove when releasing Optimisers@0.3
61
70
62
71
"""
63
72
trainable(x::Layer) -> NamedTuple
84
93
@.. x = x + y
85
94
86
95
Sometimes in-place broadcasting macro, for use in `apply!` rules.
87
- If `iswriteable (x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
96
+ If `maywrite (x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
88
97
"""
89
98
macro var".." (ex)
90
99
Meta. isexpr (ex, :(= )) || throw (" the macro @.. only accepts assignment, like @.. x = y + z" )
91
100
dst = esc (ex. args[1 ])
92
101
src = esc (Broadcast. __dot__ (ex. args[2 ]))
93
- :($ dst = if $ iswriteable ($ dst)
102
+ :($ dst = if $ maywrite ($ dst)
94
103
$ dst .= $ src
95
104
else
96
105
$ src
0 commit comments