Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 932f4e9

Browse files
committed
Make broyden batched
1 parent e600ee5 commit 932f4e9

File tree

5 files changed

+50
-19
lines changed

5 files changed

+50
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1213
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1314
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1415
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
@@ -19,6 +20,7 @@ ArrayInterfaceCore = "0.1.1"
1920
DiffEqBase = "6.114"
2021
FiniteDiff = "2"
2122
ForwardDiff = "0.10.3"
23+
NNlib = "0.8"
2224
Reexport = "0.2, 1"
2325
SciMLBase = "1.73"
2426
SnoopPrecompile = "1"

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using StaticArraysCore
77
using LinearAlgebra
88
import ArrayInterfaceCore
99
using DiffEqBase
10+
using NNlib # Batched Matrix Multiplication
1011

1112
@reexport using SciMLBase
1213

src/broyden.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@ and static array problems.
88
"""
99
struct Broyden <: AbstractSimpleNonlinearSolveAlgorithm end
1010

11-
function SciMLBase.__solve(prob::NonlinearProblem,
12-
alg::Broyden, args...; abstol = nothing,
13-
reltol = nothing,
14-
maxiters = 1000, kwargs...)
11+
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden, args...; abstol = nothing,
12+
reltol = nothing, maxiters = 1000, batch = false, kwargs...)
1513
f = Base.Fix2(prob.f, prob.p)
1614
x = float(prob.u0)
15+
16+
if batch && ndims(x) != 2
17+
error("`batch` mode works only if `ndims(prob.u0) == 2`")
18+
end
19+
1720
fₙ = f(x)
1821
T = eltype(x)
19-
J⁻¹ = init_J(x)
22+
J⁻¹ = init_J(x; batch)
2023

2124
if SciMLBase.isinplace(prob)
2225
error("Broyden currently only supports out-of-place nonlinear problems")
@@ -30,11 +33,14 @@ function SciMLBase.__solve(prob::NonlinearProblem,
3033
xₙ₋₁ = x
3134
fₙ₋₁ = fₙ
3235
for _ in 1:maxiters
33-
xₙ = xₙ₋₁ - J⁻¹ * fₙ₋₁
36+
xₙ = xₙ₋₁ .- _batched_mul(J⁻¹, fₙ₋₁, batch)
3437
fₙ = f(xₙ)
35-
Δxₙ = xₙ - xₙ₋₁
36-
Δfₙ = fₙ - fₙ₋₁
37-
J⁻¹ += ((Δxₙ - J⁻¹ * Δfₙ) ./ (Δxₙ' * J⁻¹ * Δfₙ)) * (Δxₙ' * J⁻¹)
38+
Δxₙ = xₙ .- xₙ₋₁
39+
Δfₙ = fₙ .- fₙ₋₁
40+
J⁻¹Δfₙ = _batched_mul(J⁻¹, Δfₙ, batch)
41+
# Broyden rank-one update of the inverse Jacobian estimate:
#     J⁻¹ += ((Δx - J⁻¹Δf) ./ (Δxᵀ * J⁻¹Δf)) * (Δxᵀ * J⁻¹)
# `batch` is the trailing argument of every `_batched_mul` / `_batch_transpose` call.
# It must NOT sit inside the numerator's parentheses: `(Δxₙ .- J⁻¹Δfₙ, batch)` is a
# `(array, Bool)` tuple, so `./` would broadcast over the tuple and the outer product
# would receive a `Tuple` instead of an array.
J⁻¹ += _batched_mul((Δxₙ .- J⁻¹Δfₙ) ./
                    _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹Δfₙ, batch),
                    _batched_mul(_batch_transpose(Δxₙ, batch), J⁻¹, batch), batch)
3844

3945
iszero(fₙ) &&
4046
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
@@ -50,3 +56,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,
5056

5157
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
5258
end
59+
60+
"""
    _batch_transpose(x, batch)

Return the "row" form of `x` used when forming products such as `Δxᵀ * J⁻¹`.

- `batch == false`: returns the adjoint `x'`.
- `batch == true`: returns `x` reshaped to `(1, size(x)...)`, i.e. a leading
  singleton dimension is inserted in front of the existing dimensions.
