Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit ea2fe76

Browse files
committed
bug fix
1 parent 06ae37e commit ea2fe76

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
3131
solve(prob_no_brack, alg(), tol = T(1e-2))
3232
end
3333

34-
for alg in (TrustRegion(1.0),)
34+
for alg in (TrustRegion(10.0),)
3535
solve(prob_no_brack, alg, tol = T(1e-2))
3636
end
3737

src/trustRegion.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ solver
3030
"""
3131
struct TrustRegion{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
3232
max_trust_radius::Number
33-
function TrustRegion(max_turst_radius::Number; chunk_size = Val{0}(),
34-
autodiff = Val{true}(),
35-
diff_type = Val{:forward})
33+
function TrustRegion(max_trust_radius::Number; chunk_size = Val{0}(),
34+
autodiff = Val{true}(),
35+
diff_type = Val{:forward})
3636
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
3737
SciMLBase._unwrap_val(diff_type)}(max_trust_radius)
3838
end
@@ -46,8 +46,8 @@ function SciMLBase.solve(prob::NonlinearProblem,
4646
x = float(prob.u0)
4747
T = typeof(x)
4848
Δₘₐₓ = float(alg.max_trust_radius) # The maximum trust region radius.
49-
Δ = Δₘₐₓ / 5 # Initial trust region radius.
50-
η₁ = 0.1 # Threshold for taking a step.
49+
Δ = Δₘₐₓ / 11 # Initial trust region radius.
50+
η₁ = 0.0 # Threshold for taking a step.
5151
η₂ = 0.25 # Threshold for shrinking the trust region.
5252
η₃ = 0.75 # Threshold for expanding the trust region.
5353
t₁ = 0.25 # Factor to shrink the trust region with.
@@ -88,38 +88,44 @@ function SciMLBase.solve(prob::NonlinearProblem,
8888

8989
# Update the trust region radius.
9090
if r < η₂
91-
Δ *= t₁
92-
if r > η₁
93-
if isapprox(x̂, x, atol = atol, rtol = rtol)
91+
Δ = t₁ * Δ
92+
93+
if Δ < 1e-10
9494
return SciMLBase.build_solution(prob, alg, x, F;
9595
retcode = ReturnCode.Success)
9696
end
97-
97+
end
98+
if r > η₁
99+
if isapprox(xₖ₊₁, x, atol = atol, rtol = rtol)
100+
return SciMLBase.build_solution(prob, alg, xₖ₊₁, Fₖ₊₁;
101+
retcode = ReturnCode.Success)
102+
end
103+
# Take the step.
98104
x = xₖ₊₁
99105
F = Fₖ₊₁
100106
if alg_autodiff(alg)
101107
F, ∇f = value_derivative(f, x)
102108
elseif x isa AbstractArray
103109
∇f = FiniteDiff.finite_difference_jacobian(f, x, diff_type(alg), eltype(x),
104-
F)
110+
F)
105111
else
106112
∇f = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg),
107-
eltype(x),
108-
F)
113+
eltype(x),
114+
F)
109115
end
110116

111117
iszero(F) &&
112118
return SciMLBase.build_solution(prob, alg, x, F;
113119
retcode = ReturnCode.Success)
120+
114121
# Update the trust region radius.
115122
if r > η₃ && norm(δ) Δ
116123
Δ = min(t₂ * Δ, Δₘₐₓ)
117124
end
118-
fₖ =
125+
fₖ = fₖ₊₁
119126
H = ∇f * ∇f
120127
g = ∇f * F
121128
end
122129
end
123-
124130
return SciMLBase.build_solution(prob, alg, x, F; retcode = ReturnCode.MaxIters)
125131
end

test/basictests.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ sol = benchmark_scalar(sf, csu0)
4646
@test sol.u * sol.u - 2 < 1e-9
4747
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
4848

49-
# SimpleNewtonRaphsonTrustRegion
49+
# TrustRegion
5050
function benchmark_scalar(f, u0)
5151
probN = NonlinearProblem{false}(f, u0)
52-
sol = (solve(probN, SimpleNewtonRaphsonTrustRegion(1.0)))
52+
sol = (solve(probN, TrustRegion(10.0)))
5353
end
5454

5555
sol = benchmark_scalar(sf, csu0)
@@ -63,7 +63,7 @@ using ForwardDiff
6363
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
6464

6565
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
66-
SimpleNewtonRaphsonTrustRegion(10.0)]
66+
TrustRegion(10.0)]
6767
g = function (p)
6868
probN = NonlinearProblem{false}(f, csu0, p)
6969
sol = solve(probN, alg, tol = 1e-9)
@@ -79,7 +79,7 @@ end
7979
# Scalar
8080
f, u0 = (u, p) -> u * u - p, 1.0
8181
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
82-
SimpleNewtonRaphsonTrustRegion(10.0)]
82+
TrustRegion(10.0)]
8383
g = function (p)
8484
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
8585
sol = solve(probN, alg)
@@ -121,7 +121,7 @@ for alg in [Bisection(), Falsi()]
121121
end
122122

123123
for alg in [SimpleNewtonRaphson(), Broyden(), Klement(),
124-
SimpleNewtonRaphsonTrustRegion(1.0)]
124+
TrustRegion(10.0)]
125125
global g, p
126126
g = function (p)
127127
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -141,10 +141,10 @@ probN = NonlinearProblem(f, u0)
141141
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
142142
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
143143

144-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0)).u[end] sqrt(2.0)
145-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0); immutable = false).u[end] sqrt(2.0)
146-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0; autodiff = false)).u[end] sqrt(2.0)
147-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0; autodiff = false)).u[end] sqrt(2.0)
144+
@test solve(probN, TrustRegion(10.0)).u[end] sqrt(2.0)
145+
@test solve(probN, TrustRegion(10.0); immutable = false).u[end] sqrt(2.0)
146+
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
147+
@test solve(probN, TrustRegion(10.0; autodiff = false)).u[end] sqrt(2.0)
148148

149149
@test solve(probN, Broyden()).u[end] sqrt(2.0)
150150
@test solve(probN, Broyden(); immutable = false).u[end] sqrt(2.0)
@@ -162,9 +162,9 @@ for u0 in [1.0, [1, 1.0]]
162162
@test solve(probN, SimpleNewtonRaphson()).u sol
163163
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol
164164

165-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0)).u sol
166-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0)).u sol
167-
@test solve(probN, SimpleNewtonRaphsonTrustRegion(1.0; autodiff = false)).u sol
165+
@test solve(probN, TrustRegion(10.0)).u sol
166+
@test solve(probN, TrustRegion(10.0)).u sol
167+
@test solve(probN, TrustRegion(10.0; autodiff = false)).u sol
168168

169169
@test solve(probN, Broyden()).u sol
170170

0 commit comments

Comments
 (0)