diff --git a/src/main/scala/KeyBuffer.scala b/src/main/scala/KeyBuffer.scala index 15e084b..9b39ded 100644 --- a/src/main/scala/KeyBuffer.scala +++ b/src/main/scala/KeyBuffer.scala @@ -5,14 +5,18 @@ import chisel3.util._ class KeyBufferIO(busWidth: Int, numberOfBuffers: Int) extends Bundle { + // Inputs from KV Transfer module val enq = Flipped(Decoupled(UInt(busWidth.W))) - val deq = Decoupled(UInt(busWidth.W)) - val bufferInputSelect = Input(UInt(log2Ceil(numberOfBuffers).W)) val incrWritePtr = Input(Bool()) - val clearBuffer = Input(Bool()) val lastInput = Input(Bool()) + // Control inputs from KV Transfer module + val clearBuffer = Input(Bool()) + val mask = Input(UInt(numberOfBuffers.W)) + + // Outputs to Merger module + val deq = Decoupled(UInt(busWidth.W)) val bufferOutputSelect = Output(UInt(log2Ceil(numberOfBuffers).W)) val empty = Output(Bool()) val lastOutput = Output(Bool()) @@ -52,6 +56,7 @@ class KeyBuffer(busWidth: Int, numberOfBuffers: Int, maximumKeySize: Int) extend val stateReg = RegInit(idle) val shadowReg = RegInit(0.U(busWidth.W)) + val mask = RegInit(0.U(numberOfBuffers.W)) val bufferOutputSelect = RegInit(0.U(log2Ceil(numberOfBuffers).W)) val shouldIncreaseReadPtr = bufferOutputSelect === (numberOfBuffers-1).U @@ -80,13 +85,17 @@ class KeyBuffer(busWidth: Int, numberOfBuffers: Int, maximumKeySize: Int) extend } } + val nextIndexSelector = Module(new NextIndexSelector(numberOfBuffers)) + nextIndexSelector.io.mask := io.mask + nextIndexSelector.io.currentIndex := bufferOutputSelect + when (io.clearBuffer) { stateReg := idle emptyReg := true.B fullReg := false.B writePtr := 0.U readPtr := 0.U - bufferOutputSelect := 0.U + bufferOutputSelect := PriorityEncoder(io.mask) lastChunks.foreach(_ := false.B) lastChunkCounters.foreach(_ := 0.U) } @@ -100,7 +109,7 @@ class KeyBuffer(busWidth: Int, numberOfBuffers: Int, maximumKeySize: Int) extend fullReg := false.B emptyReg := nextRead === writePtr && shouldIncreaseReadPtr && !io.incrWritePtr // prepare for the next read, as it requires one cycle delay - bufferOutputSelect := bufferOutputSelect + 1.U + bufferOutputSelect := nextIndexSelector.io.nextIndex } } is(valid) { @@ -111,9 +120,9 @@ class KeyBuffer(busWidth: Int, numberOfBuffers: Int, maximumKeySize: Int) extend fullReg := false.B emptyReg := nextRead === writePtr && shouldIncreaseReadPtr && !io.incrWritePtr incrRead := shouldIncreaseReadPtr - bufferOutputSelect := bufferOutputSelect + 1.U + bufferOutputSelect := nextIndexSelector.io.nextIndex } otherwise { - bufferOutputSelect := 0.U + bufferOutputSelect := nextIndexSelector.io.nextIndex stateReg := idle } } otherwise { @@ -130,9 +139,9 @@ class KeyBuffer(busWidth: Int, numberOfBuffers: Int, maximumKeySize: Int) extend fullReg := false.B emptyReg := nextRead === writePtr && shouldIncreaseReadPtr && !io.incrWritePtr incrRead := shouldIncreaseReadPtr - bufferOutputSelect := bufferOutputSelect + 1.U + bufferOutputSelect := nextIndexSelector.io.nextIndex } otherwise { - bufferOutputSelect := 0.U + bufferOutputSelect := nextIndexSelector.io.nextIndex stateReg := idle } } diff --git a/src/test/scala/KeyBufferSpec.scala b/src/test/scala/KeyBufferSpec.scala index 5aa9b80..368f92d 100644 --- a/src/test/scala/KeyBufferSpec.scala +++ b/src/test/scala/KeyBufferSpec.scala @@ -14,6 +14,12 @@ class KeyBufferSpec extends AnyFreeSpec with ChiselScalatestTester { dut.io.enq.valid.poke(true.B) dut.io.incrWritePtr.poke(false.B) dut.io.lastInput.poke(false.B) + dut.io.mask.poke("b1111".U) + + // clear key buffer before loading + dut.io.clearBuffer.poke(true.B) + dut.clock.step() + dut.io.clearBuffer.poke(false.B) // Write two rows of key chunks for (i <- 0 until 4) { @@ -111,6 +117,12 @@ class KeyBufferSpec extends AnyFreeSpec with ChiselScalatestTester { dut.io.enq.valid.poke(true.B) dut.io.incrWritePtr.poke(false.B) dut.io.lastInput.poke(false.B) + dut.io.mask.poke("b1111".U) + + // clear key buffer before loading + dut.io.clearBuffer.poke(true.B) + dut.clock.step() + dut.io.clearBuffer.poke(false.B) // Write first row of key chunks for (i <- 0 until 4) { @@ -213,6 +225,12 @@ class KeyBufferSpec extends AnyFreeSpec with ChiselScalatestTester { dut.io.incrWritePtr.poke(false.B) dut.io.empty.expect(true.B) dut.io.lastInput.poke(false.B) + dut.io.mask.poke("b1111".U) + + // clear key buffer before loading + dut.io.clearBuffer.poke(true.B) + dut.clock.step() + dut.io.clearBuffer.poke(false.B) // Write one row of key chunks for (i <- 0 until 4) { @@ -284,6 +302,12 @@ class KeyBufferSpec extends AnyFreeSpec with ChiselScalatestTester { dut.io.deq.ready.poke(false.B) dut.io.enq.valid.poke(true.B) dut.io.incrWritePtr.poke(false.B) + dut.io.mask.poke("b1111".U) + + // clear key buffer before loading + dut.io.clearBuffer.poke(true.B) + dut.clock.step() + dut.io.clearBuffer.poke(false.B) // Write one row of key chunks for (i <- 0 until 4) { @@ -316,6 +340,12 @@ class KeyBufferSpec extends AnyFreeSpec with ChiselScalatestTester { dut.io.deq.ready.poke(false.B) dut.io.enq.valid.poke(true.B) dut.io.lastInput.poke(false.B) + dut.io.mask.poke("b1111".U) + + // clear key buffer before loading + dut.io.clearBuffer.poke(true.B) + dut.clock.step() + dut.io.clearBuffer.poke(false.B) // Write first row of key chunks for (i <- 0 until 4) {