Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0874978
first changes
RainerHeintzmann Mar 20, 2025
0014f53
added Adapts stuff
RainerHeintzmann Mar 20, 2025
da99701
replaced test files
RainerHeintzmann Mar 20, 2025
a559c4f
added CircShiftedArray.jl
RainerHeintzmann Mar 20, 2025
29334fd
additions to make it work in CUDA
RainerHeintzmann Mar 21, 2025
9a80dd7
before streamlining
RainerHeintzmann Mar 22, 2025
89de489
first part of streamlining
RainerHeintzmann Mar 22, 2025
ce66b60
almost done
RainerHeintzmann Mar 22, 2025
48edce3
cleanup in CUDASupportExt
RainerHeintzmann Mar 22, 2025
3e53700
tiny bug fix in tests
RainerHeintzmann Mar 22, 2025
68a7cb5
removed old tests and bug-fix with unavailable CUDA hardware support.
RainerHeintzmann Mar 22, 2025
faa5677
changed opt_cu to not call CuArray under non-cuda conditions
RainerHeintzmann Mar 22, 2025
2ed72a4
better generation of zeros [skip ci]
RainerHeintzmann Mar 23, 2025
08c0295
saved some memory on plan_conv_buffer [skip ci]
RainerHeintzmann Mar 23, 2025
9bece7e
updated NDTools dependency to 0.8
RainerHeintzmann Apr 5, 2025
1713bf4
updated Project.toml
RainerHeintzmann Apr 5, 2025
0ab6450
back to using ShiftedArrays.jl
RainerHeintzmann Apr 6, 2025
c1ca435
separate cuda support extension for ShiftedArray
RainerHeintzmann Apr 6, 2025
60c0cfe
removed both Extension stub
RainerHeintzmann Apr 6, 2025
5116de8
better code coverage in 1d version of getindex.
RainerHeintzmann Apr 7, 2025
5220a1f
get_base_arr to identity check
RainerHeintzmann Apr 8, 2025
bf0c713
bug fixes as indicated by Felix
RainerHeintzmann Jun 30, 2025
32acf41
added comment [skip ci]
RainerHeintzmann Jun 30, 2025
08efce8
minor modifications
RainerHeintzmann Aug 4, 2025
80b6f4b
switched entirely to MutableShiftedArrays.jl, removing ShiftedArrays.…
RainerHeintzmann Aug 10, 2025
49b8201
Merge branch 'main' into cuda_via_mutable_shifted
RainerHeintzmann Aug 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,34 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MutableShiftedArrays = "d3d30c82-a38e-471c-a45a-3d24d2f4d22d"
NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ShiftedArrays = "1277b4bf-5013-50f5-be3d-901d8477a67a"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
CUDASupportExt_FT = ["CUDA", "Adapt"]

