diff --git a/src/graph.jl b/src/graph.jl index b93a864b..39df611c 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -32,6 +32,11 @@ SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval) SparseArrays.rowvals(S::SparsityPatternCSC) = S.rowval SparseArrays.nzrange(S::SparsityPatternCSC, j::Integer) = S.colptr[j]:(S.colptr[j + 1] - 1) +# Needed if using `coloring(::SparsityPatternCSC, ...)` +function Base.similar(A::SparsityPatternCSC, ::Type{T}) where {T} + return SparseArrays.SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, similar(A.rowval, T)) +end + """ transpose(S::SparsityPatternCSC) diff --git a/test/structured.jl b/test/structured.jl index 23ba967f..5c7b5874 100644 --- a/test/structured.jl +++ b/test/structured.jl @@ -2,6 +2,7 @@ using ArrayInterface: ArrayInterface using BandedMatrices: BandedMatrix, brand using BlockBandedMatrices: BandedBlockBandedMatrix, BlockBandedMatrix using LinearAlgebra +using SparseArrays using SparseMatrixColorings using Test @@ -56,3 +57,19 @@ end; test_structured_coloring_decompression(A) end end; + +# See https://github.com/gdalle/SparseMatrixColorings.jl/pull/299 +@testset "SparsityPatternCSC $T" for T in [Int, Float32] + S = sparse(T[ + 0 0 1 1 0 1 + 1 0 0 0 1 0 + 0 1 0 0 1 0 + 0 1 1 0 0 0 + ]) + P = SparseMatrixColorings.SparsityPatternCSC(S) + problem = ColoringProblem() + algo = GreedyColoringAlgorithm() + result = coloring(P, problem, algo) + B = compress(S, result) + @test decompress(B, result) isa SparseMatrixCSC{T,Int} +end;