diff --git a/src/correlations.jl b/src/correlations.jl index feeb4eb..b5d9008 100644 --- a/src/correlations.jl +++ b/src/correlations.jl @@ -1,4 +1,4 @@ -export ccorr +export ccorr, plan_ccorr, ccorr_psf, plan_ccorr_psf, plan_ccorr_buffer, plan_ccorr_psf_buffer """ ccorr(u, v[, dims]; centered=false) @@ -65,3 +65,81 @@ function ccorr(u::AbstractArray{<:Real, N}, v::AbstractArray{<:Real, M}, return out end end + +function ccorr_psf(u::AbstractArray{T, N}, psf::AbstractArray{D, M}, dims=ntuple(+, min(N, M))) where {T, D, N, M} + return ccorr(u, ifftshift(psf, dims), dims) +end + +function p_ccorr_aux(P, P_inv, u, v_ft) + return (P_inv.p * ((P * u) .* conj(v_ft) .* P_inv.scale)) +end + +function plan_ccorr(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple(+, N); + kwargs...) where {T1, T2, N, M} + eltype_error(T1, T2) + plan = get_plan(T1) + # do the preplanning step + P = let + # FFTW.MEASURE flag might overwrite input! Hence copy! + if (:flags in keys(kwargs) && + (getindex(kwargs, :flags) == FFTW.MEASURE || getindex(kwargs, :flags) == FFTW.PATIENT)) + plan(copy(u), dims; kwargs...) + else + plan(u, dims; kwargs...) + end + end + + v_ft = fft_or_rfft(T1)(v, dims) + # construct the efficient conv function + # P and P_inv can be understood like matrices + # but their computation is fast + ccorr = let P = P, + P_inv = inv(P), + # put a different name here! See https://discourse.julialang.org/t/type-issue-with-captured-variables-let-workaround-failed/85661 + v_ft = v_ft + ccorr(u, v_ft=v_ft) = p_ccorr_aux(P, P_inv, u, v_ft) + end + + return v_ft, ccorr +end + +function plan_ccorr_psf(u::AbstractArray{T, N}, psf::AbstractArray{T, M}, dims=ntuple(+, N); + kwargs...) where {T, N, M} +return plan_ccorr(u, ifftshift(psf, dims), dims; kwargs...) +end + +function plan_ccorr_buffer(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple(+, N); + kwargs...) where {T1, T2, N, M} + eltype_error(T1, T2) + plan = get_plan(T1) + # do the preplanning step + P_u = plan(u, dims; kwargs...) + P_v = plan(v, dims) + + u_buff = P_u * u + v_ft = P_v * v + conj!(v_ft) + uv_buff = u_buff .* v_ft + + # for fourier space we need a new plan + P = plan(u .* v, dims; kwargs...) + P_inv = inv(P) + out_buff = P_inv * uv_buff + + # construct the efficient conv function + # P and P_inv can be understood like matrices + # but their computation is fast + function ccorr(u, v_ft=v_ft) + mul!(u_buff, P_u, u) + uv_buff .= u_buff .* v_ft + mul!(out_buff, P_inv, uv_buff) + return out_buff + end + + return v_ft, ccorr +end + +function plan_ccorr_psf_buffer(u::AbstractArray{T, N}, psf::AbstractArray{T, M}, dims=ntuple(+, N); + kwargs...) where {T, N, M} + return plan_ccorr_buffer(u, ifftshift(psf, dims), dims; kwargs...) +end