@@ -27,36 +27,40 @@ function ascend_dg_all(xs, dg, level, maxlevel)
2727 return r
2828end
2929
30- function pss_graph_modia! (structure:: SystemStructure , var_eq_matching , varlevel,
30+ function pss_graph_modia! (structure:: SystemStructure , maximal_top_matching , varlevel,
3131 inv_varlevel, inv_eqlevel)
3232 @unpack eq_to_diff, var_to_diff, graph, solvable_graph = structure
3333
3434 # var_eq_matching is a maximal matching on the top-differentiated variables.
3535 # Find Strongly connected components. Note that after pantelides, we expect
3636 # a balanced system, so a maximal matching should be possible.
37- var_sccs:: Vector{Union{Vector{Int}, Int}} = find_var_sccs (graph, var_eq_matching )
38- var_eq_matching = Matching {Union{Unassigned, SelectedState}} (var_eq_matching )
37+ var_sccs:: Vector{Union{Vector{Int}, Int}} = find_var_sccs (graph, maximal_top_matching )
38+ var_eq_matching = Matching {Union{Unassigned, SelectedState}} (ndsts (graph) )
3939 for vars in var_sccs
4040 # TODO : We should have a way to not have the scc code look at unassigned vars.
41- if length (vars) == 1 && varlevel [vars[1 ]] != 0
41+ if length (vars) == 1 && maximal_top_matching [vars[1 ]] === unassigned
4242 continue
4343 end
4444
4545 # Now proceed level by level from lowest to highest and tear the graph.
46- eqs = [var_eq_matching[var] for var in vars if var_eq_matching[var] != = unassigned]
46+ eqs = [maximal_top_matching[var]
47+ for var in vars if maximal_top_matching[var] != = unassigned]
4748 isempty (eqs) && continue
48- maxlevel = level = maximum (map (x -> inv_eqlevel[x], eqs))
49+ maxeqlevel = maximum (map (x -> inv_eqlevel[x], eqs))
50+ maxvarlevel = level = maximum (map (x -> inv_varlevel[x], vars))
4951 old_level_vars = ()
5052 ict = IncrementalCycleTracker (DiCMOBiGraph {true} (graph,
5153 complete (Matching (ndsts (graph))));
5254 dir = :in )
55+
5356 while level >= 0
5457 to_tear_eqs_toplevel = filter (eq -> inv_eqlevel[eq] >= level, eqs)
5558 to_tear_eqs = ascend_dg (to_tear_eqs_toplevel, invview (eq_to_diff), level)
5659
5760 to_tear_vars_toplevel = filter (var -> inv_varlevel[var] >= level, vars)
58- to_tear_vars = ascend_dg_all (to_tear_vars_toplevel, invview (var_to_diff), level,
59- maxlevel)
61+ to_tear_vars = ascend_dg (to_tear_vars_toplevel, invview (var_to_diff), level)
62+
63+ assigned_eqs = Int[]
6064
6165 if old_level_vars != = ()
6266 # Inherit constraints from previous level.
@@ -66,45 +70,59 @@ function pss_graph_modia!(structure::SystemStructure, var_eq_matching, varlevel,
6670 removed_eqs = Int[]
6771 removed_vars = Int[]
6872 for var in old_level_vars
69- old_assign = ict. graph. matching[var]
70- if ! isa (old_assign, Int) ||
71- ict. graph. matching[var_to_diff[var]] != = unassigned
73+ old_assign = var_eq_matching[var]
74+ if isa (old_assign, SelectedState)
75+ push! (removed_vars, var)
76+ continue
77+ elseif ! isa (old_assign, Int) ||
78+ ict. graph. matching[var_to_diff[var]] != = unassigned
7279 continue
7380 end
7481 # Make sure the ict knows about this edge, so it doesn't accidentally introduce
7582 # a cycle.
76- ok = try_assign_eq! (ict, var_to_diff[var], eq_to_diff[old_assign])
83+ assgned_eq = eq_to_diff[old_assign]
84+ ok = try_assign_eq! (ict, var_to_diff[var], assgned_eq)
7785 @assert ok
78- var_eq_matching[var_to_diff[var]] = eq_to_diff[old_assign]
86+ var_eq_matching[var_to_diff[var]] = assgned_eq
7987 push! (removed_eqs, eq_to_diff[ict. graph. matching[var]])
8088 push! (removed_vars, var_to_diff[var])
89+ push! (removed_vars, var)
8190 end
8291 to_tear_eqs = setdiff (to_tear_eqs, removed_eqs)
8392 to_tear_vars = setdiff (to_tear_vars, removed_vars)
8493 end
85- filter! (var -> ict. graph. matching[var] === unassigned, to_tear_vars)
86- filter! (eq -> invview (ict. graph. matching)[eq] === unassigned, to_tear_eqs)
8794 tearEquations! (ict, solvable_graph. fadjlist, to_tear_eqs, BitSet (to_tear_vars),
8895 nothing )
96+
8997 for var in to_tear_vars
90- var_eq_matching[var] = unassigned
98+ @assert var_eq_matching[var] === unassigned
99+ assgned_eq = ict. graph. matching[var]
100+ var_eq_matching[var] = assgned_eq
101+ isa (assgned_eq, Int) && push! (assigned_eqs, assgned_eq)
91102 end
92- for var in to_tear_vars
93- var_eq_matching[var] = ict. graph. matching[var]
103+
104+ if level != 0
105+ remaining_vars = collect (v for v in to_tear_vars
106+ if var_eq_matching[v] === unassigned)
107+ if ! isempty (remaining_vars)
108+ remaining_eqs = setdiff (to_tear_eqs, assigned_eqs)
109+ nlsolve_matching = maximal_matching (graph,
110+ Base. Fix2 (in, remaining_eqs),
111+ Base. Fix2 (in, remaining_vars))
112+ for var in remaining_vars
113+ if nlsolve_matching[var] === unassigned &&
114+ var_eq_matching[var] === unassigned
115+ var_eq_matching[var] = SelectedState ()
116+ end
117+ end
118+ end
94119 end
120+
95121 old_level_vars = to_tear_vars
96122 level -= 1
97123 end
98124 end
99- for var in 1 : ndsts (graph)
100- dv = var_to_diff[var]
101- # If `var` is not algebraic (not differentiated nor a dummy derivative),
102- # then it's a SelectedState
103- if ! (dv === nothing || (varlevel[dv] != = 0 && var_eq_matching[dv] === unassigned))
104- var_eq_matching[var] = SelectedState ()
105- end
106- end
107- return var_eq_matching
125+ return complete (var_eq_matching)
108126end
109127
110128struct SelectedState end
0 commit comments