Skip to content

Commit 47286dc

Browse files
committed
In make_chi, introspect J_T to determine default via
The `via=:tau` option for `make_chi` requires that `J_T` takes a `tau` keyword argument. For the default "automatic" keyword for `via`, this should be checked explicitly via introspection. Otherwise, the naive definition of a `J_T` function by a user that does not use `tau` even though the functional is defined in terms of target states leads to a hard-to-understand error.
1 parent 423c611 commit 47286dc

File tree

4 files changed

+90
-41
lines changed

4 files changed

+90
-41
lines changed

ext/QuantumControlFiniteDifferencesExt.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using LinearAlgebra
44

55
import FiniteDifferences
66
import QuantumControl.Functionals:
7-
_default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
7+
make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
88

99

10-
function make_automatic_chi(
11-
J_T,
12-
trajectories,
13-
::Val{:FiniteDifferences};
14-
via=_default_chi_via(trajectories)
15-
)
10+
function make_automatic_chi(J_T, trajectories, ::Val{:FiniteDifferences}; via=:states)
1611

1712
# TODO: Benchmark if χ should be closure, see QuantumControlZygoteExt.jl
1813

ext/QuantumControlZygoteExt.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using LinearAlgebra
44

55
import Zygote
66
import QuantumControl.Functionals:
7-
_default_chi_via, make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
7+
make_gate_chi, make_automatic_chi, make_automatic_grad_J_a
88

99

10-
function make_automatic_chi(
11-
J_T,
12-
trajectories,
13-
::Val{:Zygote};
14-
via=_default_chi_via(trajectories)
15-
)
10+
function make_automatic_chi(J_T, trajectories, ::Val{:Zygote}; via=:states)
1611

