diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py index e67ebd1a3..6b317007e 100644 --- a/src/mase_cocotb/interfaces/streaming.py +++ b/src/mase_cocotb/interfaces/streaming.py @@ -238,50 +238,62 @@ def _check(self, got, exp): self.log.debug("Passed | Got: %20s Exp: %20s Err: %10s" % (g, e, err)) -class MultiSignalStreamDriver(Driver): - def __init__(self, clk, data, valid, ready) -> None: - super().__init__() - self.clk = clk - self.data = data - self.valid = valid - self.ready = ready - self.valid_prob = 1.0 - - def set_valid_prob(self, prob): - assert prob >= 0.0 and prob <= 1.0 - self.valid_prob = prob - - async def _driver_send(self, data) -> None: +class MultiSignalStreamDriver(StreamDriver): + async def _driver_send(self, transaction) -> None: while True: await RisingEdge(self.clk) - for hardware_target, item in zip(self.data, data): - hardware_target.value = item - + if type(self.data) == tuple: + # Drive multiple data bus + for wire, val in zip(self.data, transaction): + wire.value = val + else: + # Drive single data + self.data.value = transaction if random.random() > self.valid_prob: self.valid.value = 0 continue # Try roll random valid again at next clock self.valid.value = 1 await ReadOnly() if self.ready.value == 1: - self.log.debug(f"Sent {data}") + if type(self.data) == tuple: + # Drive multiple data bus + for t in transaction: + self.log.debug("Sent %s" % t) + else: + self.log.debug("Sent %s" % transaction) + if self.record_num_beats: + self.num_beats += 1 break + + # Load extra + # self.load_driver + if self.send_queue.empty(): await RisingEdge(self.clk) self.valid.value = 0 - -class MultiSignalStreamMonitor(Monitor): - def __init__(self, clk, data, valid, ready, check=True): - super().__init__(clk) - self.clk = clk - self.data = data - self.valid = valid - self.ready = ready - self.check = check - - def _trigger(self): - return self.valid.value == 1 and self.ready.value == 1 - + # async def _driver_send(self, data) -> None: + # while True: + # await RisingEdge(self.clk) + # print(self.data, data) + # for hardware_target, item in zip(self.data, data): + # print(hardware_target, item) + # hardware_target.value = item + + # if random.random() > self.valid_prob: + # self.valid.value = 0 + # continue # Try roll random valid again at next clock + # self.valid.value = 1 + # await ReadOnly() + # if self.ready.value == 1: + # self.log.debug(f"Sent {data}") + # break + # if self.send_queue.empty(): + # await RisingEdge(self.clk) + # self.valid.value = 0 + + +class MultiSignalStreamMonitor(StreamMonitor): def _recv(self): def cast_data(value): if type(value) == list: diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py index be535dba5..e7f7293dd 100644 --- a/src/mase_cocotb/testbench.py +++ b/src/mase_cocotb/testbench.py @@ -38,10 +38,6 @@ def get_parameter(self, parameter_name): parameter = getattr(self.dut, parameter_name) return int(parameter) - def get_parameter(self, parameter_name): - parameter = getattr(self.dut, parameter_name) - return int(parameter) - async def reset(self, active_high=True): if self.rst is None: raise Exception( @@ -53,6 +49,10 @@ async def reset(self, active_high=True): self.rst.value = 1 if active_high else 0 await RisingEdge(self.clk) self.rst.value = 0 if active_high else 1 + for monitor in self.output_monitors.values(): + monitor.ready.value = 1 + for driver in self.input_drivers.values(): + driver.valid.value = 0 await RisingEdge(self.clk) async def initialize(self): diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv b/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv index f3ed5beac..52a3f2656 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/log2_max_abs.sv @@ -13,12 +13,12 @@ module log2_max_abs #( input logic clk, input logic rst, /* verilator lint_on UNUSEDSIGNAL */ - input logic [ IN_WIDTH-1:0] data_in [IN_SIZE-1:0], - input logic data_in_valid, - output logic data_in_ready, - output logic [OUT_WIDTH-1:0] data_out, - output logic data_out_valid, - input logic data_out_ready + input logic [ IN_WIDTH-1:0] data_in_0 [IN_SIZE-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + output logic [OUT_WIDTH-1:0] data_out_0, + output logic data_out_0_valid, + input logic data_out_0_ready ); logic [IN_WIDTH - 1:0] or_result; logic [IN_WIDTH - 1:0] abs_data_in[IN_SIZE - 1:0]; @@ -26,28 +26,28 @@ module log2_max_abs #( abs #( .IN_WIDTH(IN_WIDTH) ) abs_i ( - .data_in (data_in[i]), + .data_in (data_in_0[i]), .data_out(abs_data_in[i]) ); end or_tree #( .IN_SIZE (IN_SIZE), - .IN_WIDTH(IN_WIDTH), - ) max_bas_i ( + .IN_WIDTH(IN_WIDTH) + ) or_tree_i ( .clk, .rst, .data_in(abs_data_in), - .data_in_valid(data_in_valid), - .data_in_ready(data_in_ready), + .data_in_valid(data_in_0_valid), + .data_in_ready(data_in_0_ready), .data_out(or_result), - .data_out_valid(data_out_valid), - .data_out_ready(data_out_ready) + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) ); log2_value #( - .IN_WIDTH(IN_WIDTH), + .IN_WIDTH(IN_WIDTH) ) log2_i ( .data_in (or_result), - .data_out(data_out) + .data_out(data_out_0) ); endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv index ccbf0ddf3..394e65ded 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_accumulator.sv @@ -5,13 +5,13 @@ Description : The accumulator for mxint. When inputing different exponent, the mantissa will cast to the same bitwidth then accumulate. */ module mxint_accumulator #( - parameter DATA_IN_0_PRECISION_0 = 8, - parameter DATA_IN_0_PRECISION_1 = 4, + // precision_0 = mantissa_width + // precision_1 = exponent_width + parameter DATA_IN_0_PRECISION_0 = 4, + parameter DATA_IN_0_PRECISION_1 = 8, parameter BLOCK_SIZE = 4, parameter IN_DEPTH = 2, - parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + 2 ** DATA_IN_0_PRECISION_1 + $clog2( - IN_DEPTH - ), + parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + $clog2(IN_DEPTH), parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 ) ( input logic clk, @@ -37,15 +37,23 @@ module mxint_accumulator #( assign data_out_0_valid = (counter == IN_DEPTH); /* verilator lint_on WIDTH */ - // mantissa shift - logic [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_in_0[BLOCK_SIZE - 1:0]; + // lossless shift + logic [DATA_IN_0_PRECISION_0 - 1:0] shifted_mdata_in_0[BLOCK_SIZE - 1:0]; logic [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_out_0[BLOCK_SIZE - 1:0]; logic no_value_in_register; - logic [DATA_IN_0_PRECISION_1 - 1:0] exp_min; + logic [DATA_IN_0_PRECISION_1 - 1:0] exp_max; + + logic [DATA_IN_0_PRECISION_1 - 1:0] mdata_in_shift_value; + logic [DATA_IN_0_PRECISION_1 - 1:0] mdata_in_real_shift_value; + logic [DATA_IN_0_PRECISION_1 - 1:0] mdata_out_shift_value; + logic [DATA_IN_0_PRECISION_1 - 1:0] mdata_out_real_shift_value; + + logic [DATA_IN_0_PRECISION_0 - 1:0] shifted_mdata_in_list [BLOCK_SIZE - 1:0][DATA_IN_0_PRECISION_0 - 1:0]; + logic [DATA_OUT_0_PRECISION_0 - 1:0] shifted_mdata_out_list [BLOCK_SIZE - 1:0][DATA_OUT_0_PRECISION_0 - 1:0]; assign no_value_in_register =(counter == 0 || (data_out_0_valid && data_out_0_ready && data_in_0_valid)); - assign exp_min = ($signed(edata_out_0) > $signed(edata_in_0)) ? edata_in_0 : edata_out_0; + assign exp_max = ($signed(edata_out_0) < $signed(edata_in_0)) ? edata_in_0 : edata_out_0; // counter always_ff @(posedge clk) if (rst) counter <= 0; @@ -58,43 +66,51 @@ module mxint_accumulator #( end else if (data_in_0_valid && data_in_0_ready) counter <= counter + 1; end // mantissa + always_comb begin + mdata_in_shift_value = $signed(exp_max) - $signed(edata_in_0); + mdata_in_real_shift_value = (mdata_in_shift_value < DATA_IN_0_PRECISION_0)? mdata_in_shift_value: DATA_IN_0_PRECISION_0 - 1; + mdata_out_shift_value = $signed(exp_max) - $signed(edata_out_0); + mdata_out_real_shift_value = (mdata_out_shift_value < DATA_OUT_0_PRECISION_0)? mdata_out_shift_value: DATA_OUT_0_PRECISION_0 - 1; + end - for (genvar i = 0; i < BLOCK_SIZE; i++) begin : mantissa_block - // mantissa shift - for (genvar j = 0; j < 2 ** DATA_IN_0_PRECISION_1; j++) begin : static_shift + for (genvar i = 0; i < BLOCK_SIZE; i++) begin : optimize_variable_shift + for (genvar j = 0; j < DATA_IN_0_PRECISION_0; j++) begin : data_in_shift + always_comb begin + shifted_mdata_in_list[i][j] = no_value_in_register ? $signed(mdata_in_0[i]) : + $signed(mdata_in_0[i]) >>> j; + end + end + for (genvar k = 0; k < DATA_OUT_0_PRECISION_0; k++) begin : data_out_shift always_comb begin - if (($signed(edata_in_0) - $signed(exp_min)) == j) - shifted_mdata_in_0[i] = no_value_in_register ? $signed( - mdata_in_0[i] - ) : $signed( - mdata_in_0[i] - ) <<< j; - if (($signed(edata_out_0) - $signed(exp_min)) == j) - shifted_mdata_out_0[i] = $signed(mdata_out_0[i]) <<< j; + shifted_mdata_out_list[i][k] = $signed(mdata_out_0[i]) >>> k; end end - // mantissa out + assign shifted_mdata_in_0[i] = shifted_mdata_in_list[i][mdata_in_real_shift_value]; + assign shifted_mdata_out_0[i] = shifted_mdata_out_list[i][mdata_out_real_shift_value]; + end + + for (genvar i = 0; i < BLOCK_SIZE; i++) begin : mantissa_block always_ff @(posedge clk) if (rst) mdata_out_0[i] <= '0; else begin if (data_out_0_valid) begin if (data_out_0_ready) begin - if (data_in_0_valid) mdata_out_0[i] <= shifted_mdata_in_0[i]; + if (data_in_0_valid) mdata_out_0[i] <= $signed(shifted_mdata_in_0[i]); else mdata_out_0[i] <= '0; end end else if (data_in_0_valid && data_in_0_ready) mdata_out_0[i] <= $signed(shifted_mdata_out_0[i]) + $signed(shifted_mdata_in_0[i]); end end - localparam signed [DATA_IN_0_PRECISION_1 - 1:0] MAXIMUM_EXPONENTIAL = 2**(DATA_IN_0_PRECISION_1 - 1) - 1; + localparam signed [DATA_IN_0_PRECISION_1 - 1:0] MINIMUM_EXPONENTIAL = - 2**(DATA_IN_0_PRECISION_1 - 1); // exponent always_ff @(posedge clk) - if (rst) edata_out_0 <= MAXIMUM_EXPONENTIAL; + if (rst) edata_out_0 <= MINIMUM_EXPONENTIAL; else if (data_out_0_valid) begin if (data_out_0_ready) begin if (data_in_0_valid) edata_out_0 <= edata_in_0; - else edata_out_0 <= MAXIMUM_EXPONENTIAL; + else edata_out_0 <= MINIMUM_EXPONENTIAL; end - end else if (data_in_0_valid && data_in_0_ready) edata_out_0 <= exp_min; + end else if (data_in_0_valid && data_in_0_ready) edata_out_0 <= exp_max; endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv index c5d02830f..e6d875ed2 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_cast.sv @@ -5,6 +5,7 @@ Description : MxInt Cast between Layers. */ module mxint_cast #( parameter IN_MAN_WIDTH = 1, + parameter IN_MAN_FRAC_WIDTH = IN_MAN_WIDTH - 1, parameter IN_EXP_WIDTH = 1, parameter OUT_MAN_WIDTH = 1, parameter OUT_EXP_WIDTH = 1, @@ -39,32 +40,55 @@ module mxint_cast #( logic [LOG2_WIDTH - 1:0] log2_max_value; logic log2_max_value_valid, log2_max_value_ready; - localparam EBIAS = 2 ** (OUT_EXP_WIDTH - 1); - localparam LOSSLESSS_EDATA_WIDTH = max(LOG2_WIDTH, IN_EXP_WIDTH, OUT_EXP_WIDTH) + 2; + localparam LOSSLESSS_EDATA_WIDTH = + (LOG2_WIDTH > IN_EXP_WIDTH && LOG2_WIDTH > OUT_EXP_WIDTH) ? LOG2_WIDTH + 2 : + (IN_EXP_WIDTH > OUT_EXP_WIDTH) ? IN_EXP_WIDTH + 2: + OUT_EXP_WIDTH + 2; + localparam FIFO_DEPTH = $clog2(BLOCK_SIZE); logic [LOSSLESSS_EDATA_WIDTH - 1:0] edata_out_full; + localparam SHIFT_WIDTH = (OUT_EXP_WIDTH > IN_EXP_WIDTH) ? OUT_EXP_WIDTH + 1 : IN_EXP_WIDTH + 1; + logic [SHIFT_WIDTH - 1:0] shift_value; + logic [SHIFT_WIDTH - 1:0] abs_shift_value; + // we dont need to implement full shift here, because we'll clamp in the final. + // in order to avoid shift loss, we set the shift_data_width = OUT_MAN_WIDTH + 1. + localparam SHIFT_DATA_WIDTH = OUT_MAN_WIDTH + 1; + + logic [SHIFT_DATA_WIDTH - 1:0] shift_buffer_data_for_out[BLOCK_SIZE - 1:0]; + logic [SHIFT_DATA_WIDTH - 1:0] shift_data[BLOCK_SIZE - 1:0][SHIFT_DATA_WIDTH - 1:0]; + logic [$clog2(SHIFT_DATA_WIDTH) - 1:0] real_shift_value; log2_max_abs #( .IN_SIZE (BLOCK_SIZE), - .IN_WIDTH(IN_MAN_WIDTH), + .IN_WIDTH(IN_MAN_WIDTH) ) max_bas_i ( .clk, .rst, - .data_in(mdata_in), - .data_in_valid(data_for_max_valid), - .data_in_ready(data_for_max_ready), - .data_out(log2_max_value), - .data_out_valid(log2_max_value_valid), - .data_out_ready(log2_max_value_ready) + .data_in_0(mdata_in), + .data_in_0_valid(data_for_max_valid), + .data_in_0_ready(data_for_max_ready), + .data_out_0(log2_max_value), + .data_out_0_valid(log2_max_value_valid), + .data_out_0_ready(log2_max_value_ready) ); - if (FIFO_DEPTH == 0) begin - always_comb begin - mbuffer_data_for_out = mdata_in; - ebuffer_data_for_out = edata_in; - buffer_data_for_out_valid = data_for_out_valid; - data_for_out_ready = buffer_data_for_out_ready; - end - end else begin + if (FIFO_DEPTH == 0) begin : register + mxint_register_slice #( + .DATA_PRECISION_0($bits(mbuffer_data_for_out[0])), + .DATA_PRECISION_1($bits(ebuffer_data_for_out)), + .IN_NUM(BLOCK_SIZE) + ) register_slice ( + .clk (clk), + .rst (rst), + .mdata_in (mdata_in), + .edata_in (edata_in), + .data_in_valid (data_for_out_valid), + .data_in_ready (data_for_out_ready), + .mdata_out (mbuffer_data_for_out), + .edata_out (ebuffer_data_for_out), + .data_out_valid(buffer_data_for_out_valid), + .data_out_ready(buffer_data_for_out_ready) + ); + end else begin : data_buffer unpacked_mx_fifo #( .DEPTH(FIFO_DEPTH), .MAN_WIDTH(IN_MAN_WIDTH), @@ -89,7 +113,11 @@ module mxint_cast #( .data_out_valid(data_out_valid), .data_out_ready(data_out_ready) ); - assign edata_out_full = $signed(log2_max_value) + $signed(ebuffer_data_for_out) - EBIAS; + assign edata_out_full = $signed( + log2_max_value + ) + $signed( + ebuffer_data_for_out + ) - IN_MAN_FRAC_WIDTH; // clamp signed_clamp #( .IN_WIDTH (LOSSLESSS_EDATA_WIDTH), @@ -98,36 +126,50 @@ module mxint_cast #( .in_data (edata_out_full), .out_data(edata_out) ); - localparam SHIFT_WIDTH = max(OUT_EXP_WIDTH, IN_EXP_WIDTH, 0) + 1; - logic [SHIFT_WIDTH - 1:0] shift_value; - assign shift_value = $signed(edata_out) - $signed(ebuffer_data_for_out); - logic [SHIFT_WIDTH - 1:0] abs_shift_value; - assign abs_shift_value = (shift_value[SHIFT_WIDTH-1]) ? (~shift_value + 1) : shift_value; + optimized_variable_shift #( + .IN_WIDTH(IN_MAN_WIDTH), + .SHIFT_WIDTH(SHIFT_WIDTH), + .OUT_WIDTH(OUT_MAN_WIDTH), + .BLOCK_SIZE(BLOCK_SIZE) + ) ovshift_inst ( + .data_in(mbuffer_data_for_out), + .shift_value(shift_value), + .data_out(mdata_out) + ); + assign shift_value = $signed( + edata_out + ) - $signed( + ebuffer_data_for_out + ) + IN_MAN_FRAC_WIDTH - (OUT_MAN_WIDTH - 1); + // assign abs_shift_value = (shift_value[SHIFT_WIDTH-1]) ? (~shift_value + 1) : shift_value; + // assign real_shift_value = (abs_shift_value < SHIFT_DATA_WIDTH)? abs_shift_value: SHIFT_DATA_WIDTH - 1; - logic [IN_MAN_WIDTH + EBIAS - 1:0] shift_buffer_data_for_out[BLOCK_SIZE - 1:0]; - for (genvar i = 0; i < BLOCK_SIZE; i++) begin - for (genvar j = 0; j < 2 ** SHIFT_WIDTH; j++) - always_comb - if (abs_shift_value == j) - shift_buffer_data_for_out[i] = (shift_value[SHIFT_WIDTH-1]) ? $signed( - mbuffer_data_for_out[i] - ) <<< j : $signed( - mbuffer_data_for_out[i] - ) >>> j; - signed_clamp #( - .IN_WIDTH (IN_MAN_WIDTH + EBIAS), - .OUT_WIDTH(OUT_MAN_WIDTH) - ) exp_clamp ( - .in_data (shift_buffer_data_for_out[i]), - .out_data(mdata_out[i]) - ); - end + // for (genvar i = 0; i < BLOCK_SIZE; i++) begin + // for (genvar j = 0; j < SHIFT_DATA_WIDTH; j++) begin + // always_comb begin + // shift_data[i][j] = (shift_value[SHIFT_WIDTH-1]) ? $signed( + // mbuffer_data_for_out[i] + // ) <<< j : $signed( + // mbuffer_data_for_out[i] + // ) >>> j; + // end + // end + // assign shift_buffer_data_for_out[i] = shift_data[i][real_shift_value]; + // end + // for (genvar i = 0; i < BLOCK_SIZE; i++) begin + // signed_clamp #( + // .IN_WIDTH (OUT_MAN_WIDTH + 1), + // .OUT_WIDTH(OUT_MAN_WIDTH) + // ) exp_clamp ( + // .in_data (shift_buffer_data_for_out[i]), + // .out_data(mdata_out[i]) + // ); + // end endmodule -function [31:0] max; - input [31:0] x, y, z; - begin - if (x > y && x > z) max = x; - else if (y > z) max = y; - else max = z; - end -endfunction +// function int max(input int x, y, z); +// begin +// if (x > y && x > z) max = x; +// else if (y > z) max = y; +// else max = z; +// end +// endfunction diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_circular.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_circular.sv index 563b2be5b..0396255ff 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_circular.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_circular.sv @@ -27,33 +27,36 @@ module mxint_circular #( output data_out_valid, input data_out_ready ); - initial begin - assert (DATA_PRECISION_0 >= DATA_PRECISION_1) - else $fatal("DATA_PRECISION_0 must larger than PRECISION_1"); - end - logic [DATA_PRECISION_0 - 1:0] packed_data_in [IN_NUM:0]; - logic [DATA_PRECISION_0 - 1:0] packed_data_out[IN_NUM:0]; - always_comb begin : data_pack - packed_data_in[IN_NUM-1:0] = mdata_in; - packed_data_in[IN_NUM] = $signed(edata_in); - mdata_out = packed_data_out[IN_NUM-1:0]; - edata_out = packed_data_out[IN_NUM]; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] data_in_flatten[0:0]; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] data_out_flatten[0:0]; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] packed_data_out_flatten; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] packed_data_in_flatten; + assign data_in_flatten[0] = packed_data_in_flatten; + assign packed_data_out_flatten = data_out_flatten[0]; + for (genvar i = 0; i < IN_NUM; i++) begin : reshape + assign packed_data_in_flatten[(i+1)*DATA_PRECISION_0-1:i*DATA_PRECISION_0] = mdata_in[i]; end + assign packed_data_in_flatten[DATA_PRECISION_0*IN_NUM+DATA_PRECISION_1-1:DATA_PRECISION_0*IN_NUM] = edata_in; input_buffer #( - .DATA_WIDTH (DATA_PRECISION_0), - .IN_NUM (IN_NUM + 1), + .DATA_WIDTH (DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1), + .IN_NUM (1), .REPEAT (REPEAT), .BUFFER_SIZE(BUFFER_SIZE) ) mdata_in_0_buffer ( .clk, .rst, // Input streaming port - .data_in(packed_data_in), + .data_in(data_in_flatten), .data_in_valid(data_in_valid), .data_in_ready(data_in_ready), // Output streaming port - .data_out(packed_data_out), + .data_out(data_out_flatten), .data_out_valid(data_out_valid), .data_out_ready(data_out_ready) ); + for (genvar i = 0; i < IN_NUM; i++) begin : unreshape + assign mdata_out[i] = packed_data_out_flatten[(i+1)*DATA_PRECISION_0-1:i*DATA_PRECISION_0]; + end + assign edata_out = packed_data_out_flatten[DATA_PRECISION_0*IN_NUM+DATA_PRECISION_1-1:DATA_PRECISION_0*IN_NUM]; + endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv index b345241db..f05e59bdd 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_linear.sv @@ -151,7 +151,7 @@ module mxint_linear #( DATA_IN_0_PARALLELISM_DIM_0 ); localparam FDP_EXP_WIDTH = (WEIGHT_PRECISION_1 > DATA_IN_0_PRECISION_1)? WEIGHT_PRECISION_1 + 1: DATA_IN_0_PRECISION_1 + 1; - localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH_DIM_0) + 2 ** FDP_EXP_WIDTH; + localparam ACC_WIDTH = FDP_WIDTH + $clog2(IN_0_DEPTH_DIM_0); localparam ACC_EXP_WIDTH = FDP_EXP_WIDTH; localparam LOSSLESS_OUT_WIDTH = ACC_WIDTH + HAS_BIAS; localparam LOSSLESS_OUT_EXP_WIDTH = ACC_EXP_WIDTH; @@ -196,7 +196,9 @@ module mxint_linear #( .DATA_IN_0_PRECISION_1(DATA_IN_0_PRECISION_1), .WEIGHT_PRECISION_0(WEIGHT_PRECISION_0), .WEIGHT_PRECISION_1(WEIGHT_PRECISION_1), - .BLOCK_SIZE(DATA_IN_0_PARALLELISM_DIM_0) + .BLOCK_SIZE(DATA_IN_0_PARALLELISM_DIM_0), + .DATA_OUT_0_PRECISION_0(FDP_WIDTH), + .DATA_OUT_0_PRECISION_1(FDP_EXP_WIDTH) ) mxdp_inst ( .clk(clk), .rst(rst), @@ -222,7 +224,9 @@ module mxint_linear #( .DATA_IN_0_PRECISION_0(FDP_WIDTH), .DATA_IN_0_PRECISION_1(FDP_EXP_WIDTH), .IN_DEPTH(IN_0_DEPTH_DIM_0), - .BLOCK_SIZE(DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0) + .BLOCK_SIZE(DATA_OUT_0_PARALLELISM_DIM_1 * DATA_OUT_0_PARALLELISM_DIM_0), + .DATA_OUT_0_PRECISION_0(ACC_WIDTH), + .DATA_OUT_0_PRECISION_1(FDP_EXP_WIDTH) ) accumulator_inst ( .clk(clk), .rst(rst), @@ -250,14 +254,23 @@ module mxint_linear #( .data_out_valid(cast_data_out_0_valid), .data_out_ready(cast_data_out_0_ready) ); - assign exp_difference = $signed(circular_ebias) - $signed(acc_edata_out); - assign abs_shift_value = exp_difference[FDP_EXP_WIDTH - 1]? (~exp_difference + 1): exp_difference; + assign exp_difference = -($signed( + circular_ebias + ) - $signed( + acc_edata_out + ) + DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 - 2 - (BIAS_PRECISION_0 - 1)); + + optimized_variable_shift #( + .IN_WIDTH(BIAS_PRECISION_0), + .SHIFT_WIDTH(FDP_EXP_WIDTH), + .OUT_WIDTH(LOSSLESS_OUT_WIDTH), + .BLOCK_SIZE(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1) + ) ovshift_inst ( + .data_in(mbias_sext), + .shift_value(exp_difference), + .data_out(shifted_mbias) + ); for (genvar m = 0; m < DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1; m++) begin - assign shifted_mbias[m] = exp_difference[FDP_EXP_WIDTH-1] ? $signed( - mbias_sext[m] - ) >>> abs_shift_value : $signed( - mbias_sext[m] - ) <<< abs_shift_value; assign cast_mdata_out_0[m] = $signed(shifted_mbias[m]) + $signed(acc_mdata_out[m]); end assign cast_edata_out_0 = acc_edata_out; @@ -266,10 +279,11 @@ module mxint_linear #( assign cast_data_out_0_valid = acc_data_out_valid; assign cast_mdata_out_0 = acc_mdata_out; assign cast_edata_out_0 = acc_edata_out; - assign bias_ready = 1; + assign circular_bias_ready = 1; end mxint_cast #( .IN_MAN_WIDTH(LOSSLESS_OUT_WIDTH), + .IN_MAN_FRAC_WIDTH(DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 - 2), .IN_EXP_WIDTH(LOSSLESS_OUT_EXP_WIDTH), .OUT_MAN_WIDTH(DATA_OUT_0_PRECISION_0), .OUT_EXP_WIDTH(DATA_OUT_0_PRECISION_1), diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matmul.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matmul.sv index 325060743..9461f6b5f 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matmul.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_matmul.sv @@ -141,7 +141,7 @@ module mxint_matmul #( logic [C_DEPTH_DIM0-1:0] acc_out_valid; logic [C_DEPTH_DIM0-1:0] acc_out_ready; localparam MAT_ACC_EXP_WIDTH = SM_EXP_WIDTH; - localparam MAT_ACC_OUT_WIDTH = SM_OUT_WIDTH + 2 ** SM_EXP_WIDTH + $clog2(B_DEPTH_DIM1); + localparam MAT_ACC_OUT_WIDTH = SM_OUT_WIDTH + $clog2(B_DEPTH_DIM1); logic [MAT_ACC_OUT_WIDTH-1:0] macc_out_data[C_DEPTH_DIM0-1:0][C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; logic [MAT_ACC_EXP_WIDTH-1:0] eacc_out_data[C_DEPTH_DIM0-1:0]; @@ -358,14 +358,14 @@ module mxint_matmul #( simple_matmul #( .N (A_COMPUTE_DIM1), - .M (A_COMPUTE_DIM0), // == B_COMPUTE_DIM1 + .M (A_COMPUTE_DIM0), // == B_COMPUTE_DIM1 .K (B_COMPUTE_DIM0), .X_WIDTH (A_MAN_WIDTH), - .X_FRAC_WIDTH (0), + .X_FRAC_WIDTH (A_MAN_WIDTH - 1), .Y_WIDTH (B_MAN_WIDTH), - .Y_FRAC_WIDTH (0), + .Y_FRAC_WIDTH (B_MAN_WIDTH - 1), .OUT_WIDTH (SM_OUT_WIDTH), - .OUT_FRAC_WIDTH(0) + .OUT_FRAC_WIDTH(A_MAN_WIDTH + B_MAN_WIDTH - 2) ) simple_matmul_inst ( .clk (clk), .rst (rst), @@ -414,6 +414,7 @@ module mxint_matmul #( mxint_cast #( .IN_MAN_WIDTH(MAT_ACC_OUT_WIDTH), + .IN_MAN_FRAC_WIDTH(A_MAN_WIDTH + B_MAN_WIDTH - 2), .IN_EXP_WIDTH(MAT_ACC_EXP_WIDTH), .OUT_MAN_WIDTH(OUT_MAN_WIDTH), .OUT_EXP_WIDTH(OUT_EXP_WIDTH), diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_register_slice.sv b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_register_slice.sv index 3cbbbfa66..070bdd4a1 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/mxint_register_slice.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/mxint_register_slice.sv @@ -27,31 +27,28 @@ module mxint_register_slice #( output data_out_valid, input data_out_ready ); - initial begin - assert (DATA_PRECISION_0 >= DATA_PRECISION_1) - else $fatal("DATA_PRECISION_0 must larger than PRECISION_1"); - end - logic [DATA_PRECISION_0 - 1:0] packed_data_in [IN_NUM:0]; - logic [DATA_PRECISION_0 - 1:0] packed_data_out[IN_NUM:0]; - always_comb begin : data_pack - packed_data_in[IN_NUM-1:0] = mdata_in; - packed_data_in[IN_NUM] = $signed(edata_in); - mdata_out = packed_data_out[IN_NUM-1:0]; - edata_out = packed_data_out[IN_NUM]; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] data_in_flatten; + logic [DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1 - 1:0] data_out_flatten; + for (genvar i = 0; i < IN_NUM; i++) begin : reshape + assign data_in_flatten[(i+1)*DATA_PRECISION_0-1:i*DATA_PRECISION_0] = mdata_in[i]; end + assign data_in_flatten[DATA_PRECISION_0*IN_NUM+DATA_PRECISION_1-1:DATA_PRECISION_0*IN_NUM] = edata_in; - unpacked_register_slice #( - .DATA_WIDTH(DATA_PRECISION_0), - .IN_SIZE(IN_NUM + 1) - ) register_slice ( + register_slice #( + .DATA_WIDTH(DATA_PRECISION_0 * IN_NUM + DATA_PRECISION_1) + ) register_slice_i ( .clk (clk), .rst (rst), - .data_in (packed_data_in), + .data_in (data_in_flatten), .data_in_valid (data_in_valid), .data_in_ready (data_in_ready), - .data_out (packed_data_out), + .data_out (data_out_flatten), .data_out_valid(data_out_valid), .data_out_ready(data_out_ready) ); + for (genvar i = 0; i < IN_NUM; i++) begin : unreshape + assign mdata_out[i] = data_out_flatten[(i+1)*DATA_PRECISION_0-1:i*DATA_PRECISION_0]; + end + assign edata_out = data_out_flatten[DATA_PRECISION_0*IN_NUM+DATA_PRECISION_1-1:DATA_PRECISION_0*IN_NUM]; endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/optimized_variable_shift.sv b/src/mase_components/linear_layers/mxint_operators/rtl/optimized_variable_shift.sv new file mode 100644 index 000000000..1cdcca5ff --- /dev/null +++ b/src/mase_components/linear_layers/mxint_operators/rtl/optimized_variable_shift.sv @@ -0,0 +1,40 @@ +`timescale 1ns / 1ps +/* +Module : optimized_variable_shift +Description : optimized version of variable shift. +*/ +module optimized_variable_shift #( + parameter IN_WIDTH = -1, + parameter BLOCK_SIZE = -1, + parameter SHIFT_WIDTH = -1, + parameter OUT_WIDTH = -1 +) ( + input logic [IN_WIDTH - 1:0] data_in[BLOCK_SIZE - 1:0], + input logic [SHIFT_WIDTH - 1:0] shift_value, + output logic [OUT_WIDTH - 1:0] data_out[BLOCK_SIZE - 1:0] +); + localparam SHIFT_DATA_WIDTH = OUT_WIDTH + 1; + logic [SHIFT_WIDTH - 1:0] abs_shift_value, real_shift_value; + assign abs_shift_value = (shift_value[SHIFT_WIDTH-1]) ? (~shift_value + 1) : shift_value; + assign real_shift_value = (abs_shift_value < SHIFT_DATA_WIDTH)? abs_shift_value: SHIFT_DATA_WIDTH - 1; + logic [SHIFT_DATA_WIDTH - 1:0] shift_data[BLOCK_SIZE - 1:0]; + logic [SHIFT_DATA_WIDTH - 1:0] shift_data_list[BLOCK_SIZE - 1:0][SHIFT_DATA_WIDTH -1 : 0]; + for (genvar i = 0; i < BLOCK_SIZE; i++) begin + for (genvar j = 0; j < SHIFT_DATA_WIDTH; j++) begin + always_comb begin + shift_data_list[i][j] = (shift_value[SHIFT_WIDTH-1]) ? $signed(data_in[i]) <<< j : + $signed(data_in[i]) >>> j; + end + end + assign shift_data[i] = shift_data_list[i][real_shift_value]; + end + for (genvar i = 0; i < BLOCK_SIZE; i++) begin + signed_clamp #( + .IN_WIDTH (OUT_WIDTH + 1), + .OUT_WIDTH(OUT_WIDTH) + ) data_clamp ( + .in_data (shift_data[i]), + .out_data(data_out[i]) + ); + end +endmodule diff --git a/src/mase_components/linear_layers/mxint_operators/rtl/or_tree.sv b/src/mase_components/linear_layers/mxint_operators/rtl/or_tree.sv index 570928f5d..94b60bdec 100644 --- a/src/mase_components/linear_layers/mxint_operators/rtl/or_tree.sv +++ b/src/mase_components/linear_layers/mxint_operators/rtl/or_tree.sv @@ -5,20 +5,19 @@ Description : This module actually implement the tree structure of or logic. */ module or_tree #( - parameter IN_SIZE = 2, - parameter IN_WIDTH = 32, - parameter OUT_WIDTH = IN_WIDTH + parameter IN_SIZE = 2, + parameter IN_WIDTH = 32 ) ( /* verilator lint_off UNUSEDSIGNAL */ - input logic clk, - input logic rst, + input logic clk, + input logic rst, /* verilator lint_on UNUSEDSIGNAL */ - input logic [ IN_WIDTH-1:0] data_in [IN_SIZE-1:0], - input logic data_in_valid, - output logic data_in_ready, - output logic [OUT_WIDTH-1:0] data_out, - output logic data_out_valid, - input logic data_out_ready + input logic [IN_WIDTH-1:0] data_in [IN_SIZE-1:0], + input logic data_in_valid, + output logic data_in_ready, + output logic [IN_WIDTH-1:0] data_out, + output logic data_out_valid, + input logic data_out_ready ); localparam LEVELS = $clog2(IN_SIZE); @@ -30,17 +29,29 @@ module or_tree #( generate if (LEVELS == 0) begin : gen_skip_adder_tree - assign data_out = data_in[0][IN_WIDTH-1] ? ~data_in[0] + 1 : data_in[0]; - assign data_out_valid = data_in_valid; - assign data_in_ready = data_out_ready; + register_slice #( + .DATA_WIDTH(IN_WIDTH) + ) register_slice_i ( + .clk (clk), + .rst (rst), + .data_in_valid (data_in_valid), + .data_in_ready (data_in_ready), + .data_in (data_in[0][IN_WIDTH-1] ? ~data_in[0] + 1 : data_in[0]), + .data_out_valid(data_out_valid), + .data_out_ready(data_out_ready), + .data_out (data_out) + ); + // assign data_out = data_in[0][IN_WIDTH-1] ? ~data_in[0] + 1 : data_in[0]; + // assign data_out_valid = data_in_valid; + // assign data_in_ready = data_out_ready; end else begin : gen_adder_tree // data & sum wires are oversized on purpose for vivado. - logic [OUT_WIDTH*IN_SIZE-1:0] data[LEVELS:0]; - logic [OUT_WIDTH*IN_SIZE-1:0] or_result[LEVELS-1:0]; - logic valid[IN_SIZE-1:0]; - logic ready[IN_SIZE-1:0]; + logic [IN_WIDTH*IN_SIZE-1:0] data[LEVELS:0]; + logic [IN_WIDTH*IN_SIZE-1:0] or_result[LEVELS-1:0]; + logic valid[LEVELS:0]; + logic ready[LEVELS:0]; // Generate adder for each layer for (genvar i = 0; i < LEVELS; i++) begin : level @@ -60,7 +71,7 @@ module or_tree #( register_slice #( .DATA_WIDTH(LEVEL_OUT_SIZE * LEVEL_OUT_WIDTH) - ) register_slice ( + ) register_slice_i ( .clk (clk), .rst (rst), .data_in (or_result[i]), @@ -80,7 +91,7 @@ module or_tree #( assign valid[0] = data_in_valid; assign data_in_ready = ready[0]; - assign data_out = data[LEVELS][OUT_WIDTH-1:0]; + assign data_out = data[LEVELS][IN_WIDTH-1:0]; assign data_out_valid = valid[LEVELS]; assign ready[LEVELS] = data_out_ready; diff --git a/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py b/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py index e66d2e0f2..c2e09a318 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/log2_max_abs_tb.py @@ -1,164 +1,147 @@ #!/usr/bin/env python3 -# This script tests the fixed point adder tree -import os, math, logging, pytest +# This script tests the fixed point linear +import os, logging + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import * -from mase_cocotb.random_test import RandomSource, RandomSink, check_results from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + StreamMonitor, +) + from mase_cocotb.runner import mase_runner +from utils import mxint_quantize -import cocotb -from cocotb.triggers import Timer -from cocotb.triggers import FallingEdge -from cocotb.clock import Clock +import torch +from math import ceil, log2 +import random +from mase_cocotb.utils import bit_driver -debug = False +logger = logging.getLogger("testbench") +logger.setLevel(logging.DEBUG) -logger = logging.getLogger("tb_signals") -if debug: - logger.setLevel(logging.DEBUG) +torch.manual_seed(10) -# DUT test specifications -class VerificationCase(Testbench): - def __init__(self, dut, samples=10): +class Log2_max_abs_tb(Testbench): + def __init__(self, dut, num=1) -> None: super().__init__(dut, dut.clk, dut.rst) - self.assign_self_params(["IN_SIZE", "IN_WIDTH"]) - self.data_in_width = self.IN_WIDTH - self.num = self.IN_SIZE - self.inputs = RandomSource( - samples=samples, num=self.num, max_stalls=2 * samples, debug=debug + self.num = num + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + + cocotb.start_soon(check_signal(dut, self.log)) + self.data_in_0_driver = StreamDriver( + dut.clk, + dut.data_in_0, + dut.data_in_0_valid, + dut.data_in_0_ready, ) - self.outputs = RandomSink( - samples=samples, num=self.num, max_stalls=2 * samples, debug=debug + self.data_out_0_monitor = StreamMonitor( + dut.clk, + dut.data_out_0, + dut.data_out_0_valid, + dut.data_out_0_ready, + check=True, ) - self.samples = samples - self.ref = self.sw_compute() - - def sw_compute(self): - ref = [] - for i in range(self.samples): - breakpoint() - ref.append( - math.ceil(math.log2(max([abs(data) for data in self.inputs.data[i]]))) - ) - print(self.inputs.data[i]) - ref.reverse() - return ref - - -# Check if an impossible state is reached -def is_impossible_state(data_in_ready, data_in_valid, data_out_ready, data_out_valid): - # (0, X, 0, 0) - # (0, X, 1, 0) - # (0, X, 1, 1) - if (not data_in_ready) and not ((not data_out_ready) and data_out_valid): - return True - return False + self.input_drivers = {"in": self.data_in_0_driver} + self.output_monitors = {"out": self.data_out_0_monitor} + self.data_in_0_driver.log.setLevel(logging.DEBUG) + self.data_out_0_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self): + from math import ceil, log2 + + data_in = torch.randint(-20, 20, size=(self.get_parameter("IN_SIZE"),)) + log2_max = ceil(log2((int(data_in.abs().max()) + 1e-6))) + inputs = [data_in.tolist()] + outputs = [log2_max] + return inputs, outputs + + async def run_test(self, samples, us): + await self.reset() + logger.info(f"Reset finished") + self.data_out_0_monitor.ready.value = 1 + self.data_in_0_driver.valid.value = 0 + for _ in range(samples): + logger.info(f"generating inputs") + inputs, exp_outputs = self.generate_inputs() + + # Load the inputs driver + print(inputs) + self.data_in_0_driver.load_driver(inputs) + # Load the output monitor + self.data_out_0_monitor.load_monitor(exp_outputs) + + await Timer(us, units="us") + assert self.data_out_0_monitor.exp_queue.empty() + + +async def check_signal(dut, log): + # await Timer(20, units="ns") + while True: + await RisingEdge(dut.clk) + await ReadOnly() + if str(dut.data_out_0_valid) == "1" and str(dut.data_out_0_ready) == "1": + print(dut.or_result.value) + # print("end") + + +# @cocotb.test() +# async def test(dut): +# tb = Log2_max_abs_tb(dut, 1) +# await tb.run_test(samples=10, us=5) + +# @cocotb.test() +# async def single_mult(dut): +# tb = MXIntMatmulTB(dut) +# tb.output_monitor.ready.value = 1 +# await tb.run_test(batches=1, us=100) + + +# @cocotb.test() +# async def repeated_mult(dut): +# tb = MXIntMatmulTB(dut) +# tb.output_monitor.ready.value = 1 +# await tb.run_test(batches=1000, us=2000) @cocotb.test() -async def cocotb_test_fixed_adder_tree(dut): - """Test integer based adder tree""" - samples = 1 - test_case = VerificationCase(dut, samples=samples) - - # Reset cycle - await Timer(20, units="ns") - dut.rst.value = 1 - await Timer(100, units="ns") - dut.rst.value = 0 - - # Create a 10ns-period clock on port clk - clock = Clock(dut.clk, 10, units="ns") - # Start the clock - cocotb.start_soon(clock.start()) - await Timer(500, units="ns") - - # Synchronize with the clock - dut.data_in_valid.value = 0 - dut.data_out_ready.value = 1 - logger.debug( - "Pre-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - await FallingEdge(dut.clk) - logger.debug( - "Post-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - logger.debug( - "Pre-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - await FallingEdge(dut.clk) - logger.debug( - "Post-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - - done = False - while not done: - await FallingEdge(dut.clk) - logger.debug( - "Post-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - dut.data_in_valid.value = test_case.inputs.pre_compute() - await Timer(1, units="ns") - dut.data_out_ready.value = test_case.outputs.pre_compute( - dut.data_out_valid.value - ) - await Timer(1, units="ns") - dut.data_in_valid.value, dut.data_in.value = test_case.inputs.compute( - dut.data_in_ready.value - ) - await Timer(1, units="ns") - dut.data_out_ready.value = test_case.outputs.compute( - dut.data_out_valid.value, dut.data_out.value - ) - logger.debug( - "Pre-clk State: (data_in_ready,data_in_valid,data_out_ready,data_out_valid) = ({},{},{},{})".format( - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - done = test_case.inputs.is_empty() and test_case.outputs.is_full() - check_results([i.signed_integer for i in test_case.outputs.data], test_case.ref) +async def repeated_mult_valid_backpressure(dut): + tb = Log2_max_abs_tb(dut, 1) + tb.data_in_0_driver.set_valid_prob(0.7) + cocotb.start_soon(bit_driver(dut.data_out_0_ready, dut.clk, 0.6)) + await tb.run_test(samples=20, us=200) -@pytest.mark.dev -def test_abs_max_tree(): +if __name__ == "__main__": mase_runner( + trace=True, module_param_list=[ - # Power of 2's - {"IN_SIZE": 2, "IN_WIDTH": 8}, + # { + # "DATA_IN_0_PRECISION_0": 8, + # "DATA_IN_0_PRECISION_1": 4, + # "BLOCK_SIZE": 1, + # "IN_DEPTH": 1, + # }, + # { + # "DATA_IN_0_PRECISION_0": 8, + # "DATA_IN_0_PRECISION_1": 4, + # "BLOCK_SIZE": 4, + # }, + { + "IN_WIDTH": 8, + "IN_SIZE": 16, + }, + { + "IN_WIDTH": 8, + "IN_SIZE": 4, + }, ], - trace=True, + # sim="questa", + # gui=True ) - - -if __name__ == "__main__": - test_abs_max_tree() diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py index aaa16a70f..dd9dfc16e 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_accumulator_tb.py @@ -14,7 +14,7 @@ ) from mase_cocotb.runner import mase_runner -from utils import mxint_quantize +from utils import mxint_quantize, MxIntAccumulator import torch from math import ceil, log2 @@ -33,7 +33,6 @@ def __init__(self, dut, num=1) -> None: self.num = num if not hasattr(self, "log"): self.log = SimLog("%s" % (type(self).__qualname__)) - self.data_in_0_driver = MultiSignalStreamDriver( dut.clk, (dut.mdata_in_0, dut.edata_in_0), @@ -47,14 +46,20 @@ def __init__(self, dut, num=1) -> None: dut.data_out_0_ready, check=True, ) + self.input_drivers = {"in0": self.data_in_0_driver} + self.output_monitors = {"out": self.data_out_0_monitor} def generate_inputs(self): from utils import block_mxint_quant, pack_tensor_to_mx_listed_chunk from utils import mxint_quantize from math import ceil, log2 - data_in = 20 * torch.rand( - self.get_parameter("IN_DEPTH"), self.get_parameter("BLOCK_SIZE") + data_in = ( + 20 + * torch.rand( + self.get_parameter("IN_DEPTH"), self.get_parameter("BLOCK_SIZE") + ) + - 20 ) config = { "width": self.get_parameter("DATA_IN_0_PRECISION_0"), @@ -62,15 +67,11 @@ def generate_inputs(self): } parallelism = [1, self.get_parameter("BLOCK_SIZE")] (qtensor, mtensor, etensor) = block_mxint_quant(data_in, config, parallelism) - - qout, mout, eout = mxint_quantize( - qtensor.sum(dim=0), - width=config["width"] - + 2 ** config["exponent_width"] - + ceil(log2(self.get_parameter("IN_DEPTH"))), - exponent_width=config["exponent_width"], - exponent=int(etensor.min()), + mtensor = mtensor.reshape( + self.get_parameter("IN_DEPTH"), self.get_parameter("BLOCK_SIZE") ) + etensor = etensor.reshape(self.get_parameter("IN_DEPTH")) + mout, eout = MxIntAccumulator(mtensor, etensor) tensor_inputs = pack_tensor_to_mx_listed_chunk(mtensor, etensor, parallelism) exp_outs = [(mout.int().tolist(), int(eout))] @@ -94,10 +95,27 @@ async def run_test(self, samples, us): assert self.data_out_0_monitor.exp_queue.empty() +async def check_signal(dut): + await Timer(40, units="ns") + while True: + await RisingEdge(dut.clk) + await ReadOnly() + if dut.data_in_0_valid.value == 1 and dut.data_in_0_valid.value == 1: + print( + "data_in_0 = ", [x.signed_integer for x in dut.shifted_mdata_in_0.value] + ) + print( + "data_out_0 = ", + [x.signed_integer for x in dut.shifted_mdata_out_0.value], + ) + print("end") + + # @cocotb.test() # async def test(dut): # tb = MXIntAccumulatorTB(dut, 1) -# await tb.run_test(samples=20, us=5) +# cocotb.start_soon(check_signal(dut)) +# await tb.run_test(samples=10, us=5) # @cocotb.test() # async def single_mult(dut): @@ -131,5 +149,18 @@ async def repeated_mult_valid_backpressure(dut): "BLOCK_SIZE": 1, "IN_DEPTH": 1, }, + # { + # "DATA_IN_0_PRECISION_0": 8, + # "DATA_IN_0_PRECISION_1": 4, + # "BLOCK_SIZE": 4, + # "IN_DEPTH": 1, + # }, + # { + # "DATA_IN_0_PRECISION_0": 8, + # "DATA_IN_0_PRECISION_1": 4, + # "BLOCK_SIZE": 4, + # "IN_DEPTH": 4, + # }, ], + # sim="questa", ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py index 963ae40df..90687e918 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_cast_tb.py @@ -14,6 +14,7 @@ ) from mase_cocotb.runner import mase_runner from utils import mxint_quantize +from utils import MxIntCast import torch @@ -21,7 +22,7 @@ logger.setLevel(logging.DEBUG) -class MXINTVectorMultTB(Testbench): +class MxIntCastTB(Testbench): def __init__(self, dut, num) -> None: super().__init__(dut, dut.clk, dut.rst) self.num = num @@ -42,6 +43,8 @@ def __init__(self, dut, num) -> None: dut.data_out_ready, check=True, ) + self.input_drivers = {"in0": self.data_in_0_driver} + self.output_monitors = {"out": self.data_out_0_monitor} self.data_in_0_driver.log.setLevel(logging.DEBUG) self.data_out_0_monitor.log.setLevel(logging.DEBUG) @@ -55,12 +58,17 @@ def generate_inputs(self): int(self.dut.IN_MAN_WIDTH), int(self.dut.IN_EXP_WIDTH), ) - exp_out, mexp_out, eexp_out = mxint_quantize( - data_in, - int(self.dut.OUT_MAN_WIDTH), - int(self.dut.OUT_EXP_WIDTH), + mexp_out, eexp_out = MxIntCast( + mdata_in, + edata_in, + { + "in_width": int(self.dut.IN_MAN_WIDTH), + "in_frac_width": int(self.dut.IN_MAN_FRAC_WIDTH), + "in_exponent_width": int(self.dut.IN_EXP_WIDTH), + "out_width": int(self.dut.OUT_MAN_WIDTH), + "out_exponent_width": int(self.dut.OUT_EXP_WIDTH), + }, ) - breakpoint() inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) exp_outputs.append((mexp_out.int().tolist(), eexp_out.int().tolist())) return inputs, exp_outputs @@ -68,7 +76,6 @@ def generate_inputs(self): async def run_test(self): await self.reset() logger.info(f"Reset finished") - self.data_out_0_monitor.ready.value = 1 logger.info(f"generating inputs") inputs, exp_outputs = self.generate_inputs() @@ -78,23 +85,56 @@ async def run_test(self): # Load the output monitor self.data_out_0_monitor.load_monitor(exp_outputs) - - await Timer(5, units="us") + await Timer(1, units="us") assert self.data_out_0_monitor.exp_queue.empty() @cocotb.test() async def test(dut): - tb = MXINTVectorMultTB(dut, num=1) + # cocotb.start_soon(check_signal(dut)) + tb = MxIntCastTB(dut, num=1) await tb.run_test() +async def check_signal(dut): + num = {"data_out_0": 0, "data_in_0": 0} + await Timer(40, units="ns") + while True: + await RisingEdge(dut.clk) + await ReadOnly() + if dut.data_out_valid.value == 1 and dut.data_out_ready.value == 1: + print(dut.edata_out_full) + print(dut.log2_max_value) + print(dut.ebuffer_data_for_out) + shift = dut.ovshift_inst + print(shift.SHIFT_WIDTH.value) + print(shift.OUT_WIDTH.value) + print(shift.shift_value.value.signed_integer) + print(shift.abs_shift_value.value.signed_integer) + # print(shift.shift_data.value.signed_integer) + print([x for x in shift.shift_data.value]) + # print(dut.data_for_max_ready.value) + # print(dut.data_for_out_valid.value) + # print(dut.data_for_out_ready.value) + print("end") + # print(dut.max_bas_i.or_tree_i.gen_adder_tree.level[0].register_slice.data_out_ready) + # print(dut.max_bas_i.or_tree_i.gen_adder_tree.level[0].register_slice.data_in_valid) + # print(dut.max_bas_i.or_tree_i.gen_adder_tree.level[0].register_slice.shift_reg) + # print(dut.max_bas_i.or_tree_i.data_in_ready) + # print(dut.max_bas_i.data_out_ready) + # print(dut.store_the_data.ff_inst.data_in_ready) + # print(dut.store_the_data.ff_inst.data_out_ready) + # print(dut.max_bas_i.or_tree_i.data_out_valid) + # print("end") + + if __name__ == "__main__": mase_runner( trace=True, module_param_list=[ # { # "IN_MAN_WIDTH": 6, + # "IN_MAN_FRAC_WIDTH": 5, # "IN_EXP_WIDTH": 3, # "OUT_MAN_WIDTH": 12, # "OUT_EXP_WIDTH": 4, @@ -109,10 +149,11 @@ async def test(dut): # }, { "IN_MAN_WIDTH": 8, + "IN_MAN_FRAC_WIDTH": 7, "IN_EXP_WIDTH": 4, - "OUT_MAN_WIDTH": 49, + "OUT_MAN_WIDTH": 16, "OUT_EXP_WIDTH": 5, - "BLOCK_SIZE": 4, + "BLOCK_SIZE": 1, }, # { # "IN_MAN_WIDTH": 12, @@ -122,4 +163,6 @@ async def test(dut): # "BLOCK_SIZE": 4, # }, ], + # sim="questa", + # gui=True ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py index 43e70adcf..062596565 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_dot_product_tb.py @@ -76,9 +76,6 @@ def generate_inputs(self): inputs.append((mdata_in.int().tolist(), edata_in.int().tolist())) weights.append((mweight.int().tolist(), eweight.int().tolist())) exp_outputs.append((mdp_out.int().tolist(), edp_out.int().tolist())) - print(inputs) - print(weights) - print(exp_outputs) return inputs, weights, exp_outputs async def run_test(self): @@ -118,4 +115,5 @@ async def test(dut): "BLOCK_SIZE": 4, }, ], + # sim="questa", ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py index 036c5a2e3..c07d0363f 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_linear_tb.py @@ -8,7 +8,7 @@ import cocotb from cocotb.log import SimLog -from cocotb.triggers import Timer, RisingEdge +from cocotb.triggers import Timer, RisingEdge, ReadOnly from mase_cocotb.testbench import Testbench from mase_cocotb.interfaces.streaming import ( @@ -19,7 +19,7 @@ torch.manual_seed(0) # from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner -from utils import MXIntLinear +from utils import MXIntLinear, MXIntLinearHardware class LinearTB(Testbench): @@ -40,11 +40,16 @@ def __init__(self, dut) -> None: dut.clk, (dut.mweight, dut.eweight), dut.weight_valid, dut.weight_ready ) + self.input_drivers = { + "a": self.data_in_0_driver, + "b": self.weight_driver, + } if self.get_parameter("HAS_BIAS") == 1: self.bias_driver = MultiSignalStreamDriver( dut.clk, (dut.mbias, dut.ebias), dut.bias_valid, dut.bias_ready ) self.bias_driver.log.setLevel(logging.DEBUG) + self.input_drivers["bias"] = self.bias_driver self.data_out_0_monitor = MultiSignalStreamMonitor( dut.clk, @@ -54,40 +59,37 @@ def __init__(self, dut) -> None: check=True, ) + self.output_monitors = {"out": self.data_out_0_monitor} # Model - self.model = MXIntLinear( + self.model = MXIntLinearHardware( in_features=self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"), out_features=self.get_parameter("DATA_OUT_0_TENSOR_SIZE_DIM_0"), bias=True if self.get_parameter("HAS_BIAS") == 1 else False, config={ "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"), "data_in_exponent_width": self.get_parameter("DATA_IN_0_PRECISION_1"), - "data_in_parallelism_dim_1": self.get_parameter( - "DATA_IN_0_PARALLELISM_DIM_1" - ), - "data_in_parallelism_dim_0": self.get_parameter( - "DATA_IN_0_PARALLELISM_DIM_0" - ), + "data_in_parallelism": [ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], "weight_width": self.get_parameter("WEIGHT_PRECISION_0"), "weight_exponent_width": self.get_parameter("WEIGHT_PRECISION_1"), - "weight_parallelism_dim_1": self.get_parameter( - "WEIGHT_PARALLELISM_DIM_1" - ), - "weight_parallelism_dim_0": self.get_parameter( - "WEIGHT_PARALLELISM_DIM_0" - ), + "weight_parallelism": [ + self.get_parameter("WEIGHT_PARALLELISM_DIM_1"), + self.get_parameter("WEIGHT_PARALLELISM_DIM_0"), + ], "bias_width": self.get_parameter("BIAS_PRECISION_0"), "bias_exponent_width": self.get_parameter("BIAS_PRECISION_1"), - "bias_parallelism_dim_1": self.get_parameter("BIAS_PARALLELISM_DIM_1"), - "bias_parallelism_dim_0": self.get_parameter("BIAS_PARALLELISM_DIM_0"), + "bias_parallelism": [ + self.get_parameter("BIAS_PARALLELISM_DIM_1"), + self.get_parameter("BIAS_PARALLELISM_DIM_0"), + ], "data_out_width": self.get_parameter("DATA_OUT_0_PRECISION_0"), "data_out_exponent_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), - "data_out_parallelism_dim_1": self.get_parameter( - "DATA_OUT_0_PARALLELISM_DIM_1" - ), - "data_out_parallelism_dim_0": self.get_parameter( - "DATA_OUT_0_PARALLELISM_DIM_0" - ), + "data_out_parallelism": [ + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), + ], }, ) @@ -182,7 +184,6 @@ async def run_test(self, us): self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), ], ) - breakpoint() self.data_out_0_monitor.load_monitor(outs) await Timer(us, units="us") @@ -192,9 +193,30 @@ async def run_test(self, us): @cocotb.test() async def cocotb_test(dut): tb = LinearTB(dut) + cocotb.start_soon(check_signal(dut)) await tb.run_test(us=100) +async def check_signal(dut): + await Timer(40, units="ns") + while True: + await RisingEdge(dut.clk) + await ReadOnly() + if ( + dut.cast_data_out_0_valid.value == 1 + and dut.cast_data_out_0_ready.value == 1 + ): + shift = dut.bias_cast.ovshift_inst + print(shift.SHIFT_WIDTH.value) + print(shift.OUT_WIDTH.value) + print(shift.shift_value.value.signed_integer) + print(shift.abs_shift_value.value.signed_integer) + print("data_in = ", [x.signed_integer for x in shift.data_in.value]) + print("data_out = ", [x.signed_integer for x in shift.data_out.value]) + # print("edata_out = ",dut.acc_edata_out.value.signed_integer) + # print("end") + + def get_fixed_linear_config(kwargs={}): # if pretranspose # weight1 = in0 @@ -230,7 +252,7 @@ def get_fixed_linear_config(kwargs={}): "WEIGHT_TENSOR_SIZE_DIM_1": 16, "WEIGHT_PARALLELISM_DIM_0": 4, "WEIGHT_PARALLELISM_DIM_1": 4, - "DATA_IN_0_PRECISION_0": 9, + "DATA_IN_0_PRECISION_0": 10, "DATA_IN_0_PRECISION_1": 4, "WEIGHT_PRECISION_0": 8, "WEIGHT_PRECISION_1": 3, @@ -250,7 +272,6 @@ def test_fixed_linear_smoke(): """ mase_runner( trace=True, - extra_build_args=["--trace-depth", "8"], module_param_list=[ get_fixed_linear_config(), # noticed here if change WEIGHT_PRE_TRANSPOSED also need to change the DIM_SIZE to match ACTIVATION @@ -263,6 +284,8 @@ def test_fixed_linear_smoke(): # }, # ), ], + # sim="questa", + # gui=True, ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py index f31d1ab61..322ce3d41 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py +++ b/src/mase_components/linear_layers/mxint_operators/test/mxint_matmul_tb.py @@ -14,7 +14,7 @@ ) from mase_cocotb.runner import mase_runner -from utils import block_mxint_quant +from utils import block_mxint_quant, MXIntMatmulHardware from mase_cocotb.matrix_tools import gen_random_matrix_input, matrix_mult_model from mase_cocotb.utils import bit_driver @@ -67,6 +67,12 @@ def __init__(self, dut) -> None: dut.out_ready, check=True, ) + self.input_drivers = { + "a": self.a_driver, + "b": self.b_driver, + } + self.output_monitors = {"out": self.output_monitor} + self.output_monitor.log.setLevel(logging.DEBUG) def generate_inputs(self): for _ in range(self.num): @@ -98,18 +104,43 @@ def generate_inputs(self): self.get_parameter("B_COMPUTE_DIM0"), ], ) - matmul_out = qa @ qb + self.log.debug(f"hardware_out = {ma @ mb}") - (qout, mout, eout) = block_mxint_quant( - matmul_out, - q_config={ + (mout, eout) = MXIntMatmulHardware( + ma, + ea, + mb, + eb, + { + "width": self.get_parameter("A_MAN_WIDTH"), + "exponent_width": self.get_parameter("A_EXP_WIDTH"), + "parallism_dim_0": self.get_parameter("A_COMPUTE_DIM0"), + "parallism_dim_1": self.get_parameter("A_COMPUTE_DIM1"), + "depth_dim_0": self.get_parameter("A_DEPTH_DIM0"), + "depth_dim_1": self.get_parameter("A_DEPTH_DIM1"), + "dim_0": self.get_parameter("A_TOTAL_DIM0"), + "dim_1": self.get_parameter("A_TOTAL_DIM1"), + }, + { + "width": self.get_parameter("B_MAN_WIDTH"), + "exponent_width": self.get_parameter("B_EXP_WIDTH"), + "parallism_dim_0": self.get_parameter("B_COMPUTE_DIM0"), + "parallism_dim_1": self.get_parameter("B_COMPUTE_DIM1"), + "depth_dim_0": self.get_parameter("B_DEPTH_DIM0"), + "depth_dim_1": self.get_parameter("B_DEPTH_DIM1"), + "dim_0": self.get_parameter("B_TOTAL_DIM0"), + "dim_1": self.get_parameter("B_TOTAL_DIM1"), + }, + { "width": self.get_parameter("OUT_MAN_WIDTH"), "exponent_width": self.get_parameter("OUT_EXP_WIDTH"), + "parallism_dim_0": self.get_parameter("C_COMPUTE_DIM0"), + "parallism_dim_1": self.get_parameter("C_COMPUTE_DIM1"), + "depth_dim_0": self.get_parameter("C_DEPTH_DIM0"), + "depth_dim_1": self.get_parameter("C_DEPTH_DIM1"), + "dim_0": self.get_parameter("C_TOTAL_DIM0"), + "dim_1": self.get_parameter("C_TOTAL_DIM1"), }, - parallelism=[ - self.get_parameter("C_COMPUTE_DIM1"), - self.get_parameter("C_COMPUTE_DIM0"), - ], ) from utils import pack_tensor_to_mx_listed_chunk @@ -166,11 +197,11 @@ async def run_test(self, batches, us): # await tb.run_test(batches=1, us=100) -# @cocotb.test() -# async def repeated_mult(dut): -# tb = MXIntMatmulTB(dut) -# tb.output_monitor.ready.value = 1 -# await tb.run_test(batches=1000, us=2000) +@cocotb.test() +async def repeated_mult(dut): + tb = MXIntMatmulTB(dut) + tb.output_monitor.ready.value = 1 + await tb.run_test(batches=20, us=20) # @cocotb.test() @@ -180,13 +211,13 @@ async def run_test(self, batches, us): # await tb.run_test(batches=500, us=2000) -@cocotb.test() -async def repeated_mult_valid_backpressure(dut): - tb = MXIntMatmulTB(dut) - tb.a_driver.set_valid_prob(0.7) - tb.b_driver.set_valid_prob(0.7) - cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.6)) - await tb.run_test(batches=20, us=200) +# @cocotb.test() +# async def repeated_mult_valid_backpressure(dut): +# tb = MXIntMatmulTB(dut) +# tb.a_driver.set_valid_prob(0.7) +# tb.b_driver.set_valid_prob(0.7) +# cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.6)) +# await tb.run_test(batches=20, us=200) def gen_random_dimensions(): @@ -236,19 +267,18 @@ def test_matmul(): "A_COMPUTE_DIM1": 2, "B_COMPUTE_DIM0": 2, "B_COMPUTE_DIM1": 2, # Must equal A_COMPUTE_DIM0 - "A_MAN_WIDTH": 8, - "A_EXP_WIDTH": 3, - "B_MAN_WIDTH": 8, - "B_EXP_WIDTH": 3, - "OUT_MAN_WIDTH": 8, - "OUT_EXP_WIDTH": 3, + "A_MAN_WIDTH": 4, + "A_EXP_WIDTH": 8, + "B_MAN_WIDTH": 4, + "B_EXP_WIDTH": 8, + "OUT_MAN_WIDTH": 4, + "OUT_EXP_WIDTH": 8, } mase_runner( module_param_list=[ # Default Square - DEFAULT_CONFIG, - # + # DEFAULT_CONFIG, # { # **DEFAULT_CONFIG, # "A_MAN_WIDTH": 9, @@ -258,7 +288,7 @@ def test_matmul(): # "OUT_MAN_WIDTH": 12, # "OUT_EXP_WIDTH": 4, # }, - # # Long Rectangle, should saturate many values + # Long Rectangle, should saturate many values { **DEFAULT_CONFIG, "A_TOTAL_DIM0": 16, @@ -270,20 +300,20 @@ def test_matmul(): "B_COMPUTE_DIM0": 4, "B_COMPUTE_DIM1": 4, # Must equal A_COMPUTE_DIM0 }, - # # # Change window to full size - { - **DEFAULT_CONFIG, - "A_COMPUTE_DIM0": 4, - "A_COMPUTE_DIM1": 4, - "B_COMPUTE_DIM0": 4, - "B_COMPUTE_DIM1": 4, - }, + # Change window to full size + # { + # **DEFAULT_CONFIG, + # "A_COMPUTE_DIM0": 4, + # "A_COMPUTE_DIM1": 4, + # "B_COMPUTE_DIM0": 4, + # "B_COMPUTE_DIM1": 4, + # }, # # Dimensions # *generate_random_dimension_cfg([DEFAULT_CONFIG]), ], trace=True, jobs=12, - extra_build_args=["--trace-depth", "5"], + # sim="questa", ) diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.ipynb b/src/mase_components/linear_layers/mxint_operators/test/test.ipynb deleted file mode 100644 index fabc4c551..000000000 --- a/src/mase_components/linear_layers/mxint_operators/test/test.ipynb +++ /dev/null @@ -1,52 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "#!/usr/bin/env python3\n", - "\n", - "# This script tests the fixed point linear\n", - "import os, logging\n", - "\n", - "import cocotb\n", - "from cocotb.log import SimLog\n", - "from cocotb.triggers import *\n", - "\n", - "from mase_cocotb.testbench import Testbench\n", - "from mase_cocotb.interfaces.streaming import (\n", - " MultiSignalStreamDriver,\n", - " MultiSignalStreamMonitor,\n", - ")\n", - "\n", - "from mase_cocotb.runner import mase_runner\n", - "from utils import mxint_quantize\n", - "\n", - "import torch\n", - "from math import ceil, log2\n", - "import random\n", - "\n", - "logger = logging.getLogger(\"testbench\")\n", - "logger.setLevel(logging.DEBUG)\n", - "\n", - "torch.manual_seed(10)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mase", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/src/mase_components/linear_layers/mxint_operators/test/test.py b/src/mase_components/linear_layers/mxint_operators/test/test.py deleted file mode 100644 index f58f382cf..000000000 --- a/src/mase_components/linear_layers/mxint_operators/test/test.py +++ /dev/null @@ -1,64 +0,0 @@ -from utils import mxint_quantize -import torch - -block_size = 10 -torch.manual_seed(0) -data = torch.rand(10) -w = torch.rand(10, 10) -d_man_width = 12 -w_man_width = 8 -e_width = 4 -(data_in, mdata_in, edata_in) = mxint_quantize( - data, - d_man_width, - e_width, -) -(weight, mweight, eweight) = mxint_quantize( - w, - w_man_width, - e_width, -) -linear = torch.nn.Linear(10, 10, bias=False) -linear.weight = torch.nn.Parameter(weight) -target_x = linear(data_in) -linear.weight = torch.nn.Parameter(mweight) -hardware_out = linear(mdata_in) -print(hardware_out * (2 ** (edata_in + eweight))) -# software knows -print(target_x) - - -# hardware quant back -def hardware_quant(hardware_in, be_value, e_width, width): - from math import ceil, log2 - - result = ceil(log2(max(hardware_in))) - exponent_bias = 2 ** (e_width - 1) - 1 - - # exponent - exponent_max = 2**e_width - 1 - exponent_bias - exponent_min = -exponent_bias - exponent = ( - torch.ceil(torch.log2(hardware_in.abs().max())) + be_value - exponent_bias - ) - exponent = torch.clamp(exponent, exponent_min, exponent_max) - int_min = -(2 ** (width - 1)) - int_max = 2 ** (width - 1) - 1 - mantissa = hardware_in / 2 ** (exponent - be_value) - breakpoint() - mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - - msfp_x = (2**exponent) * mantissa - return msfp_x, mantissa, exponent - - -new_man_width = 8 -new_e_width = 4 -qout, qmout, qeout = hardware_quant( - hardware_out, (edata_in + eweight), new_e_width, new_man_width -) -out, mout, eout = mxint_quantize(target_x, new_man_width, new_e_width) -breakpoint() -# def hardware_quant_back(): -# hardware_out.max().log2()+ hardware_exp -# clamp((log2(max(hardware_out))+hardware_exp),target_width) diff --git a/src/mase_components/linear_layers/mxint_operators/test/test_lint_fixed_arithmetic.py b/src/mase_components/linear_layers/mxint_operators/test/test_lint_fixed_arithmetic.py deleted file mode 100644 index b73af101e..000000000 --- a/src/mase_components/linear_layers/mxint_operators/test/test_lint_fixed_arithmetic.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from mase_components.linter import run_lint - - -@pytest.mark.dev -def test_lint_fixed_arithmetic(): - run_lint("fixed_arithmetic") - - -if __name__ == "__main__": - test_lint_fixed_arithmetic() diff --git a/src/mase_components/linear_layers/mxint_operators/test/test_synth_fixed_arithmetic.py b/src/mase_components/linear_layers/mxint_operators/test/test_synth_fixed_arithmetic.py deleted file mode 100644 index 19ea87d12..000000000 --- a/src/mase_components/linear_layers/mxint_operators/test/test_synth_fixed_arithmetic.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from mase_components.synth_runner import run_synth - - -@pytest.mark.vivado -def test_synth_fixed_arithmetic(): - run_synth("fixed_arithmetic") - - -if __name__ == "__main__": - test_synth_fixed_arithmetic() diff --git a/src/mase_components/linear_layers/mxint_operators/test/utils.py b/src/mase_components/linear_layers/mxint_operators/test/utils.py index 7edb9f6ed..0d5e0087b 100644 --- a/src/mase_components/linear_layers/mxint_operators/test/utils.py +++ b/src/mase_components/linear_layers/mxint_operators/test/utils.py @@ -29,18 +29,20 @@ def mxint_quantize(x, width: int = 12, exponent_width: int = 6, exponent: int = # exponent if exponent == None: - exponent = torch.ceil(torch.log2(x.abs().max())) - exponent_bias + exponent = torch.ceil(torch.log2(x.abs().max())) exponent = torch.clamp(exponent, exponent_min, exponent_max) # mantissa int_min = -(2 ** (width - 1)) int_max = 2 ** (width - 1) - 1 - mantissa = x / 2**exponent + mantissa = x * (2 ** (width - 1)) / 2**exponent + # print(mantissa, int_min, int_max) mantissa = torch.clamp(mantissa.floor(), int_min, int_max) - msfp_x = (2**exponent) * mantissa - return msfp_x, mantissa, exponent + q_x = (2**exponent) * mantissa / ((2 ** (width - 1))) + return q_x, mantissa, exponent def block_mxint_quant(tensor, q_config, parallelism): + print(tensor) original_shape = tensor.shape if len(tensor.shape) == 1: tensor = tensor.unsqueeze(0) @@ -175,14 +177,459 @@ def forward(self, x: Tensor) -> Tensor: else: x, mx, ex = self.x_quantizer(x) w, mw, ew = self.w_quantizer(self.weight) - print((mx @ mw.transpose(0, 1)).int()) if self.bias is not None: bias, mb, eb = self.b_quantizer(self.bias) else: bias = None - breakpoint() out = F.linear(x, w, bias) # print(f"mout = {F.linear(mx, mw, mb*2**(ex+ew - eb).floor())}") if self.out_quantizer is None: return out return self.out_quantizer(out) + + +class MXIntLinearHardware(_LinearBase): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + config=None, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype) + assert config is not None, "config is None!" + self.in_features = in_features + self.out_features = out_features + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + # establish quantizer + w_width, w_exponent_width = ( + config["weight_width"], + config["weight_exponent_width"], + ) + w_p1, w_p0 = ( + config["weight_parallelism"][0], + config["weight_parallelism"][1], + ) + x_width, x_exponent_width = ( + config["data_in_width"], + config["data_in_exponent_width"], + ) + x_p1, x_p0 = ( + config["data_in_parallelism"][0], + config["data_in_parallelism"][1], + ) + # check bias quantizer, if not, use weight quantizer + b_width, b_exponent_width = config["bias_width"], config["bias_exponent_width"] + b_p1, b_p0 = ( + config["bias_parallelism"][0], + config["bias_parallelism"][1], + ) + base_quantizer = block_mxint_quant + out_width, out_exponent_width = ( + config["data_out_width"], + config["data_out_exponent_width"], + ) + out_p1, out_p0 = ( + config["data_out_parallelism"][0], + config["data_out_parallelism"][1], + ) + self.out_quantizer = partial( + base_quantizer, + q_config={"width": out_width, "exponent_width": out_exponent_width}, + parallelism=[out_p1, out_p0], + ) + self.w_quantizer = partial( + base_quantizer, + q_config={"width": w_width, "exponent_width": w_exponent_width}, + parallelism=[w_p1, w_p0], + ) + self.x_quantizer = partial( + base_quantizer, + q_config={"width": x_width, "exponent_width": x_exponent_width}, + parallelism=[x_p1, x_p0], + ) + self.b_quantizer = partial( + base_quantizer, + q_config={"width": b_width, "exponent_width": b_exponent_width}, + parallelism=[b_p1, b_p0], + ) + + def forward(self, x: Tensor) -> Tensor: + print(x) + x, mx, ex = self.x_quantizer(x) + in_x = (mx, ex) + w, mw, ew = self.w_quantizer(self.weight) + in_w = (mw, ew) + if self.bias is not None: + bias, mbias, ebias = self.b_quantizer(self.bias) + in_bias = (mbias, ebias) + else: + bias = None + in_bias = None + + out = wrapped_mxint_linear_hardware( + in_x, in_w, in_bias, self.in_features, self.out_features, self.config + ) + + return out + + +def wrapped_mxint_linear_hardware(x, w, bias, in_features, out_features, config): + mx = x[0] + n = mx.reshape(-1, in_features).shape[0] + in_config = { + "x_config": { + "width": config["data_in_width"], + "exponent_width": config["data_in_exponent_width"], + "parallism_dim_0": config["data_in_parallelism"][1], + "parallism_dim_1": config["data_in_parallelism"][0], + "depth_dim_0": in_features // config["data_in_parallelism"][1], + "depth_dim_1": n // config["data_in_parallelism"][0], + "dim_0": in_features, + "dim_1": n, + }, + "w_config": { + "width": config["weight_width"], + "exponent_width": config["weight_exponent_width"], + "parallism_dim_0": config["weight_parallelism"][1], + "parallism_dim_1": config["weight_parallelism"][0], + "depth_dim_0": in_features // config["weight_parallelism"][1], + "depth_dim_1": out_features // config["weight_parallelism"][0], + "dim_0": in_features, + "dim_1": out_features, + }, + "bias_config": { + "width": config["bias_width"], + "exponent_width": config["bias_exponent_width"], + "parallism_dim_0": config["bias_parallelism"][1], + "parallism_dim_1": 1, + "depth_dim_0": out_features // config["bias_parallelism"][1], + "depth_dim_1": 1, + "dim_0": out_features, + "dim_1": 1, + }, + "out_config": { + "width": config["data_out_width"], + "exponent_width": config["data_out_exponent_width"], + "parallism_dim_0": config["data_out_parallelism"][1], + "parallism_dim_1": config["data_out_parallelism"][0], + "depth_dim_0": out_features // config["data_out_parallelism"][1], + "depth_dim_1": n // config["data_out_parallelism"][0], + "dim_0": out_features, + "dim_1": n, + }, + } + mout, eout = mxint_linear_hardware(x, w, bias, in_config) + out_config = in_config["out_config"] + reshaped_mout = mout.reshape( + out_config["depth_dim_1"], + out_config["parallism_dim_1"], + out_config["depth_dim_0"], + out_config["parallism_dim_0"], + ).permute(0, 2, 1, 3) + reshaped_out = reshaped_mout * 2 ** ( + eout[:, :, None, None] - config["data_out_width"] + 1 + ) + out = reshaped_out.reshape( + out_config["depth_dim_1"], + out_config["depth_dim_0"], + out_config["parallism_dim_1"], + out_config["parallism_dim_0"], + ).permute(0, 2, 1, 3) + out = out.reshape(out_config["dim_1"], out_config["dim_0"]) + + return out + + +def mxint_linear_hardware(x, w, bias, config): + """ + assume 2 dimensional input + config = { + "x_config":{ + "width": , + "exponent_width" , + "parallism_dim_0", + "parallism_dim_1", + "depth_dim_0", + "depth_dim_1", + "dim_0", + "dim_1", + }, + "w_config": { + ... + }, + "bias_config": { + ... + }, + "out_config": { + ... + }, + } + """ + mx, ex = x + mw, ew = w + x_config = config["x_config"] + w_config = config["w_config"] + out_config = config["out_config"] + from math import ceil, log2 + + def DotProductCore(man_x, exp_x, man_y, exp_y): + return man_x @ man_y.transpose(0, 1), exp_x + exp_y + + def block_wise_reshape_tensor(x, x_config): + reshaped_x = x.reshape( + x_config["depth_dim_1"], + x_config["parallism_dim_1"], + x_config["depth_dim_0"], + x_config["parallism_dim_0"], + ).permute(0, 2, 1, 3) + reshaped_x = reshaped_x.reshape( + x_config["depth_dim_1"] * x_config["depth_dim_0"], + x_config["parallism_dim_1"], + x_config["parallism_dim_0"], + ) + return reshaped_x + + # assume 2 dimensional input + assert ( + x_config["depth_dim_0"] == w_config["depth_dim_0"] + ), "need to check the setting of dim" + assert ( + x_config["parallism_dim_0"] == w_config["parallism_dim_0"] + ), "need to check the setting of dim" + reshaped_ex = ex.reshape(-1) + reshaped_mx = block_wise_reshape_tensor(mx, x_config) + reshaped_ew = ew.reshape(-1) + reshaped_mw = block_wise_reshape_tensor(mw, w_config) + man_out = torch.zeros( + x_config["depth_dim_1"], + w_config["depth_dim_1"], + x_config["parallism_dim_1"] * w_config["parallism_dim_1"], + ) + exp_out = torch.zeros(x_config["depth_dim_1"], w_config["depth_dim_1"]) + for i in range(x_config["depth_dim_1"]): + for j in range(w_config["depth_dim_1"]): + partial_man_out = torch.zeros( + w_config["depth_dim_0"], + x_config["parallism_dim_1"], + w_config["parallism_dim_1"], + ) + partial_exp_out = torch.zeros(w_config["depth_dim_0"]) + for k in range(x_config["depth_dim_0"]): + mx_block = reshaped_mx[i * x_config["depth_dim_0"] + k] + ex_block = reshaped_ex[i * x_config["depth_dim_0"] + k] + mw_block = reshaped_mw[j * w_config["depth_dim_0"] + k] + ew_block = reshaped_ew[j * w_config["depth_dim_0"] + k] + partial_man_out[k], partial_exp_out[k] = DotProductCore( + mx_block, ex_block, mw_block, ew_block + ) + acc_man_out, acc_exp_out = MxIntAccumulator( + partial_man_out.reshape(w_config["depth_dim_0"], -1), partial_exp_out + ) + if bias != None: + bias_config = config["bias_config"] + mbias, ebias = bias + reshaped_mbias = mbias.reshape( + w_config["depth_dim_1"], w_config["parallism_dim_1"] + ) + reshaped_ebias = ebias.reshape(w_config["depth_dim_1"]) + shifted_value = ( + reshaped_ebias[j] + - acc_exp_out + + x_config["width"] + + w_config["width"] + - 2 + - (bias_config["width"] - 1) + ) + shifted_bias = reshaped_mbias[j].repeat( + x_config["parallism_dim_1"] + ) * 2 ** (shifted_value) + print(reshaped_mbias[j]) + print(shifted_value) + acc_man_out = shifted_bias + acc_man_out + print("shfited_bias", shifted_bias) + man_out[i][j], exp_out[i][j] = MxIntCast( + acc_man_out, + acc_exp_out, + { + "in_width": x_config["width"] + + w_config["width"] + + ceil(log2(x_config["dim_0"])), + "in_frac_width": x_config["width"] + w_config["width"] - 2, + "in_exponent_width": max( + x_config["exponent_width"], w_config["exponent_width"] + ) + + 1, + "out_width": out_config["width"], + "out_exponent_width": out_config["exponent_width"], + }, + ) + man_out = ( + man_out.reshape( + x_config["depth_dim_1"], + w_config["depth_dim_1"], + x_config["parallism_dim_1"], + w_config["parallism_dim_1"], + ) + .permute(0, 2, 1, 3) + .reshape(x_config["dim_1"], w_config["dim_1"]) + ) + return man_out, exp_out + + +def MXIntMatmulHardware(man_x, exp_x, man_y, exp_y, x_config, y_config, out_config): + """ + assume 2 dimensional input + config = { + "width": , + "exponent_width" , + "parallism_dim_0", + "parallism_dim_1", + "depth_dim_0", + "depth_dim_1", + "dim_0", + "dim_1", + } + man.shape = [dim_1 * dim_0] + exp.shape = [depth_dim_1, depth_dim_0] + """ + from math import ceil, log2 + + def MatmulCore(man_x, exp_x, man_y, exp_y): + return man_x @ man_y, exp_x + exp_y + + # assume 2 dimensional input + assert ( + x_config["depth_dim_0"] == y_config["depth_dim_1"] + ), "need to check the setting of dim" + + def block_wise_reshape_tensor(x, x_config): + reshaped_x = x.reshape( + x_config["depth_dim_1"], + x_config["parallism_dim_1"], + x_config["depth_dim_0"], + x_config["parallism_dim_0"], + ).permute(0, 2, 1, 3) + reshaped_x = reshaped_x.reshape( + x_config["depth_dim_1"] * x_config["depth_dim_0"], + x_config["parallism_dim_1"], + x_config["parallism_dim_0"], + ) + return reshaped_x + + reshaped_exp_x = exp_x.reshape(-1) + reshaped_man_x = block_wise_reshape_tensor(man_x, x_config) + reshaped_exp_y = exp_y.reshape(-1) + reshaped_man_y = block_wise_reshape_tensor(man_y, y_config) + man_out = torch.zeros( + x_config["depth_dim_1"], + y_config["depth_dim_0"], + x_config["parallism_dim_1"] * y_config["parallism_dim_0"], + ) + exp_out = torch.zeros(x_config["depth_dim_1"], y_config["depth_dim_0"]) + for i in range(x_config["depth_dim_1"]): + for j in range(y_config["depth_dim_0"]): + partial_man_out = torch.zeros( + y_config["depth_dim_1"], + x_config["parallism_dim_1"], + y_config["parallism_dim_0"], + ) + partial_exp_out = torch.zeros(y_config["depth_dim_1"]) + for k in range(y_config["depth_dim_1"]): + man_x_block = reshaped_man_x[i * x_config["depth_dim_0"] + k] + exp_x_block = reshaped_exp_x[i * x_config["depth_dim_0"] + k] + man_y_block = reshaped_man_y[k * y_config["depth_dim_0"] + j] + exp_y_block = reshaped_exp_y[k * y_config["depth_dim_0"] + j] + partial_man_out[k], partial_exp_out[k] = MatmulCore( + man_x_block, exp_x_block, man_y_block, exp_y_block + ) + acc_man_out, acc_exp_out = MxIntAccumulator( + partial_man_out.reshape(y_config["depth_dim_1"], -1), partial_exp_out + ) + man_out[i][j], exp_out[i][j] = MxIntCast( + acc_man_out, + acc_exp_out, + { + "in_width": x_config["width"] + + y_config["width"] + + ceil(log2(x_config["dim_0"])), + "in_frac_width": x_config["width"] + y_config["width"] - 2, + "in_exponent_width": max( + x_config["exponent_width"], y_config["exponent_width"] + ) + + 1, + "out_width": out_config["width"], + "out_exponent_width": out_config["exponent_width"], + }, + ) + man_out = ( + man_out.reshape( + x_config["depth_dim_1"], + y_config["depth_dim_0"], + x_config["parallism_dim_1"], + x_config["parallism_dim_0"], + ) + .permute(0, 2, 1, 3) + .reshape(x_config["dim_1"], y_config["dim_0"]) + ) + return man_out, exp_out + + +def MxIntCast(man_in, exp_in, param): + # In Man Width + max_in = torch.ceil(torch.log2(man_in.abs().max())) + out_width = param["out_width"] + out_exponent_width = param["out_exponent_width"] + in_width = param["in_width"] + in_frac_width = param["in_frac_width"] + in_exponent_width = param["in_exponent_width"] + + out_exponent_max = 2 ** (out_exponent_width - 1) - 1 + out_exponent_min = -(2 ** (out_exponent_width - 1)) + + out_min = -(2 ** (out_width - 1)) + out_max = 2 ** (out_width - 1) - 1 + lma_in = torch.ceil(torch.log2(man_in.abs().max() + 1e-3)) + out_exp_full = lma_in + exp_in - in_frac_width + out_exp = torch.clamp(out_exp_full, out_exponent_min, out_exponent_max) + out_man = man_in // 2 ** (in_frac_width - exp_in + out_exp - (out_width - 1)) + out_man = torch.clamp(out_man, out_min, out_max) + + return out_man, out_exp + + +# def MxIntAccumulator(man, exp, clamp_width = 15): +# IN_DEPTH, BLOCK_SIZE = man.shape[0],man.shape[1] +# min_exp = torch.Tensor([64]) +# mout = torch.zeros(BLOCK_SIZE) +# out_exp = torch.Tensor([64]) +# for i in range(IN_DEPTH): +# min_exp = exp[i] if exp[i] max_exp else max_exp + mout = mout // 2 ** (max_exp - out_exp) + out_exp = max_exp + shifted_man = man[i] // 2 ** (max_exp - exp[i]) + mout = mout + shifted_man + + return mout, out_exp