Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/main/scala/vaquita/VaquitaTop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,9 @@ class VaquitaTop extends Module {
EX.ex_reg_write_in := DE.de_io.de_reg_write

// -----------------memory stage ---------------------------------
// val comparison_bit_f6 = MEM.mem_instr_out(31,26)==="b011000".U || MEM.mem_instr_out(31,26)==="b011001".U || MEM.mem_instr_out(31,26)==="b011010".U || MEM.mem_instr_out(31,26)==="b011011".U || MEM.mem_instr_out(31,26)==="b011100".U || MEM.mem_instr_out(31,26)==="b011101".U || MEM.mem_instr_out(31,26)==="b011110".U || MEM.mem_instr_out(31,26)==="b011111".U
// val comparison_bit_f3 = MEM.mem_instr_out(14,12)==="b000".U || MEM.mem_instr_out(14,12)==="b011".U || MEM.mem_instr_out(14,12)==="b100".U
io.dmemReq <> MemFetch.dccmReq
MemFetch.dccmRsp <> io.dmemRsp
MemFetch.mem_lmul_in := EX.ex_lmul_out
// MemFetch.vec_comparison_bit := comparison_bit_f6 && comparison_bit_f3

val wb_vs3_data_in_store = VecInit(Seq.fill(8)(VecInit(Seq.fill(vec_config.count_lanes)(0.S(vec_config.XLEN.W)))))

Expand Down
119 changes: 119 additions & 0 deletions src/main/scala/vaquita/components/ALUClasses/Arith.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package vaquita.components.ALUClasses
import chisel3._
import chisel3.util._
import vaquita.components.ALUObj._
import vaquita.configparameter.VaquitaConfig

/** Element-wise integer ALU helpers for the vector unit.
  *
  * The class declares no IO of its own; it hosts hardware-generating helper
  * methods that build combinational logic in the caller's context.
  */
class Arith(implicit val config: VaquitaConfig) extends Module {

  /** Builds the datapath for one element of an integer arithmetic instruction.
    *
    * @param vs1_in      first source element
    * @param vs2_in      second source element
    * @param vs3         old destination element (not read by this method)
    * @param sew         element width in bits (elaboration-time constant)
    * @param v0_bit_mask mask bit of this element; used as carry/borrow-in by vadc/vsbc
    * @param alu_opcode  operation selector, compared against the ALUObj constants
    * @param rs1_in      shift-amount source for vsll/vsrl/vsra (taken modulo sew)
    * @return 32-bit signed result; 0 when alu_opcode matches no table entry
    */
  def Arithmatic(vs1_in: SInt, vs2_in: SInt,vs3:SInt,sew:Int,v0_bit_mask:UInt,alu_opcode: UInt, rs1_in : UInt): SInt = {
    // Saturation bounds of a signed sew-bit element.
    val maxValue = (1.S << (sew - 1)) - 1.S
    val minValue = -(1.S << (sew - 1))

    // Width-expanded difference (signed) and sum (unsigned, then reinterpreted as signed).
    val sub = WireInit(0.S(33.W))
    val sum = WireInit(0.S(33.W))
    sub := (vs2_in -& vs1_in)
    sum := (vs1_in.asUInt +& vs2_in.asUInt).asSInt

    // Signed overflow detection on the sew-bit sign bits.
    val sign_vs1 = vs1_in(sew - 1)
    val sign_vs2 = vs2_in(sew - 1)
    val sign_sum = sum(sew - 1)
    val sign_sub = sub(sew - 1)
    val positiveOverflowAdd = (sign_vs1 === 0.U && sign_vs2 === 0.U && sign_sum === 1.U)
    val negativeOverflowAdd = (sign_vs1 === 1.U && sign_vs2 === 1.U && sign_sum === 0.U)
    val positiveOverflowSub = (sign_vs2 === 0.U && sign_vs1 === 1.U && sign_sub === 1.U)
    val negativeOverflowSub = (sign_vs2 === 1.U && sign_vs1 === 0.U && sign_sub === 0.U)

    // Shift amount shared by the three shift operations.
    val shamt = rs1_in % sew.U

    val vs1_u = vs1_in.asUInt
    val vs2_u = vs2_in.asUInt

    val result = WireInit(0.S(32.W))
    result := MuxLookup(alu_opcode, 0.S(32.W), Seq(
      vadd   -> (vs1_in + vs2_in),                       // add
      vsub   -> (vs2_in - vs1_in),                       // sub
      vrsub  -> (vs1_in - vs2_in),                       // reverse sub
      vand   -> (vs1_in & vs2_in),                       // and
      vor    -> (vs1_in | vs2_in),                       // or
      vxor   -> (vs1_in ^ vs2_in),                       // xor
      vsll   -> (vs2_in << shamt),                       // shift left logical
      vsrl   -> (vs2_u >> shamt).asSInt,                 // shift right logical
      vsra   -> ((vs2_in >> shamt).asSInt),              // shift right arithmetic
      vmv    -> (vs1_in),                                // move
      vminu  -> Mux(vs1_u < vs2_u, vs1_u, vs2_u).asSInt, // unsigned min
      vmin   -> Mux(vs1_in < vs2_in, vs1_in, vs2_in),    // signed min
      vmaxu  -> Mux(vs1_u > vs2_u, vs1_u, vs2_u).asSInt, // unsigned max
      vmax   -> Mux(vs1_in > vs2_in, vs1_in, vs2_in),    // signed max
      // Bit 32 of `sum` carries the unsigned carry-out: the expanded sum is
      // reinterpreted as signed, so its top bit is sign-extended into the 33-bit wire.
      vsaddu -> Mux(sum(32), "hFFFFFFFF".U, sum(31,0)).asSInt, // saturating unsigned add
      vsadd  -> (Mux(positiveOverflowAdd, maxValue, Mux(negativeOverflowAdd, minValue, sum))), // saturating signed add
      // Saturating unsigned subtract. This entry used to be keyed `vssub` a second
      // time, so it was shadowed by the signed entry below and vssubu produced 0.
      // "b100010" is the funct6 of vssubu in the RVV encoding.
      // NOTE(review): replace the literal with the ALUObj constant if one exists.
      "b100010".U -> Mux(vs2_u < vs1_u, 0.U, vs2_u - vs1_u).asSInt,
      vssub  -> Mux(positiveOverflowSub, maxValue, Mux(negativeOverflowSub, minValue, sub(31, 0).asSInt)), // saturating signed sub
      vadc   -> (vs1_u + vs2_u + v0_bit_mask).asSInt,    // add with carry-in from the mask bit
      vsbc   -> (vs2_u - vs1_u - v0_bit_mask).asSInt     // subtract with borrow-in from the mask bit
      // Narrowing shifts and clips (vnsrl, vnsra, vnclipu, vnclip) are handled by NarrowIns.
      // vssrl, vssra and vsmul are not implemented in this table.
    ))
    result
  }

  /** Applies the masking policy to one computed element.
    *
    * The computed value is used when the element is active (unmasked instruction,
    * or mask bit set) or when the opcode is one of the two carry opcodes
    * ("b010000" / "b010010"), which consume the mask bit as data instead of as a
    * predicate. Otherwise the old destination value `vs3` is kept.
    *
    * @param vs1          first source element
    * @param vs2          second source element
    * @param vs3          old destination element
    * @param mask_vs0     mask bit of this element
    * @param alu_opcode   operation selector
    * @param rs1          shift-amount source forwarded to [[Arithmatic]]
    * @param mask_arith25 when set, the instruction is unmasked
    * @param sew          element width in bits
    * @return the sew-bit element to write back
    */
  def arith_8(vs1:UInt , vs2:UInt,vs3:UInt,mask_vs0:Bool,alu_opcode: UInt, rs1 : UInt,mask_arith25 : Bool,sew:Int):UInt={
    dontTouch(mask_vs0)
    // Constant placeholder: with it tied to 0 an inactive element is always undisturbed.
    val vsetvli_mask = WireInit(0.B)
    val mask_bit_active_element = mask_arith25 || mask_vs0
    val mask_bit_undisturb = !mask_vs0 && !mask_arith25 && !vsetvli_mask

    // Build the ALU once and share it between the carry-op and the masked path.
    val computed = Arithmatic(vs1.asSInt, vs2.asSInt, vs3.asSInt, sew, mask_vs0.asUInt, alu_opcode, rs1).asUInt
    val carry_op = alu_opcode === "b010000".U || alu_opcode === "b010010".U

    val vec_sew8_result = WireInit(0.U(sew.W))
    dontTouch(vec_sew8_result)
    // Agnostic fill is sew bits of ones, matching the tail fill in arith_8_result.
    vec_sew8_result := Mux(
      carry_op || mask_bit_active_element,
      computed,
      Mux(mask_bit_undisturb, vs3, Fill(sew, 1.U))
    )
    vec_sew8_result.asUInt
  }

  /** Runs [[arith_8]] over an 8-register group and applies the tail policy.
    *
    * Elements are numbered row-major (register index first, then lane). Elements
    * whose index is below `vl` take the ALU result; the others keep `vs3`
    * because `tail` is tied to 0.
    *
    * @param vs1          first source register group
    * @param vs2          second source register group
    * @param vs3          old destination register group
    * @param mask         mask bits, indexed by element number
    * @param alu_opcode   operation selector
    * @param rs1          shift-amount source
    * @param vl           number of active elements
    * @param mask_arith25 when set, the instruction is unmasked
    * @param sew_lanes    number of elements per register
    * @param sew          element width in bits
    * @return the register group to write back
    */
  def arith_8_result(
    vs1: Vec[Vec[UInt]],
    vs2: Vec[Vec[UInt]],
    vs3: Vec[Vec[UInt]],
    mask: UInt,
    alu_opcode: UInt,
    rs1 : UInt,
    vl : UInt,
    mask_arith25 : Bool,
    sew_lanes : Int,
    sew : Int
  ): Vec[Vec[UInt]] = {
    val tail = Wire(Bool())
    tail := 0.B
    val result_8 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
    val result_r = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
    for (i <- 0 until 8; j <- 0 until sew_lanes) {
      // Row-major element number; also the index of this element's mask bit.
      val elem = i * sew_lanes + j
      result_8(i)(j) := arith_8(vs1(i)(j), vs2(i)(j), vs3(i)(j), mask(elem).asBool, alu_opcode, rs1, mask_arith25, sew)
      result_r(i)(j) := Mux(vl > elem.U, result_8(i)(j), Mux(tail === 0.B, vs3(i)(j), Fill(sew, 1.U))).asUInt
    }
    result_r
  }
}
71 changes: 71 additions & 0 deletions src/main/scala/vaquita/components/ALUClasses/Comparison.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package vaquita.components.ALUClasses
import chisel3._
import chisel3.util._
import vaquita.components.ALUObj._
import vaquita.configparameter.VaquitaConfig

/** Mask-producing integer compare helpers for the vector unit.
  *
  * The class declares no IO of its own; it hosts hardware-generating helper methods.
  */
class Comparison(implicit val config: VaquitaConfig)extends Module {

  /** Evaluates one element of a mask-producing instruction.
    *
    * The relational entries are written from the point of view of `vs2`:
    * for example vmslt yields `vs2 < vs1`, expressed here as `vs1 > vs2`.
    *
    * @param vs1_in     first source element
    * @param vs2_in     second source element
    * @param alu_opcode operation selector, compared against the ALUObj constants
    * @return the mask bit for this element; false when alu_opcode matches no entry
    */
  def comparison_operators(vs1_in:SInt,vs2_in:SInt,alu_opcode:UInt):Bool={
    val vs1_u = vs1_in.asUInt
    val vs2_u = vs2_in.asUInt
    // Carry-out of the unsigned addition: the wrapped sum is smaller than an
    // operand exactly when the addition overflowed. The previous expression
    // compared a sign-extended sum against 0xffffffff and was true for almost
    // every input. Carry-in from the mask register is not modelled here.
    val carry_out = (vs1_u + vs2_u) < vs1_u
    val comparison_table = Seq(
      vmseq  -> (vs1_in === vs2_in), // equal
      vmsne  -> (vs1_u =/= vs2_u),   // not equal
      vmsltu -> (vs1_u > vs2_u),     // vs2 <  vs1, unsigned
      vmslt  -> (vs1_in > vs2_in),   // vs2 <  vs1, signed
      vmsleu -> (vs1_u >= vs2_u),    // vs2 <= vs1, unsigned
      vmsle  -> (vs1_in >= vs2_in),  // vs2 <= vs1, signed
      vmsgtu -> (vs1_u < vs2_u),     // vs2 >  vs1, unsigned
      vmsgt  -> (vs1_in < vs2_in),   // vs2 >  vs1, signed
      vmadc  -> carry_out            // carry-out of vs2 + vs1
      // vmsbc (borrow-out) is not implemented.
    )
    MuxLookup(alu_opcode, 0.B, comparison_table)
  }

  /** Builds the mask result of a compare over an 8-register group.
    *
    * One result bit is produced per element (row-major element numbering) and
    * packed into register 0 of the returned group, starting from the old
    * contents of `vs3(0)`. A bit keeps its old value when the element is beyond
    * `vl` or is masked off (`tail` and `vsetvli_mask` are tied to 0).
    *
    * @param vs1          first source register group
    * @param vs2          second source register group
    * @param vs3          old destination register group; only register 0 is read
    * @param mask         mask bits, indexed by element number
    * @param alu_opcode   operation selector
    * @param rs1          not used by this method
    * @param vl           number of active elements
    * @param mask_arith25 when set, the instruction is unmasked
    * @param sew_lanes    number of elements per register
    * @param sew          element width in bits
    * @return register group whose register 0 holds the packed mask bits
    */
  def main_comp(
    vs1: Vec[Vec[UInt]],
    vs2: Vec[Vec[UInt]],
    vs3: Vec[Vec[UInt]],
    mask: UInt,
    alu_opcode: UInt,
    rs1 : UInt,
    vl : UInt,
    mask_arith25 : Bool,
    sew_lanes : Int,
    sew : Int
  ): Vec[Vec[UInt]] = {
    val result_val = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
    val vsetvli_mask = Wire(Bool())
    vsetvli_mask := 0.B
    val tail = Wire(Bool())
    tail := 0.B

    // Old destination mask register, flattened; every result bit defaults to it.
    val old_vd = vs3(0).asUInt
    val comp_bool = WireInit(VecInit((0 until config.vlen).map(i => old_vd(i).asBool)))
    val comp_wire = Wire(UInt(config.vlen.W))
    comp_wire := comp_bool.asUInt

    for (i <- 0 until 8; j <- 0 until sew_lanes) {
      // Row-major element number; also the index of this element's mask and result bit.
      val elem_idx = i * sew_lanes + j
      val old_bit = old_vd(elem_idx).asBool
      val mask_bit_active_element = mask_arith25 || mask(elem_idx)
      val mask_bit_undisturb = !mask(elem_idx) && !mask_arith25 && !vsetvli_mask
      val compared = comparison_operators(vs1(i)(j).asSInt, vs2(i)(j).asSInt, alu_opcode)
      comp_bool(elem_idx) := Mux(
        (vl > elem_idx.U),
        Mux(mask_bit_active_element, compared, Mux(mask_bit_undisturb, old_bit, 1.B)),
        Mux(tail === 0.B, old_bit, 1.B)
      )
    }

    // Slice the packed mask bits back into sew-wide lanes of register 0.
    for (j <- 0 until sew_lanes) {
      result_val(0)(j) := comp_wire((j + 1) * sew - 1, j * sew)
    }
    // Registers 1..7 carry a recognisable filler pattern.
    // NOTE(review): vs3(i)(j) was considered here instead; confirm that the
    // write-back stage ignores these registers for mask-producing instructions.
    for (i <- 1 until 8; j <- 0 until sew_lanes) {
      result_val(i)(j) := "hdeadbeef".U
    }
    result_val
  }
}
41 changes: 41 additions & 0 deletions src/main/scala/vaquita/components/ALUClasses/FixedRoundMode.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package vaquita.components.ALUClasses
import chisel3._
import chisel3.util._
import vaquita.components.ALUObj._
import vaquita.configparameter.VaquitaConfig

/** Fixed-point rounding helper used by the scaling/narrowing shift operations. */
class FixedRoundMode{

  /** Adds the rounding increment selected by `vxrm` to an already shifted value.
    *
    * With `d` the shift amount and `v` the unshifted value, the increment is:
    *   - vxrm = 00 (round-to-nearest-up):   v[d-1]
    *   - vxrm = 01 (round-to-nearest-even): v[d-1] & (v[d-2:0] != 0 | v[d])
    *   - vxrm = 10 (round-down, truncate):  0
    *   - vxrm = 11 (round-to-odd):          !v[d] & (v[d-1:0] != 0)
    *
    * The low-bit masks are derived from a 32-bit all-ones constant, so the
    * rne/rod results are meaningful for d <= 32 and v at most 32 bits wide.
    *
    * @param v       value before shifting
    * @param shifted `v` already shifted right by `d` (logical or arithmetic, chosen by the caller)
    * @param d       shift amount
    * @param vxrm    rounding mode selector
    * @return `shifted` plus the rounding increment (no width growth)
    */
  def fixed_round_mode(v:UInt,shifted: UInt, d: UInt, vxrm: UInt): UInt = {
    // Bit positions around the cut.
    val bit_d  = v(d)                                  // v[d]
    val bit_d1 = Mux(d === 0.U, false.B, v(d - 1.U))   // v[d-1]; no such bit when d == 0

    // Masks that keep only the LOW bits of v. Shifting the all-ones word right
    // by (32 - d) leaves exactly d ones; d == 0 shifts everything out. The
    // previous masks (all-ones >> (d-1) / (d-2)) kept high bits instead.
    val all_ones = "hffffffff".U
    val bits_d1_to_0 = v & (all_ones >> (32.U - d))    // v[d-1:0]
    val bits_d2_to_0 = v & (all_ones >> (33.U - d))    // v[d-2:0]; empty for d <= 1

    // Rounding increment for each mode.
    val rnu_inc = bit_d1.asUInt                                               // vxrm = 00
    val rne_inc = (bit_d1 && ((bits_d2_to_0 =/= 0.U) || bit_d)).asUInt        // vxrm = 01
    val rdn_inc = 0.U(1.W)                                                    // vxrm = 10: truncate, never increments
    val rod_inc = (!bit_d && (bits_d1_to_0 =/= 0.U)).asUInt                   // vxrm = 11

    // Select the rounding increment.
    val rounding_inc = WireDefault(0.U(1.W))
    switch(vxrm) {
      is("b00".U) { rounding_inc := rnu_inc }
      is("b01".U) { rounding_inc := rne_inc }
      is("b10".U) { rounding_inc := rdn_inc }
      is("b11".U) { rounding_inc := rod_inc }
    }

    // Final result
    shifted + rounding_inc
  }
}
82 changes: 82 additions & 0 deletions src/main/scala/vaquita/components/ALUClasses/Narrow.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package vaquita.components.ALUClasses
import chisel3._
import chisel3.util._
import vaquita.components.ALUObj._
import vaquita.configparameter.VaquitaConfig

/** Narrowing shift and clip helpers for the vector unit.
  *
  * The class declares no IO of its own; it hosts hardware-generating helper methods.
  */
class NarrowIns(implicit val config: VaquitaConfig) extends Module {

  /** Builds the datapath for one element of a narrowing instruction.
    *
    * The shift amount is taken from the low bits of `vs1_in`: 4 bits when
    * sew is 8, 5 bits otherwise. Clip operations round with vxrm tied to 0.
    *
    * @param vs1_in      shift-amount source element
    * @param vs2_in      source element to be shifted
    * @param vs3         old destination element (not read by this method)
    * @param sew         destination element width in bits
    * @param v0_bit_mask mask bit of this element (not read by this method)
    * @param alu_opcode  operation selector, compared against the ALUObj constants
    * @param rs1_in      not read by this method
    * @return 32-bit signed result; 0 when alu_opcode matches no table entry
    */
  def narrow_ins(vs1_in: SInt, vs2_in: SInt,vs3:SInt,sew:Int,v0_bit_mask:UInt,alu_opcode: UInt, rs1_in : UInt): SInt = {
    val fixed_round_mode = new FixedRoundMode()
    // Saturation bounds of a signed sew-bit element.
    val maxValue = (1.S << (sew - 1)) - 1.S
    val minValue = -(1.S << (sew - 1))
    val vxrm = 0.U

    // One shift amount shared by all four operations.
    val shift_vs1_amount = WireInit(0.U(6.W))
    shift_vs1_amount := Mux(sew.U===8.U ,vs1_in(3,0), vs1_in(4,0))

    val vs2_u = vs2_in.asUInt
    // Rounded shift results: logical shift for the unsigned clip, arithmetic for the signed one.
    val vnclip_u = WireInit(0.U(32.W))
    vnclip_u := fixed_round_mode.fixed_round_mode(vs2_u, (vs2_u >> shift_vs1_amount), shift_vs1_amount, vxrm)
    val vnclip_s = WireInit(0.S(32.W))
    vnclip_s := fixed_round_mode.fixed_round_mode(vs2_u, (vs2_in >> shift_vs1_amount).asUInt, shift_vs1_amount, vxrm).asSInt

    // Unsigned saturation limit: 65535 when sew is 16, 255 otherwise.
    val clipped_u16 = Mux(vnclip_u >= 65535.U, 65535.U, vnclip_u).asSInt
    val clipped_u8  = Mux(vnclip_u >= 255.U, 255.U, vnclip_u).asSInt

    val result = WireInit(0.S(32.W))
    result := MuxLookup(alu_opcode, 0.S(32.W), Seq(
      vnsrl   -> (vs2_u >> shift_vs1_amount).asSInt,   // narrowing shift right logical
      vnsra   -> (vs2_in >> shift_vs1_amount).asSInt,  // narrowing shift right arithmetic
      vnclipu -> Mux(sew.U===16.U, clipped_u16, clipped_u8), // unsigned clip
      vnclip  -> Mux(vnclip_s > maxValue, maxValue, Mux(vnclip_s < minValue, minValue, vnclip_s)) // signed clip
    ))
    result
  }

  /** Applies the masking policy to one narrowed element.
    *
    * The computed value is used when the element is active (unmasked
    * instruction, or mask bit set); otherwise the old destination value `vs3`
    * is kept.
    *
    * @param vs1          shift-amount source element
    * @param vs2          source element to be narrowed
    * @param vs3          old destination element
    * @param mask_vs0     mask bit of this element
    * @param alu_opcode   operation selector
    * @param rs1          forwarded to [[narrow_ins]]
    * @param mask_arith25 when set, the instruction is unmasked
    * @param sew          destination element width in bits
    * @return the sew-bit element to write back
    */
  def narrow_mask(vs1:UInt , vs2:UInt,vs3:UInt,mask_vs0:Bool,alu_opcode: UInt, rs1 : UInt,mask_arith25 : Bool,sew : Int):UInt={
    dontTouch(mask_vs0)
    // Constant placeholder: with it tied to 0 an inactive element is always undisturbed.
    val vsetvli_mask = WireInit(0.B)
    val mask_bit_active_element = mask_arith25 || mask_vs0
    val mask_bit_undisturb = !mask_vs0 && !mask_arith25 && !vsetvli_mask

    val computed = narrow_ins(vs1.asSInt, vs2.asSInt, vs3.asSInt, sew, mask_vs0.asUInt, alu_opcode, rs1).asUInt

    val vec_sew8_result = WireInit(0.U(sew.W))
    dontTouch(vec_sew8_result)
    // Agnostic fill is sew bits of ones, matching the tail fill in narrow_result.
    vec_sew8_result := Mux(mask_bit_active_element, computed, Mux(mask_bit_undisturb, vs3, Fill(sew, 1.U))).asUInt
    vec_sew8_result.asUInt
  }

  /** Runs [[narrow_mask]] over a register group and applies the tail policy.
    *
    * Only destination registers 0..3 are computed; registers 4..7 are driven
    * to 0. `vs2` is read through its own (register, lane) cursor, which moves
    * to the next source register after destination lane `sew_lanes/2 - 1` and
    * after the last destination lane.
    *
    * @param vs1          shift-amount source register group
    * @param vs2          source register group to be narrowed
    * @param vs3          old destination register group
    * @param mask         mask bits, indexed by destination element number
    * @param alu_opcode   operation selector
    * @param rs1          forwarded to [[narrow_ins]]
    * @param vl           number of active elements
    * @param mask_arith25 when set, the instruction is unmasked
    * @param sew_lanes    number of destination elements per register
    * @param sew          destination element width in bits
    * @return the register group to write back
    */
  def narrow_result(
    vs1: Vec[Vec[UInt]],
    vs2: Vec[Vec[UInt]],
    vs3: Vec[Vec[UInt]],
    mask: UInt,
    alu_opcode: UInt,
    rs1 : UInt,
    vl : UInt,
    mask_arith25 : Bool,
    sew_lanes : Int,
    sew : Int
  ): Vec[Vec[UInt]] = {
    val tail = Wire(Bool())
    tail := 0.B
    val result_8 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
    val result_r = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))

    // Lanes after which the source cursor moves on to the next source register.
    val last_lane_first_half = sew_lanes / 2 - 1
    val last_lane = sew_lanes - 1

    var vl_counter = 0   // destination element number (elaboration-time)
    var i_widening = 0   // source register cursor
    for (i <- 0 until 4) {
      var j_widening = 0 // source lane cursor
      for (j <- 0 until sew_lanes) {
        result_8(i)(j) := narrow_mask(vs1(i)(j), vs2(i_widening)(j_widening), vs3(i)(j), mask(vl_counter).asBool, alu_opcode, rs1, mask_arith25, sew)
        result_r(i)(j) := Mux(vl > vl_counter.U, result_8(i)(j), Mux(tail === 0.B, vs3(i)(j), Fill(sew, 1.U)))
        vl_counter += 1
        if (j == last_lane_first_half || j == last_lane) {
          j_widening = 0
          i_widening += 1
        } else {
          j_widening += 1
        }
      }
    }
    // The upper half of the destination group is not produced by narrowing ops.
    for (i <- 4 until 8; j <- 0 until sew_lanes) {
      result_r(i)(j) := 0.U
    }
    result_r
  }
}
Loading