Skip to content

Commit ee36810

Browse files
committed
More fixups
1 parent 61310ce commit ee36810

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

src/enzyme.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
122122
end
123123

124124
ddsts = dst.dval
125-
dsrcs = src.dval
125+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
126126

127127
if EnzymeCore.EnzymeRules.width(config) == 1
128128
ddsts = (ddsts,)
@@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
182182
end
183183

184184
ddsts = dst.dval
185-
dsrcs = src.dval
185+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
186186

187187
if EnzymeCore.EnzymeRules.width(config) == 1
188188
ddsts = (ddsts,)
@@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
322322
keep = nothing
323323
end
324324

325-
# Cache idx if its overwritten
326-
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
327-
&& !(typeof(src) <: EnzymeCore.Const)
328-
&& !(typeof(dst) <: EnzymeCore.Const)
329-
) ? copy(idx.val) : nothing
330-
331325
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
332326
end
333327

@@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
336330
val = convert(T, 1/(1-p.val))
337331

338332
ddsts = dst.dval
339-
dsrcs = src.dval
333+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
340334

341335
if EnzymeCore.EnzymeRules.width(config) == 1
342336
ddsts = (ddsts,)

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ end
895895

896896
EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tw, Tw) || continue
897897

898-
EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (idx, EnzymeCore.Const))
898+
EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (x, Tw), (cdims, EnzymeCore.Const))
899899
end
900900
end
901901
end

0 commit comments

Comments
 (0)