diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 0abdaa365b..2925bfd970 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,24 +10,14 @@ import firrtl.Mappers._ import annotation.tailrec object CommonSubexpressionElimination extends Pass { - private def cseOnce(s: Statement): (Statement, Long) = { - var nEliminated = 0L + private def cse(s: Statement): Statement = { val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() - def recordNodes(s: Statement): Statement = s match { - case x: DefNode => - nodes(x.name) = x.value - expressions.getOrElseUpdate(x.value, x.name) - x - case _ => s map recordNodes - } - def eliminateNodeRef(e: Expression): Expression = e match { case WRef(name, tpe, kind, gender) => nodes get name match { case Some(expression) => expressions get expression match { case Some(cseName) if cseName != name => - nEliminated += 1 WRef(cseName, tpe, kind, gender) case _ => e } @@ -36,16 +26,17 @@ object CommonSubexpressionElimination extends Pass { case _ => e map eliminateNodeRef } - def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef - - recordNodes(s) - (eliminateNodeRefs(s), nEliminated) - } + def eliminateNodeRefs(s: Statement): Statement = { + s map eliminateNodeRef match { + case x: DefNode => + nodes(x.name) = x.value + expressions.getOrElseUpdate(x.value, x.name) + x + case other => other map eliminateNodeRefs + } + } - @tailrec - private def cse(s: Statement): Statement = { - val (res, n) = cseOnce(s) - if (n > 0) cse(res) else res + eliminateNodeRefs(s) } def run(c: Circuit): Circuit = {