Skip to content

Commit d690f38

Browse files
committed
init: add jac_prototype
1 parent 61a64f9 commit d690f38

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
593593
u0 = nothing;
594594
version = nothing, tgrad = false, sparse = false,
595595
jac = false, Wfact = false, eval_expression = false,
596+
sparsity = false, analytic = nothing,
596597
eval_module = @__MODULE__,
597598
checkbounds = false, initialization_data = nothing,
598599
cse = true, kwargs...) where {iip, specialize}
@@ -641,6 +642,17 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
641642
_Wfact, _Wfact_t = nothing, nothing
642643
end
643644

645+
jac_prototype = if sparse
646+
uElType = u0 === nothing ? Float64 : eltype(u0)
647+
if jac
648+
similar(calculate_jacobian(sys, sparse = sparse), uElType)
649+
else
650+
similar(jacobian_sparsity(sys), uElType)
651+
end
652+
else
653+
nothing
654+
end
655+
644656
M = calculate_massmatrix(sys)
645657
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
646658

@@ -651,10 +663,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
651663
sys = sys,
652664
jac = _jac === nothing ? nothing : _jac,
653665
tgrad = _tgrad === nothing ? nothing : _tgrad,
666+
mass_matrix = _M
667+
jac_prototype = jac_prototype,
668+
observed = observedfun,
669+
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
670+
analytic = analytic,
654671
Wfact = _Wfact === nothing ? nothing : _Wfact,
655672
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
656-
mass_matrix = _M, initialization_data,
657-
observed = observedfun)
673+
initialization_data)
658674
end
659675

660676
"""
@@ -724,6 +740,17 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
724740
_jac = :nothing
725741
end
726742

743+
jac_prototype = if sparse
744+
uElType = u0 === nothing ? Float64 : eltype(u0)
745+
if jac
746+
similar(calculate_jacobian(sys, sparse = sparse), uElType)
747+
else
748+
similar(jacobian_sparsity(sys), uElType)
749+
end
750+
else
751+
nothing
752+
end
753+
727754
if Wfact
728755
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(
729756
sys, dvs, ps; expression = Val{true},
@@ -743,11 +770,13 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
743770
g = $g
744771
tgrad = $_tgrad
745772
jac = $_jac
773+
jac_prototype = $jac_prototype
746774
Wfact = $_Wfact
747775
Wfact_t = $_Wfact_t
748776
M = $_M
749777
SDEFunction{$iip}(f, g,
750778
jac = jac,
779+
jac_prototype = jac_prototype,
751780
tgrad = tgrad,
752781
Wfact = Wfact,
753782
Wfact_t = Wfact_t,

0 commit comments

Comments
 (0)