diff --git a/src/main/scala/vaquita/VaquitaTop.scala b/src/main/scala/vaquita/VaquitaTop.scala
index b0c74470..97b6b124 100644
--- a/src/main/scala/vaquita/VaquitaTop.scala
+++ b/src/main/scala/vaquita/VaquitaTop.scala
@@ -104,12 +104,9 @@ class VaquitaTop extends Module {
   EX.ex_reg_write_in := DE.de_io.de_reg_write
 
   // -----------------memory stage ---------------------------------
-  // val comparison_bit_f6 = MEM.mem_instr_out(31,26)==="b011000".U || MEM.mem_instr_out(31,26)==="b011001".U || MEM.mem_instr_out(31,26)==="b011010".U || MEM.mem_instr_out(31,26)==="b011011".U || MEM.mem_instr_out(31,26)==="b011100".U || MEM.mem_instr_out(31,26)==="b011101".U || MEM.mem_instr_out(31,26)==="b011110".U || MEM.mem_instr_out(31,26)==="b011111".U
-  // val comparison_bit_f3 = MEM.mem_instr_out(14,12)==="b000".U || MEM.mem_instr_out(14,12)==="b011".U || MEM.mem_instr_out(14,12)==="b100".U
   io.dmemReq <> MemFetch.dccmReq
   MemFetch.dccmRsp <> io.dmemRsp
   MemFetch.mem_lmul_in := EX.ex_lmul_out
-  // MemFetch.vec_comparison_bit := comparison_bit_f6 && comparison_bit_f3
 
   val wb_vs3_data_in_store = VecInit(Seq.fill(8)(VecInit(Seq.fill(vec_config.count_lanes)(0.S(vec_config.XLEN.W)))))
diff --git a/src/main/scala/vaquita/components/ALUClasses/Arith.scala b/src/main/scala/vaquita/components/ALUClasses/Arith.scala
new file mode 100644
index 00000000..93d57fea
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/Arith.scala
@@ -0,0 +1,120 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.components.ALUObj._
+import vaquita.configparameter.VaquitaConfig
+
+/** Element-wise integer arithmetic for the vector ALU.
+  *
+  * `Arithmatic` computes one element, `arith_8` applies the mask policy to one
+  * element, and `arith_8_result` maps that over the 8-register group and
+  * applies the vl/tail policy.
+  */
+class Arith(implicit val config: VaquitaConfig) extends Module {
+
+  /** One sew-bit element result for the operation selected by `alu_opcode`.
+    *
+    * @param vs1_in      first source element (sign-reinterpreted by the caller)
+    * @param vs2_in      second source element
+    * @param vs3         old destination element (unused here; kept for a uniform signature)
+    * @param sew         element width in bits (elaboration-time constant)
+    * @param v0_bit_mask per-element v0 bit; used as carry/borrow for vadc/vsbc
+    * @param rs1_in      scalar operand; only its low log2(sew) bits feed the shift amount
+    */
+  def Arithmatic(vs1_in: SInt, vs2_in: SInt, vs3: SInt, sew: Int, v0_bit_mask: UInt, alu_opcode: UInt, rs1_in: UInt): SInt = {
+    val maxUInt32 = "hFFFFFFFF".U(32.W)
+    val sum = WireInit(0.S(33.W))
+    val sub = WireInit(0.S(33.W))
+    // Signed saturation bounds for this element width (vsadd/vssub).
+    val maxValue = (1.S << (sew - 1)) - 1.S
+    val minValue = -(1.S << (sew - 1))
+    sub := (vs2_in -& vs1_in)
+    // +& keeps the carry; .asSInt sign-extends it into bit 32, so sum(32)
+    // doubles as the unsigned carry-out flag for any sew (used by vsaddu).
+    sum := (vs1_in.asUInt +& vs2_in.asUInt).asSInt
+    // Signed overflow detection for addition (operand signs equal, result sign differs).
+    val positiveOverflowAdd = (vs1_in(sew - 1) === 0.U && vs2_in(sew - 1) === 0.U && sum(sew - 1) === 1.U)
+    val negativeOverflowAdd = (vs1_in(sew - 1) === 1.U && vs2_in(sew - 1) === 1.U && sum(sew - 1) === 0.U)
+    // Signed overflow detection for subtraction (vs2 - vs1).
+    val positiveOverflowSub = (vs2_in(sew - 1) === 0.U && vs1_in(sew - 1) === 1.U && sub(sew - 1) === 1.U)
+    val negativeOverflowSub = (vs2_in(sew - 1) === 1.U && vs1_in(sew - 1) === 0.U && sub(sew - 1) === 0.U)
+    val wire_vs1 = WireInit(0.S(32.W))
+    wire_vs1 := vs1_in
+    val wire_vs2 = WireInit(0.S(32.W))
+    wire_vs2 := vs2_in
+    val vxrm = 0.U // NOTE(review): rounding mode hard-wired to RNU — wire up the real vxrm CSR when fixed-point ops land
+    val shift_vs1_amount = WireInit(0.U(6.W))
+    shift_vs1_amount := Mux(sew.U===16.U ,vs1_in(4,0), vs1_in(3,0)) //vnclip 16 and 8
+    // val vnclip_u = WireInit(0.U(32.W))
+    // vnclip_u := fixed_round_mode(vs2_in.asUInt,(vs2_in.asUInt >> shift_vs1_amount),shift_vs1_amount,vxrm)
+    // val vnclip_s = WireInit(0.S(32.W))
+    // vnclip_s := fixed_round_mode(vs2_in.asUInt,(vs2_in >> shift_vs1_amount).asUInt,shift_vs1_amount,vxrm).asSInt
+
+    // val multi = WireInit(0.S(33.W))
+    // multi := fixed_round_mode(vs2_in.asUInt,(vs2_in * vs1_in).asUInt,vs1_in.asUInt,vxrm).asSInt
+
+    val result = WireInit(0.S(32.W))
+    // Shift amount is taken modulo sew via bit-slicing, per the RVV shift rules.
+    val shamt = rs1_in(log2Ceil(sew)-1, 0)
+
+    result := MuxLookup(alu_opcode, 0.S(32.W), Seq(
+      vadd   -> (vs1_in + vs2_in),//add
+      vsub   -> (vs2_in - vs1_in),//sub
+      vrsub  -> (vs1_in - vs2_in),//rsub
+      vand   -> (vs1_in & vs2_in),// and
+      vor    -> (vs1_in | vs2_in),//or
+      vxor   -> (vs1_in ^ vs2_in),//xor
+      vsll   -> (vs2_in << shamt), // vsll
+      vsrl   -> (vs2_in.asUInt >> shamt).asSInt, // vsrl
+      vsra   -> (vs2_in >> shamt).asSInt, // vsra
+      vmv    -> (vs1_in), //vmv
+      vminu  -> Mux(vs1_in.asUInt < vs2_in.asUInt,vs1_in.asUInt,vs2_in.asUInt).asSInt,//minu
+      vmin   -> Mux(vs1_in < vs2_in,vs1_in,vs2_in),//min
+      vmaxu  -> Mux(vs1_in.asUInt > vs2_in.asUInt,vs1_in.asUInt,vs2_in.asUInt).asSInt,//maxu
+      vmax   -> Mux(vs1_in > vs2_in,vs1_in,vs2_in),//max
+      vsaddu -> Mux(sum(32), "hFFFFFFFF".U, sum(31,0)).asSInt,//vsaddu (sum(32) is carry-out via sign-extension, see above)
+      vsadd  -> (Mux(positiveOverflowAdd, maxValue, Mux(negativeOverflowAdd, minValue, sum))),//vsadd
+      vssubu -> Mux(vs2_in.asUInt < vs1_in.asUInt, 0.U,vs2_in.asUInt - vs1_in.asUInt ).asSInt,//vssubu (unsigned floor at 0)
+      vssub  -> Mux(positiveOverflowSub, maxValue, Mux(negativeOverflowSub, minValue, sub(31, 0).asSInt)),//vssub
+      vadc   -> (vs1_in.asUInt + vs2_in.asUInt + v0_bit_mask).asSInt, //vadc: v0 bit is the carry-in
+      vsbc   -> (vs2_in.asUInt - vs1_in.asUInt - v0_bit_mask).asSInt  //vsbc: v0 bit is the borrow-in
+      // vnsrl/vnsra/vnclip(u)/vssrl/vssra/vsmul: pending fixed-point rounding support,
+      // narrowing variants live in NarrowIns.
+    ))
+    // printf(p"vs1 = 0x${Hexadecimal(wire_vs1(4,0))} , vs2 = 0x${Hexadecimal(wire_vs2)} , result = 0x${Hexadecimal(result)} , sew = 0x${Hexadecimal(sew.U)} \n")
+    result
+  }
+
+  /** Applies the mask policy to one element.
+    *
+    * Active elements (v0 bit set, or instruction unmasked) take the computed
+    * value; inactive elements keep vs3 (undisturbed) or become all-ones
+    * (agnostic). vadc/vsbc (opcodes 0b010000/0b010010) bypass masking because
+    * v0 is their carry/borrow input, not a mask.
+    */
+  def arith_8(vs1: UInt, vs2: UInt, vs3: UInt, mask_vs0: Bool, alu_opcode: UInt, rs1: UInt, mask_arith25: Bool, sew: Int): UInt = {
+    dontTouch(mask_vs0)
+    val vsetvli_mask = WireInit(0.B)
+    val mask_bit_active_element = Wire(Bool())
+    val mask_bit_undisturb = Wire(Bool())
+    val vec_sew8_result = WireInit(0.U(sew.W))
+    mask_bit_active_element := (mask_vs0===1.B && mask_arith25===0.B) || mask_arith25===1.B
+    mask_bit_undisturb := mask_vs0===0.B && mask_arith25===0.B && vsetvli_mask===0.B
+    dontTouch(vec_sew8_result)
+    when(alu_opcode==="b010000".U || alu_opcode==="b010010".U){
+      vec_sew8_result := (Arithmatic(vs1.asSInt, vs2.asSInt,vs3.asSInt,sew,mask_vs0.asUInt,alu_opcode,rs1)).asUInt
+    }.otherwise{
+      // BUGFIX: the agnostic fill was Fill(16, 1.U), which zero-extends to
+      // 0x0000FFFF for sew = 32 instead of all-ones. Use Fill(sew, 1.U),
+      // consistent with narrow_mask and with the tail fill in arith_8_result.
+      vec_sew8_result := Mux(mask_bit_active_element===1.B,Arithmatic(vs1.asSInt, vs2.asSInt,vs3.asSInt,sew,mask_vs0.asUInt,alu_opcode,rs1).asUInt,Mux(mask_bit_undisturb===1.B,vs3,Fill(sew,1.U))).asUInt
+    }
+    vec_sew8_result.asUInt
+  }
+
+  /** Full register-group result: 8 registers x sew_lanes elements.
+    *
+    * Elements with index < vl get the masked arithmetic result; tail elements
+    * keep vs3 (tail-undisturbed, `tail` is tied to 0) or would be all-ones
+    * when tail-agnostic.
+    */
+  def arith_8_result(
+      vs1: Vec[Vec[UInt]],
+      vs2: Vec[Vec[UInt]],
+      vs3: Vec[Vec[UInt]],
+      mask: UInt,
+      alu_opcode: UInt,
+      rs1 : UInt,
+      vl : UInt,
+      mask_arith25 : Bool,
+      sew_lanes : Int,
+      sew : Int
+  ): Vec[Vec[UInt]] = {
+    val tail = Wire(Bool())
+    tail := 0.B
+    val result_8 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    val result_r = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    // Elaboration-time flat element index across the whole register group.
+    var vl_counter = 0
+    for (i <- 0 until 8) {
+      for (j <- 0 until sew_lanes) {
+        result_8(i)(j) := arith_8(vs1(i)(j), vs2(i)(j), vs3(i)(j), mask(vl_counter).asBool,alu_opcode,rs1,mask_arith25,sew)
+        result_r(i)(j) := Mux(vl > vl_counter.U, result_8(i)(j),Mux(tail === 0.B, vs3(i)(j), Fill(sew, 1.U))).asUInt
+        vl_counter = vl_counter +1
+      }
+    }
+    result_r
+  }
+}
\ No newline at end of file
diff --git
a/src/main/scala/vaquita/components/ALUClasses/Comparison.scala b/src/main/scala/vaquita/components/ALUClasses/Comparison.scala
new file mode 100644
index 00000000..e5dbd2d6
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/Comparison.scala
@@ -0,0 +1,71 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.components.ALUObj._
+import vaquita.configparameter.VaquitaConfig
+
+// Vector integer compare instructions (vmseq..vmsgt, vmadc).
+// The result is a mask: one bit per element, packed into register 0 of the group.
+class Comparison(implicit val config: VaquitaConfig)extends Module {
+  // Returns the single-bit compare result for one element pair.
+  def comparison_operators(vs1_in:SInt,vs2_in:SInt,alu_opcode:UInt):Bool={
+    val comparison_table = Seq(
+      vmseq -> (vs1_in === vs2_in), // vmseq
+      vmsne -> (vs1_in.asUInt =/= vs2_in.asUInt), // vmsne
+      vmsltu -> (vs1_in.asUInt < vs2_in.asUInt), // vmsltu
+      vmslt -> (vs1_in < vs2_in), // vmslt
+      vmsleu -> (vs1_in.asUInt <= vs2_in.asUInt), // vmsleu
+      vmsle -> (vs1_in <= vs2_in), // vmsle
+      vmsgtu -> (vs1_in.asUInt > vs2_in.asUInt), // vmsgtu
+      vmsgt -> (vs1_in > vs2_in), // vmsgt
+      // NOTE(review): this sets the bit when the 33-bit sum stays BELOW 0xffffffff,
+      // i.e. when there is no carry-out; vmadc is defined to report the carry-out.
+      // The polarity looks inverted — verify against the RVV spec before relying on it.
+      vmadc -> Mux((vs1_in +& vs2_in).asUInt < "hffffffff".U,1.B,0.B)//vmadc
+      // "b011111".U -> (vs1_in < vs2_in)//vmsbc
+    )
+    MuxLookup(alu_opcode, 0.B, comparison_table)
+  }
+  // Computes the packed mask result for the whole register group.
+  // mask_arith25: instruction bit 25 (1 = unmasked); vl/tail policy is undisturbed (tail tied to 0).
+  def main_comp(
+    vs1: Vec[Vec[UInt]],
+    vs2: Vec[Vec[UInt]],
+    vs3: Vec[Vec[UInt]],
+    mask: UInt,
+    alu_opcode: UInt,
+    rs1 : UInt,
+    vl : UInt,
+    mask_arith25 : Bool,
+    sew_lanes : Int,
+    sew : Int
+  ): Vec[Vec[UInt]] = {
+    val result_val = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    val vsetvli_mask = Wire(Bool())
+    vsetvli_mask := 0.B
+    val tail = Wire(Bool())
+    // Seeded from vs3(0) so bits not written below keep the old destination-mask value.
+    val comp_bool = WireInit(VecInit((0 until config.vlen).map(i => vs3(0).asUInt(i).asBool))) //WireInit(VecInit(Seq.fill(config.vlen)((vs3(0).asUInt)(elem_idx).asBool)))
+    val comp_wire = Wire(UInt(config.vlen.W))
+    comp_wire := comp_bool.asUInt
+    tail := 0.B
+    // Elaboration-time flat element index across the 8-register group.
+    var elem_idx = 0
+    for (i <- 0 until 8) {
+      for (j <- 0 until (sew_lanes)) {
+        val mask_bit_active_element = (mask(elem_idx) === 1.B && mask_arith25 === 0.B) || mask_arith25 === 1.B
+        val mask_bit_undisturb = mask(elem_idx) === 0.B && mask_arith25 === 0.B && vsetvli_mask === 0.B
+        // body < vl: compare if active, else keep old bit / set 1 (agnostic);
+        // tail >= vl: keep old bit (tail is tied to 0 above).
+        comp_bool(elem_idx) := Mux(
+          (vl > elem_idx.U),
+          Mux(mask_bit_active_element,comparison_operators(vs1(i)(j).asSInt,vs2(i)(j).asSInt,alu_opcode), Mux(mask_bit_undisturb, (vs3(0).asUInt)(elem_idx).asBool, 1.B)),
+          Mux(tail === 0.B, (vs3(0).asUInt)(elem_idx).asBool, 1.B)
+        )
+        elem_idx = elem_idx + 1
+      }
+    }
+    // Repack the bit vector into sew-wide lanes of result register 0.
+    var high = sew-1
+    var low = 0
+    for (j <- 0 until sew_lanes) {
+      result_val(0)(j) := comp_wire(high, low)
+      high += sew
+      low += sew
+    }
+    // Registers 1..7 of the group are not architectural results for a mask-producing
+    // compare; 0xdeadbeef is a debug sentinel.
+    for (i <- 1 until 8) {
+      for (j <- 0 until sew_lanes) {
+        result_val(i)(j) := "hdeadbeef".U //vs3(i)(j).asUInt
+      }
+    }
+    result_val
+  }
+}
diff --git a/src/main/scala/vaquita/components/ALUClasses/FixedRoundMode.scala b/src/main/scala/vaquita/components/ALUClasses/FixedRoundMode.scala
new file mode 100644
index 00000000..113b34bf
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/FixedRoundMode.scala
@@ -0,0 +1,55 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.components.ALUObj._
+import vaquita.configparameter.VaquitaConfig
+
+// Fixed-point rounding helper (plain Scala class, not a Module): applies the
+// vxrm rounding increment to an already-shifted value.
+class FixedRoundMode{
+
+  // v: pre-shift value, shifted: v >> d, d: shift amount, vxrm: rounding mode.
+  // Returns shifted + rounding increment selected by vxrm.
+  def fixed_round_mode(v:UInt,shifted: UInt, d: UInt, vxrm: UInt): UInt = {
+    // Bit positions
+    // val bit_d = v(d) // v[d]
+    // val bit_d1 = Mux(d === 0.U, 0.U, v(d - 1.U)) // v[d-1]
+
+    // val high_2 = Mux(d > 1.U, d - 2.U, 0.U)
+    // val mask_2 = "hffffffff".U >> high_2
+    // val bits_d2_to_0 = v & mask_2 // v[d-2:0]
+
+    // val high_1 = Mux(d > 1.U, d - 1.U, 0.U)
+    // val mask_1 = "hffffffff".U >> high_1
+    // val bits_d1_to_0 = v & mask_1 // v[d-1]
+
+    val width = v.getWidth
+    val bit_d = Mux(d < width.U, v(d), 0.U(1.W)) // LSB after shift
+    val bit_d1 = Mux(d === 0.U, 0.U(1.W), v(d - 1.U)) // guard bit
+    // mask_d1_to_0 covers v[d-1:0]; mask_d2_to_0 covers v[d-2:0] (sticky bits).
+    val mask_d1_to_0 = Mux(d === 0.U, 0.U(width.W), ((1.U(width.W) << d) - 1.U))
+    val mask_d2_to_0 = Mux(d <= 1.U, 0.U(width.W), ((1.U(width.W) << (d - 1.U)) - 1.U))
+    val any_d1_to_0 = (v & mask_d1_to_0).orR
+    val any_d2_to_0 = (v & mask_d2_to_0).orR
+
+    val rnu_inc = bit_d1 // vxrm = 00 (round to nearest, up)
+    val rne_inc = Mux(bit_d1 === 1.U && (any_d2_to_0 || bit_d === 1.U), 1.U, 0.U) // vxrm = 01 (ties to even)
+    val rdn_inc = 0.U(1.W) // vxrm = 10 (round down: add 0)
+    val rod_inc = Mux((bit_d === 0.U) && any_d1_to_0, 1.U, 0.U) // vxrm = 11 (round to odd: set LSB if any bit shifted out)
+
+
+    // Rounding increment flags
+    // val rnu_inc = Mux(bit_d1 === 1.U,1.U,0.U) // vxrm = 00
+    // val rne_inc = Mux((bit_d1 === 1.U) && ((bits_d2_to_0 =/= 0.U) || (bit_d === 1.U)),1.U,0.U) // vxrm = 01
+    // val rdn_inc = 1.U // vxrm = 10
+    // val rod_inc = Mux((bit_d === 0.U) && (bits_d1_to_0 =/= 0.U),1.U,0.U) // vxrm = 11
+
+    // // Select rounding increment
+
+    val rounding_inc = WireDefault(0.U(1.W)))
+    switch(vxrm) {
+      is("b00".U) { rounding_inc := rnu_inc }
+      is("b01".U) { rounding_inc := rne_inc }
+      is("b10".U) { rounding_inc := rdn_inc }
+      is("b11".U) { rounding_inc := rod_inc }
+    }
+
+    // Final result
+    shifted + rounding_inc
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/vaquita/components/ALUClasses/Narrow.scala b/src/main/scala/vaquita/components/ALUClasses/Narrow.scala
new file mode 100644
index 00000000..d3fe904e
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/Narrow.scala
@@ -0,0 +1,85 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.components.ALUObj._
+import vaquita.configparameter.VaquitaConfig
+
+// Narrowing instructions (vnsrl/vnsra/vnclipu/vnclip): 2*sew source elements
+// (vs2) are reduced to sew-wide destination elements.
+class NarrowIns(implicit val config: VaquitaConfig) extends Module {
+
+  // One narrowed element. sew here is the DESTINATION element width.
+  def narrow_ins(vs1_in: SInt, vs2_in: SInt,vs3:SInt,sew:Int,v0_bit_mask:UInt,alu_opcode: UInt, rs1_in : UInt): SInt = {
+    val fixed_round_mode = new FixedRoundMode()
+    // Signed saturation bounds of the destination width (vnclip).
+    val maxValue = (1.S << (sew - 1)) - 1.S
+    val minValue = -(1.S << (sew - 1))
+    val vxrm = 0.U // NOTE(review): rounding mode hard-wired to RNU — hook up the vxrm CSR
+    // Shift amount: low log2(2*sew) bits of vs1 (4 bits for sew=8, 5 for sew=16).
+    val shift_vs1_amount = WireInit(0.U(6.W))
+    shift_vs1_amount := Mux(sew.U===8.U ,vs1_in(3,0), vs1_in(4,0)) //vnclip 16 and 8
+    val vnclip_u = WireInit(0.U(32.W))
+    vnclip_u := fixed_round_mode.fixed_round_mode(vs2_in.asUInt,(vs2_in.asUInt >> shift_vs1_amount),shift_vs1_amount,vxrm)
+    val vnclip_s = WireInit(0.S(32.W))
+    vnclip_s := fixed_round_mode.fixed_round_mode(vs2_in.asUInt,(vs2_in >> shift_vs1_amount).asUInt,shift_vs1_amount,vxrm).asSInt
+    val result = WireInit(0.S(32.W))
+    // Unsigned saturation ceiling of the destination width (vnclipu).
+    val maxU = ((1.U << sew) - 1.U)
+    result := MuxLookup(alu_opcode, 0.S(32.W), Seq(
+      // vnsrl -> (vs2_in.asUInt >> (rs1_in%sew.U)).asSInt,
+      // vnsra -> ((vs2_in >> (rs1_in%sew.U)).asSInt),
+      vnsrl -> (vs2_in.asUInt >> (Mux(sew.U===8.U ,vs1_in(3,0), vs1_in(4,0)))).asSInt,
+      vnsra -> (vs2_in >> (Mux(sew.U===8.U ,vs1_in(3,0), vs1_in(4,0)))).asSInt, //("habcd45".U.asSInt >> 2.U).asSInt //
+      vnclipu -> Mux(vnclip_u >= maxU, maxU, vnclip_u).asSInt,
+      vnclip -> Mux(vnclip_s > maxValue, maxValue,Mux(vnclip_s < minValue, minValue, vnclip_s))
+    ))
+    // printf(p"vs1 = 0x${Hexadecimal(wire_vs1(4,0))} , vs2 = 0x${Hexadecimal(wire_vs2)} , result = 0x${Hexadecimal(result)} , sew = 0x${Hexadecimal(sew.U)} \n")
+    result
+  }
+  // Mask policy for one narrowed element: active -> computed, inactive ->
+  // vs3 (undisturbed) or all-ones (agnostic).
+  def narrow_mask(vs1:UInt , vs2:UInt,vs3:UInt,mask_vs0:Bool,alu_opcode: UInt, rs1 : UInt,mask_arith25 : Bool,sew : Int):UInt={
+    dontTouch(mask_vs0)
+    val vsetvli_mask = WireInit(0.B)
+    val mask_bit_active_element = Wire(Bool())
+    val mask_bit_undisturb = Wire(Bool())
+    val vec_sew8_result = WireInit(0.U(sew.W))
+    mask_bit_active_element := (mask_vs0===1.B && mask_arith25===0.B) || mask_arith25===1.B
+    mask_bit_undisturb := mask_vs0===0.B && mask_arith25===0.B && vsetvli_mask===0.B
+    dontTouch(vec_sew8_result)
+    vec_sew8_result := Mux(mask_bit_active_element===1.B,
+      narrow_ins(vs1.asSInt, vs2.asSInt,vs3.asSInt,sew,mask_vs0.asUInt,alu_opcode,rs1).asUInt,
+      Mux(mask_bit_undisturb===1.B, vs3, Fill(sew, 1.U))).asUInt
+    vec_sew8_result.asUInt
+  }
+  // Full-group narrowing: destination uses registers 0..3; the wide source is
+  // walked with (i_widening, j_widening), advancing one source register each
+  // time half the destination lanes are consumed.
+  def narrow_result(
+    vs1: Vec[Vec[UInt]],
+    vs2: Vec[Vec[UInt]],
+    vs3: Vec[Vec[UInt]],
+    mask: UInt,
+    alu_opcode: UInt,
+    rs1 : UInt,
+    vl : UInt,
+    mask_arith25 : Bool,
+    sew_lanes : Int,
+    sew : Int
+  ): Vec[Vec[UInt]] = {
+    val tail = Wire(Bool())
+    tail := 0.B
+    val result_8 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    val result_r = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    var vl_counter = 0
+    var i_widening = 0
+    for (i <- 0 until 4) {
+      var j_widening = 0
+      for (j <- 0 until sew_lanes) {
+        result_8(i)(j) := narrow_mask(vs1(i)(j), vs2(i_widening)(j_widening), vs3(i)(j), mask(vl_counter).asBool,alu_opcode,rs1,mask_arith25,sew)
+        result_r(i)(j) := Mux(vl > vl_counter.U, result_8(i)(j),Mux(tail === 0.B, vs3(i)(j), Fill(sew, 1.U)))
+        vl_counter = vl_counter +1
+        if ((j == (sew_lanes/2-1)) || (j == (sew_lanes-1))) { //15 31 when sew =8
+          j_widening = 0 //7 15 when sew =16
+          i_widening = i_widening + 1
+        } else {
+          j_widening = j_widening + 1
+        }
+      }
+    }
+    // Upper half of the destination group is unused by a narrowing op.
+    for (i <- 4 until 8) {
+      for (j <- 0 until sew_lanes) {
+        result_r(i)(j) := 0.U
+      }}
+    result_r
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/vaquita/components/ALUClasses/Permutation.scala b/src/main/scala/vaquita/components/ALUClasses/Permutation.scala
new file mode 100644
index 00000000..8bc70b99
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/Permutation.scala
@@ -0,0 +1,109 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.components.ALUObj._
+import vaquita.configparameter.VaquitaConfig
+// Permutation instructions: vslideup, vslidedown, vrgather.
+class Permutation(implicit val config: VaquitaConfig) {
+  // Source element index for destination element elem_idx.
+  // vslideup reads elem_idx - offset, vslidedown reads elem_idx + offset,
+  // vrgather reads the index supplied in vs1/rs1 directly.
+  def slide_target_idx_md(alu_opcode:UInt ,elem_idx:UInt , vs1_idx : UInt): UInt = {
+    val slide_target_idx = Wire(UInt(32.W))
+    when(vslideup === alu_opcode) {
+      slide_target_idx := elem_idx - vs1_idx
+    }
+    .elsewhen(vslidedown === alu_opcode) {
+      slide_target_idx := elem_idx + vs1_idx
+    }
+    .elsewhen(vrgather===alu_opcode) {
+      slide_target_idx := vs1_idx
+    }.otherwise{
+      slide_target_idx := 0.U
+    }
+    slide_target_idx
+  }
+  // True when the source index is inside the register group (VLMAX for the
+  // current lmul); out-of-range reads must return 0 per the slide/gather rules.
+  def slide_valid_md(alu_opcode:UInt ,elem_idx:UInt , vs1_idx : UInt,sew_lanes:UInt,slide_target_idx_down:UInt,lmul:UInt): Bool = {
+    val valid_wire = Wire(Bool())
+    // VLMAX = sew_lanes * 2^lmul, built by concatenating zero bits (shift-left).
+    val lmul_valid_cat = WireInit(1.U(32.W))
+    switch(lmul) {
+      is(0.U) { lmul_valid_cat := sew_lanes }
+      is(1.U) { lmul_valid_cat := Cat(sew_lanes, 0.U(1.W)) }
+      is(2.U) { lmul_valid_cat := Cat(sew_lanes, 0.U(2.W)) }
+      is(3.U) { lmul_valid_cat := Cat(sew_lanes, 0.U(3.W)) }
+    }
+    when(vslideup === alu_opcode) {
+      valid_wire := (vs1_idx <= elem_idx)
+    }
+    .elsewhen(vslidedown === alu_opcode) {
+      valid_wire := slide_target_idx_down < lmul_valid_cat
+    }
+    .elsewhen(vrgather===alu_opcode) {
+      valid_wire := (vs1_idx < lmul_valid_cat)
+    }.otherwise{
+      valid_wire := 0.B
+    }
+    valid_wire
+  }
+  // Full-group permutation. func3 selects the index source: vector (OPIVV)
+  // vs scalar/immediate (rs1_per).
+  def permutation(
+    vs1: Vec[Vec[UInt]],
+    vs2: Vec[Vec[UInt]],
+    vs3: Vec[Vec[UInt]],
+    mask: UInt,
+    alu_opcode: UInt,
+    rs1_per : UInt,
+    vl : UInt,
+    mask_arith25 : Bool,
+    sew_lanes : Int,
+    sew : Int,
+    func3 : UInt,
+    lmul:UInt
+  ): Vec[Vec[UInt]] = {
+    // for slide down idx
+    // Some instructions such as vslidedown and vrgather may read indices past vl or even VLMAX in source vector register groups. The
+    // general policy is to return the value 0 when the index is greater than VLMAX in the source vector register group.
+
+    val vsetvli_mask = Wire(Bool())
+    vsetvli_mask := 0.B
+    val tail = Wire(Bool())
+    tail := 0.B
+    var elem_idx = 0
+    val log2ByteWidth = Wire(UInt(32.W))
+    val vs1_idx = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(config.XLEN.W))))))
+    val result_r = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+    val slide_target_idx = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(32.W))))))
+    val valid_idx = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.B)))))
+    val mask_bit_active_element1 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.B)))))
+    val mask_bit_undisturb1 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.B)))))
+
+    val slide_vec_idx = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(32.W))))))
+    val slide_byte_idx = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(32.W))))))
+    val vs2_val = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(sew_lanes)(0.U(sew.W))))))
+
+
+
+
+    log2ByteWidth := (log2Ceil(sew_lanes).U)
+
+    for (i <- 0 until 8) {
+      for (j <- 0 until (sew_lanes)) {
+
+        mask_bit_active_element1(i)(j) := (mask(elem_idx) === 1.B && mask_arith25 === 0.B) || mask_arith25 === 1.B
+        mask_bit_undisturb1(i)(j) := mask(elem_idx) === 0.B && mask_arith25 === 0.B && vsetvli_mask === 0.B
+        val vs3_val = vs3(i)(j) // old value
+        val elem_idx_wire = Wire(UInt(33.W))
+        val result_val = WireInit(0.U(32.W))
+        vs1_idx(i)(j) := Mux(func3=== "b000".U, vs1(i)(j),rs1_per)
+        elem_idx_wire := elem_idx.U
+        slide_target_idx(i)(j) := slide_target_idx_md(alu_opcode ,elem_idx.U , vs1_idx(i)(j))
+        valid_idx(i)(j) := slide_valid_md(alu_opcode ,elem_idx.U , vs1_idx(i)(j),sew_lanes.U,slide_target_idx(i)(j),lmul)
+        // Split the flat source index into (register, lane).
+        // NOTE(review): the >> / & decomposition assumes sew_lanes is a power of two — verify.
+        slide_vec_idx(i)(j) := slide_target_idx(i)(j) >> log2ByteWidth //for row
+        slide_byte_idx(i)(j) := slide_target_idx(i)(j) & (sew_lanes.U - 1.U) //for column
+        // Out-of-range source reads return 0 (see policy comment above).
+        vs2_val(i)(j) := Mux(valid_idx(i)(j), vs2(slide_vec_idx(i)(j))(slide_byte_idx(i)(j)), 0.U)
+        result_r(i)(j) := Mux(
+          (vl > elem_idx.U && Mux(vslideup===alu_opcode,valid_idx(i)(j),1.B)),
+          Mux(mask_bit_active_element1(i)(j), vs2_val(i)(j), Mux(mask_bit_undisturb1(i)(j), vs3_val, Fill(sew, 1.U))),
+          Mux(tail === 0.B, vs3_val, Fill(sew, 1.U))
+        )
+        elem_idx += 1
+      }
+    }
+    result_r
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/vaquita/components/ALUClasses/VecIO.scala b/src/main/scala/vaquita/components/ALUClasses/VecIO.scala
new file mode 100644
index 00000000..2f121fec
--- /dev/null
+++ b/src/main/scala/vaquita/components/ALUClasses/VecIO.scala
@@ -0,0 +1,12 @@
+package vaquita.components.ALUClasses
+import chisel3._
+import chisel3.util._
+import vaquita.configparameter.VaquitaConfig
+
+// Decode-stage vector register ports: three sources, the v0 mask register
+// group, and one destination, each an 8-register group of XLEN lanes.
+class DecodeStageVecIO(implicit val config: VaquitaConfig) extends Bundle {
+  val vs1_data = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))
+  val vs2_data = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))
+  val vs3_data = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))
+  val vs0_data = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))
+  val vsd_data = Output(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))
+}
\ No newline at end of file
diff --git a/src/main/scala/vaquita/components/ALUObject.scala b/src/main/scala/vaquita/components/ALUObject.scala
index 746312d8..57d930f6 100644
--- a/src/main/scala/vaquita/components/ALUObject.scala
+++ b/src/main/scala/vaquita/components/ALUObject.scala
@@ -46,5 +46,16 @@ object ALUObj{
     val vnclipu = 46.U(6.W)
     val vnclip = 47.U(6.W)
     val vwredsumu = 48.U(6.W)
+
+    // val vwaddu =
+    // val vwsubu =
+    // val vwadd =
+    // val vwsub =
+    // val vwaddu_wv =
+    // val vwsubu_wv =
+    // val vwadd_wv =
+    // val vwsub_wv =
+
+    val vwredsum = 59.U(6.W)
 }
