Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit e8a3001

Browse files
committed
change singularity handeling, and init of J
1 parent 4bc4ee7 commit e8a3001

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

src/klement.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function SciMLBase.solve(prob::NonlinearProblem,
1616
x = float(prob.u0)
1717
fₙ = f(x)
1818
T = eltype(x)
19-
J = ArrayInterfaceCore.zeromatrix(x) + I
19+
J = init_J(x)
2020

2121
if SciMLBase.isinplace(prob)
2222
error("Klement currently only supports out-of-place nonlinear problems")
@@ -30,8 +30,19 @@ function SciMLBase.solve(prob::NonlinearProblem,
3030
xₙ₋₁ = x
3131
fₙ₋₁ = fₙ
3232
for _ in 1:maxiters
33-
xₙ = xₙ₋₁ - inv(J) * fₙ₋₁
33+
tmp = J \ fₙ₋₁
34+
xₙ = xₙ₋₁ - tmp
3435
fₙ = f(xₙ)
36+
37+
iszero(fₙ) &&
38+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
39+
retcode = ReturnCode.Success)
40+
41+
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
42+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
43+
retcode = ReturnCode.Success)
44+
end
45+
3546
Δxₙ = xₙ - xₙ₋₁
3647
Δfₙ = fₙ - fₙ₋₁
3748

@@ -41,19 +52,11 @@ function SciMLBase.solve(prob::NonlinearProblem,
4152
k = (Δfₙ - J * Δxₙ) ./ denominator
4253
J += (k * Δxₙ' .* J) * J
4354

44-
# Prevent inverting singular matrix
45-
if det(J) 0
46-
J = ArrayInterfaceCore.zeromatrix(x) + I
55+
# Singularity test
56+
if cond(J) > 1e9
57+
J = init_J(xₙ)
4758
end
4859

49-
iszero(fₙ) &&
50-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
51-
retcode = ReturnCode.Success)
52-
53-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
54-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
55-
retcode = ReturnCode.Success)
56-
end
5760
xₙ₋₁ = xₙ
5861
fₙ₋₁ = fₙ
5962
end

src/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(
3333
value(x) = x
3434
value(x::Dual) = ForwardDiff.value(x)
3535
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
36+
37+
function init_J(x)
38+
J = ArrayInterfaceCore.zeromatrix(x)
39+
if ismutable(x)
40+
J[diagind(J)] .= one(eltype(x))
41+
else
42+
J += I
43+
end
44+
return J
45+
end

0 commit comments

Comments
 (0)