[compat]
ChainRulesCore = "1, 1.0, 1.1"
FFTW = "1.5"
ImageTransformations = "0.9"
Adapt = "3.7, 4.0, 4.1"
CUDA = "5.2, 5.3, 5.4, 5.5, 5.6, 5.7"
ChainRulesCore = "1"
FFTW = "1.5, 1.6, 1.7, 1.8, 1.9"
ImageTransformations = "0.9, 0.10"
IndexFunArrays = "0.2"
NDTools = "0.5.1, 0.6, 0.7, 0.8"
NDTools = "0.8"
NFFT = "0.11, 0.12, 0.13"
PaddedViews = "0.5"
Reexport = "1"
ShiftedArrays = "2"
Zygote = "0.6"
julia = "1, 1.6, 1.7, 1.8, 1.9, 1.10"
Zygote = "0.6, 0.7"
julia = "1, 1.6, 1.7, 1.8, 1.9, 1.10, 1.11"
MutableShiftedArrays = "0.3"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FractionalTransforms = "e50ca838-b4f0-4a10-ad18-4b920bf1ae5c"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -36,4 +44,4 @@ TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "TestImages", "FractionalTransforms", "Random", "ImageTransformations", "Zygote"]
test = ["Test", "TestImages", "FractionalTransforms", "Random", "ImageTransformations", "Zygote", "CUDA"]
1 change: 1 addition & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ FourierTools.get_indices_around_center
FourierTools.center_extract
FourierTools.odd_view
FourierTools.fourier_reverse!
FourierTools.get_indexrange_around_center
```
102 changes: 102 additions & 0 deletions ext/CUDASupportExt_FT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
module CUDASupportExt_FT
using CUDA
using Adapt
using FourierTools
using Base
# using NFFT
# using CuNFFT

# Union aliases, so that every method in this extension is written once and
# applies to all wrapper types at the same time.

# The two lazy FourierTools wrappers, with no restriction on the wrapped parent type.
const AllShiftedType = Union{FourierTools.FourierSplit{<:Any,<:Any,<:Any},
FourierTools.FourierJoin{<:Any,<:Any,<:Any}}

# Views and reshapes of those wrappers: a SubArray of a wrapper, a ReshapedArray of a
# wrapper, and a SubArray of a ReshapedArray of a wrapper.

const AllSubArrayType = Union{SubArray{<:Any, <:Any, <:AllShiftedType, <:Any, <:Any},
Base.ReshapedArray{<:Any, <:Any, <:AllShiftedType, <:Any},
SubArray{<:Any, <:Any, <:Base.ReshapedArray{<:Any, <:Any, <:AllShiftedType, <:Any}, <:Any, <:Any}}
# Everything above combined: a wrapper or any of the supported views of a wrapper.
const AllShiftedAndViews = Union{AllShiftedType, AllSubArrayType}

# The `...Cu{N, CD}` aliases mirror the ones above, but they only match when the parent
# held by the wrapper is a `CuArray{<:Any,N,CD}`. These are the cases that need special
# treatment. `N` and `CD` are forwarded to `CUDA.CuArrayStyle{N,CD}` further down.
const AllShiftedTypeCu{N, CD} = Union{FourierTools.FourierSplit{<:Any,<:Any,<:CuArray{<:Any,N,CD}},
FourierTools.FourierJoin{<:Any,<:Any,<:CuArray{<:Any,N,CD}}}
const AllSubArrayTypeCu{N, CD} = Union{SubArray{<:Any, <:Any, <:AllShiftedTypeCu{N,CD}, <:Any, <:Any},
Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCu{N,CD}, <:Any},
SubArray{<:Any, <:Any, <:Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCu{N,CD}, <:Any}, <:Any, <:Any}}
const AllShiftedAndViewsCu{N, CD} = Union{AllShiftedTypeCu{N, CD}, AllSubArrayTypeCu{N, CD}}

# Rebuild a `FourierSplit` around the adapted parent; the split dimension `D` is a type
# parameter and is handed to the constructor as `Val(D)`.
function Adapt.adapt_structure(to, x::FourierTools.FourierSplit{T, M, AA, D}) where {T, M, AA, D}
    adapted_parent = adapt(to, parent(x))
    return FourierTools.FourierSplit(adapted_parent, Val(D), x.L1, x.L2, x.do_split)
end
# Rebuild a `FourierJoin` around the adapted parent; the join dimension `D` is a type
# parameter and is handed to the constructor as `Val(D)`.
function Adapt.adapt_structure(to, x::FourierTools.FourierJoin{T, M, AA, D}) where {T, M, AA, D}
    adapted_parent = adapt(to, parent(x))
    return FourierTools.FourierJoin(adapted_parent, Val(D), x.L1, x.L2, x.do_join)
end

# Broadcasting over a FourierSplit/FourierJoin that wraps a CuArray uses the CuArray
# broadcast style of the wrapped parent.
Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllShiftedTypeCu{N, CD}} =
    CUDA.CuArrayStyle{N,CD}()

# BroadcastStyle for views and reshapes (see `AllSubArrayTypeCu`) of a
# FourierSplit/FourierJoin that wraps a CuArray.
# NOTE(review): `N` is the dimensionality of the wrapped CuArray, not necessarily that of
# the view itself -- confirm this is what CuArrayStyle expects for reduced-rank views.
Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllSubArrayTypeCu{N, CD}} =
    CUDA.CuArrayStyle{N,CD}()

# Materialise a wrapper (or a view of one) into a freshly allocated array of the same
# kind as the innermost parent. The values are written by broadcasting.
function Base.copy(s::AllShiftedAndViews)
    base = get_base_arr(s)
    out = similar(base, eltype(s), size(s))
    out .= s
    return out
end

# `collect` delegates to `copy`, so the result has the array kind of the innermost parent
# (i.e. it stays on the GPU when the parent lives there).
Base.collect(x::AllShiftedAndViews) = copy(x)

# Converting to a plain `Array`: materialise with `copy` first, then convert the result
# (this is the step that removes the data from the GPU).
Base.Array(x::AllShiftedAndViews) = Array(copy(x))

# Equality of a CuArray-backed wrapper/view with an arbitrary array.
# The element comparison is done by broadcasting and then reduced with `all`.
function Base.:(==)(x::AllShiftedAndViewsCu, y::AbstractArray)
    # Arrays of different shape are never equal. Without this guard the broadcast below
    # would throw a DimensionMismatch or silently expand singleton dimensions.
    axes(x) == axes(y) || return false
    return all(x .== y)
end

# Equality of an arbitrary array with a CuArray-backed wrapper/view (argument order
# reversed with respect to the method taking the wrapper first).
# The element comparison is done by broadcasting and then reduced with `all`.
function Base.:(==)(y::AbstractArray, x::AllShiftedAndViewsCu)
    # Arrays of different shape are never equal. Without this guard the broadcast below
    # would throw a DimensionMismatch or silently expand singleton dimensions.
    axes(x) == axes(y) || return false
    return all(x .== y)
end

# Equality of two CuArray-backed wrappers/views. This method also resolves the ambiguity
# between the two mixed-argument `==` methods.
# The element comparison is done by broadcasting and then reduced with `all`.
function Base.:(==)(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu)
    # Arrays of different shape are never equal. Without this guard the broadcast below
    # would throw a DimensionMismatch or silently expand singleton dimensions.
    axes(x) == axes(y) || return false
    return all(x .== y)
end

# Shared implementation of the three `isapprox` methods below.
#
# Returns `true` when every element satisfies `abs(x - y) <= tol`, with
# `tol = max(atol, rtol * maximum(abs.(x)))`. `x` is the CuArray-backed argument and is
# the reference for the relative tolerance. Combining both tolerances with `max` means an
# explicitly passed `rtol` is honoured even when `atol` is non-zero (previously it was
# silently ignored in that case).
function _isapprox_elementwise(x, y, atol, rtol)
    # Skip the reduction over `x` entirely when no relative tolerance is requested.
    # `abs.(x)` is materialised by broadcasting before `maximum` is applied --
    # presumably so that the reduction runs on a plain GPU array; keep this form.
    tol = iszero(rtol) ? atol : max(atol, rtol * maximum(abs.(x)))
    return all(abs.(x .- y) .<= tol)
end

# Approximate equality of a CuArray-backed wrapper/view `x` with an arbitrary array `y`.
# Keywords: `atol` (absolute tolerance), `rtol` (relative tolerance; the default is
# `sqrt(eps)` of the real element type of `x` unless a positive `atol` is given).
# Any further keywords are accepted and ignored.
function Base.isapprox(x::AllShiftedAndViewsCu, y::AbstractArray; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
    return _isapprox_elementwise(x, y, atol, rtol)
end

# Same as above with the argument order reversed; `x` remains the CuArray-backed argument.
function Base.isapprox(y::AbstractArray, x::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
    return _isapprox_elementwise(x, y, atol, rtol)
end

# Both arguments CuArray-backed; also resolves the ambiguity between the two methods above.
function Base.isapprox(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
    return _isapprox_elementwise(x, y, atol, rtol)
end

# Text display: forward to the generic `AbstractArray` method of `show`, with scalar
# indexing permitted for the duration of the call.
function Base.show(io::IO, mm::MIME"text/plain", cs::AllShiftedAndViews)
    generic_signature = Tuple{IO, typeof(mm), AbstractArray}
    CUDA.@allowscalar invoke(Base.show, generic_signature, io, mm, cs)
end

### additional functions specific to CUDA

# A plain CuArray needs no collecting: it is returned unchanged.
FourierTools.optional_collect(a::CuArray) = a

# A CuArray is its own base array.
function get_base_arr(arr::CuArray)
    return arr
end
# A plain Array is its own base array.
get_base_arr(arr::Array) = arr

# Generic case: follow `parent` until it returns the very same object, i.e. until the
# innermost array of a chain of wrappers is reached.
function get_base_arr(arr::AbstractArray)
    inner = parent(arr)
    inner === arr && return arr
    return get_base_arr(inner)
end

# Allocate a zero-filled CuArray of size `sz` (default: the size of `arr`) with the
# element type of `arr`.
#
# The element type is passed explicitly: `CUDA.zeros(sz)` without a type always creates
# Float32 data, which would silently change the element type for e.g. Float64 or
# complex input.
# NOTE(review): this defines `similar_zeros` inside the extension module. If it is meant
# to add a method to a function of the same name in FourierTools, it must be written as
# `FourierTools.similar_zeros` -- confirm against src/utils.jl.
function similar_zeros(arr::CuArray, sz::NTuple=size(arr))
    return CUDA.zeros(eltype(arr), sz)
end

end
7 changes: 4 additions & 3 deletions src/FourierTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ module FourierTools


using Reexport
using PaddedViews, ShiftedArrays
# using PaddedViews
using MutableShiftedArrays # for circshift
@reexport using FFTW
using LinearAlgebra
using IndexFunArrays
using ChainRulesCore
using NDTools
import Base: checkbounds, getindex, setindex!, parent, size, axes, copy, collect

@reexport using NFFT
FFTW.set_num_threads(4)



include("utils.jl")
include("nfft_nd.jl")
include("resampling.jl")
Expand Down
5 changes: 4 additions & 1 deletion src/convolutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ end
plan_conv_buffer(u, v [, dims]; kwargs...)

Similar to [`plan_conv`](@ref) but instead uses buffers to prevent memory allocations.
The three buffers are internal to the function and are not exposed to the user.
Not AD friendly!

"""
Expand All @@ -161,7 +162,9 @@ function plan_conv_buffer(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims

u_buff = P_u * u
v_ft = P_v * v
uv_buff = u_buff .* v_ft
uv_sz = bc_size(u_buff, v_ft)
# this saves memory allocations:
uv_buff = (uv_sz == size(u_buff)) ? u_buff : u_buff .* v_ft;

# for fourier space we need a new plan
P = plan(u .* v, dims; kwargs...)
Expand Down
95 changes: 73 additions & 22 deletions src/custom_fourier_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ and then replaces the value by half of the parent at L1
`do_split` is a Bool that indicates whether this mechanism is active.
It is needed for type stability reasons of functions returning this type.
"""
struct FourierSplit{T,N, AA<:AbstractArray{T, N}} <: AbstractArray{T,N}
struct FourierSplit{T, N, AA<:AbstractArray{T, N}, D} <: AbstractArray{T,N}
parent::AA # holds the data (or is another view)
D::Int # dimension along which to apply to copy
# D::Int # dimension along which to apply to copy
L1::Int # low index position to copy from (and half)
L2::Int # high index positon to copy to (and half)
do_split::Bool
Expand All @@ -19,32 +19,59 @@ struct FourierSplit{T,N, AA<:AbstractArray{T, N}} <: AbstractArray{T,N}
This version below is needed to avoid a split for the first rft dimension,
but still return half the value FFTs and other RFT dimension should use the version without L2.
"""
function FourierSplit(parent::AA, D::Int,L1::Int,L2::Int, do_split::Bool) where {T,N, AA<:AbstractArray{T, N}}
return new{T,N, AA}(parent, D, L1, L2, do_split)
function FourierSplit(parent::AA, ::Val{D}, L1::Int,L2::Int, do_split::Bool) where {T,N, D, AA<:AbstractArray{T, N}}
return new{T,N, AA, D}(parent, L1, L2, do_split)
end
function FourierSplit(parent::AA, D::Int, L1::Int, do_split::Bool) where {T,N, AA<:AbstractArray{T, N}}
function FourierSplit(parent::AA, ::Val{D}, L1::Int, do_split::Bool) where {T,N, D, AA<:AbstractArray{T, N}}
mid = fft_center(size(parent)[D])
L2 = mid + (mid-L1)
return FourierSplit(parent, D,L1,L2, do_split)
return FourierSplit(parent, Val(D), L1, L2, do_split)
end
# function FourierSplit(parent::AA, D::Int, L1::Int, do_split::Bool) where {T,N, AA<:AbstractArray{T, N}}
# FourierSplit(parent, Val(D), L1, do_split)
# end
end

# get_D(A::FourierSplit{D}) where {D} = D

Base.IndexStyle(::Type{FD}) where {FD<:FourierSplit} = IndexStyle(parenttype(FD))
parenttype(::Type{FourierSplit{T,N,AA}}) where {T,N,AA} = AA
parenttype(::Type{FourierSplit{T,N,AA,D}}) where {T,N,AA,D} = AA
parenttype(A::FourierSplit) = parenttype(typeof(A))

# Allocation is delegated to the parent so that the result has the parent's array kind.
# `Dims{N}` (= `NTuple{N, Int}`) is used instead of `NTuple{N, Int64}` so that the method
# also applies on platforms where `Int` is 32 bits wide.
Base.similar(s::FourierSplit, el::Type, v::Dims{N}) where {N} = similar(s.parent, el, v)
# The wrapped array.
Base.parent(A::FourierSplit) = A.parent
# A FourierSplit has the same size as its parent.
Base.size(A::FourierSplit) = size(parent(A))

@inline function Base.getindex(A::FourierSplit{T,N, <:AbstractArray{T, N}}, i::Vararg{Int,N}) where {T,N}
if (i[A.D]==A.L2 || i[A.D]==A.L1) && A.do_split # index along this dimension A.D corrsponds to slice L2
# not that "setindex" in the line below modifies only the index, not the array
@inbounds return parent(A)[Base.setindex(i,A.L1, A.D)...] / 2
else i[A.D]==A.L2
# N-dimensional indexing into a FourierSplit.
#
# When splitting is active and the index along dimension `D` equals `L1` or `L2`, the
# result is half of the parent value at position `L1` along `D` (all other indices are
# left unchanged). Every other index is forwarded to the parent as is.
#
# `D` is taken from the type parameter. The original author notes that obtaining it via
# a helper function (`get_D(A)`) "causes huge troubles in CUDA".
@inline function Base.getindex(A::FourierSplit{T,N, <:AbstractArray{T, N}, D}, i::Vararg{Int,N}) where {T,N, D}
    if (i[D]==A.L2 || i[D]==A.L1) && A.do_split
        # `Base.setindex` on a tuple returns a modified copy of the index tuple;
        # it does not touch the array.
        @inbounds return parent(A)[Base.setindex(i, A.L1, D)...] / 2
    else
        # Plain `else`: the former `else i[D]==A.L2` evaluated a comparison whose
        # result was discarded and only looked like an `elseif`.
        @inbounds return parent(A)[i...]
    end
end

# One-D version
# Linear (single-integer) indexing into a FourierSplit.
@inline function Base.getindex(A::FourierSplit{T,N, <:AbstractArray{T, N}, D}, i::Int) where {T,N,D}
    # splitting switched off: behave exactly like the parent
    A.do_split || return parent(A)[i]

    # translate the linear index into an N-d index tuple to inspect dimension D
    cart = Tuple(CartesianIndices(parent(A))[i])
    pos = cart[D]
    if pos == A.L2 || pos == A.L1
        # both special positions yield half of the parent value at L1
        return parent(A)[Base.setindex(cart, A.L1, D)...] / 2
    end
    return parent(A)[i]
end

"""
FourierJoin{T,N, AA<:AbstractArray{T, N}} <: AbstractArray{T, N}

Expand All @@ -53,9 +80,9 @@ and then replaces the value by add the value at the mirrored position L2
`do_join` is a Bool that indicates whether this mechanism is active.
It is needed for type stability reasons of functions returning this type.
"""
struct FourierJoin{T,N, AA<:AbstractArray{T, N}} <: AbstractArray{T, N}
struct FourierJoin{T,N, AA<:AbstractArray{T, N}, D} <: AbstractArray{T, N}
parent::AA
D::Int # dimension along which to apply to copy
# D::Int # dimension along which to apply to copy
L1::Int # low index position to copy from (and half)
L2::Int # high index positon to copy to (and half)
do_join::Bool
Expand All @@ -65,28 +92,52 @@ This version below is needed to avoid a split for the
first rft dimension but still return half the value
FFTs and other RFT dimension should use the version without L2
"""
function FourierJoin(parent::AA, D::Int, L1::Int, L2::Int, do_join::Bool) where {T, N, AA<:AbstractArray{T, N}}
return new{T, N, AA}(parent, D, L1, L2, do_join)
function FourierJoin(parent::AA, ::Val{D}, L1::Int, L2::Int, do_join::Bool) where {T, N, AA<:AbstractArray{T, N}, D}
return new{T, N, AA, D}(parent, L1, L2, do_join)
end

function FourierJoin(parent::AA, D::Int,L1::Int, do_join::Bool) where {T, N, AA<:AbstractArray{T, N}}
function FourierJoin(parent::AA, ::Val{D}, L1::Int, do_join::Bool) where {T, N, AA<:AbstractArray{T, N}, D}
mid = fft_center(size(parent)[D])
L2 = mid + (mid-L1)
return FourierJoin(parent, D, L1, L2, do_join)
return FourierJoin(parent, Val(D), L1, L2, do_join)
end

# function FourierJoin(parent::AA, D::Int, L1::Int, do_split::Bool) where {T,N, AA<:AbstractArray{T, N}}
# FourierJoin(parent, Val(D), L1, do_split)
# end
end

# get_D(A::FourierJoin) = A.D

Base.IndexStyle(::Type{FS}) where {FS<:FourierJoin} = IndexStyle(parenttype(FS))
parenttype(::Type{FourierJoin{T,N,AA}}) where {T,N,AA} = AA
parenttype(::Type{FourierJoin{T,N,AA,D}}) where {T,N,AA,D} = AA
parenttype(A::FourierJoin) = parenttype(typeof(A))

# Allocation is delegated to the parent so that the result has the parent's array kind.
# `Dims{N}` (= `NTuple{N, Int}`) is used instead of `NTuple{N, Int64}` so that the method
# also applies on platforms where `Int` is 32 bits wide.
Base.similar(s::FourierJoin, el::Type, v::Dims{N}) where {N} = similar(s.parent, el, v)

# The wrapped array.
Base.parent(A::FourierJoin) = A.parent
# A FourierJoin has the same size as its parent.
Base.size(A::FourierJoin) = size(parent(A))

@inline function Base.getindex(A::FourierJoin{T,N, <:AbstractArray{T, N}}, i::Vararg{Int,N}) where {T,N}
if i[A.D]==A.L1 && A.do_join
@inbounds return (parent(A)[i...] + parent(A)[Base.setindex(i, A.L2, A.D)...])
# N-dimensional indexing into a FourierJoin: at position L1 along dimension D (and only
# when joining is active) the parent value at the mirrored position L2 is added.
@inline function Base.getindex(A::FourierJoin{T,N, <:AbstractArray{T, N}, D}, i::Vararg{Int,N}) where {T,N,D}
    joined = i[D]==A.L1 && A.do_join
    if !joined
        @inbounds return parent(A)[i...]
    end
    # same index tuple with the entry for dimension D replaced by L2
    mirrored = Base.setindex(i, A.L2, D)
    @inbounds return parent(A)[i...] + parent(A)[mirrored...]
end

# One-D version
# Linear (single-integer) indexing into a FourierJoin.
@inline function Base.getindex(A::FourierJoin{T,N, <:AbstractArray{T, N},D}, i::Int) where {T,N,D}
    # joining switched off: behave exactly like the parent
    A.do_join || return parent(A)[i]

    # translate the linear index into an N-d index tuple to inspect dimension D
    cart = Tuple(CartesianIndices(parent(A))[i])
    if cart[D] == A.L1
        # add the parent value found at the mirrored position L2 along D
        return parent(A)[i] + parent(A)[Base.setindex(cart, A.L2, D)...]
    end
    return parent(A)[i]
end

2 changes: 1 addition & 1 deletion src/czt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function get_kernel_1d(arr::AT, N::Integer, M::Integer; a= 1.0, w = cispi(-2/N),
CT = (RT <: Real) ? Complex{RT} : RT
RT = real(CT)

# converts ShiftedArrays.CircShiftedArray into a plain array type:
# converts MutableShiftedArrays.CircShiftedArray into a plain array type:
tmp = similar(arr, RT, (1,))
RAT = real_arr_type(typeof(tmp), Val(1))

Expand Down
Loading
Loading