Skip to content

Commit a1a26b4

Browse files
committed
Extract GaussianElimination to library
1 parent d40f539 commit a1a26b4

File tree

2 files changed

+114
-90
lines changed

2 files changed

+114
-90
lines changed

src/main/scala/eu/sim642/adventofcode2025/Day10.scala

Lines changed: 6 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package eu.sim642.adventofcode2025
22

33
import eu.sim642.adventofcodelib.IteratorImplicits.*
44
import com.microsoft.z3.{ArithExpr, Context, IntExpr, IntSort, Status}
5-
import eu.sim642.adventofcodelib.NumberTheory
5+
import eu.sim642.adventofcodelib.{GaussianElimination, NumberTheory}
66
import eu.sim642.adventofcodelib.graph.{BFS, Dijkstra, GraphSearch, TargetNode, UnitNeighbors}
77

88
import scala.collection.mutable
@@ -124,80 +124,8 @@ object Day10 {
124124
button.foldLeft(zeroCol)((acc, i) => acc.updated(i, 1L))
125125
)
126126
.transpose
127-
.zip(machine.joltages.map(_.toLong))
128127

129-
val m: mutable.ArraySeq[mutable.ArraySeq[Long]] = rows.map((a, b) => (a :+ b).to(mutable.ArraySeq)).to(mutable.ArraySeq)
130-
131-
def swapRows(y1: Int, y2: Int): Unit = {
132-
val row1 = m(y1)
133-
m(y1) = m(y2)
134-
m(y2) = row1
135-
}
136-
137-
def multiplyRow(y: Int, factor: Long): Unit = {
138-
for (x2 <- 0 until (machine.buttons.size + 1))
139-
m(y)(x2) *= factor
140-
}
141-
142-
def simplifyRow(y: Int): Unit = {
143-
val factor = NumberTheory.gcd(m(y).toSeq) // TODO: avoid conversion
144-
if (factor.abs > 1) {
145-
for (x2 <- 0 until (machine.buttons.size + 1))
146-
m(y)(x2) /= factor
147-
}
148-
}
149-
150-
def reduceDown(x: Int, y1: Int, y2: Int): Unit = {
151-
val c2 = m(y2)(x)
152-
if (c2 != 0) {
153-
val c1 = m(y1)(x)
154-
val (_, _, (factor, factor2)) = NumberTheory.extendedGcd(c1, c2)
155-
for (x2 <- 0 until x) // must start from 0 because we're now multiplying entire row y2
156-
m(y2)(x2) = factor2 * m(y2)(x2)
157-
for (x2 <- x until (machine.buttons.size + 1))
158-
m(y2)(x2) = factor2 * m(y2)(x2) + factor * m(y1)(x2)
159-
//simplifyRow(y2)
160-
}
161-
}
162-
163-
var y = 0
164-
for (x <- machine.buttons.indices) {
165-
val y2opt = m.indices.find(y2 => y2 >= y && m(y2)(x) != 0)
166-
y2opt match {
167-
case None => // move to next x
168-
case Some(y2) =>
169-
swapRows(y, y2)
170-
for (y3 <- (y + 1) until m.size)
171-
reduceDown(x, y, y3)
172-
173-
y += 1
174-
}
175-
}
176-
177-
// check consistency
178-
for (y2 <- y until m.size)
179-
assert(m(y2).last == 0)
180-
181-
val mainVars = mutable.ArrayBuffer.empty[Int]
182-
val freeVars = mutable.ArrayBuffer.empty[Int]
183-
y = 0
184-
for (x <- machine.buttons.indices) {
185-
if (y < m.size) { // TODO: break if y too big
186-
if (m(y)(x) == 0) {
187-
freeVars += x
188-
()
189-
} // move to next x
190-
else {
191-
mainVars += x
192-
for (y3 <- 0 until y)
193-
reduceDown(x, y, y3)
194-
195-
y += 1
196-
}
197-
}
198-
else
199-
freeVars += x // can't break if this is here
200-
}
128+
val sol = GaussianElimination.solve(rows, machine.joltages.map(_.toLong))
201129

202130
//val mSum = m.transpose.map(_.sum) // TODO: use?
203131

@@ -216,25 +144,13 @@ object Day10 {
216144
}
217145
}
218146

219-
def eval(freeVals: List[Int]): List[Long] = {
220-
val mainVals = mainVars.view.zipWithIndex.map((mainVar, y) => {
221-
val row = m(y)
222-
val r = row.last - (freeVars lazyZip freeVals).map((freeVar, freeVal) => row(freeVar) * freeVal).sum
223-
if (r % row(mainVar) == 0)
224-
r / row(mainVar)
225-
else
226-
-1
227-
}).toList
228-
mainVals
229-
}
230-
231-
val bound = freeVars.map(maxVals).sum
232-
val choices = (0 to bound).iterator.flatMap(helper0(_, freeVars.map(maxVals).toList))
147+
val bound = sol.freeVars.map(maxVals).sum
148+
val choices = (0 to bound).iterator.flatMap(helper0(_, sol.freeVars.map(maxVals).toList))
233149

