diff --git a/.gitignore b/.gitignore
index 898f111c9..ba9fc93f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,8 @@
 out/
 .*.sw[a-p]
-.bsp
-.idea
\ No newline at end of file
+.bsp/
+.idea/
+.bloop/
+.metals/
+.vscode/
+.scala-build/
\ No newline at end of file
diff --git a/rocket/src/util.scala b/rocket/src/util.scala
index 316fd572a..f7a667704 100644
--- a/rocket/src/util.scala
+++ b/rocket/src/util.scala
@@ -2,6 +2,8 @@ package org.chipsalliance.rocket
 import chisel3._
 import chisel3.util._
+import scala.collection.immutable
+import scala.collection.mutable
 //todo: remove util
 package object util {
@@ -189,4 +191,26 @@ package object util {
   implicit def uintToBitPat(x: UInt): BitPat = BitPat(x)
+  /** Lists the positions of the set bits of `x`, lowest position first.
+   *
+   * @param x    value to scan; must be non-negative (enforced with `require`)
+   * @param tail positions collected so far, most recent first; it is reversed
+   *             to form the result once no set bits remain
+   */
+  def bitIndexes(x: BigInt, tail: Seq[Int] = Nil): Seq[Int] = {
+    require (x >= 0)
+    if (x != 0) {
+      val idx = x.lowestSetBit
+      bitIndexes(x.clearBit(idx), idx +: tail)
+    } else {
+      tail.reverse
+    }
+  }
+
+  /** Like Seq.groupBy, but yields a Seq instead of a Map so the result order is
+   * deterministic (useful for code generation): keys appear in the order `f`
+   * first produces them, and each group keeps the order its elements had in `xs`.
+   */
+  def groupByIntoSeq[A, K](xs: Seq[A])(f: A => K): immutable.Seq[(K, immutable.Seq[A])] = {
+    val buckets = mutable.LinkedHashMap.empty[K, mutable.ListBuffer[A]]
+    xs.foreach { elem =>
+      buckets.getOrElseUpdate(f(elem), mutable.ListBuffer.empty[A]) += elem
+    }
+    buckets.iterator.map { case (key, members) => (key, members.toList) }.toList
+  }
 }
diff --git a/rocket/src/util/AddressDecoder.scala b/rocket/src/util/AddressDecoder.scala
new file mode 100644
index 000000000..8a84c8873
--- /dev/null
+++ b/rocket/src/util/AddressDecoder.scala
@@ -0,0 +1,134 @@
+// See LICENSE.SiFive for license details.
+
+package org.chipsalliance.rocket.util
+
+import Chisel.log2Ceil
+
+object AddressDecoder
+{
+  type Port = Seq[AddressSet]
+  type Ports = Seq[Port]
+  type Partition = Ports
+  type Partitions = Seq[Partition]
+
+  val addressOrder = Ordering.ordered[AddressSet]
+  val portOrder = Ordering.Iterable(addressOrder)
+  val partitionOrder = Ordering.Iterable(portOrder)
+
+  // Find the minimum subset of bits needed to disambiguate port addresses.
+  // ie: inspecting only the bits in the output, you can look at an address
+  // and decide to which port (outer Seq) the address belongs.
+  // Bits already set in givenBits are always part of the result; empty ports
+  // are ignored, and with at most one non-empty port givenBits is returned as-is.
+  def apply(ports: Ports, givenBits: BigInt = BigInt(0)): BigInt = {
+    val nonEmptyPorts = ports.filter(_.nonEmpty)
+    if (nonEmptyPorts.size <= 1) {
+      givenBits
+    } else {
+      // Verify the user did not give us an impossible problem
+      nonEmptyPorts.combinations(2).foreach { case Seq(x, y) =>
+        x.foreach { a => y.foreach { b =>
+          require (!a.overlaps(b), s"Ports cannot overlap: $a $b")
+        } }
+      }
+
+      val maxBits = log2Ceil(1 + nonEmptyPorts.map(_.map(_.base).max).max)
+      val (bitsToTry, bitsToTake) = (0 until maxBits).map(BigInt(1) << _).partition(b => (givenBits & b) == 0)
+      val partitions = Seq(nonEmptyPorts.map(_.sorted).sorted(portOrder))
+      val givenPartitions = bitsToTake.foldLeft(partitions) { (p, b) => partitionPartitions(p, b) }
+      val selected = recurse(givenPartitions, bitsToTry.reverse.toSeq)
+      // recurse returns an empty Seq when givenBits alone already separates every
+      // port, so fold from givenBits instead of reducing a possibly empty Seq
+      val output = selected.foldLeft(givenBits)(_ | _)
+
+      // Modify the AddressSets to allow the new wider match functions
+      val widePorts = nonEmptyPorts.map { _.map { _.widen(~output) } }
+      // Verify that it remains possible to disambiguate all ports
+      widePorts.combinations(2).foreach { case Seq(x, y) =>
+        x.foreach { a => y.foreach { b =>
+          require (!a.overlaps(b), s"Ports cannot overlap: $a $b")
+        } }
+      }
+
+      output
+    }
+  }
+
+  // A simpler version that works for a Seq[Int]
+  def apply(keys: Seq[Int]): Int = {
+    val ports = keys.map(b => Seq(AddressSet(b, 0)))
+    apply(ports).toInt
+  }
+
+  // The algorithm has a set of partitions, discriminated by the selected bits.
+  // Each partition has a set of ports, listing all addresses that lead to that port.
+  // Seq[Seq[Seq[AddressSet]]]
+  //         ^^^^^^^^^^^^^^^ set of addresses that are routed out this port
+  //     ^^^ the list of ports
+  // ^^^ cases already distinguished by the selected bits thus far
+  //
+  // Solving this problem is NP-hard, so we use a simple greedy heuristic:
+  //   pick the bit which minimizes the number of ports in each partition
+  //   as a secondary goal, reduce the number of AddressSets within a partition
+
+  // Score a candidate partitioning; scores are compared lexicographically in
+  // recurse and the smallest one wins.
+  def bitScore(partitions: Partitions): Seq[Int] = {
+    val maxPortsPerPartition = partitions.map(_.size).max
+    val maxSetsPerPartition = partitions.map(_.map(_.size).sum).max
+    val sumSquarePortsPerPartition = partitions.map(p => p.size * p.size).sum
+    // NOTE(review): named "sum..." but aggregated across partitions with .max - confirm this is intended
+    val sumSquareSetsPerPartition = partitions.map(_.map(p => p.size * p.size).sum).max
+    Seq(maxPortsPerPartition, maxSetsPerPartition, sumSquarePortsPerPartition, sumSquareSetsPerPartition)
+  }
+
+  // Split one port's addresses into those that can have `bit` clear (first
+  // result) and those that can have it set (second result). An AddressSet
+  // overlapping both cases appears in both results.
+  def partitionPort(port: Port, bit: BigInt): (Port, Port) = {
+    val addr_a = AddressSet(0, ~bit)
+    val addr_b = AddressSet(bit, ~bit)
+    // The addresses were sorted, so the filtered addresses are still sorted
+    val subset_a = port.filter(_.overlaps(addr_a))
+    val subset_b = port.filter(_.overlaps(addr_b))
+    (subset_a, subset_b)
+  }
+
+  // Apply partitionPort to every port, dropping ports left with no addresses
+  def partitionPorts(ports: Ports, bit: BigInt): (Ports, Ports) = {
+    val partitioned_ports = ports.map(p => partitionPort(p, bit))
+    // because partitionPort dropped AddressSets, the ports might no longer be sorted
+    val case_a_ports = partitioned_ports.map(_._1).filter(!_.isEmpty).sorted(portOrder)
+    val case_b_ports = partitioned_ports.map(_._2).filter(!_.isEmpty).sorted(portOrder)
+    (case_a_ports, case_b_ports)
+  }
+
+  // Refine every partition by `bit`, then sort and de-duplicate the result
+  def partitionPartitions(partitions: Partitions, bit: BigInt): Partitions = {
+    val partitioned_partitions = partitions.map(p => partitionPorts(p, bit))
+    val case_a_partitions = partitioned_partitions.map(_._1).filter(!_.isEmpty)
+    val case_b_partitions = partitioned_partitions.map(_._2).filter(!_.isEmpty)
+    val new_partitions = (case_a_partitions ++ case_b_partitions).sorted(partitionOrder)
+    // Prevent combinatorial memory explosion; if two partitions are equal, keep only one
+    // Note: AddressSets in a port are sorted, and ports in a partition are sorted.
+    // This makes it easy to structurally compare two partitions for equality
+    val keep = (new_partitions.init zip new_partitions.tail) filter { case (a,b) => partitionOrder.compare(a,b) != 0 } map { _._2 }
+    new_partitions.head +: keep
+  }
+
+  // requirement: ports have sorted addresses and are sorted lexicographically
+  val debug = false
+  // Greedily pick, one at a time, the bit with the lowest bitScore until every
+  // partition holds at most one port; returns the picked bits in pick order.
+  def recurse(partitions: Partitions, bits: Seq[BigInt]): Seq[BigInt] = {
+    if (partitions.forall(_.size <= 1)) Seq() else {
+      if (debug) {
+        println("Partitioning:")
+        partitions.foreach { partition =>
+          println(" Partition:")
+          partition.foreach { port =>
+            print(" ")
+            port.foreach { a => print(s" ${a}") }
+            println("")
+          }
+        }
+      }
+      val candidates = bits.map { bit =>
+        val result = partitionPartitions(partitions, bit)
+        val score = bitScore(result)
+        if (debug)
+          println(" For bit %x, %s".format(bit, score.toString))
+        (score, bit, result)
+      }
+      val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable))
+      if (debug) println("=> Selected bit 0x%x".format(bestBit))
+      bestBit +: recurse(bestPartitions, bits.filter(_ != bestBit))
+    }
+  }
+}
diff --git a/rocket/src/util/ClockGate.scala b/rocket/src/util/ClockGate.scala
new file mode 100644
index 000000000..143c6e327
--- /dev/null
+++ b/rocket/src/util/ClockGate.scala
@@ -0,0 +1,51 @@
+// See LICENSE.SiFive for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+import chisel3.util.{HasBlackBoxResource, HasBlackBoxPath}
+
+import java.nio.file.{Files, Paths}
+
+/** Black-box clock gate with clock input `in`, enables `en` and `test_en`,
+  * and gated clock output `out`.
+  */
+abstract class ClockGate extends BlackBox
+  with HasBlackBoxResource with HasBlackBoxPath {
+  val io = IO(new Bundle{
+    val in = Input(Clock())
+    val test_en = Input(Bool())
+    val en = Input(Bool())
+    val out = Output(Clock())
+  })
+
+  /** Attach the Verilog model: as a filesystem path when `vsrc` exists on
+    * disk, otherwise as a resource.
+    */
+  def addVerilogResource(vsrc: String): Unit = {
+    if (Files.exists(Paths.get(vsrc)))
+      addPath(vsrc)
+    else
+      addResource(vsrc)
+  }
+}
+
+object ClockGate {
+  /** Instantiate an [[EICG_wrapper]] gating `in` with `en` (`test_en` is tied low).
+    *
+    * @param in        clock to gate
+    * @param en        gate enable
+    * @param modelFile optional Verilog model handed to `addVerilogResource`
+    * @param name      optional suggested instance name
+    * @return the gated clock
+    */
+  def apply[T <: ClockGate](
+      in: Clock,
+      en: Bool,
+      modelFile: Option[String],
+      name: Option[String] = None): Clock = {
+    val cg = Module(new EICG_wrapper)
+    name.foreach(cg.suggestName(_))
+    modelFile.foreach(cg.addVerilogResource(_))
+
+    cg.io.in := in
+    cg.io.test_en := false.B
+    cg.io.en := en
+    cg.io.out
+  }
+
+  /** Convenience overload: named clock gate without a model file. */
+  def apply[T <: ClockGate](
+      in: Clock,
+      en: Bool,
+      name: String): Clock =
+    // pass None for modelFile explicitly; a lone Some(name) would bind to
+    // modelFile and the name would be treated as a Verilog resource
+    apply(in, en, None, Some(name))
+}
+
+// behavioral model of Integrated Clock Gating cell
+class EICG_wrapper extends ClockGate
\ No newline at end of file
diff --git a/rocket/src/util/CoreMonitorBundle.scala b/rocket/src/util/CoreMonitorBundle.scala
new file mode 100644
index 000000000..37ebdf51a
--- /dev/null
+++ b/rocket/src/util/CoreMonitorBundle.scala
@@ -0,0 +1,28 @@
+// See LICENSE.Berkeley for license details.
+// See LICENSE.SiFive for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+
+// this bundle is used to expose some internal core signals
+// to verification monitors which sample instruction commits
+// xLen / fLen size the integer and floating-point data fields; wrdata is
+// wide enough for the larger of the two
+class CoreMonitorBundle(val xLen: Int, val fLen: Int) extends Bundle {
+  val excpt = Bool()
+  val priv_mode = UInt(width = 3.W)
+  val hartid = UInt(width = xLen.W)
+  val timer = UInt(width = 32.W)
+  val valid = Bool()
+  val pc = UInt(width = xLen.W)
+  val wrdst = UInt(width = 5.W)
+  val wrdata = UInt(width = (xLen max fLen).W)
+  val wrenx = Bool()
+  val wrenf = Bool()
+  @deprecated("replace wren with wrenx or wrenf to specify integer or floating point","Rocket Chip 2020.05")
+  def wren: Bool = wrenx || wrenf
+  val rd0src = UInt(width = 5.W)
+  val rd0val = UInt(width = xLen.W)
+  val rd1src = UInt(width = 5.W)
+  val rd1val = UInt(width = xLen.W)
+  val inst = UInt(width = 32.W)
+}
diff --git a/rocket/src/util/DescribedSRAM.scala b/rocket/src/util/DescribedSRAM.scala
new file mode 100644
index 000000000..535781a7f
--- /dev/null
+++ b/rocket/src/util/DescribedSRAM.scala
@@ -0,0 +1,30 @@
+// See LICENSE.Berkeley for license details.
+// See LICENSE.SiFive for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3.{Data, SyncReadMem, Vec}
+import chisel3.util.log2Ceil
+
+object DescribedSRAM {
+  /** Create a SyncReadMem of `size` entries of type `data` and suggest `name` for it.
+    *
+    * @param name suggested name for the memory
+    * @param desc description; kept for interface compatibility, currently unused
+    * @param size depth of the memory
+    * @param data element type
+    */
+  def apply[T <: Data](
+    name: String,
+    desc: String,
+    size: BigInt, // depth
+    data: T
+  ): SyncReadMem[T] = {
+
+    val mem = SyncReadMem(size, data)
+
+    mem.suggestName(name)
+
+    mem
+  }
+}
\ No newline at end of file
diff --git a/rocket/src/util/ECC.scala b/rocket/src/util/ECC.scala
new file mode 100644
index 000000000..1b462cdf4
--- /dev/null
+++ b/rocket/src/util/ECC.scala
@@ -0,0 +1,233 @@
+// See LICENSE.Berkeley for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+import chisel3.util._
+import chisel3.util.random.LFSR
+
+/** Result of decoding a codeword: the data as received, the data after
+ * correction, and the two error flags.
+ */
+abstract class Decoding
+{
+  def uncorrected: UInt
+  def corrected: UInt
+  def correctable: Bool
+  def uncorrectable: Bool // If true, correctable should be ignored
+  def error = correctable || uncorrectable
+}
+
+/** Common interface of the error-detecting / error-correcting codes below. */
+abstract class Code
+{
+  def canDetect: Boolean
+  def canCorrect: Boolean
+
+  /** Encoded width for an unencoded width of w0 bits */
+  def width(w0: Int): Int
+
+  /** Takes the unencoded width and returns a list of indices indicating which
+   * bits of the encoded value will be used for ecc
+   */
+  def eccIndices(width: Int): Seq[Int]
+
+  /** Encode x to a codeword suitable for decode.
+   * If poison is true, the decoded value will report uncorrectable
+   * error despite uncorrected == corrected == x.
+   */
+  def encode(x: UInt, poison: Bool = false.B): UInt
+  def decode(x: UInt): Decoding
+
+  /** Copy the bits in x to the right bit positions in an encoded word,
+   * so that x === decode(swizzle(x)).uncorrected; but don't generate
+   * the other code bits, so decode(swizzle(x)).error might be true.
+   * For codes for which this operation is not trivial, throw an
+   * UnsupportedOperationException.
+   */
+  def swizzle(x: UInt): UInt
+}
+
+/** Pass-through code: adds no bits, never reports an error, and cannot be poisoned. */
+class IdentityCode extends Code
+{
+  def canDetect = false
+  def canCorrect = false
+
+  def width(w0: Int) = w0
+  def eccIndices(width: Int) = Seq.empty[Int]
+  def encode(x: UInt, poison: Bool = false.B) = {
+    require (poison.isLit && poison.litValue == 0, "IdentityCode can not be poisoned")
+    x
+  }
+  def swizzle(x: UInt) = x
+  def decode(y: UInt) = new Decoding {
+    def uncorrected = y
+    def corrected = y
+    def correctable = false.B
+    def uncorrectable = false.B
+  }
+}
+
+/** Single parity bit placed above the data bits; detects but never corrects. */
+class ParityCode extends Code
+{
+  def canDetect = true
+  def canCorrect = false
+
+  def width(w0: Int) = w0+1
+  def eccIndices(w0: Int) = Seq(w0)
+  // poison flips the parity bit, so decode reports the word as uncorrectable
+  def encode(x: UInt, poison: Bool = false.B) = Cat(x.xorR ^ poison, x)
+  def swizzle(x: UInt) = Cat(false.B, x)
+  def decode(y: UInt) = new Decoding {
+    val uncorrected = y(y.getWidth-2,0)
+    val corrected = uncorrected
+    val correctable = false.B
+    val uncorrectable = y.xorR
+  }
+}
+
+/** Single-error-correcting code in systematic form (data in the low bits,
+ * check bits above them).
+ */
+class SECCode extends Code
+{
+  def canDetect = true
+  def canCorrect = true
+
+  // SEC codes may or may not be poisonous depending on the length
+  // If the code is perfect, every non-codeword is correctable
+  def poisonous(n: Int) = !isPow2(n+1)
+
+  def width(k: Int) = {
+    val m = log2Floor(k) + 1
+    k + m + (if((1 << m) < m+k+1) 1 else 0)
+  }
+
+  // the check bits are every index at or above the data width
+  def eccIndices(w0: Int) = {
+    (0 until width(w0)).collect {
+      case i if i >= w0 => i
+    }
+  }
+
+  // zero-extend x to the encoded width; the check bits are left at 0
+  def swizzle(x: UInt) = {
+    val k = x.getWidth
+    val n = width(k)
+    Cat(0.U((n-k).W), x)
+  }
+
+  // An (n=16, k=11) Hamming code is naturally encoded as:
+  //   PPxPxxxPxxxxxxxP where P are parity bits and x are data
+  //   Indexes typically start at 1, because then the P are on powers of two
+  // In systematic coding, you put all the data in the front:
+  //   xxxxxxxxxxxPPPPP
+  //   Indexes typically start at 0, because Computer Science
+  // For sanity when reading SRAMs, you want systematic form.
+
+  // Builds, for an n-bit code with k data bits: the Hamming -> systematic
+  // index map, its inverse, and a function returning for check bit j the
+  // literal mask of systematic positions it covers.
+  private def impl(n: Int, k: Int) = {
+    require (n >= 3 && k >= 1 && !isPow2(n))
+    val hamm2sys = IndexedSeq.tabulate(n+1) { i =>
+      if (i == 0) {
+        n /* undefined */
+      } else if (isPow2(i)) {
+        k + log2Ceil(i)
+      } else {
+        i - 1 - log2Ceil(i)
+      }
+    }
+    val sys2hamm = hamm2sys.zipWithIndex.sortBy(_._1).map(_._2).toIndexedSeq
+    def syndrome(j: Int) = {
+      val bit = 1 << j
+      ("b" + Seq.tabulate(n) { i =>
+        if ((sys2hamm(i) & bit) != 0) "1" else "0"
+      }.reverse.mkString).U
+    }
+    (hamm2sys, sys2hamm, syndrome _)
+  }
+
+  def encode(x: UInt, poison: Bool = false.B) = {
+    val k = x.getWidth
+    val n = width(k)
+    val (_, _, syndrome) = impl(n, k)
+
+    require ((poison.isLit && poison.litValue == 0) || poisonous(n), s"SEC code of length ${n} cannot be poisoned")
+
+    /* By setting the entire syndrome on poison, the corrected bit falls off the end of the code */
+    val syndromeUInt = VecInit.tabulate(n-k) { j => (syndrome(j)(k-1, 0) & x).xorR ^ poison }.asUInt
+    Cat(syndromeUInt, x)
+  }
+
+  def decode(y: UInt) = new Decoding {
+    val n = y.getWidth
+    val k = n - log2Ceil(n)
+    val (_, sys2hamm, syndrome) = impl(n, k)
+
+    val syndromeUInt = VecInit.tabulate(n-k) { j => (syndrome(j) & y).xorR }.asUInt
+
+    // one-hot of the failing Hamming index, remapped onto the k data bits
+    val hammBadBitOH = UIntToOH(syndromeUInt, n+1)
+    val sysBadBitOH = VecInit.tabulate(k) { i => hammBadBitOH(sys2hamm(i)) }.asUInt
+
+    val uncorrected = y(k-1, 0)
+    val corrected = uncorrected ^ sysBadBitOH
+    val correctable = syndromeUInt.orR
+    val uncorrectable = if (poisonous(n)) { syndromeUInt > n.U } else { false.B }
+  }
+}
+
+/** SEC code extended with one overall parity bit on top. */
+class SECDEDCode extends Code
+{
+  def canDetect = true
+  def canCorrect = true
+
+  private val sec = new SECCode
+  private val par = new ParityCode
+
+  def width(k: Int) = sec.width(k)+1
+  def eccIndices(w0: Int) = {
+    (0 until width(w0)).collect {
+      case i if i >= w0 => i
+    }
+  }
+  def encode(x: UInt, poison: Bool = false.B) = {
+    // toggling two bits ensures the error is uncorrectable
+    // to ensure corrected == uncorrected, we pick one redundant
+    // bit from SEC (the highest); correcting it does not affect
+    // corrected == uncorrected. the second toggled bit is the
+    // parity bit, which also does not appear in the decoding
+    val toggle_lo = Cat(poison.asUInt, poison.asUInt)
+    val toggle_hi = toggle_lo << (sec.width(x.getWidth)-1)
+    par.encode(sec.encode(x)) ^ toggle_hi
+  }
+  def swizzle(x: UInt) = par.swizzle(sec.swizzle(x))
+  def decode(x: UInt) = new Decoding {
+    val secdec = sec.decode(x(x.getWidth-2,0))
+    val pardec = par.decode(x)
+
+    val uncorrected = secdec.uncorrected
+    val corrected = secdec.corrected
+    // overall parity mismatch => reported as correctable;
+    // parity intact but nonzero SEC syndrome => reported as uncorrectable
+    val correctable = pardec.uncorrectable
+    val uncorrectable = !pardec.uncorrectable && secdec.correctable
+  }
+}
+
+object ErrGen
+{
+  // generate a 1-bit error with approximate probability 2^-f
+  def apply(width: Int, f: Int): UInt = {
+    require(width > 0 && f >= 0 && log2Up(width) + f <= 16)
+    UIntToOH(LFSR(16)(log2Up(width)+f-1,0))(width-1,0)
+  }
+  // x with the generated error pattern XORed in
+  def apply(x: UInt, f: Int): UInt = x ^ apply(x.getWidth, f)
+}
+
+trait CanHaveErrors extends Bundle {
+  val correctable: Option[ValidIO[UInt]]
+  val uncorrectable: Option[ValidIO[UInt]]
+}
+
+case class ECCParams(
+  bytes: Int = 1,
+  code: Code = new IdentityCode,
+  notifyErrors: Boolean = false,
+)
+
+object Code {
+  // a missing name is treated as "none"
+  def fromString(s: Option[String]): Code = fromString(s.getOrElse("none"))
+  // case-insensitive lookup; throws IllegalArgumentException for an unknown name
+  def fromString(s: String): Code = s.toLowerCase match {
+    case "none" => new IdentityCode
+    case "identity" => new IdentityCode
+    case "parity" => new ParityCode
+    case "sec" => new SECCode
+    case "secded" => new SECDEDCode
+    case _ => throw new IllegalArgumentException("Unknown ECC type")
+  }
+}
\ No newline at end of file
diff --git a/rocket/src/util/Memory.scala b/rocket/src/util/Memory.scala
new file mode 100644
index 000000000..a5bcf7822
--- /dev/null
+++ b/rocket/src/util/Memory.scala
@@ -0,0 +1,244 @@
+// See LICENSE.SiFive for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+import chisel3.util._
+import chisel3.util.experimental._
+
+import org.chipsalliance.rocket._
+
+object Memory {
+  // The safe version will check the entire address
+  // Yields one Bool per slave: does any of its AddressSets contain `address`
+  def findSafe(address: UInt, slaves: Seq[MemSlaveParameters]) = VecInit(slaves.map(_.address.map(_.contains(address)).reduce(_ || _)))
+
+  // Compute the simplest AddressSets that decide a key
+  // Slaves are grouped by p(slave) (first-occurrence order), the AddressDecoder
+  // picks the bits needed to tell the groups apart, and each group's sets are
+  // widened to ignore every other bit before being unified.
+  def fastPropertyGroup[K](p: MemSlaveParameters => K, slaves: Seq[MemSlaveParameters]): Seq[(K, Seq[AddressSet])] = {
+    val groups = groupByIntoSeq(slaves.map(m => (p(m), m.address)))( _._1).map { case (k, vs) =>
+      k -> vs.flatMap(_._2)
+    }
+    val reductionMask = AddressDecoder(groups.map(_._2))
+    groups.map { case (k, seq) => k -> AddressSet.unify(seq.map(_.widen(~reductionMask)).distinct) }
+  }
+  // Select a property
+  def fastProperty[K, D <: Data](address: UInt, p: MemSlaveParameters => K, d: K => D, slaves: Seq[MemSlaveParameters]): D =
+    Mux1H(fastPropertyGroup(p, slaves).map { case (v, a) => (a.map(_.contains(address)).reduce(_||_), d(v)) })
+}
+
+/** Options for describing the attributes of memory regions */
+object RegionType {
+  // Define the 'more relaxed than' ordering
+  val cases = Seq(CACHED, TRACKED, UNCACHED, IDEMPOTENT, VOLATILE, PUT_EFFECTS, GET_EFFECTS)
+  sealed trait T extends Ordered[T] {
+    // earlier entries in `cases` compare as greater
+    def compare(that: T): Int = cases.indexOf(that) compare cases.indexOf(this)
+  }
+
+  case object CACHED extends T // an intermediate agent may have cached a copy of the region for you
+  case object TRACKED extends T // the region may have been cached by another master, but coherence is being provided
+  case object UNCACHED extends T // the region has not been cached yet, but should be cached when possible
+  case object IDEMPOTENT extends T // gets return most recently put content, but content should not be cached
+  case object VOLATILE extends T // content may change without a put, but puts and gets have no side effects
+  case object PUT_EFFECTS extends T // puts produce side effects and so must not be combined/delayed
+  case object GET_EFFECTS extends T // gets produce side effects and so must not be issued speculatively
+}
+
+// A potentially empty inclusive range of 2-powers [min, max] (in bytes)
+case class TransferSizes(min: Int, max: Int)
+{
+  def this(x: Int) = this(x, x)
+
+  require (min <= max, s"Min transfer $min > max transfer $max")
+  require (min >= 0 && max >= 0, s"TransferSizes must be positive, got: ($min, $max)")
+  require (max == 0 || isPow2(max), s"TransferSizes must be a power of 2, got: $max")
+  require (min == 0 || isPow2(min), s"TransferSizes must be a power of 2, got: $min")
+  require (max == 0 || min != 0, s"TransferSize 0 is forbidden unless (0,0), got: ($min, $max)")
+
+  // the empty range is encoded as (0, 0)
+  def none = min == 0
+  def contains(x: Int) = isPow2(x) && min <= x && x <= max
+  def containsLg(x: Int) = contains(1 << x)
+  def containsLg(x: UInt) =
+    if (none) false.B
+    else if (min == max) { log2Ceil(min).U === x }
+    else { log2Ceil(min).U <= x && x <= log2Ceil(max).U }
+
+  def contains(x: TransferSizes) = x.none || (min <= x.min && x.max <= max)
+
+  def intersect(x: TransferSizes) =
+    if (x.max < min || max < x.min) TransferSizes.none
+    else TransferSizes(scala.math.max(min, x.min), scala.math.min(max, x.max))
+
+  // Not a union, because the result may contain sizes contained by neither term
+  // NOT TO BE CONFUSED WITH COVERPOINTS
+  def mincover(x: TransferSizes) = {
+    if (none) {
+      x
+    } else if (x.none) {
+      this
+    } else {
+      TransferSizes(scala.math.min(min, x.min), scala.math.max(max, x.max))
+    }
+  }
+
+  override def toString() = "TransferSizes[%d, %d]".format(min, max)
+}
+
+object TransferSizes {
+  def apply(x: Int) = new TransferSizes(x)
+  val none = new TransferSizes(0)
+
+  def mincover(seq: Seq[TransferSizes]) = seq.foldLeft(none)(_ mincover _)
+  def intersect(seq: Seq[TransferSizes]) = seq.reduce(_ intersect _)
+
+  implicit def asBool(x: TransferSizes) = !x.none
+}
+
+// AddressSets specify the address space managed by the manager
+// Base is the base address, and mask are the bits consumed by the manager
+// e.g: base=0x200, mask=0xff describes a device managing 0x200-0x2ff
+// e.g: base=0x1000, mask=0xf0f describes a device managing 0x1000-0x100f, 0x1100-0x110f, ...
+// NOTE(review): the comments above treat mask bits as "ignored" bits, whereas a
+// chisel3 BitPat mask presumably marks the bits that are compared; confirm the
+// polarity passed to BitPat in the companion apply and relied on by matches /
+// overlap / cover.
+case class AddressSet(val bitSet: BitSet) extends Ordered[AddressSet]
+{
+  // TODO: This assumption might not hold true after BitSet intersection or subtraction. It is highly depended on the concrete implementation of BitSet.
+  require(bitSet.terms.size == 1, "The wrapped BitSet should only have one BitPat")
+
+  val base = bitSet.terms.head.value
+  val mask = bitSet.terms.head.mask
+
+  def contains(x: BigInt) = bitSet matches x.U
+  def contains(x: UInt) = bitSet matches x
+
+  // turn x into an address contained in this set
+  def legalize(x: UInt): UInt = base.U | (mask.U & x)
+
+  // overlap iff bitwise: both care (~mask0 & ~mask1) => both equal (base0=base1)
+  def overlaps(x: AddressSet) = bitSet overlap x.bitSet
+  // contains iff bitwise: x.mask => mask && contains(x.base)
+  def contains(x: AddressSet) = bitSet cover x.bitSet
+
+  // The number of bytes to which the manager must be aligned
+  def alignment = ((mask + 1) & ~mask)
+  // Is this a contiguous memory range
+  def contiguous = alignment == mask+1
+
+  def finite = mask >= 0
+  def max = { require (finite, "Max cannot be calculated on infinite mask"); base | mask }
+
+  // Widen the match function to ignore all bits in imask
+  def widen(imask: BigInt) = AddressSet(base & ~imask, mask | imask)
+
+  // Return an AddressSet that only contains the addresses both sets contain
+  def intersect(x: AddressSet): Option[AddressSet] = {
+    if (!overlaps(x)) {
+      None
+    } else {
+      Some(AddressSet(bitSet intersect x.bitSet))
+    }
+  }
+
+  // Return the AddressSets covering the addresses of this set that are not in x,
+  // one AddressSet per term of the resulting BitSet
+  def subtract(x: AddressSet): Seq[AddressSet] = {
+    (bitSet subtract x.bitSet).terms.toSeq.map(p => AddressSet(BitSet(p)))
+  }
+
+  // AddressSets have one natural Ordering (the containment order, if contiguous)
+  def compare(x: AddressSet) = {
+    val primary = (this.base - x.base).signum // smallest address first
+    val secondary = (x.mask - this.mask).signum // largest mask first
+    if (primary != 0) primary else secondary
+  }
+
+  // We always want to see things in hex
+  override def toString() = {
+    if (mask >= 0) {
+      "AddressSet(0x%x, 0x%x)".format(base, mask)
+    } else {
+      "AddressSet(0x%x, ~0x%x)".format(base, ~mask)
+    }
+  }
+
+  // Split this set into its contiguous pieces, each `alignment` bytes long
+  def toRanges = {
+    require (finite, "Ranges cannot be calculated on infinite mask")
+    val size = alignment
+    val fragments = mask & ~(size-1)
+    val bits = bitIndexes(fragments)
+    (BigInt(0) until (BigInt(1) << bits.size)).map { i =>
+      val off = bitIndexes(i).foldLeft(base) { case (a, b) => a.setBit(bits(b)) }
+      // a contiguous piece of `size` bytes is base = off, mask = size - 1
+      AddressSet(off, size - 1)
+    }
+  }
+}
+
+object AddressSet
+{
+  def apply(base: BigInt, mask: BigInt): AddressSet = {
+    // Forbid misaligned base address (and empty sets)
+    require ((base & mask) == 0, s"Mis-aligned AddressSets are forbidden, got: (base = 0x${base.toString(16)}, mask = 0x${mask.toString(16)})")
+    require (base >= 0, s"AddressSet negative base is ambiguous: $base") // TL2 address widths are not fixed => negative is ambiguous
+    // We do allow negative mask (=> ignore all high bits)
+    // NOTE(review): mask.U below is a UInt literal conversion; confirm it accepts
+    // a negative mask (used by `everything` and by widen(~bits))
+
+    AddressSet(BitSet(new BitPat(base, mask, base.U.getWidth max mask.U.getWidth)))
+  }
+
+  val everything = AddressSet(0, -1)
+  // Cover [base, base+size) with aligned power-of-two AddressSets, in ascending order
+  def misaligned(base: BigInt, size: BigInt, tail: Seq[AddressSet] = Seq()): Seq[AddressSet] = {
+    if (size == 0) tail.reverse else {
+      val maxBaseAlignment = base & (-base) // 0 for infinite (LSB)
+      val maxSizeAlignment = BigInt(1) << log2Floor(size) // MSB of size
+      val step =
+        if (maxBaseAlignment == 0 || maxBaseAlignment > maxSizeAlignment)
+          maxSizeAlignment else maxBaseAlignment
+      misaligned(base+step, size-step, AddressSet(base, step-1) +: tail)
+    }
+  }
+
+  def unify(seq: Seq[AddressSet], bit: BigInt): Seq[AddressSet] = {
+    // Pair terms up by ignoring 'bit'
+    seq.distinct.groupBy(x => AddressSet(x.base & ~bit, x.mask)).map { case (key, seq) =>
+      if (seq.size == 1) {
+        seq.head // singleton -> unaffected
+      } else {
+        AddressSet(key.base, key.mask | bit) // pair - widen mask by bit
+      }
+    }.toList
+  }
+
+  // Try to merge pairs on every bit that is set in some base; result is sorted
+  def unify(seq: Seq[AddressSet]): Seq[AddressSet] = {
+    val bits = seq.map(_.base).foldLeft(BigInt(0))(_ | _)
+    AddressSet.enumerateBits(bits).foldLeft(seq) { case (acc, bit) => unify(acc, bit) }.sorted
+  }
+
+  // Every value whose set bits are a subset of mask, in ascending order
+  def enumerateMask(mask: BigInt): Seq[BigInt] = {
+    def helper(id: BigInt, tail: Seq[BigInt]): Seq[BigInt] =
+      if (id == mask) (id +: tail).reverse else helper(((~mask | id) + 1) & mask, id +: tail)
+    helper(0, Nil)
+  }
+
+  // The individual set bits of mask, lowest first
+  def enumerateBits(mask: BigInt): Seq[BigInt] = {
+    def helper(x: BigInt): Seq[BigInt] = {
+      if (x == 0) {
+        Nil
+      } else {
+        val bit = x & (-x)
+        bit +: helper(x & ~bit)
+      }
+    }
+    helper(mask)
+  }
+}
+
+case class MemSlaveParameters(
+  val address: Seq[AddressSet],
+  val regionType: RegionType.T = RegionType.GET_EFFECTS,
+
+  val executable: Boolean = false,
+
+  val supportsAcquireT: TransferSizes = TransferSizes.none,
+  val supportsAcquireB: TransferSizes = TransferSizes.none,
+  val supportsArithmetic: TransferSizes = TransferSizes.none,
+  val supportsLogical: TransferSizes = TransferSizes.none,
+  val supportsGet: TransferSizes = TransferSizes.none,
+  val supportsPutFull: TransferSizes = TransferSizes.none,
+  val supportsPutPartial: TransferSizes = TransferSizes.none,
+  val supportsHint: TransferSizes = TransferSizes.none,
+
+  val name: String,
+)
\ No newline at end of file
diff --git a/rocket/src/util/Misc.scala b/rocket/src/util/Misc.scala
new file mode 100644
index 000000000..c44773b54
--- /dev/null
+++ b/rocket/src/util/Misc.scala
@@ -0,0 +1,59 @@
+// See LICENSE.SiFive for license details.
+// See LICENSE.Berkeley for license details.
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+import chisel3.util._
+import chisel3.util.random._
+
+/** Tests whether at least `n` bits of a UInt are set. */
+object PopCountAtLeast {
+  // Returns (at least one bit set, at least two bits set) by splitting x in halves
+  private def two(x: UInt): (Bool, Bool) = x.getWidth match {
+    case 1 => (x.asBool, false.B)
+    case n =>
+      val half = x.getWidth / 2
+      val (leftOne, leftTwo) = two(x(half - 1, 0))
+      val (rightOne, rightTwo) = two(x(x.getWidth - 1, half))
+      (leftOne || rightOne, leftTwo || rightTwo || (leftOne && rightOne))
+  }
+  def apply(x: UInt, n: Int): Bool = n match {
+    case 0 => true.B
+    case 1 => x.orR
+    case 2 => two(x)._2
+    // general case: any other n (previously only n == 3 matched, so larger
+    // values raised a MatchError)
+    case _ => PopCount(x) >= n.U
+  }
+}
+
+/** Derives a value in [0, mod), or its one-hot form, from a random source. */
+object Random
+{
+  def apply(mod: Int, random: UInt): UInt = {
+    if (isPow2(mod)) random.extract(log2Ceil(mod)-1,0)
+    else PriorityEncoder(partition(apply(1 << log2Up(mod*8), random), mod))
+  }
+  def apply(mod: Int): UInt = apply(mod, randomizer)
+  def oneHot(mod: Int, random: UInt): UInt = {
+    if (isPow2(mod)) UIntToOH(random(log2Up(mod)-1,0))
+    else PriorityEncoderOH(partition(apply(1 << log2Up(mod*8), random), mod)).asUInt
+  }
+  def oneHot(mod: Int): UInt = oneHot(mod, randomizer)
+
+  private def randomizer = LFSR(16)
+  // Compare value against `slices` evenly spaced thresholds. The threshold must
+  // be a literal (.U); UInt(x.W) would be a bare type of width x, not a value.
+  private def partition(value: UInt, slices: Int) =
+    Seq.tabulate(slices)(i => value < (((i + 1) << value.getWidth) / slices).U)
+}
+
+/** Splits a UInt into fields at the given bit positions, most significant field first. */
+object Split
+{
+  def apply(x: UInt, n0: Int) = {
+    val w = x.getWidth
+    (x.extract(w-1,n0), x.extract(n0-1,0))
+  }
+  def apply(x: UInt, n1: Int, n0: Int) = {
+    val w = x.getWidth
+    (x.extract(w-1,n1), x.extract(n1-1,n0), x.extract(n0-1,0))
+  }
+  def apply(x: UInt, n2: Int, n1: Int, n0: Int) = {
+    val w = x.getWidth
+    (x.extract(w-1,n2), x.extract(n2-1,n1), x.extract(n1-1,n0), x.extract(n0-1,0))
+  }
+}
\ No newline at end of file
diff --git a/rocket/src/util/Replacement.scala b/rocket/src/util/Replacement.scala
new file mode 100644
index 000000000..2d4dbb266
--- /dev/null
+++ b/rocket/src/util/Replacement.scala
@@ -0,0 +1,325 @@
+// See LICENSE.SiFive for license details.
+// See LICENSE.Berkeley for license details.
+
+// TODO: Should be upstreamed to Chisel
+
+package org.chipsalliance.rocket.util
+
+import chisel3._
+import chisel3.util._
+import chisel3.util.random._
+
+/** Common interface of the way-replacement policies defined in this file
+  * (RandomReplacement, TrueLRU, PseudoLRU).
+  *
+  * `get_next_state` and `get_replace_way` take the policy state as an argument and return
+  * hardware computed from it; the remaining members act on state held by the policy object.
+  */
+abstract class ReplacementPolicy {
+  // Width, in bits, of the policy state
+  def nBits: Int
+  // RandomReplacement reports false; TrueLRU and PseudoLRU report true.
+  // NOTE(review): presumably means "state should be kept per cache set" - confirm against callers
+  def perSet: Boolean
+  // Way currently selected for replacement
+  def way: UInt
+  // Notify the policy of a miss / of a hit
+  def miss: Unit
+  def hit: Unit
+  // Update the internal state for one touched way / for several simultaneously touched ways
+  def access(touch_way: UInt): Unit
+  def access(touch_ways: Seq[Valid[UInt]]): Unit
+  // Current value of the internal state
+  def state_read: UInt
+  // State that results from touching `touch_way` when starting from `state`
+  def get_next_state(state: UInt, touch_way: UInt): UInt
+  // Multi-touch variant: applies the single-touch update once per entry, in Seq order;
+  // an entry whose `valid` is low passes the running state through unchanged.
+  def get_next_state(state: UInt, touch_ways: Seq[Valid[UInt]]): UInt = {
+    touch_ways.foldLeft(state)((prev, touch_way) => Mux(touch_way.valid, get_next_state(prev, touch_way.bits), prev))
+  }
+  // Way to replace for the given `state`
+  def get_replace_way(state: UInt): UInt
+}
+
+object ReplacementPolicy {
+  /** Builds a policy from its name, matched case-insensitively: "random", "lru" or "plru".
+    *
+    * @param s      policy name
+    * @param n_ways number of ways handed to the policy constructor
+    * @throws IllegalArgumentException for any other name
+    */
+  def fromString(s: String, n_ways: Int): ReplacementPolicy = s.toLowerCase match {
+    case "random" => new RandomReplacement(n_ways)
+    case "lru" => new TrueLRU(n_ways)
+    case "plru" => new PseudoLRU(n_ways)
+    case t => throw new IllegalArgumentException(s"unknown Replacement Policy type $t")
+  }
+}
+
+// Replacement policy whose way is derived from a 16-bit LFSR instead of from access history.
+class RandomReplacement(n_ways: Int) extends ReplacementPolicy {
+  // Defaults to false; miss() drives it true. It is passed as the second argument of the LFSR below.
+  private val replace = Wire(Bool())
+  replace := false.B
+  def nBits = 16
+  def perSet = false
+  // NOTE(review): assuming chisel3.util.random.LFSR(width, increment), the LFSR only steps
+  // while `replace` is high, i.e. on misses - confirm against the Chisel API
+  private val lfsr = LFSR(nBits, replace)
+  def state_read = WireDefault(lfsr)
+
+  // `Random` is not defined in this file; presumably it reduces the LFSR value to a way
+  // index below n_ways - confirm. This is a def, so every use evaluates Random(...) again.
+  def way = Random(n_ways, lfsr)
+  def miss = replace := true.B
+  // Hits and accesses do not influence this policy
+  def hit = {}
+  def access(touch_way: UInt) = {}
+  def access(touch_ways: Seq[Valid[UInt]]) = {}
+  // No access-driven state to compute: a constant 0 is returned
+  def get_next_state(state: UInt, touch_way: UInt) = 0.U //DontCare
+  // The `state` argument is ignored; the LFSR-derived way is returned
+  def get_replace_way(state: UInt) = way
+}
+
+// Interface implemented by SeqRandom and SeqPLRU below.
+abstract class SeqReplacementPolicy {
+  def access(set: UInt): Unit
+  def update(valid: Bool, hit: Bool, set: UInt, way: UInt): Unit
+  def way: UInt
+}
+
+// Interface implemented by SetAssocLRU below; every operation names the set it applies to.
+abstract class SetAssocReplacementPolicy {
+  def access(set: UInt, touch_way: UInt): Unit
+  def access(sets: Seq[UInt], touch_ways: Seq[Valid[UInt]]): Unit
+  def way(set: UInt): UInt
+}
+
+// SeqReplacementPolicy wrapper around RandomReplacement.
+class SeqRandom(n_ways: Int) extends SeqReplacementPolicy {
+  val logic = new RandomReplacement(n_ways)
+  // Nothing to read: there is no per-set state
+  def access(set: UInt) = { }
+  // Only a valid miss is forwarded to the underlying policy; `set` and `way` are unused
+  def update(valid: Bool, hit: Bool, set: UInt, way: UInt) = {
+    when (valid && !hit) { logic.miss }
+  }
+  def way = logic.way
+}
+
+class TrueLRU(n_ways: Int) extends ReplacementPolicy {
+  // True LRU replacement policy, using a triangular matrix to track which ways are more recently used than others.
+  // The matrix is packed into a single UInt (or Bits). Example 4-way (6-bits):
+  // [5] - 3 more recent than 2
+  // [4] - 3 more recent than 1
+  // [3] - 2 more recent than 1
+  // [2] - 3 more recent than 0
+  // [1] - 2 more recent than 0
+  // [0] - 1 more recent than 0
+  //
+  // NOTE(review): with n_ways == 1 the Vecs below are empty and get_next_state calls `.head`
+  // on one of them - confirm whether a single way is meant to be supported here.
+  def nBits = (n_ways * (n_ways-1)) / 2
+  def perSet = true
+  private val state_reg = RegInit(0.U(nBits.W))
+  def state_read = WireDefault(state_reg)
+
+  // Unpacks `state` into n_ways-1 rows of n_ways bits each. Row i takes the next
+  // (n_ways-i-1) state bits as its bits n_ways-1..i+1 and has zeros in bits i..0.
+  private def extractMRUVec(state: UInt): Seq[UInt] = {
+    // Extract per-way information about which higher-indexed ways are more recently used
+    val moreRecentVec = Wire(Vec(n_ways-1, UInt(n_ways.W)))
+    var lsb = 0 // position in `state` of the first bit belonging to row i
+    for (i <- 0 until n_ways-1) {
+      moreRecentVec(i) := Cat(state(lsb+n_ways-i-2,lsb), 0.U((i+1).W))
+      lsb = lsb + (n_ways - i - 1)
+    }
+    moreRecentVec
+  }
+
+  def get_next_state(state: UInt, touch_way: UInt): UInt = {
+    val nextState = Wire(Vec(n_ways-1, UInt(n_ways.W)))
+    val moreRecentVec = extractMRUVec(state) // reconstruct lower triangular matrix
+    val wayDec = UIntToOH(touch_way, n_ways)
+
+    // Compute next value of triangular matrix
+    // set the touched way as more recent than every other way
+    // (the touched way's own row is cleared; every other row gets the touched way's bit set)
+    nextState.zipWithIndex.map { case (e, i) =>
+      e := Mux(i.U === touch_way, 0.U(n_ways.W), moreRecentVec(i) | wayDec)
+    }
+
+    // Re-pack the rows into one UInt: row 0 contributes bits n_ways-1..1 at the bottom, and each
+    // later row ci is concatenated above with its bits n_ways-1..ci+1 (lower bits are dropped)
+    nextState.zipWithIndex.tail.foldLeft((nextState.head.apply(n_ways-1,1),0)) { case ((pe,pi),(ce,ci)) => (Cat(ce.apply(n_ways-1,ci+1), pe), ci) }._1
+  }
+
+  def access(touch_way: UInt): Unit = {
+    state_reg := get_next_state(state_reg, touch_way)
+  }
+  // The register is only written when at least one entry is valid
+  def access(touch_ways: Seq[Valid[UInt]]): Unit = {
+    when (touch_ways.map(_.valid).orR) {
+      state_reg := get_next_state(state_reg, touch_ways)
+    }
+    // One cover point per number of simultaneously valid touches.
+    // NOTE(review): `1 until touch_ways.size` stops at size-1, so "every entry valid" has no
+    // cover point - confirm that this is intended
+    for (i <- 1 until touch_ways.size) {
+      cover(PopCount(touch_ways.map(_.valid)) === i.U, s"LRU_UpdateCount$i; LRU Update $i simultaneous")
+    }
+  }
+
+  def get_replace_way(state: UInt): UInt = {
+    val moreRecentVec = extractMRUVec(state) // reconstruct lower triangular matrix
+    // For each way, determine if all other ways are more recent
+    // (despite its name, mruWayDec therefore marks the way to replace, one bit per way)
+    val mruWayDec = (0 until n_ways).map { i =>
+      // every higher-indexed way is more recent than way i
+      val upperMoreRecent = (if (i == n_ways-1) true.B else moreRecentVec(i).apply(n_ways-1,i+1).andR)
+      // way i is not marked more recent in any row
+      val lowerMoreRecent = (if (i == 0) true.B else moreRecentVec.map(e => !e(i)).reduce(_ && _))
+      upperMoreRecent && lowerMoreRecent
+    }
+    OHToUInt(mruWayDec)
+  }
+
+  def way = get_replace_way(state_reg)
+  // On a miss the way chosen for replacement is recorded as touched
+  def miss = access(way)
+  def hit = {}
+  @deprecated("replace 'replace' with 'way' from abstract class ReplacementPolicy","Rocket Chip 2020.05")
+  def replace: UInt = way
+}
+
+class PseudoLRU(n_ways: Int) extends ReplacementPolicy {
+  // Pseudo-LRU tree algorithm: https://en.wikipedia.org/wiki/Pseudo-LRU#Tree-PLRU
+  //
+  //
+  // - bits storage example for 4-way PLRU binary tree:
+  //                  bit[2]: ways 3+2 older than ways 1+0
+  //                  /                                  \
+  //     bit[1]: way 3 older than way 2    bit[0]: way 1 older than way 0
+  //
+  //
+  // - bits storage example for 3-way PLRU binary tree:
+  //                  bit[1]: way 2 older than ways 1+0
+  //                                                  \
+  //                                       bit[0]: way 1 older than way 0
+  //
+  //
+  // - bits storage example for 8-way PLRU binary tree:
+  //                      bit[6]: ways 7-4 older than ways 3-0
+  //                      /                                  \
+  //            bit[5]: ways 7+6 > 5+4                bit[2]: ways 3+2 > 1+0
+  //            /                    \                /                    \
+  //     bit[4]: way 7>6    bit[3]: way 5>4    bit[1]: way 3>2    bit[0]: way 1>0
+
+  def nBits = n_ways - 1
+  def perSet = true
+  // With a single way there are no state bits: a zero-width Reg without reset value is used
+  private val state_reg = if (nBits == 0) Reg(UInt(0.W)) else RegInit(0.U(nBits.W))
+  def state_read = WireDefault(state_reg)
+
+  def access(touch_way: UInt): Unit = {
+    state_reg := get_next_state(state_reg, touch_way)
+  }
+  // The register is only written when at least one entry is valid
+  def access(touch_ways: Seq[Valid[UInt]]): Unit = {
+    when (touch_ways.map(_.valid).orR) {
+      state_reg := get_next_state(state_reg, touch_ways)
+    }
+    // One cover point per number of simultaneously valid touches.
+    // NOTE(review): `1 until touch_ways.size` stops at size-1, so "every entry valid" has no
+    // cover point - confirm that this is intended
+    for (i <- 1 until touch_ways.size) {
+      cover(PopCount(touch_ways.map(_.valid)) === i.U, s"PLRU_UpdateCount$i; PLRU Update $i simultaneous")
+    }
+  }
+
+
+  /** Next PLRU state of one sub-tree after `touch_way` has been touched.
+    *
+    * For a sub-tree with more than 2 ways the state is laid out as: top bit for this node, then
+    * the left sub-tree's bits (if any), then the right sub-tree's bits in the lowest
+    * `right_nways-1` positions. The right sub-tree holds the ways whose top encoded bit is 0.
+    *
+    * @param state state_reg bits for this sub-tree
+    * @param touch_way touched way encoded value bits for this sub-tree
+    * @param tree_nways number of ways in this sub-tree
+    */
+  def get_next_state(state: UInt, touch_way: UInt, tree_nways: Int): UInt = {
+    require(state.getWidth == (tree_nways-1), s"wrong state bits width ${state.getWidth} for $tree_nways ways")
+    require(touch_way.getWidth == (log2Ceil(tree_nways) max 1), s"wrong encoded way width ${touch_way.getWidth} for $tree_nways ways")
+
+    if (tree_nways > 2) {
+      // we are at a branching node in the tree, so recurse
+      val right_nways: Int = 1 << (log2Ceil(tree_nways) - 1) // number of ways in the right sub-tree
+      val left_nways: Int = tree_nways - right_nways // number of ways in the left sub-tree
+      // top bit of the touched way is 0 => the touch lands in the right sub-tree, so left becomes older
+      val set_left_older = !touch_way(log2Ceil(tree_nways)-1)
+      val left_subtree_state = state.extract(tree_nways-3, right_nways-1)
+      val right_subtree_state = state(right_nways-2, 0)
+
+      if (left_nways > 1) {
+        // we are at a branching node in the tree with both left and right sub-trees, so recurse both sub-trees
+        Cat(set_left_older,
+            Mux(set_left_older,
+                left_subtree_state, // if setting left sub-tree as older, do NOT recurse into left sub-tree
+                get_next_state(left_subtree_state, touch_way.extract(log2Ceil(left_nways)-1,0), left_nways)), // recurse left if newer
+            Mux(set_left_older,
+                get_next_state(right_subtree_state, touch_way(log2Ceil(right_nways)-1,0), right_nways), // recurse right if newer
+                right_subtree_state)) // if setting right sub-tree as older, do NOT recurse into right sub-tree
+      } else {
+        // we are at a branching node in the tree with only a right sub-tree, so recurse only right sub-tree
+        Cat(set_left_older,
+            Mux(set_left_older,
+                get_next_state(right_subtree_state, touch_way(log2Ceil(right_nways)-1,0), right_nways), // recurse right if newer
+                right_subtree_state)) // if setting right sub-tree as older, do NOT recurse into right sub-tree
+      }
+    } else if (tree_nways == 2) {
+      // we are at a leaf node at the end of the tree, so set the single state bit opposite of the lsb of the touched way encoded value
+      !touch_way(0)
+    } else { // tree_nways <= 1
+      // we are at an empty node in an empty tree for 1 way, so return single zero bit for Chisel (no zero-width wires)
+      0.U(1.W)
+    }
+  }
+
+  // Entry point for the whole tree: `touch_way` is first brought to exactly log2Ceil(n_ways) bits
+  // (padded when narrower, truncated otherwise; `padTo` and `extract` come from outside this file)
+  def get_next_state(state: UInt, touch_way: UInt): UInt = {
+    val touch_way_sized = if (touch_way.getWidth < log2Ceil(n_ways)) touch_way.padTo (log2Ceil(n_ways))
+                                                                else touch_way.extract(log2Ceil(n_ways)-1,0)
+    get_next_state(state, touch_way_sized, n_ways)
+  }
+
+
+  /** Encoded index of the way to replace within one sub-tree, following the "older" bits from
+    * the top node downwards. Uses the same state layout as the three-argument get_next_state.
+    *
+    * @param state state_reg bits for this sub-tree
+    * @param tree_nways number of ways in this sub-tree
+    */
+  def get_replace_way(state: UInt, tree_nways: Int): UInt = {
+    require(state.getWidth == (tree_nways-1), s"wrong state bits width ${state.getWidth} for $tree_nways ways")
+
+    // this algorithm recursively descends the binary tree, filling in the way-to-replace encoded value from msb to lsb
+    if (tree_nways > 2) {
+      // we are at a branching node in the tree, so recurse
+      val right_nways: Int = 1 << (log2Ceil(tree_nways) - 1) // number of ways in the right sub-tree
+      val left_nways: Int = tree_nways - right_nways // number of ways in the left sub-tree
+      val left_subtree_older = state(tree_nways-2)
+      val left_subtree_state = state.extract(tree_nways-3, right_nways-1)
+      val right_subtree_state = state(right_nways-2, 0)
+
+      if (left_nways > 1) {
+        // we are at a branching node in the tree with both left and right sub-trees, so recurse both sub-trees
+        Cat(left_subtree_older, // return the top state bit (current tree node) as msb of the way-to-replace encoded value
+            Mux(left_subtree_older, // if left sub-tree is older, recurse left, else recurse right
+                get_replace_way(left_subtree_state, left_nways), // recurse left
+                get_replace_way(right_subtree_state, right_nways))) // recurse right
+      } else {
+        // we are at a branching node in the tree with only a right sub-tree, so recurse only right sub-tree
+        Cat(left_subtree_older, // return the top state bit (current tree node) as msb of the way-to-replace encoded value
+            Mux(left_subtree_older, // if left sub-tree is older, return and do not recurse right
+                0.U(1.W),
+                get_replace_way(right_subtree_state, right_nways))) // recurse right
+      }
+    } else if (tree_nways == 2) {
+      // we are at a leaf node at the end of the tree, so just return the single state bit as lsb of the way-to-replace encoded value
+      state(0)
+    } else { // tree_nways <= 1
+      // we are at an empty node in an unbalanced tree for non-power-of-2 ways, so return single zero bit as lsb of the way-to-replace encoded value
+      0.U(1.W)
+    }
+  }
+
+  // Entry point for the whole tree
+  def get_replace_way(state: UInt): UInt = get_replace_way(state, n_ways)
+
+  def way = get_replace_way(state_reg)
+  // On a miss the way chosen for replacement is recorded as touched
+  def miss = access(way)
+  def hit = {}
+}
+
+// Pseudo-LRU with one state word per set, stored in a SyncReadMem of n_sets entries.
+class SeqPLRU(n_sets: Int, n_ways: Int) extends SeqReplacementPolicy {
+  // Only the get_replace_way / get_next_state functions of `logic` are used in this class
+  val logic = new PseudoLRU(n_ways)
+  val state = SyncReadMem(n_sets, UInt(logic.nBits.W))
+  // Driven only inside access(); next_state is driven only inside update()
+  val current_state = Wire(UInt(logic.nBits.W))
+  val next_state = Wire(UInt(logic.nBits.W))
+  val plru_way = logic.get_replace_way(current_state)
+
+  // Reads the state word of `set` into current_state
+  def access(set: UInt) = {
+    current_state := state.read(set)
+  }
+
+  // On a hit the hit `way` is the touched way, otherwise the PLRU victim is;
+  // the resulting state is written back to `set` when `valid` is high
+  def update(valid: Bool, hit: Bool, set: UInt, way: UInt) = {
+    val update_way = Mux(hit, way, plru_way)
+    next_state := logic.get_next_state(current_state, update_way)
+    when (valid) { state.write(set, next_state) }
+  }
+
+  def way = plru_way
+}
+
+
+/** LRU ("lru") or pseudo-LRU ("plru") replacement with one state word per set held in a
+  * register vector. The policy name is matched case-insensitively; any other name throws
+  * IllegalArgumentException.
+  */
+class SetAssocLRU(n_sets: Int, n_ways: Int, policy: String) extends SetAssocReplacementPolicy {
+  // Only the get_next_state / get_replace_way functions of `logic` are used in this class;
+  // the per-set state lives in state_vec
+  val logic = policy.toLowerCase match {
+    case "plru" => new PseudoLRU(n_ways)
+    case "lru" => new TrueLRU(n_ways)
+    case t => throw new IllegalArgumentException(s"unknown Replacement Policy type $t")
+  }
+  val state_vec =
+    if (logic.nBits == 0) Reg(Vec(n_sets, UInt(logic.nBits.W))) // Work around elaboration error on following line
+    else RegInit(VecInit(Seq.fill(n_sets)(0.U(logic.nBits.W))))
+
+  // Single update: the state of `set` advances as if `touch_way` had been touched
+  def access(set: UInt, touch_way: UInt) = {
+    state_vec(set) := logic.get_next_state(state_vec(set), touch_way)
+  }
+
+  // Simultaneous updates: entry k touches way touch_ways(k) in set sets(k)
+  def access(sets: Seq[UInt], touch_ways: Seq[Valid[UInt]]) = {
+    require(sets.size == touch_ways.size, "internal consistency check: should be same number of simultaneous updates for sets and touch_ways")
+    for (set <- 0 until n_sets) {
+      // Per set, keep each touch valid only if it targets this set. Pipe is called with
+      // latency 0 here, apparently just to package (valid, bits) as a Valid[UInt] - confirm
+      val set_touch_ways = (sets zip touch_ways).map { case (touch_set, touch_way) =>
+        Pipe(touch_way.valid && (touch_set === set.U), touch_way.bits, 0)}
+      when (set_touch_ways.map(_.valid).orR) {
+        state_vec(set) := logic.get_next_state(state_vec(set), set_touch_ways)
+      }
+    }
+  }
+
+  // Way to replace in `set`, computed from that set's current state
+  def way(set: UInt) = logic.get_replace_way(state_vec(set))
+
+}
\ No newline at end of file