129129# build something that works for us here and worry about it later.
130130nonzerosmap (a:: CLILVector ) = NonZeros (a)
131131
132+ using FindFirstFunctions: findfirstequal
133+
132134function bareiss_update_virtual_colswap_mtk! (zero!, M:: SparseMatrixCLIL , k, swapto, pivot,
133135 last_pivot; pivot_equal_optimization = true )
134136 # for ei in nzrows(>= k)
@@ -168,12 +170,11 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
168170 # conservative, we leave it at this, as this captures the most important
169171 # case for MTK (where most pivots are `1` or `-1`).
170172 pivot_equal = pivot_equal_optimization && abs (pivot) == abs (last_pivot)
171-
172- for ei in (k + 1 ): size (M, 1 )
173+ @inbounds for ei in (k + 1 ): size (M, 1 )
173174 # eliminate `v`
174175 coeff = 0
175176 ivars = eadj[ei]
176- vj = findfirst ( isequal ( vpivot) , ivars)
177+ vj = findfirstequal ( vpivot, ivars)
177178 if vj != = nothing
178179 coeff = old_cadj[ei][vj]
179180 deleteat! (old_cadj[ei], vj)
@@ -189,24 +190,118 @@ function bareiss_update_virtual_colswap_mtk!(zero!, M::SparseMatrixCLIL, k, swap
189190 ivars = eadj[ei]
190191 icoeffs = old_cadj[ei]
191192
192- tmp_incidence = similar (eadj[ei], 0 )
193- tmp_coeffs = similar (old_cadj[ei], 0 )
194- # TODO : We know both ivars and kvars are sorted, we could just write
195- # a quick iterator here that does this without allocation/faster.
196- vars = sort (union (ivars, kvars))
197-
198- for v in vars
199- v == vpivot && continue
200- ck = getcoeff (kvars, kcoeffs, v)
201- ci = getcoeff (ivars, icoeffs, v)
202- p1 = Base. Checked. checked_mul (pivot, ci)
203- p2 = Base. Checked. checked_mul (coeff, ck)
204- ci = exactdiv (Base. Checked. checked_sub (p1, p2), last_pivot)
205- if ! iszero (ci)
206- push! (tmp_incidence, v)
207- push! (tmp_coeffs, ci)
193+ numkvars = length (kvars)
194+ numivars = length (ivars)
195+ tmp_incidence = similar (eadj[ei], numkvars + numivars)
196+ tmp_coeffs = similar (old_cadj[ei], numkvars + numivars)
197+ tmp_len = 0
198+ kvind = ivind = 0
199+ if _debug_mode
200+ # in debug mode, we at least check to confirm we're iterating over
201+ # `v`s in the correct order
202+ vars = sort (union (ivars, kvars))
203+ vi = 0
204+ end
205+ if numivars > 0 && numkvars > 0
206+ kvv = kvars[kvind += 1 ]
207+ ivv = ivars[ivind += 1 ]
208+ dobreak = false
209+ while true
210+ if kvv == ivv
211+ v = kvv
212+ ck = kcoeffs[kvind]
213+ ci = icoeffs[ivind]
214+ kvind += 1
215+ ivind += 1
216+ if kvind > numkvars
217+ dobreak = true
218+ else
219+ kvv = kvars[kvind]
220+ end
221+ if ivind > numivars
222+ dobreak = true
223+ else
224+ ivv = ivars[ivind]
225+ end
226+ p1 = Base. Checked. checked_mul (pivot, ci)
227+ p2 = Base. Checked. checked_mul (coeff, ck)
228+ ci = exactdiv (Base. Checked. checked_sub (p1, p2), last_pivot)
229+ elseif kvv < ivv
230+ v = kvv
231+ ck = kcoeffs[kvind]
232+ kvind += 1
233+ if kvind > numkvars
234+ dobreak = true
235+ else
236+ kvv = kvars[kvind]
237+ end
238+ p2 = Base. Checked. checked_mul (coeff, ck)
239+ ci = exactdiv (Base. Checked. checked_neg (p2), last_pivot)
240+ else # kvv > ivv
241+ v = ivv
242+ ci = icoeffs[ivind]
243+ ivind += 1
244+ if ivind > numivars
245+ dobreak = true
246+ else
247+ ivv = ivars[ivind]
248+ end
249+ ci = exactdiv (Base. Checked. checked_mul (pivot, ci), last_pivot)
250+ end
251+ if _debug_mode
252+ @assert v == vars[vi += 1 ]
253+ end
254+ if v != vpivot && ! iszero (ci)
255+ tmp_incidence[tmp_len += 1 ] = v
256+ tmp_coeffs[tmp_len] = ci
257+ end
258+ dobreak && break
259+ end
260+ elseif numkvars > 0
261+ ivind = 1
262+ kvv = kvars[kvind += 1 ]
263+ elseif numivars > 0
264+ kvind = 1
265+ ivv = ivars[ivind += 1 ]
266+ end
267+ if kvind <= numkvars
268+ v = kvv
269+ while true
270+ if _debug_mode
271+ @assert v == vars[vi += 1 ]
272+ end
273+ if v != vpivot
274+ ck = kcoeffs[kvind]
275+ p2 = Base. Checked. checked_mul (coeff, ck)
276+ ci = exactdiv (Base. Checked. checked_neg (p2), last_pivot)
277+ if ! iszero (ci)
278+ tmp_incidence[tmp_len += 1 ] = v
279+ tmp_coeffs[tmp_len] = ci
280+ end
281+ end
282+ (kvind == numkvars) && break
283+ v = kvars[kvind += 1 ]
284+ end
285+ elseif ivind <= numivars
286+ v = ivv
287+ while true
288+ if _debug_mode
289+ @assert v == vars[vi += 1 ]
290+ end
291+ if v != vpivot
292+ p1 = Base. Checked. checked_mul (pivot, icoeffs[ivind])
293+ ci = exactdiv (p1, last_pivot)
294+ if ! iszero (ci)
295+ tmp_incidence[tmp_len += 1 ] = v
296+ tmp_coeffs[tmp_len] = ci
297+ end
298+ end
299+ (ivind == numivars) && break
300+ v = ivars[ivind += 1 ]
208301 end
209302 end
303+ resize! (tmp_incidence, tmp_len)
304+ resize! (tmp_coeffs, tmp_len)
210305 eadj[ei] = tmp_incidence
211306 old_cadj[ei] = tmp_coeffs
212307 end
0 commit comments