@@ -16,7 +16,6 @@ import dotty.tools.dotc.ast.Trees._
1616import SymUtils ._
1717
1818import annotation .threadUnsafe
19- import collection .mutable
2019
2120object CompleteJavaEnums {
2221 val name : String = " completeJavaEnums"
@@ -116,13 +115,13 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
116115 && (((cls.owner.name eq nme.DOLLAR_NEW ) && cls.owner.isAllOf(Private | Synthetic )) || cls.owner.isAllOf(EnumCase ))
117116 && cls.owner.owner.linkedClass.derivesFromJavaEnum
118117
119- @ threadUnsafe
120- private lazy val enumCaseOrdinals : mutable.Map [Symbol , Int ] = mutable.AnyRefMap .empty
118+ private val enumCaseOrdinals : MutableSymbolMap [Int ] = newMutableSymbolMap
121119
122120 private def registerEnumClass (cls : Symbol )(using Context ): Unit =
123121 cls.children.zipWithIndex.foreach(enumCaseOrdinals.put)
124122
125- private def ordinalFor (enumCase : Symbol ): Int = enumCaseOrdinals.remove(enumCase).get
123+ private def ordinalFor (enumCase : Symbol ): Int =
124+ enumCaseOrdinals.remove(enumCase).get
126125
127126 /** 1. If this is an enum class, add $name and $ordinal parameters to its
128127 * parameter accessors and pass them on to the java.lang.Enum constructor.
@@ -145,17 +144,20 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
145144 * "same as before"
146145 * }
147146 */
148- override def transformTemplate (templ : Template )(using Context ): Template = {
147+ override def transformTemplate (templ : Template )(using Context ): Tree = {
149148 val cls = templ.symbol.owner
150149 if cls.derivesFromJavaEnum then
151- registerEnumClass(cls)
150+ registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala
152151 val (params, rest) = decomposeTemplateBody(templ.body)
153152 val addedDefs = addedParams(cls, isLocal= true , ParamAccessor )
154153 val addedSyms = addedDefs.map(_.symbol.entered)
155154 val addedForwarders = addedEnumForwarders(cls)
156155 cpy.Template (templ)(
157156 parents = addEnumConstrArgs(defn.JavaEnumClass , templ.parents, addedSyms.map(ref)),
158157 body = params ++ addedDefs ++ addedForwarders ++ rest)
158+ else if cls.linkedClass.derivesFromJavaEnum then
159+ enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases
160+ templ
159161 else if isJavaEnumValueImpl(cls) then
160162 def creatorParamRef (name : TermName ) =
161163 ref(cls.owner.paramSymss.head.find(_.name == name).get)
0 commit comments