From aecef26b60ddb378a0e748252de016150453714d Mon Sep 17 00:00:00 2001 From: ducky Date: Fri, 20 Jan 2017 18:18:13 -0800 Subject: [PATCH] Refactor out fromBits => asTypeOf, deprecate flatten and other bad APIs --- chiselFrontend/src/main/scala/Mux.scala | 33 ++++++ .../main/scala/chisel3/core/Aggregate.scala | 71 +++++++----- .../src/main/scala/chisel3/core/Bits.scala | 89 ++++----------- .../src/main/scala/chisel3/core/Data.scala | 107 ++++++++++++++---- .../src/main/scala/chisel3/core/Reg.scala | 11 +- .../main/scala/chisel3/core/SeqUtils.scala | 4 +- .../scala/chiselTests/ReinterpretCast.scala | 68 +++++++++++ 7 files changed, 260 insertions(+), 123 deletions(-) create mode 100644 chiselFrontend/src/main/scala/Mux.scala create mode 100644 src/test/scala/chiselTests/ReinterpretCast.scala diff --git a/chiselFrontend/src/main/scala/Mux.scala b/chiselFrontend/src/main/scala/Mux.scala new file mode 100644 index 00000000000..0483736c85a --- /dev/null +++ b/chiselFrontend/src/main/scala/Mux.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package chisel3.core + +import scala.language.experimental.macros + +import chisel3.internal.Builder.{pushOp} +import chisel3.internal.sourceinfo.{SourceInfo, MuxTransform} +import chisel3.internal.firrtl._ +import chisel3.internal.firrtl.PrimOp._ + +object Mux { + /** Creates a mux, whose output is one of the inputs depending on the + * value of the condition. + * + * @param cond condition determining the input to choose + * @param con the value chosen when `cond` is true + * @param alt the value chosen when `cond` is false + * @example + * {{{ + * val muxOut = Mux(data_in === 3.U, 3.U(4.W), 0.U(4.W)) + * }}} + */ + def apply[T <: Data](cond: Bool, con: T, alt: T): T = macro MuxTransform.apply[T] + + def do_apply[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = { + Binding.checkSynthesizable(cond, s"'cond' ($cond)") + Binding.checkSynthesizable(con, s"'con' ($con)") + Binding.checkSynthesizable(alt, s"'alt' ($alt)") + val d = cloneSupertype(Seq(con, alt), "Mux") + pushOp(DefPrim(sourceInfo, d, MultiplexOp, cond.ref, con.ref, alt.ref)) + } +} \ No newline at end of file diff --git a/chiselFrontend/src/main/scala/chisel3/core/Aggregate.scala b/chiselFrontend/src/main/scala/chisel3/core/Aggregate.scala index 5f64b079626..f334fdf8bdd 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Aggregate.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Aggregate.scala @@ -15,26 +15,25 @@ import chisel3.internal.sourceinfo._ * of) other Data objects. */ sealed abstract class Aggregate extends Data { - private[core] def cloneTypeWidth(width: Width): this.type = cloneType - private[core] def width: Width = flatten.map(_.width).reduce(_ + _) + /** Returns a Seq of the immediate contents of this Aggregate, in order. + */ + def getElements: Seq[Data] + + private[core] def width: Width = getElements.map(_.width).reduce(_ + _) private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = pushCommand(BulkConnect(sourceInfo, this.lref, that.lref)) - override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = SeqUtils.do_asUInt(this.flatten) - def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = { + SeqUtils.do_asUInt(getElements.map(_.asUInt())) + } + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = { var i = 0 - val wire = Wire(this.chiselCloneType) - val bits = - if (that.width.known && that.width.get >= wire.width.get) { - that - } else { - Wire(that.cloneTypeWidth(wire.width), init = that) - } - for (x <- wire.flatten) { - x := x.fromBits(bits(i + x.getWidth-1, i)) + val bits = Wire(UInt(this.width), init=that) // handles width padding + for (x <- getElements) { + x.connectFromBits(bits(i + x.getWidth - 1, i)) i += x.getWidth } - wire.asInstanceOf[this.type] } } @@ -78,19 +77,16 @@ object Vec { // Check that types are homogeneous. Width mismatch for Elements is safe. require(!elts.isEmpty) - def eltsCompatible(a: Data, b: Data) = a match { - case _: Element => a.getClass == b.getClass - case _: Aggregate => Mux.typesCompatible(a, b) - } - val t = elts.head - for (e <- elts.tail) - require(eltsCompatible(t, e), s"can't create Vec of heterogeneous types ${t.getClass} and ${e.getClass}") + val vec = Wire(new Vec(cloneSupertype(elts, "Vec"), elts.length)) - val maxWidth = elts.map(_.width).reduce(_ max _) - val vec = Wire(new Vec(t.cloneTypeWidth(maxWidth).chiselCloneType, elts.length)) def doConnect(sink: T, source: T) = { - if (elts.head.flatten.exists(_.dir != Direction.Unspecified)) { + // TODO: this looks bad, and should feel bad. Replace with a better abstraction. + val hasDirectioned = vec.sample_element match { + case t: Aggregate => t.flatten.exists(_.dir != Direction.Unspecified) + case t: Element => t.dir != Direction.Unspecified + } + if (hasDirectioned) { sink bulkConnect source } else { sink connect source @@ -163,13 +159,20 @@ object Vec { */ sealed class Vec[T <: Data] private (gen: => T, val length: Int) extends Aggregate with VecLike[T] { + private[core] override def typeEquivalent(that: Data): Boolean = that match { + case that: Vec[T] => + this.length == that.length && + (this.sample_element typeEquivalent that.sample_element) + case _ => false + } + // Note: the constructor takes a gen() function instead of a Seq to enforce // that all elements must be the same and because it makes FIRRTL generation // simpler. private val self: Seq[T] = Vector.fill(length)(gen) /** - * sample_element 'tracks' all changes to the elements of self. + * sample_element 'tracks' all changes to the elements. * For consistency, sample_element is always used for creating dynamically * indexed ports and outputing the FIRRTL type. * @@ -181,7 +184,7 @@ sealed class Vec[T <: Data] private (gen: => T, val length: Int) // This is somewhat weird although I think the best course of action here is // to deprecate allElements in favor of dispatched functions to Data or // a pattern matched recursive descent - private[chisel3] final def allElements: Seq[Element] = + private[chisel3] final override def allElements: Seq[Element] = (sample_element +: self).flatMap(_.allElements) /** Strong bulk connect, assigning elements in this Vec from elements in a Seq. @@ -244,8 +247,8 @@ sealed class Vec[T <: Data] private (gen: => T, val length: Int) } private[chisel3] def toType: String = s"${sample_element.toType}[$length]" - private[chisel3] lazy val flatten: IndexedSeq[Bits] = - (0 until length).flatMap(i => this.apply(i).flatten) + override def getElements: Seq[Data] = + (0 until length).map(apply(_)) for ((elt, i) <- self.zipWithIndex) elt.setRef(this, i) @@ -377,7 +380,15 @@ abstract class Record extends Aggregate { elements.toIndexedSeq.reverse.map(e => eltPort(e._2)).mkString("{", ", ", "}") } - private[chisel3] lazy val flatten = elements.toIndexedSeq.flatMap(_._2.flatten) + private[core] override def typeEquivalent(that: Data): Boolean = that match { + case that: Record => + this.getClass == that.getClass && + this.elements.size == that.elements.size && + this.elements.forall{case (name, model) => + that.elements.contains(name) && + (that.elements(name) typeEquivalent model)} + case _ => false + } // NOTE: This sets up dependent references, it can be done before closing the Module private[chisel3] override def _onModuleClose: Unit = { // scalastyle:ignore method.name @@ -390,6 +401,8 @@ abstract class Record extends Aggregate { private[chisel3] final def allElements: Seq[Element] = elements.toIndexedSeq.flatMap(_._2.allElements) + override def getElements: Seq[Data] = elements.toIndexedSeq.map(_._2) + // Helper because Bundle elements are reversed before printing private[chisel3] def toPrintableHelper(elts: Seq[(String, Data)]): Printable = { val xs = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala index e885f1eeb65..65217069a4e 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala @@ -8,7 +8,7 @@ import chisel3.internal._ import chisel3.internal.Builder.{pushCommand, pushOp} import chisel3.internal.firrtl._ import chisel3.internal.sourceinfo.{SourceInfo, DeprecatedSourceInfo, SourceInfoTransform, SourceInfoWhiteboxTransform, - UIntTransform, MuxTransform} + UIntTransform} import chisel3.internal.firrtl.PrimOp._ // TODO: remove this once we have CompileOptions threaded through the macro system. import chisel3.core.ExplicitCompileOptions.NotStrict @@ -54,7 +54,9 @@ sealed abstract class Bits(width: Width, override val litArg: Option[LitArg]) // Arguments for: self-checking code (can't do arithmetic on bits) // Arguments against: generates down to a FIRRTL UInt anyways - private[chisel3] def flatten: IndexedSeq[Bits] = IndexedSeq(this) + // Only used for Mux, which needs to return a value of the same type with width the larger of the + // inputs. + private[core] def cloneTypeWidth(width: Width): this.type def cloneType: this.type = cloneTypeWidth(width) @@ -390,6 +392,9 @@ abstract trait Num[T <: Data] { sealed class UInt private[core] (width: Width, lit: Option[ULit] = None) extends Bits(width, lit) with Num[UInt] { + private[core] override def typeEquivalent(that: Data): Boolean = + that.isInstanceOf[UInt] && this.width == that.width + private[core] override def cloneTypeWidth(w: Width): this.type = new UInt(w).asInstanceOf[this.type] private[chisel3] def toType = s"UInt$width" @@ -513,13 +518,9 @@ sealed class UInt private[core] (width: Width, lit: Option[ULit] = None) throwException(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint") } } - def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { - val res = Wire(this, null).asInstanceOf[this.type] - res := (that match { - case u: UInt => u - case _ => that.asUInt - }) - res + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = { + this := that.asUInt } } @@ -555,6 +556,9 @@ object Bits extends UIntFactory sealed class SInt private[core] (width: Width, lit: Option[SLit] = None) extends Bits(width, lit) with Num[SInt] { + private[core] override def typeEquivalent(that: Data): Boolean = + this.getClass == that.getClass && this.width == that.width // TODO: should this be true for unspecified widths? + private[core] override def cloneTypeWidth(w: Width): this.type = new SInt(w).asInstanceOf[this.type] private[chisel3] def toType = s"SInt$width" @@ -657,13 +661,8 @@ sealed class SInt private[core] (width: Width, lit: Option[SLit] = None) throwException(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint") } } - def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { - val res = Wire(this, null).asInstanceOf[this.type] - res := (that match { - case s: SInt => s - case _ => that.asSInt - }) - res + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { + this := that.asSInt } } @@ -756,52 +755,6 @@ trait BoolFactory { object Bool extends BoolFactory -object Mux { - /** Creates a mux, whose output is one of the inputs depending on the - * value of the condition. - * - * @param cond condition determining the input to choose - * @param con the value chosen when `cond` is true - * @param alt the value chosen when `cond` is false - * @example - * {{{ - * val muxOut = Mux(data_in === 3.U, 3.U(4.W), 0.U(4.W)) - * }}} - */ - def apply[T <: Data](cond: Bool, con: T, alt: T): T = macro MuxTransform.apply[T] - - def do_apply[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = - (con, alt) match { - // Handle Mux(cond, UInt, Bool) carefully so that the concrete type is UInt - case (c: Bool, a: Bool) => doMux(cond, c, a).asInstanceOf[T] - case (c: UInt, a: Bool) => doMux(cond, c, a << 0).asInstanceOf[T] - case (c: Bool, a: UInt) => doMux(cond, c << 0, a).asInstanceOf[T] - case (c: Bits, a: Bits) => doMux(cond, c, a).asInstanceOf[T] - case _ => doAggregateMux(cond, con, alt) - } - - private def doMux[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = { - require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}") - Binding.checkSynthesizable(cond, s"'cond' ($cond)") - Binding.checkSynthesizable(con, s"'con' ($con)") - Binding.checkSynthesizable(alt, s"'alt' ($alt)") - val d = alt.cloneTypeWidth(con.width max alt.width) - pushOp(DefPrim(sourceInfo, d, MultiplexOp, cond.ref, con.ref, alt.ref)) - } - - private[core] def typesCompatible[T <: Data](x: T, y: T): Boolean = { - val sameTypes = x.getClass == y.getClass - val sameElements = x.flatten zip y.flatten forall { case (a, b) => a.getClass == b.getClass && a.width == b.width } - val sameNumElements = x.flatten.size == y.flatten.size - sameTypes && sameElements && sameNumElements - } - - private def doAggregateMux[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = { - require(typesCompatible(con, alt), s"can't Mux between heterogeneous types ${con.getClass} and ${alt.getClass}") - doMux(cond, con, alt) - } -} - //scalastyle:off number.of.methods /** * A sealed class representing a fixed point number that has a bit width and a binary point @@ -817,6 +770,11 @@ object Mux { */ sealed class FixedPoint private (width: Width, val binaryPoint: BinaryPoint, lit: Option[FPLit] = None) extends Bits(width, lit) with Num[FixedPoint] { + private[core] override def typeEquivalent(that: Data): Boolean = that match { + case that: FixedPoint => this.width == that.width && this.binaryPoint == that.binaryPoint // TODO: should this be true for unspecified widths? + case _ => false + } + private[core] override def cloneTypeWidth(w: Width): this.type = new FixedPoint(w, binaryPoint).asInstanceOf[this.type] private[chisel3] def toType = s"Fixed$width$binaryPoint" @@ -926,13 +884,12 @@ sealed class FixedPoint private (width: Width, val binaryPoint: BinaryPoint, lit override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = pushOp(DefPrim(sourceInfo, UInt(this.width), AsUIntOp, ref)) override def do_asSInt(implicit sourceInfo: SourceInfo): SInt = pushOp(DefPrim(sourceInfo, SInt(this.width), AsSIntOp, ref)) - def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { - val res = Wire(this, null).asInstanceOf[this.type] - res := (that match { + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { + // TODO: redefine as just asFixedPoint on that, where FixedPoint.asFixedPoint just works. + this := (that match { case fp: FixedPoint => fp.asSInt.asFixedPoint(this.binaryPoint) case _ => that.asFixedPoint(this.binaryPoint) }) - res } //TODO(chick): Consider "convert" as an arithmetic conversion to UInt/SInt } diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index 4f3e2268357..9310271ddee 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -29,6 +29,37 @@ object DataMirror { def widthOf(target: Data): Width = target.width } +/** Creates a clone of the super-type of the input elements. Super-type is defined as: + * - for Bits type of the same class: the cloned type of the largest width + * - Bools are treated as UInts + * - For other types of the same class are are the same: clone of any of the elements + * - Otherwise: fail + */ +private[core] object cloneSupertype { + def apply[T <: Data](elts: Seq[T], createdType: String)(implicit sourceInfo: SourceInfo): T = { + require(!elts.isEmpty, s"can't create $createdType with no inputs") + + if (elts forall {_.isInstanceOf[Bits]}) { + val model: T = elts reduce { (elt1: T, elt2: T) => ((elt1, elt2) match { + case (elt1: Bool, elt2: Bool) => elt1 + case (elt1: Bool, elt2: UInt) => elt2 // TODO: what happens with zero width UInts? + case (elt1: UInt, elt2: Bool) => elt1 // TODO: what happens with zero width UInts? + case (elt1: UInt, elt2: UInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 // TODO: perhaps redefine Widths to allow >= op? + case (elt1: SInt, elt2: SInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 + case (elt1, elt2) => + throw new AssertionError(s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}") + }).asInstanceOf[T] } + model.chiselCloneType + } else { + for (elt <- elts.tail) { + require(elt.getClass == elts.head.getClass, s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}") + require(elt typeEquivalent elts.head, s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}") + } + elts.head.chiselCloneType + } + } +} + /** * Input, Output, and Flipped are used to define the directions of Module IOs. * @@ -128,8 +159,18 @@ object Data { * from bits. */ abstract class Data extends HasId { + // This is a bad API that punches through object boundaries. + @deprecated("pending removal once all instances replaced", "chisel3") + private[chisel3] def flatten: IndexedSeq[Element] = { + this match { + case elt: Aggregate => elt.getElements.toIndexedSeq flatMap {_.flatten} + case elt: Element => IndexedSeq(elt) + } + } + // Return ALL elements at root of this type. // Contasts with flatten, which returns just Bits + // TODO: refactor away this, this is outside the scope of Data private[chisel3] def allElements: Seq[Element] private[core] def badConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = @@ -166,9 +207,14 @@ abstract class Data extends HasId { this legacyConnect that } } + + /** Whether this Data has the same model ("data type") as that Data. + * Data subtypes should overload this with checks against their own type. + */ + private[core] def typeEquivalent(that: Data): Boolean + private[chisel3] def lref: Node = Node(this) private[chisel3] def ref: Arg = if (isLit) litArg.get else lref - private[core] def cloneTypeWidth(width: Width): this.type private[chisel3] def toType: String private[core] def width: Width private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit @@ -187,6 +233,8 @@ abstract class Data extends HasId { * @return a copy of the object with appropriate core state. */ def chiselCloneType(implicit compileOptions: CompileOptions): this.type = { + // TODO: refactor away allElements, handle this with Aggregate/Element match inside Bindings + // Call the user-supplied cloneType method val clone = this.cloneType // In compatibility mode, simply return cloneType; otherwise, propagate @@ -215,17 +263,6 @@ abstract class Data extends HasId { /** Returns Some(width) if the width is known, else None. */ final def widthOption: Option[Int] = if (isWidthKnown) Some(getWidth) else None - // While this being in the Data API doesn't really make sense (should be in - // Aggregate, right?) this is because of an implementation limitation: - // cloneWithDirection, which is private and defined here, needs flatten to - // set element directionality. - // Related: directionality is mutable state. A possible solution for both is - // to define directionality relative to the container, but these parent links - // currently don't exist (while this information may be available during - // FIRRTL emission, it would break directionality querying from Chisel, which - // does get used). - private[chisel3] def flatten: IndexedSeq[Bits] - /** Creates an new instance of this type, unpacking the input Bits into * structured data. * @@ -236,16 +273,39 @@ abstract class Data extends HasId { * @note does NOT check bit widths, may drop bits during assignment * @note what fromBits assigs to must have known widths */ - def fromBits(that: Bits): this.type = macro CompileOptionsTransform.thatArg - - def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type + def fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + val output = Wire(chiselCloneType).asInstanceOf[this.type] + output.connectFromBits(that) + output + } /** Packs the value of this object as plain Bits. * * This performs the inverse operation of fromBits(Bits). */ @deprecated("Best alternative, .asUInt()", "chisel3") - def toBits(): UInt = SeqUtils.do_asUInt(this.flatten)(DeprecatedSourceInfo) + def toBits(): UInt = do_asUInt(DeprecatedSourceInfo) + + /** Does a reinterpret cast of the bits in this node into the format that provides. + * Returns a new Wire of that type. Does not modify existing nodes. + * + * x.asTypeOf(that) performs the inverse operation of x = that.toBits. + * + * @note bit widths are NOT checked, may pad or drop bits from input + * @note that should have known widths + */ + def asTypeOf[T <: Data](that: T): T = macro CompileOptionsTransform.thatArg + + def do_asTypeOf[T <: Data](that: T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { + val thatCloned = Wire(that.chiselCloneType) + thatCloned.connectFromBits(this.asUInt()) + thatCloned + } + + /** Assigns this node from Bits type. Internal implementation for asTypeOf. + */ + private[core] def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit /** Reinterpret cast to UInt. * @@ -256,8 +316,7 @@ abstract class Data extends HasId { */ final def asUInt(): UInt = macro SourceInfoTransform.noArg - def do_asUInt(implicit sourceInfo: SourceInfo): UInt = - SeqUtils.do_asUInt(this.flatten)(sourceInfo) + def do_asUInt(implicit sourceInfo: SourceInfo): UInt // firrtlDirection is the direction we report to firrtl. // It maintains the user-specified value (as opposed to the "actual" or applied/propagated value). @@ -310,10 +369,11 @@ object Clock { // TODO: Document this. sealed class Clock extends Element(Width(1)) { def cloneType: this.type = Clock().asInstanceOf[this.type] - private[chisel3] override def flatten: IndexedSeq[Bits] = IndexedSeq() - private[core] def cloneTypeWidth(width: Width): this.type = cloneType private[chisel3] def toType = "Clock" + private[core] def typeEquivalent(that: Data): Boolean = + this.getClass == that.getClass + override def connect (that: Data)(implicit sourceInfo: SourceInfo, connectCompileOptions: CompileOptions): Unit = that match { case _: Clock => super.connect(that)(sourceInfo, connectCompileOptions) case _ => super.badConnect(that)(sourceInfo) @@ -323,9 +383,8 @@ sealed class Clock extends Element(Width(1)) { def toPrintable: Printable = PString("CLOCK") override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = pushOp(DefPrim(sourceInfo, UInt(this.width), AsUIntOp, ref)) - override def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { - val ret = Wire(this.cloneType) - ret := that - ret.asInstanceOf[this.type] + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = { + this := that } } diff --git a/chiselFrontend/src/main/scala/chisel3/core/Reg.scala b/chiselFrontend/src/main/scala/chisel3/core/Reg.scala index 57979be0496..30abe5e5f23 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Reg.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Reg.scala @@ -16,12 +16,19 @@ object Reg { } t.chiselCloneType } else if (next ne null) { - next.cloneTypeWidth(Width()) + (next match { + case next: Bits => next.cloneTypeWidth(Width()) + case _ => next.chiselCloneType + }).asInstanceOf[T] + } else if (init ne null) { init.litArg match { // For e.g. Reg(init=UInt(0, k)), fix the Reg's width to k case Some(lit) if lit.forcedWidth => init.chiselCloneType - case _ => init.cloneTypeWidth(Width()) + case _ => (init match { + case init: Bits => init.cloneTypeWidth(Width()) + case _ => init.chiselCloneType + }).asInstanceOf[T] } } else { throwException("cannot infer type") diff --git a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala index e435860e25e..3254a1fb58a 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala @@ -58,8 +58,8 @@ private[chisel3] object SeqUtils { in.head._2 } else { val masked = for ((s, i) <- in) yield Mux(s, i.asUInt, 0.U) - val width = in.map(_._2.width).reduce(_ max _) - in.head._2.cloneTypeWidth(width).fromBits(masked.reduceLeft(_|_)) + val output = cloneSupertype(in.toSeq map {_._2}, "oneHotMux") + output.fromBits(masked.reduceLeft(_|_)) } } } diff --git a/src/test/scala/chiselTests/ReinterpretCast.scala b/src/test/scala/chiselTests/ReinterpretCast.scala new file mode 100644 index 00000000000..d4aecbe1f86 --- /dev/null +++ b/src/test/scala/chiselTests/ReinterpretCast.scala @@ -0,0 +1,68 @@ +// See LICENSE for license details. + +package chiselTests + +import org.scalatest._ + +import chisel3._ +import chisel3.experimental.FixedPoint +import chisel3.testers.BasicTester +import chisel3.util._ +import chisel3.core.DataMirror + +class AsBundleTester extends BasicTester { + class MultiTypeBundle extends Bundle { + val u = UInt(4.W) + val s = SInt(4.W) + val fp = FixedPoint(4.W, 3.BP) + } + + val rawData = ((4 << 8) + (15 << 4) + (12 << 0)).U + val bunFromBits = rawData.asTypeOf(new MultiTypeBundle) + + assert(bunFromBits.u === 4.U) + assert(bunFromBits.s === -1.S) + assert(bunFromBits.fp === FixedPoint.fromDouble(-0.5, width=4, binaryPoint=3)) + + stop() +} + +class AsVecTester extends BasicTester { + val rawData = ((15 << 12) + (0 << 8) + (1 << 4) + (2 << 0)).U + val vec = rawData.asTypeOf(Vec(4, SInt(4.W))) + + assert(vec(0) === 2.S) + assert(vec(1) === 1.S) + assert(vec(2) === 0.S) + assert(vec(3) === -1.S) + + stop() +} + +class AsBitsTruncationTester extends BasicTester { + val truncate = (64 + 3).U.asTypeOf(UInt(3.W)) + val expand = 1.U.asTypeOf(UInt(3.W)) + + assert( DataMirror.widthOf(truncate).get == 3 ) + assert( truncate === 3.U ) + assert( DataMirror.widthOf(expand).get == 3 ) + assert( expand === 1.U ) + + stop() +} + +class AsBitsSpec extends ChiselFlatSpec { + behavior of "fromBits" + + it should "work with Bundles containing Bits Types" in { + assertTesterPasses{ new AsBundleTester } + } + + it should "work with Vecs containing Bits Types" in { + assertTesterPasses{ new AsVecTester } + } + + it should "expand and truncate UInts of different width" in { + assertTesterPasses{ new AsBitsTruncationTester } + } +}