1712
# TODO: At some point, for a large system, we could benchmark if there is
1813
# any benefit to making χ a closure and using LinearAlgebra.axpby! to
@@ -26,7 +21,14 @@ function make_automatic_chi(
2621
χ = Vector{eltype(Ψ)}(undef, length(Ψ))
2722
∇J = Zygote.gradient(_J_T, Ψ...)
2823
for (k, ∇Jₖ) enumerate(∇J)
29-
χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv
24+
if isnothing(∇Jₖ)
25+
# Functional does not depend on Ψₖ. That probably means a buggy
26+
# J_T, but who knows: maybe there are situations where that
27+
# makes sense. It would be extremely noisy to warn here.
28+
χ[k] = zero(χ[k])
29+
else
30+
χ[k] = 0.5 * ∇Jₖ # ½ corrects for gradient vs Wirtinger deriv
31+
end
3032
# axpby!(0.5, ∇Jₖ, false, χ[k])
3133
end
3234
return χ
@@ -43,7 +45,12 @@ function make_automatic_chi(
4345
χ = Vector{eltype(Ψ)}(undef, length(Ψ))
4446
∇J = Zygote.gradient(_J_T, τ...)
4547
for (k, traj) enumerate(trajectories)
46-
∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv
48+
if isnothing(∇J[k])
49+
# Functional does not depend on τₖ
50+
∂J╱∂τ̄ₖ = zero(ComplexF64)
51+
else
52+
∂J╱∂τ̄ₖ = 0.5 * ∇J[k] # ½ corrects for gradient vs Wirtinger deriv
53+
end
4754
χ[k] = ∂J╱∂τ̄ₖ * traj.target_state
4855
# axpby!(∂J╱∂τ̄ₖ, traj.target_state, false, χ[k])
4956
end

src/functionals.jl

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,21 @@ export make_grad_J_a, make_chi
99
using LinearAlgebra: axpy!, dot
1010

1111

12-
# default for `via` argument of `make_chi`
13-
function _default_chi_via(trajectories)
14-
if any(isnothing(traj.target_state) for traj in trajectories)
15-
return :states
16-
else
17-
return :tau
12+
function _check_chi(chi; states, trajectories, tau, via)
13+
try
14+
if via == :tau
15+
chi_states = chi(states, trajectories; tau)
16+
else
17+
chi_states = chi(states, trajectories)
18+
end
19+
if typeof(chi_states) typeof(states)
20+
msg = "`chi` must return a vector of states"
21+
error(msg)
22+
end
23+
catch exception
24+
msg = "The chi generated by `make_chi` does not have the required interface"
25+
@error msg exception
26+
error("Cannot make chi")
1827
end
1928
end
2029

@@ -86,15 +95,25 @@ chi = make_chi(
8695
trajectories;
8796
mode=:any,
8897
automatic=:default,
89-
via=(any(isnothing(t.target_state) for t in trajectories) ? :states : :tau),
98+
via=:automatic, # one of :automatic, :tau, :states
9099
)
91100
```
92101
93-
creates a function `chi(Ψ, trajectories; τ)` that returns
94-
a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩`` is the
95-
k'th element of `Ψ`. These are the states used as the boundary condition for
96-
the backward propagation propagation in Krotov's method and GRAPE. Each
97-
``|χₖ⟩`` is defined as a matrix calculus
102+
creates a function `chi(Ψ, trajectories)` or `chi(Ψ, trajectories; tau)` that
103+
returns a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩``
104+
is the k'th element of `Ψ`. These are the states used as the boundary condition
105+
for the backward propagation propagation in Krotov's method and GRAPE.
106+
107+
The resulting `chi` function takes the keyword argument `tau`
108+
if and only if `via=:tau` or `via=:automatic` (default) if the following
109+
conditions are met:
110+
111+
* All `trajectories` have a defined `target_state` component (not `nothing`)
112+
* `J_T` takes `tau` as a keyword argument (determined via introspection)
113+
114+
Both of these conditions are _requirements_ for `via=:tau`.
115+
116+
Each ``|χₖ⟩`` is defined as a matrix calculus
98117
[Wirtinger derivative](https://www.ekinakyurek.me/complex-derivatives-wirtinger/),
99118
100119
```math
@@ -193,25 +212,53 @@ and the definition of the Zygote gradient with respect to a complex scalar,
193212
gradients). Always test automatic derivatives against finite differences
194213
and/or other automatic differentiation frameworks.
195214
"""
196-
function make_chi(
197-
J_T,
198-
trajectories;
199-
mode=:any,
200-
automatic=:default,
201-
via=_default_chi_via(trajectories),
202-
)
215+
function make_chi(J_T, trajectories; mode=:any, automatic=:default, via=:automatic,)
216+
states = [traj.initial_state for traj in trajectories]
217+
tau = [zero(ComplexF64) for _ in states]
218+
J_T_takes_tau = hasmethod(J_T, Tuple{typeof(states),typeof(trajectories)}, (:tau,))
219+
has_target_states = all((traj.target_state nothing) for traj in trajectories)
220+
if (via == :tau) && !J_T_takes_tau
221+
msg = "Called `make_chi` with `via=:tau`, but given J_T does not take `tau` keyword argument"
222+
error(msg)
223+
end
224+
if (via == :tau) && !has_target_states
225+
msg = "Called `make_chi` with `via=:tau`, but not all `trajectories` define a `target_state`"
226+
error(msg)
227+
end
228+
if via == :automatic
229+
via = :states
230+
if J_T_takes_tau && has_target_states
231+
via = :tau
232+
end
233+
end
234+
chi = nothing
235+
try
236+
if via == :tau
237+
J_T_val = J_T(states, trajectories; tau)
238+
else
239+
J_T_val = J_T(states, trajectories)
240+
end
241+
if !(J_T_val isa Float64)
242+
msg = "J_T passed to `make_chi` must return a Float64, not $(typeof(J_T_val))"
243+
error(msg)
244+
end
245+
catch exception
246+
msg = "The J_T passed to `make_chi` does not have the required interface"
247+
@error msg exception
248+
error("Cannot make chi")
249+
end
203250
if mode == :any
204251
try
205252
chi = make_analytic_chi(J_T, trajectories)
206253
@debug "make_chi for J_T=$(J_T) -> analytic"
207-
# TODO: call chi to compile it and ensure required properties
254+
_check_chi(chi; states, trajectories, tau, via)
208255
return chi
209256
catch exception
210257
if exception isa MethodError
211258
@info "make_chi for J_T=$(J_T): fallback to mode=:automatic"
212259
try
213260
chi = make_automatic_chi(J_T, trajectories, automatic; via)
214-
# TODO: call chi to compile it and ensure required properties
261+
_check_chi(chi; states, trajectories, tau, via)
215262
return chi
216263
catch exception
217264
if exception isa MethodError
@@ -228,7 +275,7 @@ function make_chi(
228275
elseif mode == :analytic
229276
try
230277
chi = make_analytic_chi(J_T, trajectories)
231-
# TODO: call chi to compile it and ensure required properties
278+
_check_chi(chi; states, trajectories, tau, via)
232279
return chi
233280
catch exception
234281
if exception isa MethodError
@@ -241,7 +288,7 @@ function make_chi(
241288
elseif mode == :automatic
242289
try
243290
chi = make_automatic_chi(J_T, trajectories, automatic; via)
244-
# TODO: call chi to compile it and ensure required properties
291+
_check_chi(chi; states, trajectories, tau, via)
245292
return chi
246293
catch exception
247294
if exception isa MethodError

test/test_functionals.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,13 @@ end
299299
throw(DomainError("XXX"))
300300
end
301301

302-
@test_throws DomainError begin
302+
@test_throws Exception begin
303303
IOCapture.capture() do
304304
make_chi(J_T_xxx, trajectories)
305305
end
306306
end
307307

308-
@test_throws DomainError begin
308+
@test_throws Exception begin
309309
IOCapture.capture() do
310310
make_chi(J_T_xxx, trajectories; mode=:automatic)
311311
end

0 commit comments

Comments
 (0)