@@ -593,6 +593,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
593593 u0 = nothing ;
594594 version = nothing , tgrad = false , sparse = false ,
595595 jac = false , Wfact = false , eval_expression = false ,
596+ sparsity = false , analytic = nothing ,
596597 eval_module = @__MODULE__ ,
597598 checkbounds = false , initialization_data = nothing ,
598599 cse = true , kwargs... ) where {iip, specialize}
@@ -641,6 +642,17 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
641642 _Wfact, _Wfact_t = nothing , nothing
642643 end
643644
645+ jac_prototype = if sparse
646+ uElType = u0 === nothing ? Float64 : eltype (u0)
647+ if jac
648+ similar (calculate_jacobian (sys, sparse = sparse), uElType)
649+ else
650+ similar (jacobian_sparsity (sys), uElType)
651+ end
652+ else
653+ nothing
654+ end
655+
644656 M = calculate_massmatrix (sys)
645657 _M = (u0 === nothing || M == I) ? M : ArrayInterface. restructure (u0 .* u0' , M)
646658
@@ -651,10 +663,14 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
651663 sys = sys,
652664 jac = _jac === nothing ? nothing : _jac,
653665 tgrad = _tgrad === nothing ? nothing : _tgrad,
666+ mass_matrix = _M
667+ jac_prototype = jac_prototype,
668+ observed = observedfun,
669+ sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
670+ analytic = analytic,
654671 Wfact = _Wfact === nothing ? nothing : _Wfact,
655672 Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
656- mass_matrix = _M, initialization_data,
657- observed = observedfun)
673+ initialization_data)
658674end
659675
660676"""
@@ -724,6 +740,17 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
724740 _jac = :nothing
725741 end
726742
743+ jac_prototype = if sparse
744+ uElType = u0 === nothing ? Float64 : eltype (u0)
745+ if jac
746+ similar (calculate_jacobian (sys, sparse = sparse), uElType)
747+ else
748+ similar (jacobian_sparsity (sys), uElType)
749+ end
750+ else
751+ nothing
752+ end
753+
727754 if Wfact
728755 tmp_Wfact, tmp_Wfact_t = generate_factorized_W (
729756 sys, dvs, ps; expression = Val{true },
@@ -743,11 +770,13 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = unknowns(sys),
743770 g = $ g
744771 tgrad = $ _tgrad
745772 jac = $ _jac
773+ jac_prototype = $ jac_prototype
746774 Wfact = $ _Wfact
747775 Wfact_t = $ _Wfact_t
748776 M = $ _M
749777 SDEFunction {$iip} (f, g,
750778 jac = jac,
779+ jac_prototype = jac_prototype,
751780 tgrad = tgrad,
752781 Wfact = Wfact,
753782 Wfact_t = Wfact_t,
0 commit comments