Skip to content

Commit bd5f46a

Browse files
committed
only build Cholesky and LU cache if possible
1 parent bbfa99d commit bd5f46a

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ function LinearSolve.is_cusparse(A::Union{
1717
CUDA.CUSPARSE.CuSparseMatrixCSR, CUDA.CUSPARSE.CuSparseMatrixCSC})
1818
true
1919
end
20+
LinearSolve.is_cusparse_csr(::CUDA.CUSPARSE.CuSparseMatrixCSR) = true
21+
LinearSolve.is_cusparse_csc(::CUDA.CUSPARSE.CuSparseMatrixCSC) = true
2022

2123
function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
2224
assump::OperatorAssumptions{Bool}) where {Tv, Ti}

ext/LinearSolveSparseArraysExt.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function LinearSolve.init_cacheval(
129129
maxiters::Int, abstol, reltol,
130130
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
131131
if LinearSolve.is_cusparse(A)
132-
ArrayInterface.lu_instance(A)
132+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
133133
else
134134
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int64}(
135135
zero(Int64), zero(Int64), [Int64(1)], Int64[], T[]))
@@ -141,7 +141,7 @@ function LinearSolve.init_cacheval(
141141
maxiters::Int, abstol, reltol,
142142
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
143143
if LinearSolve.is_cusparse(A)
144-
ArrayInterface.lu_instance(A)
144+
LinearSolve.cudss_loaded(A) ? ArrayInterface.lu_instance(A) : nothing
145145
else
146146
SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC{T, Int32}(
147147
zero(Int32), zero(Int32), [Int32(1)], Int32[], T[]))
@@ -344,7 +344,13 @@ function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
344344
Symmetric{T, <:AbstractSparseArray{T}}}, b, u, Pl, Pr,
345345
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
346346
assumptions::OperatorAssumptions) where {T <: BLASELTYPES}
347-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
347+
if LinearSolve.is_cusparse_csc(A)
348+
nothing
349+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
350+
nothing
351+
else
352+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
353+
end
348354
end
349355

350356
# Specialize QR for the non-square case

src/LinearSolve.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ ALREADY_WARNED_CUDSS = Ref{Bool}(false)
478478
error_no_cudss_lu(A) = nothing
479479
cudss_loaded(A) = false
480480
is_cusparse(A) = false
481+
is_cusparse_csr(A) = false
482+
is_cusparse_csc(A) = false
481483

482484
export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
483485
GenericLUFactorization, SimpleLUFactorization, RFLUFactorization, ButterflyFactorization,

src/factorization.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,13 @@ end
395395
function init_cacheval(
396396
alg::CholeskyFactorization, A::AbstractArray{<:BLASELTYPES}, b, u, Pl, Pr,
397397
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
398-
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
398+
if LinearSolve.is_cusparse_csc(A)
399+
nothing
400+
elseif LinearSolve.is_cusparse_csr(A) && !LinearSolve.cudss_loaded(A)
401+
nothing
402+
else
403+
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
404+
end
399405
end
400406

401407
const PREALLOCATED_CHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

0 commit comments

Comments
 (0)