@@ -264,7 +264,8 @@ function find_linear_variables(graph, linear_equations, var_to_diff, irreducible
264264 return linear_variables
265265end
266266
267- function aag_bareiss! (graph, var_to_diff, mm_orig:: SparseMatrixCLIL{T, Ti} ) where {T, Ti}
267+ function aag_bareiss! (structure, mm_orig:: SparseMatrixCLIL{T, Ti} ) where {T, Ti}
268+ @unpack graph, var_to_diff = structure
268269 mm = copy (mm_orig)
269270 linear_equations_set = BitSet (mm_orig. nzrows)
270271
@@ -279,6 +280,7 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
279280 v -> var_to_diff[v] === nothing === invview (var_to_diff)[v]
280281 end
281282 is_linear_variables = is_algebraic .(1 : length (var_to_diff))
283+ is_highest_diff = computed_highest_diff_variables (structure)
282284 for i in 𝑠vertices (graph)
283285 # only consider linear algebraic equations
284286 (i in linear_equations_set && all (is_algebraic, 𝑠neighbors (graph, i))) &&
@@ -291,25 +293,31 @@ function aag_bareiss!(graph, var_to_diff, mm_orig::SparseMatrixCLIL{T, Ti}) wher
291293
292294 local bar
293295 try
294- bar = do_bareiss! (mm, mm_orig, is_linear_variables)
296+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff )
295297 catch e
296298 e isa OverflowError || rethrow (e)
297299 mm = convert (SparseMatrixCLIL{BigInt, Ti}, mm_orig)
298- bar = do_bareiss! (mm, mm_orig, is_linear_variables)
300+ bar = do_bareiss! (mm, mm_orig, is_linear_variables, is_highest_diff )
299301 end
300302
301303 return mm, solvable_variables, bar
302304end
303305
304- function do_bareiss! (M, Mold, is_linear_variables)
306+ function do_bareiss! (M, Mold, is_linear_variables, is_highest_diff )
305307 rank1r = Ref {Union{Nothing, Int}} (nothing )
308+ rank2r = Ref {Union{Nothing, Int}} (nothing )
306309 find_pivot = let rank1r = rank1r
307310 (M, k) -> begin
308311 if rank1r[] === nothing
309312 r = find_masked_pivot (is_linear_variables, M, k)
310313 r != = nothing && return r
311314 rank1r[] = k - 1
312315 end
316+ if rank2r[] === nothing
317+ r = find_masked_pivot (is_highest_diff, M, k)
318+ r != = nothing && return r
319+ rank2r[] = k - 1
320+ end
313321 # TODO : It would be better to sort the variables by
314322 # derivative order here to enable more elimination
315323 # opportunities.
@@ -334,15 +342,19 @@ function do_bareiss!(M, Mold, is_linear_variables)
334342 bareiss_ops = ((M, i, j) -> nothing , myswaprows!,
335343 bareiss_update_virtual_colswap_mtk!, bareiss_zero!)
336344
337- rank2, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
345+ rank3, = bareiss! (M, bareiss_ops; find_pivot = find_and_record_pivot)
346+ rank2 = something (rank2r[], rank3)
338347 rank1 = something (rank1r[], rank2)
339- (rank1, rank2, pivots)
348+ (rank1, rank2, rank3, pivots)
340349end
341350
342- function simple_aliases! (ils, graph, solvable_graph, eq_to_diff, var_to_diff)
343- ils, solvable_variables, (rank1, rank2, pivots) = aag_bareiss! (graph,
344- var_to_diff,
345- ils)
351+ function alias_eliminate_graph! (state:: TransformationState , ils:: SparseMatrixCLIL )
352+ @unpack structure = state
353+ @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state. structure
354+ # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
355+ # subsystem of the system we're interested in.
356+ #
357+ ils, solvable_variables, (rank1, rank2, rank3, pivots) = aag_bareiss! (structure, ils)
346358
347359 # # Step 2: Simplify the system using the Bareiss factorization
348360 rk1vars = BitSet (@view pivots[1 : rank1])
@@ -362,14 +374,6 @@ function simple_aliases!(ils, graph, solvable_graph, eq_to_diff, var_to_diff)
362374 return ils
363375end
364376
365- function alias_eliminate_graph! (state:: TransformationState , ils:: SparseMatrixCLIL )
366- @unpack graph, solvable_graph, var_to_diff, eq_to_diff = state. structure
367- # Step 1: Perform Bareiss factorization on the adjacency matrix of the linear
368- # subsystem of the system we're interested in.
369- #
370- return simple_aliases! (ils, graph, solvable_graph, eq_to_diff, var_to_diff)
371- end
372-
373377function exactdiv (a:: Integer , b)
374378 d, r = divrem (a, b)
375379 @assert r == 0
0 commit comments