Skip to content

Commit e2c0c94

Browse files
committed
Mark trailing map receives body to inspect
1 parent c7af9ab commit e2c0c94

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,7 +2236,7 @@ object desugar {
22362236
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
22372237
case _ => false
22382238

2239-
def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit =
2239+
def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName, body: Tree): Unit =
22402240
if sourceVersion.enablesBetterFors
22412241
&& selectName == mapName
22422242
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
@@ -2247,9 +2247,8 @@ object desugar {
22472247
enums match {
22482248
case Nil if sourceVersion.enablesBetterFors => body
22492249
case (gen: GenFrom) :: Nil =>
2250-
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2251-
markTrailingMap(aply, gen, mapName)
2252-
aply
2250+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2251+
.tap(markTrailingMap(_, gen, mapName, body))
22532252
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
22542253
val cont = makeFor(mapName, flatMapName, rest, body)
22552254
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
@@ -2266,7 +2265,7 @@ object desugar {
22662265
if suffix.exists(_.isInstanceOf[GenFrom]) then flatMapName
22672266
else mapName
22682267
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2269-
.tap(markTrailingMap(_, gen, selectName))
2268+
.tap(markTrailingMap(_, gen, selectName, cont))
22702269
else
22712270
val (pats, rhss) = valeqs.map { case GenAlias(pat, rhs) => (pat, rhs) }.unzip
22722271
val (defpat0, id0) = makeIdPat(gen.pat)

tests/run/i24673.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@main def Test = {
2+
def result = for {
3+
a <- Option(2)
4+
_ = if (true) {
5+
sys.error("err")
6+
}
7+
} yield a
8+
9+
try
10+
result
11+
???
12+
catch case e: RuntimeException => assert(e.getMessage == "err")
13+
}

0 commit comments

Comments
 (0)