Skip to content

Commit 2ab1c47

Browse files
gdalleamontoison
andauthored
Propagate arbitrary integer types as indices (#222)
* Propagate arbitrary integer types as indices * Fixes * Single type * Coverage * Fix * Add tests * Update src/coloring.jl Co-authored-by: Alexis Montoison <35051714+amontoison@users.noreply.github.com> * Fix type * Update src/graph.jl Co-authored-by: Alexis Montoison <35051714+amontoison@users.noreply.github.com> * Fix * Coverage --------- Co-authored-by: Alexis Montoison <35051714+amontoison@users.noreply.github.com>
1 parent 87fd1b3 commit 2ab1c47

File tree

11 files changed

+250
-168
lines changed

11 files changed

+250
-168
lines changed

src/adtypes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function coloring(
55
kwargs...,
66
)
77
bg = BipartiteGraph(A)
8-
color = convert(Vector{Int}, ADTypes.column_coloring(A, algo))
8+
color = convert(Vector{eltype(bg)}, ADTypes.column_coloring(A, algo))
99
return ColumnColoringResult(A, bg, color)
1010
end
1111

@@ -16,6 +16,6 @@ function coloring(
1616
kwargs...,
1717
)
1818
bg = BipartiteGraph(A)
19-
color = convert(Vector{Int}, ADTypes.row_coloring(A, algo))
19+
color = convert(Vector{eltype(bg)}, ADTypes.row_coloring(A, algo))
2020
return RowColoringResult(A, bg, color)
2121
end

src/check.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ function directly_recoverable_columns(
262262
end
263263

264264
"""
265-
valid_dynamic_order(g::AdjacencyGraph, π::AbstractVector{Int}, order::DynamicDegreeBasedOrder)
266-
valid_dynamic_order(bg::AdjacencyGraph, ::Val{side}, π::AbstractVector{Int}, order::DynamicDegreeBasedOrder)
265+
valid_dynamic_order(g::AdjacencyGraph, π::AbstractVector{<:Integer}, order::DynamicDegreeBasedOrder)
266+
valid_dynamic_order(bg::AdjacencyGraph, ::Val{side}, π::AbstractVector{<:Integer}, order::DynamicDegreeBasedOrder)
267267
268268
Check that a permutation `π` corresponds to a valid application of a [`DynamicDegreeBasedOrder`](@ref).
269269
@@ -273,7 +273,9 @@ This is done by checking, for each ordered vertex, that its back- or forward-deg
273273
This function is not coded with efficiency in mind, it is designed for small-scale tests.
274274
"""
275275
function valid_dynamic_order(
276-
g::AdjacencyGraph, π::AbstractVector{Int}, ::DynamicDegreeBasedOrder{degtype,direction}
276+
g::AdjacencyGraph,
277+
π::AbstractVector{<:Integer},
278+
::DynamicDegreeBasedOrder{degtype,direction},
277279
) where {degtype,direction}
278280
length(π) != nb_vertices(g) && return false
279281
length(unique(π)) != nb_vertices(g) && return false
@@ -300,7 +302,7 @@ end
300302
function valid_dynamic_order(
301303
g::BipartiteGraph,
302304
::Val{side},
303-
π::AbstractVector{Int},
305+
π::AbstractVector{<:Integer},
304306
::DynamicDegreeBasedOrder{degtype,direction},
305307
) where {side,degtype,direction}
306308
length(π) != nb_vertices(g, Val(side)) && return false

src/coloring.jl

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ The vertices are colored in a greedy fashion, following the `order` supplied.
1717
> [_What Color Is Your Jacobian? Graph Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005), Algorithm 3.2
1818
"""
1919
function partial_distance2_coloring(
20-
bg::BipartiteGraph, ::Val{side}, order::AbstractOrder
21-
) where {side}
22-
color = Vector{Int}(undef, nb_vertices(bg, Val(side)))
23-
forbidden_colors = Vector{Int}(undef, nb_vertices(bg, Val(side)))
20+
bg::BipartiteGraph{T}, ::Val{side}, order::AbstractOrder
21+
) where {T,side}
22+
color = Vector{T}(undef, nb_vertices(bg, Val(side)))
23+
forbidden_colors = Vector{T}(undef, nb_vertices(bg, Val(side)))
2424
vertices_in_order = vertices(bg, Val(side), order)
2525
partial_distance2_coloring!(color, forbidden_colors, bg, Val(side), vertices_in_order)
2626
return color
2727
end
2828

2929
function partial_distance2_coloring!(
30-
color::Vector{Int},
31-
forbidden_colors::Vector{Int},
30+
color::AbstractVector{<:Integer},
31+
forbidden_colors::AbstractVector{<:Integer},
3232
bg::BipartiteGraph,
3333
::Val{side},
3434
vertices_in_order::AbstractVector{<:Integer},
@@ -76,16 +76,18 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
7676
7777
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 4.1
7878
"""
79-
function star_coloring(g::AdjacencyGraph, order::AbstractOrder, postprocessing::Bool)
79+
function star_coloring(
80+
g::AdjacencyGraph{T}, order::AbstractOrder, postprocessing::Bool
81+
) where {T<:Integer}
8082
# Initialize data structures
8183
nv = nb_vertices(g)
8284
ne = nb_edges(g)
83-
color = zeros(Int, nv)
84-
forbidden_colors = zeros(Int, nv)
85-
first_neighbor = fill((0, 0, 0), nv) # at first no neighbors have been encountered
86-
treated = zeros(Int, nv)
87-
star = Vector{Int}(undef, ne)
88-
hub = Int[] # one hub for each star, including the trivial ones
85+
color = zeros(T, nv)
86+
forbidden_colors = zeros(T, nv)
87+
first_neighbor = fill((zero(T), zero(T), zero(T)), nv) # at first no neighbors have been encountered
88+
treated = zeros(T, nv)
89+
star = Vector{T}(undef, ne)
90+
hub = T[] # one hub for each star, including the trivial ones
8991
vertices_in_order = vertices(g, order)
9092

9193
for v in vertices_in_order
@@ -196,11 +198,11 @@ Encode a set of 2-colored stars resulting from the [`star_coloring`](@ref) algor
196198
197199
$TYPEDFIELDS
198200
"""
199-
struct StarSet
201+
struct StarSet{T}
200202
"a mapping from edges (pair of vertices) to their star index"
201-
star::Vector{Int}
203+
star::Vector{T}
202204
"a mapping from star indices to their hub (undefined hubs for single-edge stars are the negative value of one of the vertices, picked arbitrarily)"
203-
hub::Vector{Int}
205+
hub::Vector{T}
204206
end
205207

206208
"""
@@ -226,15 +228,17 @@ If `postprocessing=true`, some colors might be replaced with `0` (the "neutral"
226228
227229
> [_New Acyclic and Star Coloring Algorithms with Application to Computing Hessians_](https://epubs.siam.org/doi/abs/10.1137/050639879), Gebremedhin et al. (2007), Algorithm 3.1
228230
"""
229-
function acyclic_coloring(g::AdjacencyGraph, order::AbstractOrder, postprocessing::Bool)
231+
function acyclic_coloring(
232+
g::AdjacencyGraph{T}, order::AbstractOrder, postprocessing::Bool
233+
) where {T<:Integer}
230234
# Initialize data structures
231235
nv = nb_vertices(g)
232236
ne = nb_edges(g)
233-
color = zeros(Int, nv)
234-
forbidden_colors = zeros(Int, nv)
235-
first_neighbor = fill((0, 0, 0), nv) # at first no neighbors have been encountered
236-
first_visit_to_tree = fill((0, 0), ne)
237-
forest = Forest{Int}(ne)
237+
color = zeros(T, nv)
238+
forbidden_colors = zeros(T, nv)
239+
first_neighbor = fill((zero(T), zero(T), zero(T)), nv) # at first no neighbors have been encountered
240+
first_visit_to_tree = fill((zero(T), zero(T)), ne)
241+
forest = Forest{T}(ne)
238242
vertices_in_order = vertices(g, order)
239243

240244
for v in vertices_in_order
@@ -367,23 +371,23 @@ Encode a set of 2-colored trees resulting from the [`acyclic_coloring`](@ref) al
367371
368372
$TYPEDFIELDS
369373
"""
370-
struct TreeSet
371-
reverse_bfs_orders::Vector{Vector{Tuple{Int,Int}}}
374+
struct TreeSet{T}
375+
reverse_bfs_orders::Vector{Vector{Tuple{T,T}}}
372376
is_star::Vector{Bool}
373377
end
374378

375-
function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
379+
function TreeSet(g::AdjacencyGraph{T}, forest::Forest{T}) where {T}
376380
S = pattern(g)
377381
edge_to_index = edge_indices(g)
378382
nv = nb_vertices(g)
379383
nt = forest.num_trees
380384

381385
# dictionary that maps a tree's root to the index of the tree
382-
roots = Dict{Int,Int}()
386+
roots = Dict{T,T}()
383387
sizehint!(roots, nt)
384388

385389
# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
386-
trees = [Dict{Int,Vector{Int}}() for i in 1:nt]
390+
trees = [Dict{T,Vector{T}}() for i in 1:nt]
387391

388392
# current number of roots found
389393
nr = 0
@@ -423,10 +427,10 @@ function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
423427
end
424428

425429
# degrees is a vector of integers that stores the degree of each vertex in a tree
426-
degrees = Vector{Int}(undef, nv)
430+
degrees = Vector{T}(undef, nv)
427431

428432
# reverse breadth first (BFS) traversal order for each tree in the forest
429-
reverse_bfs_orders = [Tuple{Int,Int}[] for i in 1:nt]
433+
reverse_bfs_orders = [Tuple{T,T}[] for i in 1:nt]
430434

431435
# nvmax is the number of vertices of the biggest tree in the forest
432436
nvmax = 0
@@ -436,7 +440,7 @@ function TreeSet(g::AdjacencyGraph, forest::Forest{Int})
436440
end
437441

438442
# Create a queue with a fixed size nvmax
439-
queue = Vector{Int}(undef, nvmax)
443+
queue = Vector{T}(undef, nvmax)
440444

441445
# Specify if each tree in the forest is a star,
442446
# meaning that one vertex is directly connected to all other vertices in the tree
@@ -519,7 +523,7 @@ function postprocess!(
519523
color::AbstractVector{<:Integer},
520524
star_or_tree_set::Union{StarSet,TreeSet},
521525
g::AdjacencyGraph,
522-
offsets::Vector{Int},
526+
offsets::AbstractVector{<:Integer},
523527
)
524528
S = pattern(g)
525529
edge_to_index = edge_indices(g)

src/constant.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Indeed, for symmetric coloring problems, we need more than just the vector of co
1414
1515
- `partition::Symbol`: either `:row` or `:column`.
1616
- `matrix_template::AbstractMatrix`: matrix for which the vector of colors was precomputed (the algorithm will only accept matrices of the exact same size).
17-
- `color::Vector{Int}`: vector of integer colors, one for each row or column (depending on `partition`).
17+
- `color::Vector{<:Integer}`: vector of integer colors, one for each row or column (depending on `partition`).
1818
1919
!!! warning
2020
The second constructor (based on keyword arguments) is type-unstable.
@@ -65,33 +65,36 @@ julia> column_colors(result)
6565
- [`ADTypes.row_coloring`](@extref ADTypes.row_coloring)
6666
"""
6767
struct ConstantColoringAlgorithm{
68-
partition,M<:AbstractMatrix,R<:AbstractColoringResult{:nonsymmetric,partition,:direct}
68+
partition,
69+
M<:AbstractMatrix,
70+
T<:Integer,
71+
R<:AbstractColoringResult{:nonsymmetric,partition,:direct},
6972
} <: ADTypes.AbstractColoringAlgorithm
7073
matrix_template::M
71-
color::Vector{Int}
74+
color::Vector{T}
7275
result::R
7376
end
7477

7578
function ConstantColoringAlgorithm{:column}(
76-
matrix_template::AbstractMatrix, color::Vector{Int}
79+
matrix_template::AbstractMatrix, color::Vector{<:Integer}
7780
)
7881
bg = BipartiteGraph(matrix_template)
7982
result = ColumnColoringResult(matrix_template, bg, color)
80-
M, R = typeof(matrix_template), typeof(result)
81-
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
83+
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
84+
return ConstantColoringAlgorithm{:column,M,T,R}(matrix_template, color, result)
8285
end
8386

8487
function ConstantColoringAlgorithm{:row}(
85-
matrix_template::AbstractMatrix, color::Vector{Int}
88+
matrix_template::AbstractMatrix, color::Vector{<:Integer}
8689
)
8790
bg = BipartiteGraph(matrix_template)
8891
result = RowColoringResult(matrix_template, bg, color)
89-
M, R = typeof(matrix_template), typeof(result)
90-
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
92+
T, M, R = eltype(bg), typeof(matrix_template), typeof(result)
93+
return ConstantColoringAlgorithm{:row,M,T,R}(matrix_template, color, result)
9194
end
9295

9396
function ConstantColoringAlgorithm(
94-
matrix_template::AbstractMatrix, color::Vector{Int}; partition=:column
97+
matrix_template::AbstractMatrix, color::Vector{<:Integer}; partition::Symbol=:column
9598
)
9699
return ConstantColoringAlgorithm{partition}(matrix_template, color)
97100
end

src/decompression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,12 +677,12 @@ function decompress!(
677677
result::LinearSystemColoringResult,
678678
uplo::Symbol=:F,
679679
)
680-
(; color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) = result
680+
(; color, strict_upper_nonzero_inds, M_factorization, strict_upper_nonzeros_A) = result
681681
S = result.ag.S
682682
uplo == :F && check_same_pattern(A, S)
683683

684684
# TODO: for some reason I cannot use ldiv! with a sparse QR
685-
strict_upper_nonzeros_A = T_factorization \ vec(B)
685+
strict_upper_nonzeros_A = M_factorization \ vec(B)
686686
fill!(A, zero(eltype(A)))
687687
for i in axes(A, 1)
688688
if !iszero(S[i, i])

src/graph.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ end
2323

2424
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)
2525

26+
Base.eltype(::SparsityPatternCSC{T}) where {T} = T
2627
Base.size(S::SparsityPatternCSC) = (S.m, S.n)
27-
Base.size(S::SparsityPatternCSC, d) = d::Integer <= 2 ? size(S)[d] : 1
28+
Base.size(S::SparsityPatternCSC, d::Integer) = d::Integer <= 2 ? size(S)[d] : 1
2829
Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))
2930

3031
SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval)
@@ -222,11 +223,13 @@ The adjacency graph of a symmetric matrix `A ∈ ℝ^{n × n}` is `G(A) = (V, E)
222223
223224
> [_What Color Is Your Jacobian? SparsityPatternCSC Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005)
224225
"""
225-
struct AdjacencyGraph{T,has_diagonal}
226+
struct AdjacencyGraph{T<:Integer,has_diagonal}
226227
S::SparsityPatternCSC{T}
227228
edge_to_index::Vector{T}
228229
end
229230

231+
Base.eltype(::AdjacencyGraph{T}) where {T} = T
232+
230233
function AdjacencyGraph(
231234
S::SparsityPatternCSC{T},
232235
edge_to_index::Vector{T}=build_edge_to_index(S);
@@ -298,7 +301,7 @@ function has_neighbor(g::AdjacencyGraph, v::Integer, u::Integer)
298301
return false
299302
end
300303

301-
function degree_in_subset(g::AdjacencyGraph, v::Integer, subset::AbstractVector{Int})
304+
function degree_in_subset(g::AdjacencyGraph, v::Integer, subset::AbstractVector{<:Integer})
302305
d = 0
303306
for u in subset
304307
if has_neighbor(g, v, u)
@@ -338,11 +341,13 @@ When `symmetric_pattern` is `true`, this construction is more efficient.
338341
339342
> [_What Color Is Your Jacobian? SparsityPatternCSC Coloring for Computing Derivatives_](https://epubs.siam.org/doi/10.1137/S0036144504444711), Gebremedhin et al. (2005)
340343
"""
341-
struct BipartiteGraph{T}
344+
struct BipartiteGraph{T<:Integer}
342345
S1::SparsityPatternCSC{T}
343346
S2::SparsityPatternCSC{T}
344347
end
345348

349+
Base.eltype(::BipartiteGraph{T}) where {T} = T
350+
346351
function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
347352
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
348353
end
@@ -425,7 +430,7 @@ function has_neighbor_dist2(
425430
end
426431

427432
function degree_dist2_in_subset(
428-
bg::BipartiteGraph, ::Val{side}, v::Integer, subset::AbstractVector{Int}
433+
bg::BipartiteGraph, ::Val{side}, v::Integer, subset::AbstractVector{<:Integer}
429434
) where {side}
430435
d = 0
431436
for u in subset

src/interface.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ function _coloring(
306306
symmetric_result = StarSetColoringResult(A_and_Aᵀ, ag, color, star_set)
307307
return BicoloringResult(A, ag, symmetric_result, R)
308308
else
309-
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
309+
row_color, column_color, _ = remap_colors(
310+
eltype(ag), color, maximum(color), size(A)...
311+
)
310312
return row_color, column_color
311313
end
312314
end
@@ -326,7 +328,9 @@ function _coloring(
326328
symmetric_result = TreeSetColoringResult(A_and_Aᵀ, ag, color, tree_set, R)
327329
return BicoloringResult(A, ag, symmetric_result, R)
328330
else
329-
row_color, column_color, _ = remap_colors(color, maximum(color), size(A)...)
331+
row_color, column_color, _ = remap_colors(
332+
eltype(ag), color, maximum(color), size(A)...
333+
)
330334
return row_color, column_color
331335
end
332336
end

0 commit comments

Comments
 (0)