@@ -87,56 +87,42 @@ class LabelDefs extends MiniPhase {
8787 override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context ): tpd.Tree = {
8888 if (tree.symbol is Label ) tree
8989 else {
90- collectLabelDefs.clear()
91- val newRhs = collectLabelDefs.transform(tree.rhs)
92- var labelDefs = collectLabelDefs.labelDefs
90+ val labelDefs = collectLabelDefs(tree.rhs)
9391
9492 def putLabelDefsNearCallees = new TreeMap () {
95-
9693 override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
9794 tree match {
95+ case t : Template => t
9896 case t : Apply if labelDefs.contains(t.symbol) =>
9997 val labelDef = labelDefs(t.symbol)
10098 labelDefs -= t.symbol
101-
102- val labelDef2 = transform(labelDef)
99+ val labelDef2 = cpy.DefDef (labelDef)(rhs = transform(labelDef.rhs))
103100 Block (labelDef2:: Nil , t)
104-
101+ case t : DefDef =>
102+ assert(t.symbol is Label )
103+ EmptyTree
105104 case _ => if (labelDefs.nonEmpty) super .transform(tree) else tree
106105 }
107106 }
108107 }
109108
110- val res = cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(newRhs))
111-
112- res
109+ cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(tree.rhs))
113110 }
114111 }
115112
116- private object collectLabelDefs extends TreeMap () {
117-
113+ private def collectLabelDefs (tree : Tree )(implicit ctx : Context ): mutable.HashMap [Symbol , DefDef ] = {
118114 // labelSymbol -> Defining tree
119- val labelDefs = new mutable.HashMap [Symbol , Tree ]()
120-
121- def clear (): Unit = {
122- labelDefs.clear()
123- }
124-
125- override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = tree match {
126- case t : Template => t
127- case t : Block =>
128- val r = super .transform(t)
129- r match {
130- case t : Block if t.stats.isEmpty => t.expr
131- case _ => r
132- }
133- case t : DefDef =>
134- assert(t.symbol is Label )
135- val r = super .transform(tree)
136- labelDefs(r.symbol) = r
137- EmptyTree
138- case _ =>
139- super .transform(tree)
140- }
115+ val labelDefs = new mutable.HashMap [Symbol , DefDef ]()
116+ new TreeTraverser {
117+ override def traverse (tree : tpd.Tree )(implicit ctx : Context ): Unit = tree match {
118+ case _ : Template =>
119+ case t : DefDef =>
120+ assert(t.symbol is Label )
121+ labelDefs(t.symbol) = t
122+ traverseChildren(t)
123+ case _ => traverseChildren(tree)
124+ }
125+ }.traverse(tree)
126+ labelDefs
141127 }
142128}
0 commit comments