@@ -10,19 +10,27 @@ private[quoted] object Matcher {
1010 class QuoteMatcher [QCtx <: QuoteContext & Singleton ](given val qctx : QCtx ) {
1111 // TODO improve performance
1212
13+ // TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
1314 private final val debug = false
1415
1516 import qctx .tasty .{_ , given }
1617 import Matching ._
1718
18- private type Env = Set [(Symbol , Symbol )]
19+ /** A map relating equivalent symbols from the scrutinee and the pattern
20+ * For example in
21+ * ```
22+ * '{val a = 4; a * a} match case '{ val x = 4; x * x }
23+ * ```
24+ * when matching `a * a` with `x * x` the enviroment will contain `Map(a -> x)`.
25+ */
26+ private type Env = Map [Symbol , Symbol ]
1927
2028 inline private def withEnv [T ](env : Env )(body : => (given Env ) => T ): T = body(given env )
2129
2230 class SymBinding (val sym : Symbol , val fromAbove : Boolean )
2331
2432 def termMatch (scrutineeTerm : Term , patternTerm : Term , hasTypeSplices : Boolean ): Option [Tuple ] = {
25- implicit val env : Env = Set .empty
33+ implicit val env : Env = Map .empty
2634 if (hasTypeSplices) {
2735 implicit val ctx : Context = internal.Context_GADT_setFreshGADTBounds (rootContext)
2836 val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +50,7 @@ private[quoted] object Matcher {
4250
4351 // TODO factor out common logic with `termMatch`
4452 def typeTreeMatch (scrutineeTypeTree : TypeTree , patternTypeTree : TypeTree , hasTypeSplices : Boolean ): Option [Tuple ] = {
45- implicit val env : Env = Set .empty
53+ implicit val env : Env = Map .empty
4654 if (hasTypeSplices) {
4755 implicit val ctx : Context = internal.Context_GADT_setFreshGADTBounds (rootContext)
4856 val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +146,29 @@ private[quoted] object Matcher {
138146 matched(scrutinee.seal)
139147
140148 // Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141- case (scrutinee : Term , TypeApply (patternHole, tpt :: Nil ))
149+ case (ClosedPatternTerm ( scrutinee) , TypeApply (patternHole, tpt :: Nil ))
142150 if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143151 scrutinee.tpe <:< tpt.tpe =>
144152 matched(scrutinee.seal)
145153
154+ // Matches an open term and wraps it into a lambda that provides the free variables
155+ case (scrutinee, pattern @ Apply (Select (TypeApply (patternHole, List (Inferred ())), " apply" ), args0 @ IdentArgs (args)))
156+ if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
157+ def bodyFn (lambdaArgs : List [Tree ]): Tree = {
158+ val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf [List [Term ]]).toMap
159+ new TreeMap {
160+ override def transformTerm (tree : Term )(given ctx : Context ): Term =
161+ tree match
162+ case tree : Ident => summon[Env ].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
163+ case tree => super .transformTerm(tree)
164+ }.transformTree(scrutinee)
165+ }
166+ val names = args.map(_.name)
167+ val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
168+ val resType = pattern.tpe
169+ val res = Lambda (MethodType (names)(_ => argTypes, _ => resType), bodyFn)
170+ matched(res.seal)
171+
146172 //
147173 // Match two equivalent trees
148174 //
@@ -156,7 +182,7 @@ private[quoted] object Matcher {
156182 case (scrutinee, Typed (expr2, _)) =>
157183 scrutinee =?= expr2
158184
159- case (Ident (_), Ident (_)) if scrutinee.symbol == pattern.symbol || summon[Env ].apply(( scrutinee.symbol, pattern.symbol) ) =>
185+ case (Ident (_), Ident (_)) if scrutinee.symbol == pattern.symbol || summon[Env ].get( scrutinee.symbol).contains( pattern.symbol) =>
160186 matched
161187
162188 case (Select (qual1, _), Select (qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -165,18 +191,24 @@ private[quoted] object Matcher {
165191 case (_ : Ref , _ : Ref ) if scrutinee.symbol == pattern.symbol =>
166192 matched
167193
168- case (Apply (fn1, args1), Apply (fn2, args2)) if fn1.symbol == fn2.symbol =>
194+ case (Apply (fn1, args1), Apply (fn2, args2)) if fn1.symbol == fn2.symbol || summon[ Env ].get(fn1.symbol).contains(fn2.symbol) =>
169195 fn1 =?= fn2 && args1 =?= args2
170196
171- case (TypeApply (fn1, args1), TypeApply (fn2, args2)) if fn1.symbol == fn2.symbol =>
197+ case (TypeApply (fn1, args1), TypeApply (fn2, args2)) if fn1.symbol == fn2.symbol || summon[ Env ].get(fn1.symbol).contains(fn2.symbol) =>
172198 fn1 =?= fn2 && args1 =?= args2
173199
174200 case (Block (stats1, expr1), Block (binding :: stats2, expr2)) if isTypeBinding(binding) =>
175201 qctx.tasty.internal.Context_GADT_addToConstraint (summon[Context ])(binding.symbol :: Nil )
176202 matched(new SymBinding (binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block (stats1, expr1) =?= Block (stats2, expr2)
177203
178204 case (Block (stat1 :: stats1, expr1), Block (stat2 :: stats2, expr2)) =>
179- withEnv(summon[Env ] + (stat1.symbol -> stat2.symbol)) {
205+ val newEnv = (stat1, stat2) match {
206+ case (stat1 : Definition , stat2 : Definition ) =>
207+ summon[Env ] + (stat1.symbol -> stat2.symbol)
208+ case _ =>
209+ summon[Env ]
210+ }
211+ withEnv(newEnv) {
180212 stat1 =?= stat2 && Block (stats1, expr1) =?= Block (stats2, expr2)
181213 }
182214
@@ -268,7 +300,7 @@ private[quoted] object Matcher {
268300 |
269301 | ${pattern.showExtractors}
270302 |
271- |
303+ |with environment: ${summon[ Env ]}
272304 |
273305 |
274306 | """ .stripMargin)
@@ -277,6 +309,33 @@ private[quoted] object Matcher {
277309 }
278310 end treeOps
279311
312+ private object ClosedPatternTerm {
313+ /** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
314+ def unapply (term : Term )(given Context , Env ): Option [term.type ] =
315+ if freePatternVars(term).isEmpty then Some (term) else None
316+
317+ /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
318+ def freePatternVars (term : Term )(given qctx : Context , env : Env ): Set [Symbol ] =
319+ val accumulator = new TreeAccumulator [Set [Symbol ]] {
320+ def foldTree (x : Set [Symbol ], tree : Tree )(given ctx : Context ): Set [Symbol ] =
321+ tree match
322+ case tree : Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
323+ case _ => foldOverTree(x, tree)
324+ }
325+ accumulator.foldTree(Set .empty, term)
326+ }
327+
328+ private object IdentArgs {
329+ def unapply (args : List [Term ])(given Context ): Option [List [Ident ]] =
330+ args.foldRight(Option (List .empty[Ident ])) {
331+ case (id : Ident , Some (acc)) => Some (id :: acc)
332+ case (Block (List (DefDef (" $anonfun" , Nil , List (params), Inferred (), Some (Apply (id : Ident , args)))), Closure (Ident (" $anonfun" ), None )), Some (acc))
333+ if params.zip(args).forall(_.symbol == _.symbol) =>
334+ Some (id :: acc)
335+ case _ => None
336+ }
337+ }
338+
280339 private def treeOptMatches (scrutinee : Option [Tree ], pattern : Option [Tree ])(given Context , Env ): Matching = {
281340 (scrutinee, pattern) match {
282341 case (Some (x), Some (y)) => x =?= y
@@ -344,7 +403,7 @@ private[quoted] object Matcher {
344403 |
345404 | ${pattern.showExtractors}
346405 |
347- |
406+ |with environment: ${summon[ Env ]}
348407 |
349408 |
350409 | """ .stripMargin)
0 commit comments