@@ -9,12 +9,21 @@ export make_grad_J_a, make_chi
99using LinearAlgebra: axpy!, dot
1010
1111
12- # default for `via` argument of `make_chi`
13- function _default_chi_via (trajectories)
14- if any (isnothing (traj. target_state) for traj in trajectories)
15- return :states
16- else
17- return :tau
12+ function _check_chi (chi; states, trajectories, tau, via)
13+ try
14+ if via == :tau
15+ chi_states = chi (states, trajectories; tau)
16+ else
17+ chi_states = chi (states, trajectories)
18+ end
19+ if typeof (chi_states) ≠ typeof (states)
20+ msg = " `chi` must return a vector of states"
21+ error (msg)
22+ end
23+ catch exception
24+ msg = " The chi generated by `make_chi` does not have the required interface"
25+ @error msg exception
26+ error (" Cannot make chi" )
1827 end
1928end
2029
@@ -86,15 +95,25 @@ chi = make_chi(
8695 trajectories;
8796 mode=:any,
8897 automatic=:default,
89- via=(any(isnothing(t.target_state) for t in trajectories) ? :states : : tau),
98+ via=:automatic, # one of :automatic, : tau, :states
9099)
91100```
92101
93- creates a function `chi(Ψ, trajectories; τ)` that returns
94- a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩`` is the
95- k'th element of `Ψ`. These are the states used as the boundary condition for
96- the backward propagation propagation in Krotov's method and GRAPE. Each
97- ``|χₖ⟩`` is defined as a matrix calculus
102+ creates a function `chi(Ψ, trajectories)` or `chi(Ψ, trajectories; tau)` that
103+ returns a vector of states `χ` with ``|χ_k⟩ = -∂J_T/∂⟨Ψ_k|``, where ``|Ψ_k⟩``
104+ is the k'th element of `Ψ`. These are the states used as the boundary condition
105+ for the backward propagation propagation in Krotov's method and GRAPE.
106+
107+ The resulting `chi` function takes the keyword argument `tau`
108+ if and only if `via=:tau` or `via=:automatic` (default) if the following
109+ conditions are met:
110+
111+ * All `trajectories` have a defined `target_state` component (not `nothing`)
112+ * `J_T` takes `tau` as a keyword argument (determined via introspection)
113+
114+ Both of these conditions are _requirements_ for `via=:tau`.
115+
116+ Each ``|χₖ⟩`` is defined as a matrix calculus
98117[Wirtinger derivative](https://www.ekinakyurek.me/complex-derivatives-wirtinger/),
99118
100119```math
@@ -193,25 +212,53 @@ and the definition of the Zygote gradient with respect to a complex scalar,
193212 gradients). Always test automatic derivatives against finite differences
194213 and/or other automatic differentiation frameworks.
195214"""
196- function make_chi (
197- J_T,
198- trajectories;
199- mode= :any ,
200- automatic= :default ,
201- via= _default_chi_via (trajectories),
202- )
215+ function make_chi (J_T, trajectories; mode= :any , automatic= :default , via= :automatic ,)
216+ states = [traj. initial_state for traj in trajectories]
217+ tau = [zero (ComplexF64) for _ in states]
218+ J_T_takes_tau = hasmethod (J_T, Tuple{typeof (states),typeof (trajectories)}, (:tau ,))
219+ has_target_states = all ((traj. target_state ≢ nothing ) for traj in trajectories)
220+ if (via == :tau ) && ! J_T_takes_tau
221+ msg = " Called `make_chi` with `via=:tau`, but given J_T does not take `tau` keyword argument"
222+ error (msg)
223+ end
224+ if (via == :tau ) && ! has_target_states
225+ msg = " Called `make_chi` with `via=:tau`, but not all `trajectories` define a `target_state`"
226+ error (msg)
227+ end
228+ if via == :automatic
229+ via = :states
230+ if J_T_takes_tau && has_target_states
231+ via = :tau
232+ end
233+ end
234+ chi = nothing
235+ try
236+ if via == :tau
237+ J_T_val = J_T (states, trajectories; tau)
238+ else
239+ J_T_val = J_T (states, trajectories)
240+ end
241+ if ! (J_T_val isa Float64)
242+ msg = " J_T passed to `make_chi` must return a Float64, not $(typeof (J_T_val)) "
243+ error (msg)
244+ end
245+ catch exception
246+ msg = " The J_T passed to `make_chi` does not have the required interface"
247+ @error msg exception
248+ error (" Cannot make chi" )
249+ end
203250 if mode == :any
204251 try
205252 chi = make_analytic_chi (J_T, trajectories)
206253 @debug " make_chi for J_T=$(J_T) -> analytic"
207- # TODO : call chi to compile it and ensure required properties
254+ _check_chi ( chi; states, trajectories, tau, via)
208255 return chi
209256 catch exception
210257 if exception isa MethodError
211258 @info " make_chi for J_T=$(J_T) : fallback to mode=:automatic"
212259 try
213260 chi = make_automatic_chi (J_T, trajectories, automatic; via)
214- # TODO : call chi to compile it and ensure required properties
261+ _check_chi ( chi; states, trajectories, tau, via)
215262 return chi
216263 catch exception
217264 if exception isa MethodError
@@ -228,7 +275,7 @@ function make_chi(
228275 elseif mode == :analytic
229276 try
230277 chi = make_analytic_chi (J_T, trajectories)
231- # TODO : call chi to compile it and ensure required properties
278+ _check_chi ( chi; states, trajectories, tau, via)
232279 return chi
233280 catch exception
234281 if exception isa MethodError
@@ -241,7 +288,7 @@ function make_chi(
241288 elseif mode == :automatic
242289 try
243290 chi = make_automatic_chi (J_T, trajectories, automatic; via)
244- # TODO : call chi to compile it and ensure required properties
291+ _check_chi ( chi; states, trajectories, tau, via)
245292 return chi
246293 catch exception
247294 if exception isa MethodError
0 commit comments