\ No newline at end of file
diff --git a/src/main/scala/vaquita/components/ForwardingUnit.scala b/src/main/scala/vaquita/components/ForwardingUnit.scala
index aff2b74c..b1d2eaed 100644
---
a/src/main/scala/vaquita/components/ForwardingUnit.scala +++ b/src/main/scala/vaquita/components/ForwardingUnit.scala @@ -31,9 +31,4 @@ class ForwardingUnit extends Module { io.forward_c := forwardLogic(vs3_addr, io.mem_vd, io.wb_vd, io.mem_regWrite, io.wb_regWrite) io.forward_d := forwardLogic(0.U, io.mem_vd, io.wb_vd, io.mem_regWrite, io.wb_regWrite) - io.forward_a := DontCare - io.forward_b := DontCare - io.forward_c := DontCare - io.forward_d := DontCare - } diff --git a/src/main/scala/vaquita/components/VecALU.scala b/src/main/scala/vaquita/components/VecALU.scala index 1b79491e..cea5b5e8 100644 --- a/src/main/scala/vaquita/components/VecALU.scala +++ b/src/main/scala/vaquita/components/VecALU.scala @@ -3,6 +3,7 @@ import chisel3._ import chisel3.util._ import ALUObj._ import vaquita.configparameter.VaquitaConfig +import vaquita.components.ALUClasses.{Arith,Permutation,Comparison,NarrowIns} class VecALU(implicit val config: VaquitaConfig) extends Module{ val io = IO(new Bundle{ val vs1_in = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) @@ -10,451 +11,227 @@ class VecALU(implicit val config: VaquitaConfig) extends Module{ val vs3_in = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) val vs0_in = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) val sew = Input(UInt(3.W)) - // val vl = Input(UInt(32.W)) //remove this vl val vl_in = Input(UInt(32.W)) + val rs1_in = Input(UInt(32.W)) + val lmul = Input(UInt(32.W)) + val func3 =Input(UInt(3.W)) val alu_opcode = Input(UInt(6.W)) val mask_arith = Input(Bool()) val vsd_out = Output(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) + // val vxrm_bits = Input(UInt(32.W)) }) - def slide_sew_selector( - vstart: Int, - maxStartOffset: SInt, - vs2_in: SInt, - vs3_in: SInt, - slide_vec_wire: UInt, - vs1_value: UInt, - slide_i_value: UInt, - sew: Int, - count_sew: Int - ): SInt = { - val valid_vl = io.vl_in > vstart.U - val vstart_s = vstart.S - Mux1H(Seq( - (valid_vl && (vstart_s < 
maxStartOffset && vstart_s >= 0.S)) -> vs3_in, - (valid_vl && (vstart_s >= maxStartOffset && vstart_s < io.vl_in.asSInt)) -> vslideup( - vs2_in, - slide_vec_wire % 8.U, - vs0_mask(vstart), - vs3_in, - Mux(vs1_value === 0.U, slide_i_value, slide_vec_wire / 8.U), - sew, - count_sew - ), - valid_vl -> 0.S, - (!valid_vl && (tail === 0.B)) -> vs3_in, - !valid_vl -> Fill(32, 1.U).asSInt - )) - } + val rs1_imm_value = WireInit(0.U(32.W)) + when(io.func3==="b100".U){ + rs1_imm_value := io.rs1_in + }.elsewhen(io.func3==="b011".U){ //for imm + rs1_imm_value := io.vs1_in(0)(0)(4,0).asUInt + }.otherwise{ + rs1_imm_value := io.vs1_in(0)(0).asUInt //////////////////////////////////////////////////////----------------------/////// + } + //convert into one array + val vs0_mask = io.vs0_in.asUInt()(config.vlen-1,0) + // val vs0_mask_bool = Wire(Vec(256, Bool())) - //convert into one array - val vs0_mask = io.vs0_in.asUInt()(config.vlen,0) - //To check active elements and send them to the comparison_func; otherwise, return the element according to the vector tailing concept. 
+ // for (i <- 0 until 256) { + // vs0_mask_bool(i) := vs0_mask(i) // picks bit i as Bool + // } - def comp_element_fn(sew:Int,counter:UInt):SInt={ - val cat_element = WireInit(0.S(32.W)) - val comp_fn_value = comparison_func(sew).asSInt - val comp_shift = 0 - val output_comp_Data = VecInit(Seq.tabulate(config.count_lanes)(i => comp_fn_value(32 * (i + 1) - 1, 32 * i))) - val comp_1bt_cn = WireInit(VecInit(Seq.fill(32)(0.U(32.W)))) - for (i <- 1 to 31) { - comp_1bt_cn(i) := ((io.vl_in) - (32.U * counter)) - when(comp_1bt_cn(i) === i.U) { - cat_element := Cat(io.vs3_in(0)(counter)(31,i), output_comp_Data(counter)(i-1, 0)).asSInt - }.elsewhen(comp_1bt_cn(i)===32.U || (comp_1bt_cn(i)/32.U)>0.U){ - cat_element := output_comp_Data(counter).asSInt - } - } - cat_element - } - def comparison_operators(vs1_in:SInt,vs2_in:SInt):Bool={ - val comparison_table = Seq( - vmseq -> (vs1_in===vs2_in),//vmseq - vmsne -> (vs1_in.asUInt =/= vs2_in.asUInt),//vmsne - vmsltu -> (vs1_in.asUInt > vs2_in.asUInt),//vmsltu - vmslt -> (vs1_in > vs2_in),//vmslt - vmsleu -> (vs1_in.asUInt >= vs2_in.asUInt),//vmsleu - vmsle -> (vs1_in >= vs2_in),//vmsle - vmsgtu -> (vs1_in.asUInt < vs2_in.asUInt),//vmsgtu - vmsgt -> (vs1_in < vs2_in),//vmsgt - vmadc -> Mux((vs1_in +& vs2_in).asUInt < "hffffffff".U,1.B,0.B)//vmadc - // "b011111".U -> (vs1_in < vs2_in)//vmsbc - ) - MuxLookup(io.alu_opcode, 0.B, comparison_table) +// *****************convert into 1 D array to vector register ************************* + val vs1_in = Wire(Vec(8, UInt((config.vlen).W))) + val vs2_in = Wire(Vec(8, UInt((config.vlen).W))) + val vs3_in = Wire(Vec(8, UInt((config.vlen).W))) + // val vsd_out = Wire(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) + for (i <- 0 until 8) { + vs1_in(i) := io.vs1_in(i).asUInt + vs2_in(i) := io.vs2_in(i).asUInt + vs3_in(i) := io.vs3_in(i).asUInt } + // ********************* convert into 2d vectors (8 bit elements) ********************* - def comparison_func(sew: Int): UInt = { - val 
elementsPerLane = config.vlen / sew - val comparison_vec_bit_wires = WireInit(VecInit(Seq.fill(config.vlen)(0.B))) - val comp_1b = WireInit(VecInit(Seq.fill(config.vlen)(0.B))) - val comp_0b = WireInit(VecInit(Seq.fill(config.vlen)(0.B))) - val comp_vs3 = WireInit(VecInit(Seq.fill(config.vlen)(0.B))) - val vs3_bit = io.vs3_in.asUInt - var counter = 0 - for (i <- 0 until config.count_lanes) { - for (elem_idx <- 0 until elementsPerLane) { - val startBit = elem_idx * sew - val endBit = (elem_idx + 1) * sew - 1 - if (endBit < io.vs1_in(i).getWidth && endBit < io.vs2_in(i).getWidth) { - val vs1_elem = io.vs1_in(i).asUInt()(endBit, startBit) - val vs2_elem = io.vs2_in(i).asUInt()(endBit, startBit) - val comparison = comparison_operators(vs1_elem.asSInt, vs2_elem.asSInt) - comp_1b(counter) := (io.mask_arith && comparison) || (!io.mask_arith && comparison && vs0_mask(counter)) - comp_vs3(counter) := (!vs0_mask(counter) && !io.mask_arith) - comp_0b(counter) := (io.mask_arith && !comparison) - comparison_vec_bit_wires(counter) := MuxCase(0.B, Array( - (comp_0b(counter) === 1.B) -> 0.B, - (comp_vs3(counter) === 1.B) -> vs3_bit(counter), - (comp_1b(counter) === 1.B) -> 1.B - )) - counter += 1 - } - } + val vs1_8 = Wire(Vec(8, Vec(config.lane8, UInt(8.W)))) + val vs2_8 = Wire(Vec(8, Vec(config.lane8, UInt(8.W)))) + val vs3_8 = Wire(Vec(8, Vec(config.lane8, UInt(8.W)))) + // val vsd_8 = Wire(Vec(8, Vec(config.lane8, UInt(8.W)))) + dontTouch(vs1_8) + dontTouch(vs2_8) + dontTouch(vs3_8) + // dontTouch(vsd_8) + + for (i <- 0 until 8) { + var high = 7 + var low = 0 + for (j <- 0 until config.lane8) { + vs1_8(i)(j) := vs1_in(i)(high, low) + vs2_8(i)(j) := vs2_in(i)(high, low) + vs3_8(i)(j) := vs3_in(i)(high, low) + high += 8 + low += 8 } - comparison_vec_bit_wires.asUInt } - def Arithmatic(vs1_in: SInt, vs2_in: SInt,vs3:SInt,sew:Int,v0_bit_mask:UInt): SInt = { - val maxUInt32 = "hFFFFFFFF".U(32.W) - val sum = WireInit(0.S(33.W)) - val sub = WireInit(0.S(33.W)) - val maxValue = (1.S 
<< (sew - 1)) - 1.S - val minValue = -(1.S << (sew - 1)) - sub := (vs2_in -& vs1_in) - sum := (vs1_in.asUInt +& vs2_in.asUInt).asSInt - // Overflow for addition - val positiveOverflowAdd = (vs1_in(sew - 1) === 0.U && vs2_in(sew - 1) === 0.U && sum(sew - 1) === 1.U) - val negativeOverflowAdd = (vs1_in(sew - 1) === 1.U && vs2_in(sew - 1) === 1.U && sum(sew - 1) === 0.U) - // Overflow for subtraction - val positiveOverflowSub = (vs2_in(sew - 1) === 0.U && vs1_in(sew - 1) === 1.U && sub(sew - 1) === 1.U) - val negativeOverflowSub = (vs2_in(sew - 1) === 1.U && vs1_in(sew - 1) === 0.U && sub(sew - 1) === 0.U) - val lookuptable = Seq( - vadd -> (vs1_in + vs2_in),//add - vsub -> (vs2_in - vs1_in),//sub - vrsub -> (vs1_in - vs2_in),//rsub - vand -> (vs1_in & vs2_in),// and - vor -> (vs1_in | vs2_in),//or - vxor -> (vs1_in ^ vs2_in),//xor - vsll -> (vs2_in << (vs1_in.asUInt%sew.U)), //vsll - vsrl -> (vs2_in.asUInt >> (vs1_in.asUInt%sew.U)).asSInt, //vsrl - vsra -> ((vs2_in >> (vs1_in.asUInt%sew.U)).asSInt), //vsra - vmv -> (vs1_in), //vmv - vminu -> Mux(vs1_in.asUInt < vs2_in.asUInt,vs1_in.asUInt,vs2_in.asUInt).asSInt,//minu - vmin -> Mux(vs1_in < vs2_in,vs1_in,vs2_in),//min - vmaxu -> Mux(vs1_in.asUInt > vs2_in.asUInt,vs1_in.asUInt,vs2_in.asUInt).asSInt,//maxu - vmax -> Mux(vs1_in > vs2_in,vs1_in,vs2_in),//max - vsaddu -> Mux(sum(32), "hFFFFFFFF".U, sum(31,0)).asSInt,//vsaddu - vsadd -> (Mux(positiveOverflowAdd, maxValue, Mux(negativeOverflowAdd, minValue, sum))),//vsadd - vssub -> Mux(vs2_in.asUInt < vs1_in.asUInt, 0.U,vs2_in.asUInt - vs1_in.asUInt ).asSInt,//vssubu - vssub -> Mux(positiveOverflowSub, maxValue, Mux(negativeOverflowSub, minValue, sub(31, 0).asSInt)),//vssub - vadc -> (vs1_in.asUInt + vs2_in.asUInt + v0_bit_mask).asSInt, //vadc - vsbc -> (vs2_in.asUInt - vs1_in.asUInt - v0_bit_mask).asSInt //vsbc - // "b101010".U -> vs1_in,//vnsra - // "b101010".U -> vs1_in,//vadc - // "b101010".U -> vs1_in,//vsbc - // "b101010".U -> vs1_in - ) - MuxLookup(io.alu_opcode, 
0.S, lookuptable) - } - def arith_32(vs1:SInt , vs2:SInt,vs3:SInt,mask_vs0:Bool):SInt={ - val vsetvli_mask = 0.B - val mask_bit_active_element = (mask_vs0===1.B && io.mask_arith===0.B) || io.mask_arith===1.B - val mask_bit_undisturb = mask_vs0===0.B && io.mask_arith===0.B && vsetvli_mask===0.B - val vec_sew32_b = WireInit(0.S(32.W)) - val vec_sew32_result = WireInit(0.S(config.XLEN.W)) - when(io.alu_opcode==="b010000".U || io.alu_opcode==="b010010".U){ - vec_sew32_b := (Arithmatic(vs1, vs2,vs3,32,mask_vs0.asUInt)).asSInt - }.otherwise{ - vec_sew32_b := Mux(mask_bit_active_element===1.B,Arithmatic(vs1, vs2,vs3,32,mask_vs0.asUInt),Mux(mask_bit_undisturb===1.B,vs3,Fill(32,1.U).asSInt)).asSInt - } - vec_sew32_result := vec_sew32_b - vec_sew32_result - } + // ******************* convert into 2d vectors (16 bit elements) ************* - def arith_16(vs1:SInt , vs2:SInt,vs3:SInt,mask_vs0:Bool):SInt={ - val vsetvli_mask = 0.B - val mask_bit_active_element = (mask_vs0===1.B && io.mask_arith===0.B) || io.mask_arith===1.B - val mask_bit_undisturb = mask_vs0===0.B && io.mask_arith===0.B && vsetvli_mask===0.B - val vec_sew16_result = WireInit(0.S(16.W)) - when(io.alu_opcode==="b010000".U || io.alu_opcode==="b010010".U){ - vec_sew16_result := (Arithmatic(vs1.asSInt, vs2.asSInt,vs3,16,mask_vs0.asUInt)).asSInt - }.otherwise{ - vec_sew16_result := Mux(mask_bit_active_element===1.B,Arithmatic(vs1.asSInt, vs2.asSInt,vs3,16,mask_vs0.asUInt),Mux(mask_bit_undisturb===1.B,vs3,Fill(16,1.U).asSInt)).asSInt - } - vec_sew16_result - } + val vs1_16 = Wire(Vec(8, Vec(config.lane16, UInt(16.W)))) + val vs2_16 = Wire(Vec(8, Vec(config.lane16, UInt(16.W)))) + val vs3_16 = Wire(Vec(8, Vec(config.lane16, UInt(16.W)))) + // val vsd_16 = Wire(Vec(8, Vec(config.lane16, UInt(16.W)))) - def arith_8(vs1:SInt , vs2:SInt,vs3:SInt,mask_vs0:Bool):SInt={ - dontTouch(mask_vs0) - val vsetvli_mask = 0.B - val mask_bit_active_element = (mask_vs0===1.B && io.mask_arith===0.B) || io.mask_arith===1.B - val 
mask_bit_undisturb = mask_vs0===0.B && io.mask_arith===0.B && vsetvli_mask===0.B - val vec_sew8_result = WireInit(0.S(8.W)) - dontTouch(vec_sew8_result) - when(io.alu_opcode==="b010000".U || io.alu_opcode==="b010010".U){ - vec_sew8_result := (Arithmatic(vs1.asSInt, vs2.asSInt,vs3,8,mask_vs0.asUInt)).asSInt - }.otherwise{ - vec_sew8_result := Mux(mask_bit_active_element===1.B,Arithmatic(vs1.asSInt, vs2.asSInt,vs3,8,mask_vs0.asUInt),Mux(mask_bit_undisturb===1.B,vs3,Fill(16,1.U).asSInt)).asSInt - } - vec_sew8_result + for (i <- 0 until 8) { + var high = 15 + var low = 0 + for (j <- 0 until config.lane16) { + vs1_16(i)(j) := vs1_in(i)(high, low) + vs2_16(i)(j) := vs2_in(i)(high, low) + vs3_16(i)(j) := vs3_in(i)(high, low) + high += 16 + low += 16 + } } + dontTouch(vs2_16) - // Function for vslideup - def vslideup(input: SInt, slide_amount:UInt,mask_vs0:Bool,vs3:SInt,i_incre:UInt,sew:Int,count_sew:Int): SInt = { - val shifted = Wire(SInt(32.W)) - val in_i = WireInit(0.U(33.W)) - val in_j = WireInit(0.U(10.W)) - val sew1 = WireInit(0.U(10.W)) - val count_sew1 = WireInit(0.U(10.W)) - val vs2_element = WireInit(0.S(sew.W)) - val sewup = WireInit(0.U(10.W)) - val sewdown = WireInit(0.U(10.W)) - val vsetvli_mask = 0.B - val mask_bit_active_element = (mask_vs0===1.B && io.mask_arith===0.B) || io.mask_arith===1.B - val mask_bit_undisturb = mask_vs0===0.B && io.mask_arith===0.B && vsetvli_mask===0.B - sewup := ((count_sew*sew)-1).U - sewdown := ((count_sew-1)*sew).U - sew1 := sew.U - count_sew1 := count_sew.U - // Mux(mask_bit_active_element===1.B,Arithmatic(vs1, vs2,vs3,32,mask_vs0.asUInt),Mux(mask_bit_undisturb===1.B,vs3,Fill(32,1.U).asSInt)).asSInt - when(count_sew1===1.U && io.sew==="b000".U){ - in_i := ((slide_amount)/(config.vlen.U/32.U)+i_incre) - in_j := (slide_amount)%(config.vlen.U/32.U) - }.elsewhen(count_sew1===2.U && io.sew==="b000".U){ - in_i := (slide_amount)/(config.vlen.U/32.U)+i_incre - in_j := (slide_amount)%(config.vlen.U/32.U) - } - .otherwise{ - in_i := 
(slide_amount)/(config.vlen.U/32.U)+i_incre - in_j := (slide_amount)%(config.vlen.U/32.U) - } - vs2_element := (io.vs2_in(in_i)(in_j)((count_sew*sew)-1,(count_sew-1)*sew).asSInt) - shifted := Mux(mask_bit_active_element===1.B,vs2_element,Mux(mask_bit_undisturb===1.B,vs3,Fill(sew,1.U).asSInt)).asSInt - shifted + // ******************* convert into Unsign (32 bit elements Unsign) ************* + + val vs1_32 = Wire(Vec(8, Vec(config.count_lanes, UInt(32.W)))) + val vs2_32 = Wire(Vec(8, Vec(config.count_lanes, UInt(32.W)))) + val vs3_32 = Wire(Vec(8, Vec(config.count_lanes, UInt(32.W)))) + // val vsd_16 = Wire(Vec(8, Vec(config.lane16, UInt(16.W)))) + + for (i <- 0 until 8) { + for (j <- 0 until config.count_lanes) { + vs1_32(i)(j) := io.vs1_in(i)(j).asUInt + vs2_32(i)(j) := io.vs2_in(i)(j).asUInt + vs3_32(i)(j) := io.vs3_in(i)(j).asUInt + } } - val slide_instr = "b001110".U === io.alu_opcode || "b001111".U === io.alu_opcode + + val narrow_ins = Wire(Bool()) + val slide_ins = Wire(Bool()) + val comp_ins = Wire(Bool()) + val red_sum_ins = Wire(Bool()) + + val narrow_valid = Wire(Bool()) + val slide_valid = Wire(Bool()) + val comp_valid = Wire(Bool()) + val red_sum_valid = Wire(Bool()) + val arith_valid = Wire(Bool()) + + slide_ins := "b001110".U === io.alu_opcode || "b001111".U === io.alu_opcode || vrgather === io.alu_opcode + comp_ins := "b011000".U === io.alu_opcode || "b011001".U === io.alu_opcode || "b011010".U === io.alu_opcode || "b011011".U === io.alu_opcode || "b011100".U === io.alu_opcode || "b011101".U === io.alu_opcode || "b011110".U === io.alu_opcode || "b011111".U === io.alu_opcode + narrow_ins := vnsrl===io.alu_opcode || vnsra===io.alu_opcode || vnclipu===io.alu_opcode || vnclip===io.alu_opcode + red_sum_ins := vwredsumu===io.alu_opcode || vwredsum===io.alu_opcode + + narrow_valid := narrow_ins && !slide_ins && !comp_ins && !red_sum_ins + slide_valid := !narrow_ins && slide_ins && !comp_ins && !red_sum_ins + comp_valid := !narrow_ins && !slide_ins && 
comp_ins && !red_sum_ins + red_sum_valid := !narrow_ins && !slide_ins && !comp_ins && red_sum_ins + arith_valid := !narrow_ins && !slide_ins && !comp_ins && !red_sum_ins + + + val vl= 4 val tail = 0.B + + // call main function var count_mask = 0.U - val comp_bit = "b011000".U === io.alu_opcode || "b011001".U === io.alu_opcode || "b011010".U === io.alu_opcode || "b011011".U === io.alu_opcode || "b011100".U === io.alu_opcode || "b011101".U === io.alu_opcode || "b011110".U === io.alu_opcode || "b011111".U === io.alu_opcode //code changes - when(io.sew==="b000".U && slide_instr===0.B){ - when(comp_bit===0.B){ - var vl_counter = 0 - for (i <- 0 until 8) { - for (j <- 0 until config.count_lanes) { - val idx = (i * config.count_lanes) + j - io.vsd_out(i)(j) := Cat(Mux(io.vl_in > vl_counter.U+3.U, - arith_8(io.vs1_in(i)(j)(31,24).asSInt, io.vs2_in(i)(j)(31,24).asSInt, io.vs3_in(i)(j)(31,24).asSInt, vs0_mask(vl_counter+3)), - Mux(tail === 0.B, io.vs3_in(i)(j)(31,24).asSInt, Fill(8, 1.U).asSInt)), + val comp_8 = Module(new Comparison()(config)) + val arith_8 = Module(new Arith()(config)) + val narrow_8 = Module(new NarrowIns()(config)) + val Permutation_8 = new Permutation()(config) - Mux(io.vl_in > vl_counter.U +2.U, - arith_8(io.vs1_in(i)(j)(23,16).asSInt, io.vs2_in(i)(j)(23,16).asSInt, io.vs3_in(i)(j)(23,16).asSInt, vs0_mask(vl_counter+2)), - Mux(tail === 0.B, io.vs3_in(i)(j)(23,16).asSInt, Fill(8, 1.U).asSInt)), - - Mux(io.vl_in > vl_counter.U+1.U, - arith_8(io.vs1_in(i)(j)(15,8).asSInt, io.vs2_in(i)(j)(15,8).asSInt, io.vs3_in(i)(j)(15,8).asSInt, vs0_mask(vl_counter+1)), - Mux(tail === 0.B, io.vs3_in(i)(j)(15,8).asSInt, Fill(8, 1.U).asSInt)), + val comp_16 = Module(new Comparison()(config)) + val arith_16 = Module(new Arith()(config)) + val narrow_16 = Module(new NarrowIns()(config)) + val Permutation_16 = new Permutation()(config) - Mux(io.vl_in > vl_counter.U, - arith_8(io.vs1_in(i)(j)(7,0).asSInt, io.vs2_in(i)(j)(7,0).asSInt, io.vs3_in(i)(j)(7,0).asSInt, 
vs0_mask(vl_counter)), - Mux(tail === 0.B, io.vs3_in(i)(j)(7,0).asSInt, Fill(8, 1.U).asSInt)) - ).asSInt - vl_counter = vl_counter + 4 - } - } - }.otherwise{ - var vl_counter1 = 1 - var counter2 = 0 - for (j <- 0 until config.count_lanes) { - io.vsd_out(0)(j) := Mux(io.vl_in > vl_counter1.U,comp_element_fn(8,counter2.U), Mux(tail === 0.B, io.vs3_in(0)(j), Fill(32, 1.U).asSInt)) - vl_counter1 = vl_counter1 + 32 - counter2 = counter2 + 1 - } - for (i <- 1 until 8) { - for (j <- 0 until config.count_lanes) { - io.vsd_out(i)(j) := Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt) - } - } - } - }.elsewhen(io.sew==="b001".U && slide_instr===0.B){//sew=16 - when(comp_bit===0.B){ - val vec_sew16_b = WireInit(0.S(16.W)) - dontTouch(vec_sew16_b) - val sew16_result = WireInit(0.S(config.XLEN.W)) - var vl_counter = 0 - for (i <- 0 until 8) { - for (j <- 0 until config.count_lanes) { - val idx = (i * config.count_lanes) + j - io.vsd_out(i)(j) := Cat(Mux(io.vl_in > vl_counter.U+1.U, - arith_16(io.vs1_in(i)(j)(31,16).asSInt, io.vs2_in(i)(j)(31,16).asSInt, io.vs3_in(i)(j)(31,16).asSInt, vs0_mask(vl_counter+1)), - Mux(tail === 0.B, io.vs3_in(i)(j)(31,16).asSInt, Fill(16, 1.U).asSInt)), - Mux(io.vl_in > vl_counter.U, - arith_16(io.vs1_in(i)(j)(15,0).asSInt, io.vs2_in(i)(j)(15,0).asSInt, io.vs3_in(i)(j)(15,0).asSInt, vs0_mask(vl_counter)), - Mux(tail === 0.B, io.vs3_in(i)(j)(15,0).asSInt, Fill(16, 1.U).asSInt))).asSInt - vl_counter = vl_counter + 2 //counter_of_2 * 2 - } - } - }.otherwise{ - var vl_counter1 = 1 - var counter2 = 0 - for (j <- 0 until config.count_lanes) { - io.vsd_out(0)(j) := Mux(io.vl_in > vl_counter1.U,comp_element_fn(16,counter2.U), Mux(tail === 0.B, io.vs3_in(0)(j), Fill(32, 1.U).asSInt)) - vl_counter1 = vl_counter1 + 32 - counter2 = counter2 + 1 - } - for (i <- 1 until 8) { - for (j <- 0 until config.count_lanes) { - io.vsd_out(i)(j) := Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt) - } - } - } - }.elsewhen(io.sew==="b010".U && 
slide_instr===0.B){//sew = 32 - when(comp_bit === 0.B) { - var vl_counter = 1 - for (i <- 0 until 8) { - for (j <- 0 until config.count_lanes) { - val idx = (i * config.count_lanes) + j - val mask = vs0_mask(idx) - - io.vsd_out(i)(j) := Mux(io.vl_in >= vl_counter.U, - MuxLookup(io.alu_opcode, arith_32(io.vs1_in(i)(j),io.vs2_in(i)(j), io.vs3_in(i)(j), mask), Seq( - "b001110".U -> 0.S//Mux(io.vs1_in(i)(j) < 8.S,vslideup(io.vs2_in(i)(j), slide_value,mask,io.vs3_in(i)(j),"b001110".U),io.vs3_in(i)(j)), - // "b001111".U -> vslidedown(io.vs1_in(i)(j), vl_counter.U) - // "b001100".U -> vrgather(io.vs1_in(i)(j), io.vs2_in(i)(j)) - )), - Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt) - ) - vl_counter = vl_counter + 1 - } - } - }.otherwise{ - var vl_counter1 = 1 - var counter2 = 0 - for (j <- 0 until config.count_lanes) { - io.vsd_out(0)(j) := Mux(io.vl_in > vl_counter1.U,comp_element_fn(32,counter2.U), Mux(tail === 0.B, io.vs3_in(0)(j), Fill(32, 1.U).asSInt)) - vl_counter1 = vl_counter1 + 32 - counter2 = counter2 + 1 - } - for (i <- 1 until 8) { - for (j <- 0 until config.count_lanes) { - io.vsd_out(i)(j) := Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt) - } - } - } - } - // 368 // } - .elsewhen(slide_instr===1.B && io.sew === "b000".U){ //for slide instructions and sew 8 - var vl_counter = 1 - var vstart = 0 - val slidedown_value = WireInit(0.U(33.W)) - var j_slide_count = 0 - val vs1_value = WireInit(0.U(33.W)) - val slide_vec_wire11 = WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.U(32.W)})})) - // val maxStartOffset = WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.S(32.W)})})) - vs1_value := io.vs1_in(0)(0).asUInt + val comp_32 = Module(new Comparison()(config)) + val arith_32 = Module(new Arith()(config)) + val Permutation_32 = new Permutation()(config) + + val result_8 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(config.lane8)(0.U(8.W)))))) + val result_16 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(config.lane16)(0.U(16.W)))))) + 
val result_32 = WireInit(VecInit(Seq.fill(8)(VecInit(Seq.fill(config.count_lanes)(0.U(32.W)))))) + when(io.sew==="b000".U){ + when (arith_valid){ + result_8 <> arith_8.arith_8_result(vs1_8,vs2_8,vs3_8,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane8,8) + } + .elsewhen(slide_valid){ + result_8 <> Permutation_8.permutation(vs1_8,vs2_8,vs3_8,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane8,8,io.func3,io.lmul) + } + .elsewhen (narrow_valid){ + result_8 <> narrow_8.narrow_result(vs1_8,vs2_16,vs3_8,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane8,8) + } + .elsewhen(comp_valid){ + result_8 <> comp_8.main_comp(vs1_8,vs2_8,vs3_8,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane8,8) + } + // connect wires with io.vsd out*********************** for (i <- 0 until 8) { + var base = 0 for (j <- 0 until config.count_lanes) { - slidedown_value := (vs1_value+vl_counter.U-1.U) - // maxStartOffset(i)(j) := Mux((vstart.S) > vs1_value.asSInt, (vstart.S), vs1_value.asSInt) - slide_vec_wire11(i)(j) := (j_slide_count.U) - vs1_value(7,0) // (vstart.U) - vs1_value - when(io.alu_opcode==="b001110".U){ // for slide up // Mux(i.U>0.U && j.U===0.U && vs1_value=/=0.U,,(j.U-vs1_value))("b1111".U).asSInt - io.vsd_out(i)(j) := Cat( - slide_sew_selector(vstart+3,Mux((vstart.S + 3.S) > vs1_value.asSInt, (vstart.S + 3.S), vs1_value.asSInt) ,io.vs2_in(i)(j)(31,24).asSInt,io.vs3_in(i)(j)(31,24).asSInt, (vstart.U) - vs1_value ,vs1_value,i.U,8,4)(7,0).asSInt, - slide_sew_selector(vstart+2,Mux((vstart.S + 2.S) > vs1_value.asSInt, (vstart.S + 2.S), vs1_value.asSInt) ,io.vs2_in(i)(j)(23,16).asSInt,io.vs3_in(i)(j)(23,16).asSInt, (vstart.U) - vs1_value ,vs1_value,i.U,8,3)(7,0).asSInt, - slide_sew_selector(vstart+1,Mux((vstart.S + 1.S) > vs1_value.asSInt, (vstart.S + 1.S), vs1_value.asSInt) ,io.vs2_in(i)(j)(15,8).asSInt,io.vs3_in(i)(j)(15,8).asSInt, (vstart.U) - vs1_value ,vs1_value,i.U,8,2)(7,0).asSInt, - 
slide_sew_selector(vstart, Mux((vstart.S ) > vs1_value.asSInt, (vstart.S ), vs1_value.asSInt) ,io.vs2_in(i)(j)(7,0).asSInt ,io.vs3_in(i)(j)(7,0).asSInt, (vstart.U) - vs1_value ,vs1_value,i.U,8,1)(7,0).asSInt).asSInt - Mux(vs1_value =/=0.U && vstart.S + 3.S >= (Mux((vstart.S + 3.S) > vs1_value.asSInt, (vstart.S + 3.S), vs1_value.asSInt)) && vstart.S + 3.S < io.vl_in.asSInt,slide_vec_wire11(i)(j)+1.U,slide_vec_wire11(i)(j) - ) // Mux((vstart.U + 3.U) - vs1_value(7,0) %4.U===3.U,(vstart.U + 3.U) - vs1_value(7,0)+1.U,(vstart.U + 3.U) - vs1_value(7,0)) - }.otherwise{ /// for slide down - io.vsd_out(i)(j) := 0.S//Mux(io.vl_in >= vl_counter.U, vslideup(io.vs2_in(i)(j), slidedown_value,(vs0_mask((i * config.count_lanes) + j)),io.vs3_in(i)(j),i.U), - // Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt)) - } - vl_counter = vl_counter + 4 - vstart = vstart + 4 - j_slide_count = j_slide_count + 1 + io.vsd_out(i)(j) := Cat( + result_8(i)(base + 3), + result_8(i)(base + 2), + result_8(i)(base + 1), + result_8(i)(base) + ).asSInt + base = base + 4 } } - }.elsewhen(slide_instr===1.B && io.sew === "b01".U){ //for slide instructions and sew 16 - var vl_counter = 1 - var vstart = 0 - val slidedown_value = WireInit(0.U(33.W)) - val vs1_value = WireInit(0.U(33.W)) - val slide_vec_wire = WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.U(32.W)})})) - val maxStartOffset = WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.S(32.W)})})) - vs1_value := io.vs1_in(0)(0).asUInt + } + .elsewhen(io.sew==="b001".U){ + when (arith_valid){ + result_16 <> arith_16.arith_8_result(vs1_16,vs2_16,vs3_16,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane16,16) + } + .elsewhen(slide_valid){ + result_16 <> Permutation_8.permutation(vs1_16,vs2_16,vs3_16,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane16,16,io.func3,io.lmul) + } + .elsewhen (narrow_valid){ + result_16 <> 
narrow_16.narrow_result(vs1_16,vs2_32,vs3_16,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane16,16) + } + .elsewhen(comp_valid){ + result_16 <> comp_16.main_comp(vs1_16,vs2_16,vs3_16,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.lane16,16) + } + // // connect wires with io.vsd out*********************** for (i <- 0 until 8) { + var base = 0 for (j <- 0 until config.count_lanes) { - slidedown_value := (vs1_value+vl_counter.U-1.U) - maxStartOffset(i)(j) := Mux((vstart.S) > vs1_value.asSInt, (vstart.S), vs1_value.asSInt) - slide_vec_wire(i)(j) := (vstart.U) - vs1_value - when(io.alu_opcode==="b001110".U){ // for slide up // Mux(i.U>0.U && j.U===0.U && vs1_value=/=0.U,,(j.U-vs1_value))("b1111".U).asSInt - io.vsd_out(i)(j) := Cat( - slide_sew_selector(vstart+1,Mux((vstart.S + 1.S) > vs1_value.asSInt, (vstart.S + 1.S), vs1_value.asSInt) ,io.vs2_in(i)(j)(15,8).asSInt,io.vs3_in(i)(j)(31,0).asSInt, (vstart.U+1.U) - vs1_value,vs1_value,i.U,16,2)(31,16), - slide_sew_selector(vstart, Mux((vstart.S ) > vs1_value.asSInt, (vstart.S ), vs1_value.asSInt) ,io.vs2_in(i)(j)(7,0).asSInt ,io.vs3_in(i)(j)(15,0).asSInt, (vstart.U ) - vs1_value,vs1_value,i.U,16,1)(15,0)).asSInt - }.otherwise{ /// for slide down - io.vsd_out(i)(j) := 0.S//Mux(io.vl_in >= vl_counter.U, vslideup(io.vs2_in(i)(j), slidedown_value,(vs0_mask((i * config.count_lanes) + j)),io.vs3_in(i)(j),i.U), - // Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt)) - } - vl_counter = vl_counter + 2 - vstart = vstart + 2 + io.vsd_out(i)(j) := Cat( + result_16(i)(base + 1), + result_16(i)(base) + ).asSInt + base = base + 2 } } - }.elsewhen(slide_instr===1.B && io.sew === "b10".U){ //for slide instructions and sew 32 - var vl_counter = 1 - var vstart = 0 - val slidedown_value = WireInit(0.U(33.W)) - val vs1_value = WireInit(0.U(33.W)) - vs1_value := io.vs1_in(0)(0).asUInt - val slide_vec_wire = WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.U(32.W)})})) - val maxStartOffset = 
WireInit(VecInit(Seq.fill(8){VecInit(Seq.fill(8) {0.S(32.W)})})) + } + .elsewhen(io.sew==="b010".U){ + when (arith_valid){ + result_32 <> arith_32.arith_8_result(vs1_32,vs2_32,vs3_32,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.count_lanes,32) + } + .elsewhen(slide_valid){ + val Permutation_8 = new Permutation()(config) + result_32 <> Permutation_8.permutation(vs1_32,vs2_32,vs3_32,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.count_lanes,32,io.func3,io.lmul) + } + .elsewhen(comp_valid){ + result_32 <> comp_32.main_comp(vs1_32,vs2_32,vs3_32,vs0_mask,io.alu_opcode,rs1_imm_value,io.vl_in,io.mask_arith,config.count_lanes,32) + } + // // connect wires with io.vsd out*********************** + for (i <- 0 until 8) { for (j <- 0 until config.count_lanes) { - slidedown_value := (vs1_value+vl_counter.U-1.U) - maxStartOffset(i)(j) := Mux((vstart.S) > vs1_value.asSInt, (vstart.S), vs1_value.asSInt) - slide_vec_wire(i)(j) := (vstart.U) - vs1_value - when(io.alu_opcode==="b001110".U){ // for slide up // Mux(i.U>0.U && j.U===0.U && vs1_value=/=0.U,,(j.U-vs1_value))("b1111".U).asSInt - io.vsd_out(i)(j) := slide_sew_selector(vstart,maxStartOffset(i)(j) ,io.vs2_in(i)(j),io.vs3_in(i)(j),slide_vec_wire(i)(j),vs1_value,i.U,32,1) - }.otherwise{ /// for slide down - io.vsd_out(i)(j) := 0.S//Mux(io.vl_in >= vl_counter.U, vslideup(io.vs2_in(i)(j), slidedown_value,(vs0_mask((i * config.count_lanes) + j)),io.vs3_in(i)(j),i.U), - // Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt)) - } - vl_counter = vl_counter + 1 - vstart = vstart + 1 + io.vsd_out(i)(j) := result_32(i)(j).asSInt } } } .otherwise{ - var vl_counter = 1 - for (i <- 0 until 8) { - for (j <- 0 until config.count_lanes) { - val idx = (i * config.count_lanes) + j - val mask = vs0_mask(idx) - io.vsd_out(i)(j) :=0.S //Mux(io.vl_in >= vl_counter.U, - // arith_32(io.vs1_in(i)(j), io.vs2_in(i)(j), io.vs3_in(i)(j), mask), - // Mux(tail === 0.B, io.vs3_in(i)(j), Fill(32, 1.U).asSInt) - 
// ) - vl_counter = vl_counter + 1 - } + for (i <- 0 until 8) { + for (j <- 0 until config.count_lanes) { + io.vsd_out(i)(j) :=("h0000ffff".U).asSInt } } -} - -// Function for vslidedown - // def vslidedown(input: SInt, slide_amount: UInt): SInt = { - // val shifted = WireInit(0.S(32.W)) - // // shifted := Mux(mask===1.B && io.mask_arith===0.B || io.mask_arith===1.B,(io.vs2_in(in_i)(in_j)).asSInt, vs3) - // shifted - // } - - // // Function for vrgather - // def vrgather(input: Vec[Vec[SInt]], indices: Vec[Vec[UInt]], i: Int, j: Int): SInt = { - // // Gathers element from input based on indices provided in indices vector - // // Ensures index is within bounds, else returns zero - // val index = indices(i)(j) - // Mux(index < input.length.U, input(i)(index), 0.S) - // } \ No newline at end of file + } +} \ No newline at end of file diff --git a/src/main/scala/vaquita/components/VecRegFile.scala b/src/main/scala/vaquita/components/VecRegFile.scala index 825ad47e..632f0492 100644 --- a/src/main/scala/vaquita/components/VecRegFile.scala +++ b/src/main/scala/vaquita/components/VecRegFile.scala @@ -23,6 +23,7 @@ class VecRegFile(implicit val config: VaquitaConfig) extends Module { val func3 = Input(UInt(3.W)) val store_vs3_to_mem = Input(Bool()) val reg_write_decode = Input(Bool()) + val de_instr = Input(UInt(32.W)) }) val vrf = RegInit(VecInit(Seq.fill(config.reg_count){VecInit(Seq.fill(config.count_lanes) {0.S(config.XLEN.W)})})) dontTouch(vrf) @@ -46,7 +47,7 @@ class VecRegFile(implicit val config: VaquitaConfig) extends Module { io.vs1_data(i)(j) := vrf(io.vs1_addr + offset)(j) io.vs3_data(i)(j) := vrf(vs3_addr + offset)(j) io.vs0_data(i)(j) := vrf(vs0_addr + offset)(j) - }}}.elsewhen((io.reg_write === 1.B) && (io.vd_addr === io.wb_vd_addr && io.store_vs3_to_mem===1.B) ){//use next vs3 addr for store instruction + }}}.elsewhen((io.reg_write === 1.B) && (io.vd_addr === io.wb_vd_addr && io.store_vs3_to_mem===1.B) ){ for (i <- 0 until a) { // for grouping = 8 val offset 
= i.U for (j <- 0 until (config.count_lanes)) { @@ -99,14 +100,17 @@ class VecRegFile(implicit val config: VaquitaConfig) extends Module { } } - - when(io.lmul===0.U){ + val lmul_wire = WireInit(0.U(5.W)) + val narrow_f6 = io.de_instr(31,26) + val isNarrowOp = narrow_f6 === "d44".U || narrow_f6 === "d45".U || narrow_f6 === "d46".U || narrow_f6 === "d47".U + val lmul_read = Mux(isNarrowOp, Mux(io.lmul === 3.U, 3.U, io.lmul + 1.U), io.lmul) + when(lmul_read===0.U){ read_vrf(1) - }.elsewhen(io.lmul===1.U){ + }.elsewhen(lmul_read===1.U){ read_vrf(2) - }.elsewhen(io.lmul===2.U){ + }.elsewhen(lmul_read===2.U){ read_vrf(4) - }.elsewhen(io.lmul===3.U){ + }.elsewhen(lmul_read===3.U){ read_vrf(8) } .otherwise{ read_vrf(1) diff --git a/src/main/scala/vaquita/configparameter/VaquitaConfig.scala b/src/main/scala/vaquita/configparameter/VaquitaConfig.scala index ea4a53d0..67f365da 100644 --- a/src/main/scala/vaquita/configparameter/VaquitaConfig.scala +++ b/src/main/scala/vaquita/configparameter/VaquitaConfig.scala @@ -3,8 +3,10 @@ package vaquita.configparameter import chisel3._ case class VaquitaConfig( - vlen: Int , - reg_count: Int , - XLEN: Int , - count_lanes: Int,//vlen >> 5, + vlen: Int = 256, + reg_count: Int = 32, + XLEN: Int = 32, + count_lanes: Int = 8 ,// vlen/32=8 + lane8: Int = 32,//vlen/8=32 + lane16: Int = 16,//vlen/16=16 ) \ No newline at end of file diff --git a/src/main/scala/vaquita/pipeline/DecodeStage.scala b/src/main/scala/vaquita/pipeline/DecodeStage.scala index 7a2281f5..dc8dafc4 100644 --- a/src/main/scala/vaquita/pipeline/DecodeStage.scala +++ b/src/main/scala/vaquita/pipeline/DecodeStage.scala @@ -56,6 +56,8 @@ class DecodeStage(implicit val config: VaquitaConfig) extends Module { vec_reg_module.io.vtype := vcsr_module.io.vtype_out vec_reg_module.io.wb_vd_addr := io.de_io.wb_de_instr_in(11, 7) vec_reg_module.io.store_vs3_to_mem := vec_cu_module.io.store_vs3_to_mem + vec_reg_module.io.de_instr := io.de_io.instr + /** Vector CSR Module Wiring */ 
vcsr_module.io.vec_config := vec_cu_module.io.vec_config diff --git a/src/main/scala/vaquita/pipeline/ExcuteStage.scala b/src/main/scala/vaquita/pipeline/ExcuteStage.scala index fbdec2fa..7e3e79d5 100644 --- a/src/main/scala/vaquita/pipeline/ExcuteStage.scala +++ b/src/main/scala/vaquita/pipeline/ExcuteStage.scala @@ -37,10 +37,14 @@ class ExcuteStage(implicit val config: VaquitaConfig) extends Module { val ex_alu_op_out = RegNext(io.ex_alu_op_in) io.ex_instr_out := RegNext(io.ex_instr_in) vec_alu_module.io.vl_in := vsetvli_module.io.vl + vec_alu_module.io.rs1_in := io.hazard_rs1 + vec_alu_module.io.lmul := io.ex_lmul_in + vec_alu_module.io.func3 := RegNext(io.ex_instr_in(14,12)) val sew_selector = new SewSelector() for (i <- 0 to 7) { // for grouping = 8 for (j <- 0 until (config.count_lanes)) { + // vec_alu_module.io.vs1_in(i)(j) := Mux(io.ex_instr_out(6,0)==="b1010111".U && io.ex_instr_out(14,12)==="b100".U,io.hazard_rs1.asSInt,io.ex_vs1_data_in(i)(j)) vec_alu_module.io.vs1_in(i)(j) := Mux(io.ex_instr_out(6,0)==="b1010111".U && io.ex_instr_out(14,12)==="b100".U,sew_selector.sew_selector_with_element(next_sew,io.hazard_rs1.asSInt),io.ex_vs1_data_in(i)(j)) }} vec_alu_module.io.vs2_in <> io.ex_vs2_data_in diff --git a/src/main/scala/vaquita/pipeline/MemStage.scala b/src/main/scala/vaquita/pipeline/MemStage.scala index d1c4f361..bf667842 100644 --- a/src/main/scala/vaquita/pipeline/MemStage.scala +++ b/src/main/scala/vaquita/pipeline/MemStage.scala @@ -29,7 +29,7 @@ class MemStage(implicit val config: VaquitaConfig) extends Module { vsd_data <> io.mem_vsd_data_in io.mem_vsd_data_out <> vsd_data vs3_data <> io.mem_vs1_data_vs3_in - io.vs3_data_out <> vs3_data//io.mem_vs1_data_vs3_in(i)(j)// vs3_data(i)(j) + io.vs3_data_out <> vs3_data io.mem_instr_out := RegNext(io.mem_instr_in) io.mem_stage_write_en := RegNext(io.write_en) io.mem_stage_read_en := RegNext(io.read_en) diff --git a/src/main/scala/vaquita/pipeline/WBStage.scala 
b/src/main/scala/vaquita/pipeline/WBStage.scala index 74fbf68c..3dd593bc 100644 --- a/src/main/scala/vaquita/pipeline/WBStage.scala +++ b/src/main/scala/vaquita/pipeline/WBStage.scala @@ -4,7 +4,7 @@ import chisel3.util._ import vaquita.configparameter.VaquitaConfig -class WBStage(implicit val config: VaquitaConfig,val on : Bool =1.B, val off : Bool =0.B) extends Module { +class WBStage(implicit val config: VaquitaConfig) extends Module { val io = IO (new Bundle { val wb_vsd_data_in = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W)))) val wb_vs3_data_in_store = Input(Vec(8, Vec(config.count_lanes, SInt(config.XLEN.W))))