Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 71132ce

Browse files
Merge pull request #22 from CCsimon123/main
Trying to improve the Klement-solver
2 parents e681591 + 8915554 commit 71132ce

File tree

3 files changed

+77
-23
lines changed

3 files changed

+77
-23
lines changed

src/broyden.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function SciMLBase.solve(prob::NonlinearProblem,
1616
x = float(prob.u0)
1717
fₙ = f(x)
1818
T = eltype(x)
19-
J⁻¹ = ArrayInterfaceCore.zeromatrix(x) + I
19+
J⁻¹ = init_J(x)
2020

2121
if SciMLBase.isinplace(prob)
2222
error("Broyden currently only supports out-of-place nonlinear problems")

src/klement.jl

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Klement()
44
```
55
66
A low-overhead implementation of [Klement](https://jatm.com.br/jatm/article/view/373).
7-
This method is non-allocating on scalar and static array problems.
7+
This method is non-allocating on scalar problems.
88
"""
99
struct Klement <: AbstractSimpleNonlinearSolveAlgorithm end
1010

@@ -16,7 +16,7 @@ function SciMLBase.solve(prob::NonlinearProblem,
1616
x = float(prob.u0)
1717
fₙ = f(x)
1818
T = eltype(x)
19-
J = ArrayInterfaceCore.zeromatrix(x) + I
19+
singular_tol = 1e-9
2020

2121
if SciMLBase.isinplace(prob)
2222
error("Klement currently only supports out-of-place nonlinear problems")
@@ -29,33 +29,77 @@ function SciMLBase.solve(prob::NonlinearProblem,
2929
xₙ = x
3030
xₙ₋₁ = x
3131
fₙ₋₁ = fₙ
32-
for _ in 1:maxiters
33-
xₙ = xₙ₋₁ - inv(J) * fₙ₋₁
34-
fₙ = f(xₙ)
35-
Δxₙ = xₙ - xₙ₋₁
36-
Δfₙ = fₙ - fₙ₋₁
3732

38-
# Prevent division by 0
39-
denominator = max.(J' .^ 2 * Δxₙ .^ 2, 1e-9)
33+
# x is scalar
34+
if x isa Number
35+
J = 1.0
36+
for _ in 1:maxiters
37+
xₙ = xₙ₋₁ - fₙ₋₁ / J
38+
fₙ = f(xₙ)
4039

41-
k = (Δfₙ - J * Δxₙ) ./ denominator
42-
J += (k * Δxₙ' .* J) * J
40+
iszero(fₙ) &&
41+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
42+
retcode = ReturnCode.Success)
4343

44-
# Prevent inverting singular matrix
45-
if det(J) 0
46-
J = ArrayInterfaceCore.zeromatrix(x) + I
44+
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
45+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
46+
retcode = ReturnCode.Success)
47+
end
48+
49+
Δxₙ = xₙ - xₙ₋₁
50+
Δfₙ = fₙ - fₙ₋₁
51+
52+
# Prevent division by 0
53+
denominator = max(J^2 * Δxₙ^2, 1e-9)
54+
55+
k = (Δfₙ - J * Δxₙ) / denominator
56+
J += (k * Δxₙ * J) * J
57+
58+
# Singularity test
59+
if J < singular_tol
60+
J = 1.0
61+
end
62+
63+
xₙ₋₁ = xₙ
64+
fₙ₋₁ = fₙ
4765
end
66+
# x is a vector
67+
else
68+
J = init_J(x)
69+
for _ in 1:maxiters
70+
F = lu(J, check = false)
71+
72+
# Singularity test
73+
if any(abs.(F.U[diagind(F.U)]) .< singular_tol)
74+
J = init_J(xₙ)
75+
F = lu(J, check = false)
76+
end
77+
78+
tmp = F \ fₙ₋₁
79+
xₙ = xₙ₋₁ - tmp
80+
fₙ = f(xₙ)
81+
82+
iszero(fₙ) &&
83+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
84+
retcode = ReturnCode.Success)
85+
86+
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
87+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
88+
retcode = ReturnCode.Success)
89+
end
90+
91+
Δxₙ = xₙ - xₙ₋₁
92+
Δfₙ = fₙ - fₙ₋₁
93+
94+
# Prevent division by 0
95+
denominator = max.(J' .^ 2 * Δxₙ .^ 2, 1e-9)
4896

49-
iszero(fₙ) &&
50-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
51-
retcode = ReturnCode.Success)
97+
k = (Δfₙ - J * Δxₙ) ./ denominator
98+
J += (k * Δxₙ' .* J) * J
5299

53-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
54-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
55-
retcode = ReturnCode.Success)
100+
xₙ₋₁ = xₙ
101+
fₙ₋₁ = fₙ
56102
end
57-
xₙ₋₁ = xₙ
58-
fₙ₋₁ = fₙ
59103
end
60104

61105
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)

src/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ value_derivative(f::F, x::AbstractArray) where {F} = f(x), ForwardDiff.jacobian(
3333
value(x) = x
3434
value(x::Dual) = ForwardDiff.value(x)
3535
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
36+
37+
function init_J(x)
38+
J = ArrayInterfaceCore.zeromatrix(x)
39+
if ismutable(x)
40+
J[diagind(J)] .= one(eltype(x))
41+
else
42+
J += I
43+
end
44+
return J
45+
end

0 commit comments

Comments
 (0)