234150
val answer =
235151
choices
236-
.map(freeVals => (eval(freeVals), freeVals))
237-
.filter(p => p._1.forall(_ >= 0) && (p._1 lazyZip mainVars).forall((a, b) => a <= maxVals(b))) // all main vals must be non-negative, but at most their max
152+
.map(freeVals => (sol.evaluate(freeVals.map(_.toLong)), freeVals))
153+
.filter(p => p._1.forall(_ >= 0) && (p._1 lazyZip sol.dependentVars).forall((a, b) => a <= maxVals(b))) // all main vals must be non-negative, but at most their max
238154
.map((s1, s2) => s1.sum + s2.sum)
239155
.min
240156

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package eu.sim642.adventofcodelib
2+
3+
import scala.collection.mutable
4+
import scala.math.Integral.Implicits.infixIntegralOps
5+
import scala.math.Ordering.Implicits.infixOrderingOps
6+
import scala.reflect.ClassTag
7+
8+
object GaussianElimination {
9+
10+
case class Solution[A: Integral](dependentVars: Seq[Int], dependentGenerator: Seq[A],
11+
freeVars: Seq[Int], freeGenerators: Seq[Seq[A]],
12+
const: Seq[A]) {
13+
def evaluate(freeVals: Seq[A]): Seq[A] = {
14+
(dependentGenerator lazyZip const).zipWithIndex.map({case ((mainVar, v), i) =>
15+
val r = v - (freeGenerators lazyZip freeVals).map((freeVar, freeVal) => freeVar(i) * freeVal).sum
16+
if (r % mainVar == 0)
17+
r / mainVar
18+
else
19+
-summon[Integral[A]].one // TODO: Option
20+
}).toList
21+
}
22+
}
23+
24+
def solve[A: ClassTag](initialA: Seq[Seq[A]], initialb: Seq[A])(using aIntegral: Integral[A]): Solution[A] = {
25+
val rows = initialA zip initialb // TODO: lazyZip
26+
val m: mutable.ArraySeq[mutable.ArraySeq[A]] = rows.map((a, b) => (a :+ b).to(mutable.ArraySeq)).to(mutable.ArraySeq)
27+
val n = initialA.head.size
28+
29+
def swapRows(y1: Int, y2: Int): Unit = {
30+
val row1 = m(y1)
31+
m(y1) = m(y2)
32+
m(y2) = row1
33+
}
34+
35+
def multiplyRow(y: Int, factor: A): Unit = {
36+
for (x2 <- 0 until (n + 1))
37+
m(y)(x2) *= factor
38+
}
39+
40+
def simplifyRow(y: Int): Unit = {
41+
val factor = NumberTheory.gcd(m(y).toSeq) // TODO: avoid conversion
42+
if (factor.abs > summon[Integral[A]].one) {
43+
for (x2 <- 0 until (n + 1))
44+
m(y)(x2) /= factor
45+
}
46+
}
47+
48+
def reduceDown(x: Int, y1: Int, y2: Int): Unit = {
49+
val c2 = m(y2)(x)
50+
if (c2 != 0) {
51+
val c1 = m(y1)(x)
52+
val (_, _, (factor, factor2)) = NumberTheory.extendedGcd(c1, c2)
53+
for (x2 <- 0 until x) // must start from 0 because we're now multiplying entire row y2
54+
m(y2)(x2) = factor2 * m(y2)(x2)
55+
for (x2 <- x until (n + 1))
56+
m(y2)(x2) = factor2 * m(y2)(x2) + factor * m(y1)(x2)
57+
//simplifyRow(y2) // TODO: helps?
58+
}
59+
}
60+
61+
var y = 0
62+
for (x <- 0 until n) {
63+
val y2opt = m.indices.find(y2 => y2 >= y && m(y2)(x) != 0)
64+
y2opt match {
65+
case None => // move to next x
66+
case Some(y2) =>
67+
swapRows(y, y2)
68+
for (y3 <- (y + 1) until m.size)
69+
reduceDown(x, y, y3)
70+
71+
y += 1
72+
}
73+
}
74+
75+
// check consistency
76+
for (y2 <- y until m.size)
77+
assert(m(y2).last == 0) // TODO: return Option
78+
79+
val mainVars = mutable.ArrayBuffer.empty[Int]
80+
val freeVars = mutable.ArrayBuffer.empty[Int]
81+
y = 0
82+
for (x <- 0 until n) {
83+
if (y < m.size) { // TODO: break if y too big
84+
if (m(y)(x) == 0) {
85+
freeVars += x
86+
()
87+
} // move to next x
88+
else {
89+
mainVars += x
90+
for (y3 <- 0 until y)
91+
reduceDown(x, y, y3)
92+
93+
y += 1
94+
}
95+
}
96+
else
97+
freeVars += x // can't break if this is here
98+
}
99+
100+
Solution(
101+
dependentVars = mainVars.toSeq,
102+
dependentGenerator = (mainVars lazyZip m).view.map((v, row) => row(v)).toSeq,
103+
freeVars = freeVars.toSeq,
104+
freeGenerators = freeVars.view.map(x => m.view.map(_(x)).toSeq).toSeq,
105+
const = m.view.map(_.last).toSeq
106+
)
107+
}
108+
}

0 commit comments

Comments
 (0)