Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 93c8279

Browse files
committed
Add requires for backward compat
1 parent 5e468fe commit 93c8279

File tree

7 files changed

+51
-42
lines changed

7 files changed

+51
-42
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1415
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
1516
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
@@ -34,10 +35,11 @@ julia = "1.6"
3435

3536
[extras]
3637
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
38+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3739
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
3840
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3941
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4042
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4143

4244
[targets]
43-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays"]
45+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]

ext/SimpleBatchedNonlinearSolveExt/broyden.jl renamed to ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
1+
module SimpleBatchedNonlinearSolveExt
2+
3+
using SimpleNonlinearSolve, SciMLBase
4+
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
5+
6+
_batch_transpose(x) = reshape(x, 1, size(x)...)
7+
8+
_batched_mul(x, y) = x * y
9+
10+
function _batched_mul(x::AbstractArray{T, 3}, y::AbstractMatrix) where {T}
11+
return dropdims(batched_mul(x, reshape(y, size(y, 1), 1, size(y, 2))); dims = 2)
12+
end
13+
14+
function _batched_mul(x::AbstractMatrix, y::AbstractArray{T, 3}) where {T}
15+
return batched_mul(reshape(x, size(x, 1), 1, size(x, 2)), y)
16+
end
17+
18+
function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T1, T2}
19+
return batched_mul(x, y)
20+
end
21+
22+
function _init_J_batched(x::AbstractMatrix{T}) where {T}
23+
J = ArrayInterfaceCore.zeromatrix(x[:, 1])
24+
if ismutable(x)
25+
J[diagind(J)] .= one(eltype(x))
26+
else
27+
J += I
28+
end
29+
return repeat(J, 1, 1, size(x, 2))
30+
end
31+
132
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
233
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
334
f = Base.Fix2(prob.f, prob.p)
@@ -46,3 +77,6 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
4677

4778
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.MaxIters)
4879
end
80+
81+
82+
end

ext/SimpleBatchedNonlinearSolveExt/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

ext/SimpleBatchedNonlinearSolveExt/lbroyden.jl

Whitespace-only changes.

ext/SimpleBatchedNonlinearSolveExt/utils.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/SimpleNonlinearSolve.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ using DiffEqBase
1010

1111
@reexport using SciMLBase
1212

13+
if !isdefined(Base, :get_extension)
14+
using Requires
15+
end
16+
17+
function __init__()
18+
@static if !isdefined(Base, :get_extension)
19+
@require NNlib="872c559c-99b0-510c-b3b7-b6c96a88d5cd" begin
20+
include("../ext/SimpleBatchedNonlinearSolveExt.jl")
21+
end
22+
end
23+
end
24+
1325
abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end
1426
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
1527
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end

src/broyden.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,20 @@
11
"""
2-
Broyden()
2+
Broyden(; batched = false)
33
44
A low-overhead implementation of Broyden. This method is non-allocating on scalar
55
and static array problems.
66
"""
77
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
8-
Broyden(batched = false) = new{batched}()
8+
Broyden(; batched = false) = new{batched}()
99
end
1010

1111
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
1212
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
1313
f = Base.Fix2(prob.f, prob.p)
1414
x = float(prob.u0)
1515

16-
# if batch && ndims(x) != 2
17-
# error("`batch` mode works only if `ndims(prob.u0) == 2`")
18-
# end
19-
2016
fₙ = f(x)
2117
T = eltype(x)
22-
# J⁻¹ = init_J(x; batch)
2318
J⁻¹ = init_J(x)
2419

2520
if SciMLBase.isinplace(prob)

0 commit comments

Comments
 (0)