Changes from all commits
2 changes: 1 addition & 1 deletion src/chop/nn/quantized/functional/linear.py
@@ -44,7 +44,7 @@ def linearInteger(
         config["data_out_frac_width"],
     )
 
-    floor = config.get("floor", False)
+    floor = config.get("floor", True)
     base_quantizer = integer_floor_quantizer if floor else integer_quantizer
     w_quantizer = partial(base_quantizer, width=w_width, frac_width=w_frac_width)
     x_quantizer = partial(base_quantizer, width=x_width, frac_width=x_frac_width)
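This flips the default so `linearInteger` floor-quantizes unless the config explicitly asks for rounding, matching the `floor=True` quantization used by the testbench and BRAM emission later in this PR. As a reference, a minimal standalone sketch of the two behaviours the `floor` key selects between (the real `integer_quantizer`/`integer_floor_quantizer` live in `chop.nn.quantizers`; this version only illustrates their semantics):

```python
import torch

def integer_quantize(x: torch.Tensor, width: int, frac_width: int, floor: bool = False):
    # Fake-quantize x onto a signed fixed-point grid: `width` total bits,
    # `frac_width` fractional bits. Illustrates what the `floor` config key
    # toggles; not the library implementation.
    scale = 2.0**frac_width
    thresh = 2 ** (width - 1)
    stepped = torch.floor(x * scale) if floor else torch.round(x * scale)
    return torch.clamp(stepped, -thresh, thresh - 1) / scale

x = torch.tensor([0.30, -0.30])
print(integer_quantize(x, 8, 4))              # rounds: [ 0.3125, -0.3125]
print(integer_quantize(x, 8, 4, floor=True))  # floors: [ 0.2500, -0.3125]
```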
8 changes: 4 additions & 4 deletions src/chop/nn/quantizers/quantizers_for_hw.py
@@ -3,15 +3,15 @@
 import torch.nn.functional as F
 from torch import Tensor
 
 # from .quantizers import integer_quantizer
 from .integer import integer_quantizer, integer_floor_quantizer
 from .utils import block, my_clamp, my_round, unblock, my_floor
 
 
-def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int):
+def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int, floor=False):
     thresh = 2 ** (width - 1)
     scale = 2**frac_width
 
-    fixed_point_value = my_clamp(my_round(x.mul(scale)), -thresh, thresh - 1)
+    base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+    fixed_point_value = base_quantizer(x, width, frac_width) * scale
     fixed_point_value = fixed_point_value.to(torch.int)
     fixed_point_value = fixed_point_value % (2**width)
     return fixed_point_value
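The hardware-facing quantizer now reuses the software quantizers (round or floor) and rescales, instead of duplicating the clamp-and-round logic, so driver-side integers agree bit-for-bit with the fake-quantized tensors; the trailing `% (2**width)` folds negatives into unsigned two's-complement words. A standalone sketch of the numeric effect (illustrative, not the library code):

```python
import torch

def integer_quantizer_for_hw_sketch(x, width, frac_width, floor=False):
    # Quantize as the software quantizer does (round or floor), rescale to
    # the underlying integer, then wrap negatives into unsigned
    # two's-complement form so the word can be driven onto a DUT bus.
    scale = 2**frac_width
    thresh = 2 ** (width - 1)
    step = torch.floor if floor else torch.round
    q = torch.clamp(step(x * scale), -thresh, thresh - 1)
    return q.to(torch.int) % (2**width)

print(integer_quantizer_for_hw_sketch(torch.tensor(-1.5), width=8, frac_width=4))
# tensor(232, dtype=torch.int32) -> 0xE8, the 8-bit two's complement of -24
```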
79 changes: 38 additions & 41 deletions src/chop/passes/graph/transforms/verilog/emit_bram.py
@@ -7,7 +7,7 @@
 import torch
 
 from chop.passes.graph.utils import vf, v2p, get_module_by_name, init_project
-from chop.nn.quantizers import integer_quantizer_for_hw, integer_floor_quantizer_for_hw
+from chop.nn.quantizers import integer_quantizer_for_hw
 
 logger = logging.getLogger(__name__)
 from pathlib import Path
@@ -49,7 +49,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
             f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"
         ]
     )
-    out_depth = int((total_size + out_size - 1) / out_size)
+    out_depth = int(total_size / out_size)
     out_width = int(
         node.meta["mase"].parameters["common"]["args"][verilog_param_name]["precision"][
             0
@@ -84,7 +84,7 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
   logic [DWIDTH-1:0] q0_t1;
 
   initial begin
-    $readmemh("{data_name}", ram);
+    $readmemb("{data_name}", ram);
   end
 
   assign q0 = q0_t1;
@@ -119,14 +119,14 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
 
 `timescale 1ns / 1ps
 module {node_param_name}_source #(
-    parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = 32,
-    parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = 1,
-    parameter {_cap(verilog_param_name)}_PRECISION_0 = 16,
-    parameter {_cap(verilog_param_name)}_PRECISION_1 = 3,
-
-    parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = 1,
-    parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = 1,
-    parameter OUT_DEPTH = (({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 + {_cap(verilog_param_name)}_PARALLELISM_DIM_0 - 1) / {_cap(verilog_param_name)}_PARALLELISM_DIM_0) * (({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 + {_cap(verilog_param_name)}_PARALLELISM_DIM_1 - 1) / {_cap(verilog_param_name)}_PARALLELISM_DIM_1)
+    parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = -1,
+    parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = -1,
+    parameter {_cap(verilog_param_name)}_PRECISION_0 = -1,
+    parameter {_cap(verilog_param_name)}_PRECISION_1 = -1,
+
+    parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = -1,
+    parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = -1,
+    parameter OUT_DEPTH = ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 / {_cap(verilog_param_name)}_PARALLELISM_DIM_0) * ({_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 / {_cap(verilog_param_name)}_PARALLELISM_DIM_1)
 ) (
     input clk,
     input rst,
@@ -138,7 +138,6 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
   // 1-bit wider so IN_DEPTH also fits.
   localparam COUNTER_WIDTH = $clog2(OUT_DEPTH);
   logic [COUNTER_WIDTH:0] counter;
-
   always_ff @(posedge clk)
     if (rst) counter <= 0;
     else begin
@@ -147,7 +146,6 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name):
       else counter <= counter + 1;
     end
   end
-
   logic [1:0] clear;
   always_ff @(posedge clk)
     if (rst) clear <= 0;
@@ -205,30 +203,39 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
             f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"
         ]
     )
-    out_depth = int((total_size + out_size - 1) / out_size)
+    out_depth = int(total_size / out_size)
 
     data_buff = ""
     param_data = node.meta["mase"].module.get_parameter(param_name).data
+    param_meta = node.meta["mase"].parameters["hardware"]["verilog_param"]
     if node.meta["mase"].parameters["hardware"]["interface"][verilog_param_name][
         "transpose"
     ]:
+        raise NotImplementedError("only supports linear with non-transposed weights")
+    else:
+        assert (
+            param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1"]
+            % param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"]
+            == 0
+        ) and (
+            param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0"]
+            % param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"]
+            == 0
+        ), "The tensor size must be divisible by the parallelism parameter."
         param_data = torch.reshape(
             param_data,
             (
-                node.meta["mase"].parameters["hardware"]["verilog_param"][
-                    "DATA_OUT_0_SIZE"
-                ],
-                node.meta["mase"].parameters["hardware"]["verilog_param"][
-                    "DATA_IN_0_DEPTH"
-                ],
-                node.meta["mase"].parameters["hardware"]["verilog_param"][
-                    "DATA_IN_0_SIZE"
-                ],
+                -1,
+                param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1"]
+                // param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"],
+                param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1"],
+                param_meta[f"{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0"]
+                // param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"],
+                param_meta[f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0"],
             ),
         )
-        param_data = torch.transpose(param_data, 0, 1)
+        param_data = param_data.permute(0, 1, 3, 2, 4)
     param_data = torch.flatten(param_data).tolist()
 
     if (
         node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"]
         == "fixed"
@@ -240,31 +247,21 @@ def emit_parameters_in_dat_internal(node, param_name, file_name):
             "precision"
         ][1]
 
-        if node.meta["mase"].module.config.get("floor", False):
-            base_quantizer = integer_floor_quantizer_for_hw
-        else:
-            base_quantizer = integer_quantizer_for_hw
-
         scale = 2**frac_width
         thresh = 2**width
         for i in range(0, out_depth):
             line_buff = ""
             for j in range(0, out_size):
-                if i * out_size + out_size - 1 - j >= len(param_data):
-                    value = 0
-                else:
-                    value = param_data[i * out_size + out_size - 1 - j]
-
-                # TODO: please clear this up later
-                value = base_quantizer(torch.tensor(value), width, frac_width).item()
+                value = param_data[i * out_size + j]
+                value = integer_quantizer_for_hw(
+                    torch.tensor(value), width, frac_width, floor=True
+                ).item()
                 value = str(bin(value))
                 value_bits = value[value.find("0b") + 2 :]
                 value_bits = "0" * (width - len(value_bits)) + value_bits
                 assert len(value_bits) == width
-                value_bits = hex(int(value_bits, 2))
-                value_bits = value_bits[value_bits.find("0x") + 2 :]
-                value_bits = "0" * (width // 4 - len(value_bits)) + value_bits
-                line_buff = value_bits + line_buff
+                line_buff += value_bits
+            hex_buff = hex(int(line_buff, 2))
 
             data_buff += line_buff + "\n"
     else:
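Together with the `$readmemh`-to-`$readmemb` switch above, every row of the emitted `.dat` file is now a plain binary string: `out_size` words of `width` bits each, concatenated in index order by `line_buff += value_bits`. A hypothetical helper showing the same packing (the word order within a row must match how the generated `*_source` module slices it; values are assumed already wrapped to unsigned two's complement, as `integer_quantizer_for_hw` returns them):

```python
def pack_words(values, width, out_size):
    # One memory row per line: `out_size` words of `width` bits, first word
    # leftmost, ready for $readmemb.
    lines = []
    for i in range(0, len(values), out_size):
        row = values[i : i + out_size]
        lines.append("".join(format(v, f"0{width}b") for v in row))
    return "\n".join(lines)

# Two 4-bit words per row: 3 -> 0011, 14 -> 1110 (i.e. -2 wrapped to unsigned).
print(pack_words([3, 14, 1, 8], width=4, out_size=2))
# 00111110
# 00011000
```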
4 changes: 2 additions & 2 deletions src/chop/passes/graph/transforms/verilog/emit_tb.py
@@ -12,8 +12,6 @@
 
 from pathlib import Path
 
-torch.manual_seed(0)
-
 import cocotb
 from mase_cocotb.testbench import Testbench
 from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor
@@ -147,6 +145,7 @@ def load_drivers(self, in_tensors):
                     self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"),
                     self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"),
                 ],
+                floor=True,
             )
 
         else:
@@ -177,6 +176,7 @@ def load_monitors(self, expectation):
                 self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"),
                 self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"),
             ],
+            floor=True,
        )
 
         # Set expectation for each monitor
36 changes: 32 additions & 4 deletions src/mase_cocotb/utils.py
@@ -101,9 +101,11 @@ def product_dict(**kwargs):
         yield dict(zip(keys, instance))
 
 
-def fixed_preprocess_tensor(tensor: Tensor, q_config: dict, parallelism: list) -> list:
+def fixed_preprocess_tensor(
+    tensor: Tensor, q_config: dict, parallelism: list, floor=False
+) -> list:
     """Preprocess a tensor before driving it into the DUT.
-    1. Quantize to requested fixed-point precision using floor rounding.
+    1. Quantize to requested fixed-point precision.
     2. Convert to integer format to be compatible with Cocotb drivers.
     3. Split into blocks according to parallelism in each dimension.
@@ -125,11 +127,13 @@ def fixed_preprocess_tensor(
         tensor = tensor.view((-1, tensor.shape[-1]))
 
     # Quantize
-    quantizer = partial(integer_floor_quantizer, **q_config)
+    base_quantizer = integer_floor_quantizer if floor else integer_quantizer
+    quantizer = partial(base_quantizer, **q_config)
     q_tensor = quantizer(tensor)
 
+    # breakpoint()
     # Convert to integer format
     q_tensor = (q_tensor * 2 ** q_config["frac_width"]).int()
+    # q_tensor = signed_to_unsigned(q_tensor, bits=q_config["width"])
 
     # Split into chunks according to parallelism in each dimension
     # parallelism[0]: along rows, parallelism[1]: along columns
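Putting the three documented steps together, a compact illustrative model of what `fixed_preprocess_tensor` now produces for a 2D tensor (the slicing below is a sketch, not the exact mase_cocotb implementation):

```python
import torch

def preprocess_sketch(tensor, width, frac_width, parallelism, floor=False):
    # 1. Quantize (round or floor) onto the fixed-point grid.
    scale = 2**frac_width
    thresh = 2 ** (width - 1)
    step = torch.floor if floor else torch.round
    q = torch.clamp(step(tensor * scale), -thresh, thresh - 1).int()
    # 2/3. Emit one (p1 x p0) integer block per interface beat.
    p1, p0 = parallelism  # parallelism[0]: rows, parallelism[1]: columns
    blocks = []
    for r in range(0, q.shape[0], p1):
        for c in range(0, q.shape[1], p0):
            blocks.append(q[r : r + p1, c : c + p0].flatten().tolist())
    return blocks

t = torch.arange(16, dtype=torch.float32).reshape(4, 4) / 4
print(preprocess_sketch(t, width=8, frac_width=2, parallelism=[2, 2]))
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```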
@@ -174,3 +178,27 @@ def fixed_cast(val, in_width, in_frac_width, out_width, out_frac_width):
         val = val
     # val = int(val % (1 << out_width))
     return val  # << out_frac_width # treat data<out_width, out_frac_width> as data<out_width, 0>
+
+
+async def check_signal(dut, log, signal_list):
+    # TODO: support a configurable start count
+    # TODO: support signals whose valid/ready lines use different names
+    def handshake_signal_check(
+        dut, log, signal_base, valid=None, ready=None, count_start: dict = {}
+    ):
+        data_valid = getattr(dut, f"{signal_base}_valid") if valid is None else valid
+        data_ready = getattr(dut, f"{signal_base}_ready") if ready is None else ready
+        data = getattr(dut, signal_base)
+        svalue = [i.signed_integer for i in data.value]
+        if data_valid.value & data_ready.value:
+            count_start[signal_base] = (
+                count_start[signal_base] + 1
+                if count_start.get(signal_base) is not None
+                else 1
+            )
+            log.debug(f"handshake {count_start[signal_base]} {signal_base} = {svalue}")
+
+    while True:
+        await RisingEdge(dut.clk)
+        for signal in signal_list:
+            handshake_signal_check(dut, log, signal)
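A hypothetical usage of the new watcher from a cocotb test: fork it once and every valid-and-ready handshake on the listed interfaces is logged while the test runs. Note `check_signal` awaits `RisingEdge`, so `from cocotb.triggers import RisingEdge` must be in scope in `utils.py`; the signal names below are examples only.

```python
import cocotb
from mase_cocotb.utils import check_signal

@cocotb.test()
async def test_with_handshake_logging(dut):
    # Fork the watcher; each name must exist on the DUT as <name>,
    # <name>_valid and <name>_ready.
    cocotb.start_soon(check_signal(dut, dut._log, ["data_in_0", "data_out_0"]))
    ...  # drive stimulus and await monitors as usual
```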
11 changes: 6 additions & 5 deletions src/mase_components/cast/rtl/fixed_round.sv
@@ -24,12 +24,13 @@ module fixed_round #(
   logic carry_in, input_sign;
   assign input_sign = data_in[IN_WIDTH-1];
   assign input_data = (input_sign) ? ~(data_in[IN_WIDTH-2:0] - 1) : data_in[IN_WIDTH-2:0];
+  /* verilator lint_off SELRANGE */
+  logic [IN_WIDTH + OUT_FRAC_WIDTH - 1:0] lsb_check;
+  assign lsb_check = {input_data, {(OUT_FRAC_WIDTH) {1'b0}}};
   always_comb begin
-    lsb_below[2] = (IN_FRAC_WIDTH >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH] : 0;
-    lsb_below[1] = (IN_FRAC_WIDTH-1 >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-1] : 0;
-    // lsb_below[0] = (IN_FRAC_WIDTH-2 >= OUT_FRAC_WIDTH) ? |(input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-2:0]): 0;
-    lsb_below[0] = '0; // to do: fix
+    lsb_below[2] = (IN_FRAC_WIDTH >= OUT_FRAC_WIDTH) ? lsb_check[IN_FRAC_WIDTH] : 0;
+    lsb_below[1] = (IN_FRAC_WIDTH - 1 >= OUT_FRAC_WIDTH) ? lsb_check[IN_FRAC_WIDTH-1] : 0;
+    lsb_below[0] = (IN_FRAC_WIDTH - 2 >= OUT_FRAC_WIDTH) ? |(lsb_check[IN_FRAC_WIDTH-2:0]) : 0;
+    // lsb_below[0] = '0; // to do: fix
   end
   always_comb begin
     if ((IN_FRAC_WIDTH - OUT_FRAC_WIDTH) >= 0)
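The old part-selects into `input_data` go out of range whenever `OUT_FRAC_WIDTH` approaches `IN_FRAC_WIDTH` (the index `IN_FRAC_WIDTH-OUT_FRAC_WIDTH-2` turns negative), which is why the sticky bit had been stubbed to `'0`. Zero-padding the value by `OUT_FRAC_WIDTH` bits (`lsb_check`) keeps every index non-negative, so the sticky bit is restored and the statically unreachable selects only need the Verilator `SELRANGE` waiver. A rough Python golden model of the new bit selection, assuming the usual guard/round/sticky increment `carry = guard & (sticky | lsb)` (the actual `carry_in` expression is in a part of the module this diff does not show):

```python
def round_frac(mag: int, in_frac: int, out_frac: int) -> int:
    """Narrow a non-negative fixed-point magnitude from in_frac to out_frac
    fractional bits, rounding to nearest. Sketch only; assumes in_frac >= 1."""
    # Mirror lsb_check = {input_data, {OUT_FRAC_WIDTH{1'b0}}}: padding with
    # out_frac zeros guarantees every bit index below is non-negative.
    padded = mag << out_frac
    lsb = (padded >> in_frac) & 1                           # lsb_below[2]
    guard = (padded >> (in_frac - 1)) & 1                   # lsb_below[1]
    sticky = int(padded & ((1 << (in_frac - 1)) - 1) != 0)  # lsb_below[0]
    kept = padded >> in_frac  # the magnitude with out_frac fractional bits
    return kept + (guard & (sticky | lsb))

# 2.6875 (43 with 4 fractional bits) -> 2 fractional bits: 10.75 rounds to 11 (2.75)
assert round_frac(43, in_frac=4, out_frac=2) == 11
```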
fixed_linear.sv
@@ -19,6 +19,7 @@ module fixed_linear #(
     /* verilator lint_off UNUSEDPARAM */
     parameter HAS_BIAS = 1,
     parameter WEIGHTS_PRE_TRANSPOSED = 0,
+    parameter FIFO = 1,
 
     parameter DATA_IN_0_PRECISION_0 = 16,
     parameter DATA_IN_0_PRECISION_1 = 3,
@@ -113,6 +114,9 @@ module fixed_linear #(
   logic add_bias_in_valid;
   logic add_bias_in_ready;
 
+  logic [DATA_OUT_0_PRECISION_0 - 1:0] rounding_out [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0];
+  logic rounding_out_valid;
+  logic rounding_out_ready;
   // * Instances
   // * ---------------------------------------------------------------------------------------------------
@@ -223,9 +227,9 @@ module fixed_linear #(
         .data_in_valid(add_bias_in_valid),
         .data_in_ready(add_bias_in_ready),
 
-        .data_out(data_out_0),
-        .data_out_valid(data_out_0_valid),
-        .data_out_ready(data_out_0_ready)
+        .data_out(rounding_out),
+        .data_out_valid(rounding_out_valid),
+        .data_out_ready(rounding_out_ready)
     );
   end
 
@@ -245,7 +249,7 @@ module fixed_linear #(
 
   // * Add bias
   if (HAS_BIAS == 1) begin
-    fixed_cast #(
+    fixed_rounding #(
         .IN_SIZE (BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1),
         .IN_WIDTH (BIAS_PRECISION_0),
         .IN_FRAC_WIDTH (BIAS_PRECISION_1),
@@ -275,10 +279,45 @@ module fixed_linear #(
         .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1)
     ) output_cast (
         .data_in (matmul_out),
-        .data_out(data_out_0)
+        .data_out(rounding_out)
     );
-    assign data_out_0_valid = matmul_out_valid;
-    assign matmul_out_ready = data_out_0_ready;
+    assign rounding_out_valid = matmul_out_valid;
+    assign matmul_out_ready = rounding_out_ready;
   end
 
+  if (FIFO == 1) begin
+    localparam FIFO_DEPTH = DATA_OUT_0_TENSOR_SIZE_DIM_0 / DATA_OUT_0_PARALLELISM_DIM_0;
+
+    fifo_for_autogen #(
+        .DATA_IN_0_PRECISION_0(DATA_OUT_0_PRECISION_0),  // = 8
+        .DATA_IN_0_PRECISION_1(DATA_OUT_0_PRECISION_1),  // = 4
+        .DATA_IN_0_TENSOR_SIZE_DIM_0(DATA_OUT_0_TENSOR_SIZE_DIM_0),  // = 20
+        .DATA_IN_0_PARALLELISM_DIM_0(DATA_OUT_0_PARALLELISM_DIM_0),  // = 2
+        .DATA_IN_0_TENSOR_SIZE_DIM_1(DATA_OUT_0_TENSOR_SIZE_DIM_1),  // = 4
+        .DATA_IN_0_PARALLELISM_DIM_1(DATA_OUT_0_PARALLELISM_DIM_1),  // = 2
+        .DEPTH(FIFO_DEPTH),
+        .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0),
+        .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1),
+        .DATA_OUT_0_TENSOR_SIZE_DIM_0(DATA_OUT_0_TENSOR_SIZE_DIM_0),
+        .DATA_OUT_0_PARALLELISM_DIM_0(DATA_OUT_0_PARALLELISM_DIM_0),
+        .DATA_OUT_0_TENSOR_SIZE_DIM_1(DATA_OUT_0_TENSOR_SIZE_DIM_1),
+        .DATA_OUT_0_PARALLELISM_DIM_1(DATA_OUT_0_PARALLELISM_DIM_1)
+    ) fifo_1_inst (
+        .clk(clk),
+        .rst(rst),
+
+        .data_in_0(rounding_out),
+        .data_in_0_valid(rounding_out_valid),
+        .data_in_0_ready(rounding_out_ready),
+        .data_out_0(data_out_0),
+        .data_out_0_valid(data_out_0_valid),
+        .data_out_0_ready(data_out_0_ready)
+    );
+  end else begin
+    always_comb begin
+      data_out_0 = rounding_out;
+      data_out_0_valid = rounding_out_valid;
+      rounding_out_ready = data_out_0_ready;
+    end
+  end
 endmodule