"""
function _batch_transpose(x, batch)
    batch || return x'
    # NOTE(review): `x'` conjugates complex entries while this reshape does not, so the
    # two modes differ for complex `x` — confirm whether complex problems are in scope.
    return reshape(x, 1, size(x)...)
end

src/lbroyden.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ Base.@kwdef struct LBroyden <: AbstractSimpleNonlinearSolveAlgorithm
88
threshold::Int = 27
99
end
1010

11-
@views function SciMLBase.__solve(prob::NonlinearProblem,
12-
alg::LBroyden, args...; abstol = nothing,
13-
reltol = nothing,
14-
maxiters = 1000, kwargs...)
11+
@views function SciMLBase.__solve(prob::NonlinearProblem, alg::LBroyden, args...;
12+
abstol = nothing, reltol = nothing, maxiters = 1000,
13+
batch = false, kwargs...)
1514
threshold = min(maxiters, alg.threshold)
1615
x = float(prob.u0)
1716

src/utils.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,17 @@ value(x) = x
3434
value(x::Dual) = ForwardDiff.value(x)
3535
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
3636

37-
"""
    init_J(x; batch = false)

Build the initial matrix used by the Broyden solver: the result of
`ArrayInterfaceCore.zeromatrix` with ones placed on its diagonal.

# Keywords
- `batch`: when `true`, the matrix is built from the first column of `x` and then
  repeated along a third dimension, once per column of `x`.

# Returns
The matrix itself when `batch == false`; otherwise a 3-dimensional array whose
third dimension has length `size(x, 2)`.
"""
function init_J(x; batch = false)
    # In batch mode a single column of `x` serves as the template for one sample.
    x_ = batch ? x[:, 1] : x

    J = ArrayInterfaceCore.zeromatrix(x_)
    if ismutable(x_)
        J[diagind(J)] .= one(eltype(x_))
    else
        # Immutable inputs cannot be written in place; rebind `J` to `J + I` instead.
        J += I
    end

    # One copy of the matrix per column of `x`, stacked along dimension 3.
    return batch ? repeat(J, 1, 1, size(x, 2)) : J
end
4649

4750
function dogleg_method(H, g, Δ)
@@ -68,3 +71,18 @@ function dogleg_method(H, g, Δ)
6871
tau = (-dot_δsd_δN_δsd + sqrt(fact)) / dot_δN_δsd
6972
return δsd + tau * δN_δsd
7073
end
74+
75+
# Generic fallback: whenever no array-rank-specific method applies, the product is
# the ordinary `x * y` and the `batch` flag is ignored.
function _batched_mul(x, y, batch)
    return x * y
end
76+
"""
    _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch)

Multiply a 3-dimensional array `x` by a matrix `y`.

With `batch == true`, `y` is reshaped to `(size(y, 1), 1, size(y, 2))`, passed with `x`
to `batched_mul`, and dimension 2 of the result is dropped. With `batch == false` the
call falls through to plain `x * y`.
"""
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix, batch) where {T}
    # NOTE(review): plain `*` has no method for a 3-dimensional left operand, so this
    # branch is expected to throw a `MethodError` — confirm that is intended.
    batch || return x * y
    # Give `y` a singleton middle dimension so its last dimension lines up with the
    # last dimension of `x` in `batched_mul`.
    y3 = reshape(y, size(y, 1), 1, size(y, 2))
    return dropdims(batched_mul(x, y3); dims = 2)
end
80+
"""
    _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch)

Multiply a matrix `x` by a 3-dimensional array `y`.

With `batch == true`, `x` is reshaped to `(size(x, 1), 1, size(x, 2))` and passed with
`y` to `batched_mul`; the 3-dimensional result is returned as is. With `batch == false`
the call falls through to plain `x * y`.
"""
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}, batch) where {T}
    # NOTE(review): plain `*` has no method for a 3-dimensional right operand, so this
    # branch is expected to throw a `MethodError` — confirm that is intended.
    batch || return x * y
    # Give `x` a singleton middle dimension so its last dimension lines up with the
    # last dimension of `y` in `batched_mul`.
    x3 = reshape(x, size(x, 1), 1, size(x, 2))
    return batched_mul(x3, y)
end
84+
"""
    _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}, batch)

Multiply two 3-dimensional arrays.

With `batch == true` both arguments are forwarded unchanged to `batched_mul`. With
`batch == false` the call falls through to plain `x * y`.
"""
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3},
                      batch) where {T1, T2}
    # NOTE(review): plain `*` has no method for 3-dimensional operands, so this branch
    # is expected to throw a `MethodError` — confirm that is intended.
    batch || return x * y
    return batched_mul(x, y)
end

0 commit comments

Comments
 (0)