230230=#
231231
232232function tearing_reassemble (state:: TearingState , var_eq_matching,
233- full_var_eq_matching = nothing ; simplify = false , mm = nothing )
233+ full_var_eq_matching = nothing ; simplify = false , mm = nothing , cse_hack = true , array_hack = true )
234234 @unpack fullvars, sys, structure = state
235235 @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure
236236 extra_vars = Int[]
@@ -574,39 +574,6 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
574574 # TODO : compute the dependency correctly so that we don't have to do this
575575 obs = [fast_substitute (observed (sys), obs_sub); subeqs]
576576
577- # HACK: Substitute non-scalarized symbolic arrays of observed variables
578- # E.g. if `p[1] ~ (...)` and `p[2] ~ (...)` then substitute `p => [p[1], p[2]]` in all equations
579- # ideally, we want to support equations such as `p ~ [p[1], p[2]]` which will then be handled
580- # by the topological sorting and dependency identification pieces
581- obs_arr_subs = Dict ()
582-
583- for eq in obs
584- lhs = eq. lhs
585- iscall (lhs) || continue
586- operation (lhs) === getindex || continue
587- Symbolics. shape (lhs) != = Symbolics. Unknown () || continue
588- arg1 = arguments (lhs)[1 ]
589- haskey (obs_arr_subs, arg1) && continue
590- obs_arr_subs[arg1] = [arg1[i] for i in eachindex (arg1)] # e.g. p => [p[1], p[2]]
591- index_first = eachindex (arg1)[1 ]
592-
593- # respect non-1-indexed arrays
594- # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
595- obs_arr_subs[arg1] = Origin (index_first)(obs_arr_subs[arg1])
596- end
597- for i in eachindex (neweqs)
598- neweqs[i] = fast_substitute (neweqs[i], obs_arr_subs; operator = Symbolics. Operator)
599- end
600- for i in eachindex (obs)
601- obs[i] = fast_substitute (obs[i], obs_arr_subs; operator = Symbolics. Operator)
602- end
603- for i in eachindex (subeqs)
604- subeqs[i] = fast_substitute (subeqs[i], obs_arr_subs; operator = Symbolics. Operator)
605- end
606-
607- @set! sys. eqs = neweqs
608- @set! sys. observed = obs
609-
610577 unknowns = Any[v
611578 for (i, v) in enumerate (fullvars)
612579 if diff_to_var[i] === nothing && ispresent (i)]
@@ -616,6 +583,13 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
616583 end
617584 end
618585 @set! sys. unknowns = unknowns
586+
587+ obs, subeqs, deps = cse_and_array_hacks (
588+ obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
589+
590+ @set! sys. eqs = neweqs
591+ @set! sys. observed = obs
592+
619593 @set! sys. substitutions = Substitutions (subeqs, deps)
620594
621595 # Only makes sense for time-dependent
@@ -629,6 +603,168 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
629603 return invalidate_cache! (sys)
630604end
631605
606+ """
607+ # HACK 1
608+
609+ Since we don't support array equations, any equation of the sort `x[1:n] ~ f(...)[1:n]`
610+ gets turned into `x[1] ~ f(...)[1], x[2] ~ f(...)[2]`. Repeatedly calling `f` gets
611+ _very_ expensive. this hack performs a limited form of CSE specifically for this case to
612+ avoid the unnecessary cost. This and the below hack are implemented simultaneously
613+
614+ # HACK 2
615+
616+ Add equations for array observed variables. If `p[i] ~ (...)` are equations, add an
617+ equation `p ~ [p[1], p[2], ...]` allow topsort to reorder them only add the new equation
618+ if all `p[i]` are present and the unscalarized form is used in any equation (observed or
619+ not) we first count the number of times the scalarized form of each observed variable
620+ occurs in observed equations (and unknowns if it's split).
621+ """
622+ function cse_and_array_hacks (obs, subeqs, unknowns, neweqs; cse = true , array = true )
623+ # HACK 1
624+ # mapping of rhs to temporary CSE variable
625+ # `f(...) => tmpvar` in above example
626+ rhs_to_tempvar = Dict ()
627+
628+ # HACK 2
629+ # map of array observed variable (unscalarized) to number of its
630+ # scalarized terms that appear in observed equations
631+ arr_obs_occurrences = Dict ()
632+ # to check if array variables occur in unscalarized form anywhere
633+ all_vars = Set ()
634+ for (i, eq) in enumerate (obs)
635+ lhs = eq. lhs
636+ rhs = eq. rhs
637+ vars! (all_vars, rhs)
638+
639+ # HACK 1
640+ if cse && is_getindexed_array (rhs)
641+ rhs_arr = arguments (rhs)[1 ]
642+ if ! haskey (rhs_to_tempvar, rhs_arr)
643+ tempvar = gensym (Symbol (lhs))
644+ N = length (rhs_arr)
645+ tempvar = unwrap (Symbolics. variable (
646+ tempvar; T = Symbolics. symtype (rhs_arr)))
647+ tempvar = setmetadata (
648+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
649+ tempeq = tempvar ~ rhs_arr
650+ rhs_to_tempvar[rhs_arr] = tempvar
651+ push! (obs, tempeq)
652+ push! (subeqs, tempeq)
653+ end
654+
655+ # getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
656+ # so it doesn't find a dependency between this equation and `tempvar ~ rhs_arr`
657+ # which fails the topological sort
658+ neweq = lhs ~ getindex_wrapper (
659+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
660+ obs[i] = neweq
661+ subeqi = findfirst (isequal (eq), subeqs)
662+ if subeqi != = nothing
663+ subeqs[subeqi] = neweq
664+ end
665+ end
666+ # end HACK 1
667+
668+ array || continue
669+ iscall (lhs) || continue
670+ operation (lhs) === getindex || continue
671+ Symbolics. shape (lhs) != Symbolics. Unknown () || continue
672+ arg1 = arguments (lhs)[1 ]
673+ cnt = get (arr_obs_occurrences, arg1, 0 )
674+ arr_obs_occurrences[arg1] = cnt + 1
675+ continue
676+ end
677+
678+ # Also do CSE for `equations(sys)`
679+ if cse
680+ for (i, eq) in enumerate (neweqs)
681+ (; lhs, rhs) = eq
682+ is_getindexed_array (rhs) || continue
683+ rhs_arr = arguments (rhs)[1 ]
684+ if ! haskey (rhs_to_tempvar, rhs_arr)
685+ tempvar = gensym (Symbol (lhs))
686+ N = length (rhs_arr)
687+ tempvar = unwrap (Symbolics. variable (
688+ tempvar; T = Symbolics. symtype (rhs_arr)))
689+ tempvar = setmetadata (
690+ tempvar, Symbolics. ArrayShapeCtx, Symbolics. shape (rhs_arr))
691+ tempeq = tempvar ~ rhs_arr
692+ rhs_to_tempvar[rhs_arr] = tempvar
693+ push! (obs, tempeq)
694+ push! (subeqs, tempeq)
695+ end
696+ # don't need getindex_wrapper, but do it anyway to know that this
697+ # hack took place
698+ neweq = lhs ~ getindex_wrapper (
699+ rhs_to_tempvar[rhs_arr], Tuple (arguments (rhs)[2 : end ]))
700+ neweqs[i] = neweq
701+ end
702+ end
703+
704+ # count variables in unknowns if they are scalarized forms of variables
705+ # also present as observed. e.g. if `x[1]` is an unknown and `x[2] ~ (..)`
706+ # is an observed equation.
707+ for sym in unknowns
708+ iscall (sym) || continue
709+ operation (sym) === getindex || continue
710+ Symbolics. shape (sym) != Symbolics. Unknown () || continue
711+ arg1 = arguments (sym)[1 ]
712+ cnt = get (arr_obs_occurrences, arg1, 0 )
713+ cnt == 0 && continue
714+ arr_obs_occurrences[arg1] = cnt + 1
715+ end
716+ for eq in neweqs
717+ vars! (all_vars, eq. rhs)
718+ end
719+ obs_arr_eqs = Equation[]
720+ for (arrvar, cnt) in arr_obs_occurrences
721+ cnt == length (arrvar) || continue
722+ arrvar in all_vars || continue
723+ # firstindex returns 1 for multidimensional array symbolics
724+ firstind = first (eachindex (arrvar))
725+ scal = [arrvar[i] for i in eachindex (arrvar)]
726+ # respect non-1-indexed arrays
727+ # TODO : get rid of this hack together with the above hack, then remove OffsetArrays dependency
728+ # `change_origin` is required because `Origin(firstind)(scal)` makes codegen
729+ # try to `create_array(OffsetArray{...}, ...)` which errors.
730+ # `term(Origin(firstind), scal)` doesn't retain the `symtype` and `size`
731+ # of `scal`.
732+ push! (obs_arr_eqs, arrvar ~ change_origin (Origin (firstind), scal))
733+ end
734+ append! (obs, obs_arr_eqs)
735+ append! (subeqs, obs_arr_eqs)
736+
737+ # need to re-sort subeqs
738+ subeqs = ModelingToolkit. topsort_equations (subeqs, [eq. lhs for eq in subeqs])
739+
740+ deps = Vector{Int}[i == 1 ? Int[] : collect (1 : (i - 1 ))
741+ for i in 1 : length (subeqs)]
742+
743+ return obs, subeqs, deps
744+ end
745+
746+ function is_getindexed_array (rhs)
747+ (! ModelingToolkit. isvariable (rhs) || ModelingToolkit. iscalledparameter (rhs)) &&
748+ iscall (rhs) && operation (rhs) === getindex &&
749+ Symbolics. shape (rhs) != Symbolics. Unknown ()
750+ end
751+
752+ # PART OF HACK 1
753+ getindex_wrapper (x, i) = x[i... ]
754+
755+ @register_symbolic getindex_wrapper (x:: AbstractArray , i:: Tuple{Vararg{Int}} )
756+
757+ # PART OF HACK 2
758+ function change_origin (origin, arr)
759+ return origin (arr)
760+ end
761+
762+ @register_array_symbolic change_origin (origin:: Origin , arr:: AbstractArray ) begin
763+ size = size (arr)
764+ eltype = eltype (arr)
765+ ndims = ndims (arr)
766+ end
767+
632768function tearing (state:: TearingState ; kwargs... )
633769 state. structure. solvable_graph === nothing && find_solvables! (state; kwargs... )
634770 complete! (state. structure)
@@ -643,10 +779,10 @@ new residual equations after tearing. End users are encouraged to call [`structu
643779instead, which calls this function internally.
644780"""
645781function tearing (sys:: AbstractSystem , state = TearingState (sys); mm = nothing ,
646- simplify = false , kwargs... )
782+ simplify = false , cse_hack = true , array_hack = true , kwargs... )
647783 var_eq_matching, full_var_eq_matching = tearing (state)
648784 invalidate_cache! (tearing_reassemble (
649- state, var_eq_matching, full_var_eq_matching; mm, simplify))
785+ state, var_eq_matching, full_var_eq_matching; mm, simplify, cse_hack, array_hack ))
650786end
651787
652788"""
@@ -668,7 +804,7 @@ Perform index reduction and use the dummy derivative technique to ensure that
668804the system is balanced.
669805"""
670806function dummy_derivative (sys, state = TearingState (sys); simplify = false ,
671- mm = nothing , kwargs... )
807+ mm = nothing , cse_hack = true , array_hack = true , kwargs... )
672808 jac = let state = state
673809 (eqs, vars) -> begin
674810 symeqs = EquationsView (state)[eqs]
@@ -692,5 +828,5 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
692828 end
693829 var_eq_matching = dummy_derivative_graph! (state, jac; state_priority,
694830 kwargs... )
695- tearing_reassemble (state, var_eq_matching; simplify, mm)
831+ tearing_reassemble (state, var_eq_matching; simplify, mm, cse_hack, array_hack )
696832end
0 commit comments