diff --git a/.gitignore b/.gitignore index 645649057..5c6f4d6cf 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ coverage.xml # Hardware tests build/ +obj_dir/ # Translations *.mo @@ -143,6 +144,16 @@ venv.bak/ # Data files for hardware simulation *.dat +# Vivado +*.xpr +*.jou +*.cache +*.runs +*.srcs +*.hw +vivado_pid* +synth/ + # Generated SV LUT files *_lut.sv diff --git a/Makefile b/Makefile index 4729b121c..5c1b5afc0 100644 --- a/Makefile +++ b/Makefile @@ -9,14 +9,15 @@ VIVADO_AVAILABLE := $(shell command -v vivado 2> /dev/null) ifeq ($(GPU_AVAILABLE),) PLATFORM := cpu else - PLATFORM := cuda + PLATFORM := gpu endif # * Mount Vivado HLS path only if Vivado is available (to avoid path not found errors) +# Include shared folder containing board files etc ifeq ($(VIVADO_AVAILABLE),) DOCKER_RUN_EXTRA_ARGS= else - DOCKER_RUN_EXTRA_ARGS=-v $(vhls):$(vhls) + DOCKER_RUN_EXTRA_ARGS= -v /mnt/applications/:/mnt/applications -v $(vhls):$(vhls) -v /$(USER_PREFIX)/$(shell whoami)/shared:/root/shared endif # * Set docker image according to local flag @@ -42,6 +43,11 @@ hw_test_dir = src/mase_components/ NUM_WORKERS ?= 1 +sw_test_dir = machop/test/ +hw_test_dir = machop/mase_components/ + +NUM_WORKERS ?= 1 + # Make sure the repo is up to date sync: git submodule sync @@ -70,8 +76,8 @@ shell: -w /workspace \ -v /$(USER_PREFIX)/$(shell whoami)/.gitconfig:/root/.gitconfig \ -v /$(USER_PREFIX)/$(shell whoami)/.ssh:/root/.ssh \ - -v /$(USER_PREFIX)/$(shell whoami)/.mase:/root/.mase:z \ - -v $(shell pwd):/workspace:z \ + -v /$(USER_PREFIX)/$(shell whoami)/.mase:/root/.mase \ + -v $(shell pwd):/workspace \ $(DOCKER_RUN_EXTRA_ARGS) \ $(img) /bin/bash diff --git a/docs/source/modules/documentation/tutorials.rst b/docs/source/modules/documentation/tutorials.rst index 926629f65..384bbe9f6 100644 --- a/docs/source/modules/documentation/tutorials.rst +++ b/docs/source/modules/documentation/tutorials.rst @@ -1,6 +1,12 @@ Tutorials ============================= +.. toctree:: + :maxdepth: 1 + :caption: Common Usecases + + tutorials/common/bert_emit + .. toctree:: :maxdepth: 1 :caption: The Command Line System with Machop diff --git a/docs/source/modules/documentation/tutorials/common/bert_emit.md b/docs/source/modules/documentation/tutorials/common/bert_emit.md new file mode 100644 index 000000000..eced9a3ec --- /dev/null +++ b/docs/source/modules/documentation/tutorials/common/bert_emit.md @@ -0,0 +1,2 @@ +# Deploy a BERT model on an FPGA + diff --git a/docs/source/modules/hardware/arithmetic/mac.md b/docs/source/modules/hardware/arithmetic/mac.md index fd5141635..3690115c9 100644 --- a/docs/source/modules/hardware/arithmetic/mac.md +++ b/docs/source/modules/hardware/arithmetic/mac.md @@ -1,3 +1,3 @@ # Multiply-Accumulate (MAC) Unit -![MAC](https://raw.githubusercontent.com/DeepWok/mase/main/docs/source/imgs/hardware/mac.png) \ No newline at end of file +![MAC](https://raw.githubusercontent.com/DeepWok/mase/main/docs/source/imgs/hardware/mac.png) diff --git a/docs/tutorials/emit_verilog_bert.ipynb b/docs/tutorials/emit_verilog_bert.ipynb new file mode 100644 index 000000000..8d366f5ee --- /dev/null +++ b/docs/tutorials/emit_verilog_bert.ipynb @@ -0,0 +1,410 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to generate an FPGA accelerator for a quantized Bert model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "In this tutorial, we'll see how to load a Bert model from the Mase model library, optimize it by quantizing the weights, then emit the SystemVerilog code for a custom dataflow accelerator, ready to be deployed on an Intel or Xilinx FPGA. This involves using generating a computation graph for the model, then invoking several Mase compiler passes. First, we go through this in detail, discussing the steps required. Then, we show how to use the `chop.pipelines` pass managers to encapsulate all this functionality within a single function call. Finally, we'll run the generated [Cocotb](https://www.cocotb.org/) testbench to evaluate the throughput and latency of the emitted accelerator.\n", + "\n", + "This tutorial assumes you have a working Mase environment. Follow the instructions [here](https://deepwok.github.io/mase/modules/documentation/getting_started.html) to get started using Conda or Docker. You will also need a working Questa installation to run the testbench of the accelerator. If you don't have Questa available, you can also use Verilator, however the runtime may be very large." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up a logger\n", + "from chop.tools import get_logger\n", + "\n", + "logger = get_logger(__name__)\n", + "logger.setLevel(\"INFO\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import and quantize the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's import the Bert model from Mase's [patched model library](https://github.com/DeepWok/mase/tree/main/src/chop/models). We'll define a small configuration with 3 layers and a hidden size of 96. We'll also define a quantization configuration which specifies the fixed-point precision we want to run the model with." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from chop.models.patched.bert import BertConfig, BertModel\n", + "\n", + "config = BertConfig()\n", + "config.num_hidden_layers = 3\n", + "config.hidden_size = 96\n", + "config.intermediate_size = 384\n", + "\n", + "q_config = {\n", + " \"data_in_width\": 8,\n", + " \"data_in_frac_width\": 3,\n", + " \"weight_width\": 8,\n", + " \"weight_frac_width\": 3,\n", + " \"bias_width\": 8,\n", + " \"bias_frac_width\": 3,\n", + " \"data_out_width\": 8,\n", + " \"data_out_frac_width\": 3,\n", + "}\n", + "\n", + "model = BertModel(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the model is defined, we are ready to quantize it by writing a module-level pass. This simply iterates through the modules in the Pytorch model and replaces the relevant ones with their quantized equivalents. In the Bert model, the relevant modules that need to be quantized are:\n", + "1. The self attention layer\n", + "2. Linear layers\n", + "3. Layer normalization layer\n", + "4. GELU activation layer\n", + "\n", + "You can see that Mase has a library of quantized neural network layers under the `chop.nn.quantized` API. See [here](https://github.com/DeepWok/mase/tree/main/src/chop/nn) for a full reference of the available modules." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "from transformers.activations import GELUActivation\n", + "from chop.models.patched.bert.modeling_bert import BertSelfAttention\n", + "from chop.nn.quantized import (\n", + " BertSelfAttentionInteger,\n", + " LinearInteger,\n", + " LayerNormInteger,\n", + " GELUInteger,\n", + ")\n", + "from chop.passes.graph.utils import deepsetattr\n", + "\n", + "\n", + "def bert_module_level_quantize(model, model_config, q_config):\n", + " for module in model.named_modules():\n", + " if isinstance(module[1], BertSelfAttention):\n", + " new_module = BertSelfAttentionInteger(\n", + " model_config, q_config, output_tensor_only=True\n", + " )\n", + " elif isinstance(module[1], nn.Linear):\n", + " new_module = LinearInteger(\n", + " in_features=module[1].in_features,\n", + " out_features=module[1].out_features,\n", + " bias=module[1].bias is not None,\n", + " config=q_config,\n", + " )\n", + " elif isinstance(module[1], nn.LayerNorm):\n", + " new_module = LayerNormInteger(\n", + " normalized_shape=module[1].normalized_shape,\n", + " eps=module[1].eps,\n", + " config=q_config,\n", + " )\n", + " elif isinstance(module[1], GELUActivation):\n", + " new_module = GELUInteger(config=q_config)\n", + " else:\n", + " continue\n", + " logger.info(f\"Replacing module: {module[0]}\")\n", + " deepsetattr(model, module[0], new_module)\n", + " return model\n", + "\n", + "\n", + "model = bert_module_level_quantize(model, config, q_config)\n", + "logger.info(f\"Quantized BERT model: {model}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Emit SystemVerilog code for the accelerator: step-by-step" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that the model is quantized, we are ready to run Mase's FX compiler flow. This involves extracting a computation graph from the Pytorch model leveraging Pytorch FX (see details [here](https://pytorch.org/docs/stable/fx.html)), then running a few analysis and transformation passes on this graph until it's ready for emitting the Verilog code for the dataflow accelerator. First, we'll do this step-by-step, then we'll see how to automate all these operations with a single function call, using the `chop.pipelines` API. \n", + "\n", + "In either case, we start by generating the computation graph through a process called symbolic tracing. As discussed in the `torch.fx` documentation, this involves running a forward pass of the model using dedicated `fx.Proxy` objects as the arguments, instead of the regular `torch.Tensor`s. These proxies record every operation executed on them, which is then used to generate the computation graph. Each node in the generated graph can be a single sublayer, such as `nn.Linear`, or fine-grained function call such as `torch.matmul`. For the emit verilog flow, we require the graph to be at layer granularity, meaning the internal function calls of each layer are hidden in the graph. To achieve this, we pass a `custom_ops` dictionary to the MaseGraph constructor, which instructs the FX tracer to skip this layer during FX tracing. We also provide the desired implementation for the self attention layer, which is available in the [Mase Components](https://github.com/DeepWok/mase/tree/main/src/mase_components) library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from chop.ir import MaseGraph\n", + "from mase_components import get_module_dependencies\n", + "\n", + "BERT_CUSTOM_OPS = {\n", + " \"modules\": {\n", + " BertSelfAttentionInteger: {\n", + " \"args\": {\n", + " \"hidden_states\": \"data_in\",\n", + " \"attention_mask\": None,\n", + " \"head_mask\": None,\n", + " \"encoder_hidden_states\": None,\n", + " \"encoder_attention_mask\": None,\n", + " \"past_key_value\": None,\n", + " \"output_attentions\": \"config\",\n", + " },\n", + " \"toolchain\": \"INTERNAL_RTL\",\n", + " \"module\": \"fixed_self_attention_single_precision_wrapper\",\n", + " \"dependence_files\": get_module_dependencies(\n", + " \"attention/fixed_self_attention_single_precision_wrapper\"\n", + " ),\n", + " },\n", + " },\n", + " \"functions\": {},\n", + "}\n", + "\n", + "mg = MaseGraph(model, custom_ops=BERT_CUSTOM_OPS)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the bert model graph is generated, we can start with the analysis passes, which annotate the graph with relevant information, without changing the topology of the nodes and edges. The `add_common_metadata_analysis_pass` performs shape propagation, i.e. running a forward pass on the model to annotate each node with tensor metadata for each of the operator's input and output tensors. `add_hardware_metadata_analysis_pass` builds on top of this, annotating each node with the verilog parameters which will later be used by the pass that emits the SystemVerilog code. One crucial aspect is the `max_parallelism` parameter, which corresponds to the number of arithmetic cores in each hardware submodule, affecting the resource consumption and latency performance of the resulting hardware. The `patch_metadata_transform_pass` pass annotates the fixed-point precision according to the quantiation configuration for a subset of nodes which are relevant for the control flow of the generated hardware. For more information about each pass, see the [pass API documentation](https://deepwok.github.io/mase/modules/api/passes.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import chop.passes as passes\n", + "\n", + "# Redefine some configuration parameters\n", + "CONFIG_BATCH_SIZE = 1\n", + "CONFIG_SEQUENCE_LENGTH = 4\n", + "MAX_PARALLELISM = 4\n", + "WAIT_COUNT = 15\n", + "WAIT_UNIT = \"ms\"\n", + "\n", + "\n", + "mg, _ = passes.init_metadata_analysis_pass(mg)\n", + "\n", + "# * Add metadata analysis passes\n", + "mg, _ = passes.add_common_metadata_analysis_pass(\n", + " mg,\n", + " pass_args={\n", + " \"dummy_in\": {\n", + " \"input_ids\": torch.randn(\n", + " (CONFIG_BATCH_SIZE, CONFIG_SEQUENCE_LENGTH, config.hidden_size)\n", + " )\n", + " },\n", + " \"add_value\": False,\n", + " },\n", + ")\n", + "\n", + "mg, _ = passes.patch_metadata_transform_pass(\n", + " mg,\n", + " pass_args={\n", + " \"precision\": \"fixed\",\n", + " \"q_config\": q_config,\n", + " },\n", + ")\n", + "\n", + "mg, _ = passes.add_hardware_metadata_analysis_pass(\n", + " mg,\n", + " pass_args={\n", + " \"max_parallelism\": [MAX_PARALLELISM] * 4,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At this stage, we are ready to execute the graph transformation passes, which use the annotated metadata to change the topology of the graph such that it is ready for verilog emit. The `emit_verilog_top_transform_pass` generates the SystemVerilog top-level file, while `emit_internal_rtl_transform_pass` copies the relevant submodules from the [Mase Components](https://github.com/DeepWok/mase/tree/main/src/mase_components) SystemVerilog library to the user's workarea. The `emit_bram_transform_pass` pass emits the BRAM modules which store the weights and biases on the FPGA for each layer in the model. A Cocotb testbench is generated in the `emit_cocotb_transform_pass`, which can be used for testing the generated hardware using real Pytorch datasets. Finally, `emit_vivado_project_transform_pass` prepares a Vivado project containing the emitted Verilog code, making it ready for Synthesis and Implementation on the FPGA board." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the timeout time for the generated testbench\n", + "WAIT_COUNT = 15\n", + "WAIT_UNIT = \"ms\"\n", + "\n", + "mg, _ = passes.emit_verilog_top_transform_pass(mg)\n", + "mg, _ = passes.emit_bram_transform_pass(mg)\n", + "mg, _ = passes.emit_internal_rtl_transform_pass(mg)\n", + "mg, _ = passes.emit_cocotb_transform_pass(\n", + " mg,\n", + " pass_args={\n", + " \"wait_time\": WAIT_COUNT,\n", + " \"wait_unit\": WAIT_UNIT,\n", + " },\n", + ")\n", + "mg, _ = passes.emit_vivado_project_transform_pass(mg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Hoorah!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Emit SystemVerilog code for the accelerator: with automation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we've seen everything Mase does under the hood, but we don't want to write that much code each time we generate Verilog for a new model. Luckily, the workflow for every model is very similar, and can be abstracted into a pass manager, which runs a default set of passes. This is achieved through the AutoPipeline API." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from chop import AutoPipelineForEmitVerilog\n", + "\n", + "# Redefine some configuration parameters\n", + "CONFIG_BATCH_SIZE = 1\n", + "CONFIG_SEQUENCE_LENGTH = 4\n", + "WAIT_COUNT = 15\n", + "WAIT_UNIT = \"ms\"\n", + "MAX_PARALLELISM = 4\n", + "\n", + "mg = MaseGraph(model, custom_ops=BERT_CUSTOM_OPS)\n", + "\n", + "pipeline = AutoPipelineForEmitVerilog()\n", + "mg = pipeline(\n", + " mg,\n", + " pass_args={\n", + " \"add_common_metadata_analysis_pass\": {\n", + " \"dummy_in\": {\n", + " \"input_ids\": torch.randn(\n", + " (\n", + " CONFIG_BATCH_SIZE,\n", + " CONFIG_SEQUENCE_LENGTH,\n", + " config.hidden_size,\n", + " )\n", + " )\n", + " },\n", + " \"add_value\": False,\n", + " },\n", + " \"patch_metadata_transform_pass\": {\n", + " \"q_config\": q_config,\n", + " },\n", + " \"add_hardware_metadata_analysis_pass\": {\n", + " \"max_parallelism\": [MAX_PARALLELISM] * 4,\n", + " },\n", + " \"report_node_meta_param_analysis_pass\": {\n", + " \"which\": [\"common\", \"hardware\"],\n", + " \"save_path\": \"llama_graph_meta_params.txt\",\n", + " },\n", + " \"emit_cocotb_transform_pass\": {\n", + " \"wait_time\": WAIT_COUNT,\n", + " \"wait_unit\": WAIT_UNIT,\n", + " },\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate the generated accelerator\n", + "\n", + "Now everything is ready, and the generated Verilog files can be found under `~/.mase/top/hardware/rtl`. You can inspect the `top.sv` file to see how data is propagated from the inputs of the module through every layer in the original model. You can also find the emitted Cocotb test under `~/.mase/top/hardware/test.py`. Note that the Cocotb testbench class is not emitted as a text file, but rather pickled and stored as a .dill file, which is a compressed way of sharing the testbench. This is then unpickled and instantiated in the `test.py` file which is executed by the Cocotb runner. Now, simply run the `simulate` action to obtain the latency for a single batch inference pass." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import chop.actions as actions\n", + "\n", + "os.environ[\"COCOTB_RESOLVE_X\"] = \"ZEROS\"\n", + "actions.simulate(\n", + " skip_build=False, skip_test=False, gui=False, waves=False, simulator=\"questa\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this tutorial, we demonstrated the process of generating an FPGA accelerator for a quantized BERT model using the Mase framework. We began by loading a BERT model and defining its configuration and quantization parameters, then proceeded to quantize the model at the module level. Next, we walked through the detailed steps of emitting SystemVerilog code for the accelerator, which included generating a computation graph using Torch FX, performing various metadata analysis passes, and transforming the graph to be ready for Verilog emission. We showed how to automate these steps using the chop.pipelines API, greatly simplifying the workflow. Finally, we ran the generated Cocotb testbench to evaluate the performance of the accelerator, obtaining throughput and latency metrics.\n", + "\n", + "By following this tutorial, you should now have a solid understanding of how to optimize transformer models for FPGA deployment using Mase, from quantization to hardware code generation and performance evaluation. If you are interested in experimenting further, we propose the following suggested exercises.\n", + "\n", + "1. Re-run the flow by changing the q_config dictionary to try different fixed-point precisions. In each case, open the generated Vivado project and launch the synthesis flow to compare the resource consumption of the generated hardware. Create a plot of the LUT, FF and DSP utilization statistics for a range of fixed-point precisions.\n", + "\n", + "2. Repeat exercise 1, but this time experiment with the maximum parallelism parameter. Again, compare the resource consumption for a range of parallelism parameters. This time, also run the Cocotb testbench in each iteration to see how the parallelism affects the inference latency. Based on this analysis, can you suggest an optimal design point that trades off resource consumption with inference latency?\n", + "\n", + "If you are interested in contributing to the Mase project, we suggest the following extension task.\n", + "\n", + "3. Try to support this flow for a new model, such as Llama, Mistral or GPT. Follow the steps in the [documentation]() to import a new model into Mase from the HuggingFace hub, and try running the `AutoPipelineForVerilogEmit`. If that doesn't work directly, see the hints in the [debugging guide]() to support the new model, now you know the steps required." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nas_bert.py b/nas_bert.py new file mode 100644 index 000000000..cb6b56108 --- /dev/null +++ b/nas_bert.py @@ -0,0 +1,193 @@ +import os + +import time +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +from transformers import ( + Trainer, + TrainingArguments, + AutoTokenizer, + DataCollatorWithPadding, +) +from transformers.models.bert.configuration_bert import BertConfig +from datasets import load_dataset +import evaluate + +import nni +import nni.nas.strategy as strategy +from nni.nas.evaluator import FunctionalEvaluator +from nni.nas.experiment import NasExperiment, NasExperimentConfig + +from chop.actions.search.search_space import NasBertSpace + +# * Config +# * ------------------------------------------ + +NUM_TRIALS = 10 +TRIAL_CONCURRENCY = 3 +EPOCHS_PER_TRIAL = 5 +NUM_LATENCY_EVALUATION_ITERATIONS = 10 +checkpoint = "bert-base-uncased" +os.environ["WANDB_DISABLED"] = "true" + +# * Utils +# * ------------------------------------------ + + +def get_datasets(): + raw_datasets = load_dataset("glue", "sst2") + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + + def tokenize_function(example): + return tokenizer(example["sentence"], truncation=True) + + tokenized_datasets = raw_datasets.map(tokenize_function, batched=True) + data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + return tokenized_datasets, data_collator, tokenizer + + +# * Evaluator +# * ------------------------------------------ + + +def fit(model): + """ + Train the model using HuggingFace trainer, the call onto evaluate library to get the accuracy + """ + tokenized_datasets, data_collator, tokenizer = get_datasets() + + training_args = TrainingArguments("test-trainer", report_to=None) + training_args.num_train_epochs = EPOCHS_PER_TRIAL + training_args.learning_rate = 1e-4 + training_args.save_strategy = "no" + + trainer = Trainer( + model, + training_args, + train_dataset=tokenized_datasets["train"], + eval_dataset=tokenized_datasets["validation"], + data_collator=data_collator, + tokenizer=tokenizer, + ) + + trainer.train() + + # calculate time taken for prediction + avg_time = 0 + for _ in range(NUM_LATENCY_EVALUATION_ITERATIONS): + start_time = time.time() + predictions = trainer.predict(tokenized_datasets["validation"]) + end_time = time.time() + avg_time += end_time - start_time + + avg_time /= NUM_LATENCY_EVALUATION_ITERATIONS + + # Get the accuracy from the last prediction + preds = np.argmax(predictions.predictions, axis=-1) + metric = evaluate.load("glue", "sst2") + results = metric.compute(predictions=preds, references=predictions.label_ids) + nni.report_final_result( + { + "accuracy": results["accuracy"], + "default": results["accuracy"], + "average_latency": avg_time, + "average_tps": len(tokenized_datasets["validation"]) / avg_time, + } + ) + + +# * Define and run experiment +# * ------------------------------------------ + +cf = BertConfig.from_pretrained(checkpoint) +cf._attn_implementation = "eager" + +# Full model parameters +cf.num_hidden_layers = 3 +cf.space_hidden_size = [128, 256, 512, 768, 1024] + +# Per layer +cf.space_self_attention_implementation = ["attention", "linear", "feedthrough"] +cf.space_self_attention_layer_norm = ["layer_norm", "identity"] +cf.space_output_layer_norm = ["layer_norm", "identity"] +cf.space_intermediate_size = [192, 384, 768, 1536, 3072] +cf.space_num_attention_heads = [2, 4, 8, 16] + +space = NasBertSpace(cf) + +evaluator = FunctionalEvaluator(fit) +strat = strategy.TPE() + +experiment_config = NasExperimentConfig.default(space, evaluator, strat) +experiment_config.max_trial_number = NUM_TRIALS # spawn 3 trials at most +experiment_config.trial_concurrency = TRIAL_CONCURRENCY # will run 1 trial concurrently +experiment_config.trial_gpu_number = 1 # use 1 GPU for each trial +experiment_config.training_service.use_active_gpu = True + +experiment = NasExperiment(space, evaluator, strat, config=experiment_config) +experiment.start(port=8081) + + +def dump_experiment_results(data): + df = pd.DataFrame( + columns=["accuracy", "average_latency", "default", "average_tps"], index=[] + ) + for trial in data: + df.loc[trial.trialJobId] = trial.value + df.to_csv(f"experiment_{experiment.id}_results.csv") + return df + + +while True: + if experiment.get_status() == "DONE": + data = experiment.export_data() + df = dump_experiment_results(data) + break + +# create random dataframe with 100 trials where x-y follows an exponential distribution + +# DEBUG +# np.random.seed(0) +# times = np.linspace(0.5, 100, 100) +# df = pd.DataFrame( +# { +# "accuracy": np.log(times) + np.random.normal(0, 0.5, 100), +# "average_latency": times, +# "average_tps": times, +# } +# ) + + +def plot_pareto(df): + plt.figure() + plt.scatter(df["average_latency"], df["accuracy"]) + + # Plot pareto frontier + df = df.sort_values("average_latency") + pareto = df["accuracy"].cummax() + plt.plot(df["average_latency"], pareto, color="red") + + plt.xlabel("Latency [ms]") + plt.ylabel("Accuracy [%]") + plt.title("Accuracy/Latency Pareto Frontier") + plt.savefig(f"experiment_{experiment.id}_results.png") + + +# plt a figure with two subplots, one for latency and one for tps +# def plot_pareto(df): +# fig, ax = plt.subplots(1, 2) +# ax[0].scatter(df["average_latency"], df["accuracy"]) +# ax[0].set_xlabel("Latency [ms]") +# ax[0].set_ylabel("Accuracy [%]") +# ax[0].set_title("Accuracy/Latency Pareto Frontier") + +# ax[1].scatter(df["average_tps"], df["accuracy"]) +# ax[1].set_xlabel("TPS") +# ax[1].set_ylabel("Accuracy [%]") +# ax[1].set_title("Accuracy/TPS Pareto Frontier") +# fig.savefig(f"experiment_{0}_results.png") + + +plot_pareto(df) diff --git a/requirements.txt b/requirements.txt index ebbc92747..84b62f38d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,6 +49,7 @@ einops sphinx sphinx-rtd-theme myst-parser +bitstring absl-py scipy sphinx-glpi-theme @@ -62,4 +63,4 @@ bitstring # pynvml # pycuda # cuda-python -# pytorch-quantization # this needs installing from source \ No newline at end of file +# pytorch-quantization # this needs installing from source diff --git a/setup.py b/setup.py index 250652a0e..77063c850 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def is_cuda_available(): "torch-tensorRT; platform_system == 'Linux'", "tensorRT; platform_system == 'Linux'", "cuda-python; platform_system == 'Linux'", - "pytorch-quantization; platform_system == 'Linux'", + "pytorch-quantization @ https://pypi.nvidia.com/pytorch-quantization/pytorch_quantization-2.1.2-py3-none-any.whl ", ] setup( diff --git a/src/chop/__init__.py b/src/chop/__init__.py index 0202dc9a1..b17af66eb 100644 --- a/src/chop/__init__.py +++ b/src/chop/__init__.py @@ -4,3 +4,5 @@ from .ir.onnx.mase_onnx_graph import MaseOnnxGraph from . import passes + +from .pipelines import AutoPipelineForEmitVerilog diff --git a/src/chop/actions/search/search_space/__init__.py b/src/chop/actions/search/search_space/__init__.py index c4f11c154..22ece38b4 100644 --- a/src/chop/actions/search/search_space/__init__.py +++ b/src/chop/actions/search/search_space/__init__.py @@ -6,6 +6,8 @@ from .systolic import SystolicMappingSearchSpace from .base import SearchSpaceBase +from .nas_bert import NasBertSpace + SEARCH_SPACE_MAP = { "graph/quantize/mixed_precision_ptq": GraphSearchSpaceMixedPrecisionPTQ, "module/manual_hf/quantize/llm_mixed_precision_ptq": ManualHFModuleSearchSpaceMixedPrecisionPTQ, diff --git a/src/chop/actions/search/search_space/nas_bert.py b/src/chop/actions/search/search_space/nas_bert.py new file mode 100644 index 000000000..6dcda0912 --- /dev/null +++ b/src/chop/actions/search/search_space/nas_bert.py @@ -0,0 +1,1501 @@ +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + SequenceClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + ModelOutput, + get_torch_version, + logging, +) +from transformers.models.bert.configuration_bert import BertConfig + +import nni +from nni.nas.nn.pytorch import ModelSpace, LayerChoice, ParametrizedModule + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(ParametrizedModule): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, hidden_size_choice, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, hidden_size_choice + ) + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, hidden_size_choice + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(hidden_size_choice, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + input_shape[0], seq_length + ) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device + ) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(ParametrizedModule): + def __init__( + self, + config, + position_embedding_type=None, + num_attention_heads_choice=12, + hidden_size_choice=768, + ): + super().__init__() + if hidden_size_choice % num_attention_heads_choice != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size ({hidden_size_choice}) is not a multiple of the number of attention " + f"heads ({num_attention_heads_choice})" + ) + + self.num_attention_heads = num_attention_heads_choice + self.attention_head_size = int(hidden_size_choice / num_attention_heads_choice) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size_choice, self.all_head_size) + self.key = nn.Linear(hidden_size_choice, self.all_head_size) + self.value = nn.Linear(hidden_size_choice, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor( + key_length - 1, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + else: + position_ids_l = torch.arange( + query_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + key_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSdpaSelfAttention(BertSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse( + get_torch_version() + ) < version.parse("2.2.0") + + # Adapted from BertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if ( + self.position_embedding_type != "absolute" + or output_attentions + or head_mask is not None + ): + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = ( + encoder_attention_mask if is_cross_attention else attention_mask + ) + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if ( + is_cross_attention + and past_key_value + and past_key_value[0].shape[2] == current_states.shape[1] + ): + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if ( + self.require_contiguous_qkv + and query_layer.device.type == "cuda" + and attention_mask is not None + ): + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True + if self.is_decoder and attention_mask is None and tgt_len > 1 + else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(ParametrizedModule): + def __init__(self, config, choice="layer_norm", hidden_size_choice=768): + super().__init__() + self.dense = nn.Linear(hidden_size_choice, hidden_size_choice) + + if choice == "layer_norm": + self.LayerNorm = nn.LayerNorm(hidden_size_choice, eps=config.layer_norm_eps) + elif choice == "identity": + self.LayerNorm = BertIdentity() + else: + raise ValueError(f"Unrecognized choice: {choice}") + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +BERT_SELF_ATTENTION_CLASSES = { + "eager": BertSelfAttention, + "sdpa": BertSdpaSelfAttention, +} + + +class BertLinear(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.dense = nn.Linear(in_features, out_features) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + return self.dense(hidden_states) + + +class BertIdentity(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + return hidden_states + + +class BertAttention(ParametrizedModule): + def __init__( + self, + config, + position_embedding_type=None, + self_attention_choice="attention", + self_attention_layer_norm_choice="attention", + hidden_size_choice=768, + num_attention_heads_choice=12, + ): + super().__init__() + if self_attention_choice == "attention": + self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, + position_embedding_type=position_embedding_type, + num_attention_heads_choice=num_attention_heads_choice, + hidden_size_choice=hidden_size_choice, + ) + elif self_attention_choice == "linear": + self.self = BertLinear(hidden_size_choice, hidden_size_choice) + elif self_attention_choice == "feedthrough": + self.self = BertIdentity() + else: + raise ValueError(f"Unrecognized choice: {self_attention_choice}") + + self.output = BertSelfOutput( + config, + choice=self_attention_layer_norm_choice, + hidden_size_choice=hidden_size_choice, + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + if isinstance(self_outputs, tuple): + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + elif isinstance(self_outputs, torch.Tensor): + outputs = (attention_output, self_outputs) + else: + raise ValueError(f"Unrecognized output type: {type(self_outputs)}") + + return outputs + + +class BertIntermediate(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768, intermediate_size=3072): + super().__init__() + self.dense = nn.Linear(hidden_size_choice, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(ParametrizedModule): + def __init__( + self, + config, + hidden_size_choice=768, + intermediate_size=3072, + output_layer_norm_choice="layer_norm", + ): + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size_choice) + # self.LayerNorm = nn.LayerNorm(hidden_size_choice, eps=config.layer_norm_eps) + + if output_layer_norm_choice == "layer_norm": + self.LayerNorm = nn.LayerNorm(hidden_size_choice, eps=config.layer_norm_eps) + elif output_layer_norm_choice == "identity": + self.LayerNorm = BertIdentity() + else: + raise ValueError(f"Unrecognized choice: {output_layer_norm_choice}") + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(ParametrizedModule): + def __init__( + self, + config, + self_attention_choice="attention", + self_attention_layer_norm_choice="layer_norm", + output_layer_norm_choice="layer_norm", + hidden_size_choice=768, + intermediate_size_choice=3072, + num_attention_heads_choice=12, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention( + config, + self_attention_choice=self_attention_choice, + self_attention_layer_norm_choice=self_attention_layer_norm_choice, + hidden_size_choice=hidden_size_choice, + num_attention_heads_choice=num_attention_heads_choice, + ) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError( + f"{self} should be used as a decoder model if cross attention is added" + ) + self.crossattention = BertAttention( + config, + position_embedding_type="absolute", + hidden_size_choice=hidden_size_choice, + num_attention_heads_choice=num_attention_heads_choice, + ) + self.intermediate = BertIntermediate( + config, + intermediate_size=intermediate_size_choice, + hidden_size_choice=hidden_size_choice, + ) + self.output = BertOutput( + config, + intermediate_size=intermediate_size_choice, + output_layer_norm_choice=output_layer_norm_choice, + hidden_size_choice=hidden_size_choice, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[ + 1: + ] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = ( + past_key_value[-2:] if past_key_value is not None else None + ) + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.config = config + + self_attention_choice = [ + nni.choice( + f"layer_{layer}_attn_impl_choice", + config.space_self_attention_implementation, + ) + for layer in range(config.num_hidden_layers) + ] + + self_attention_layer_norm_choice = [ + nni.choice( + f"layer_{layer}_attn_layer_norm_choice", + config.space_self_attention_layer_norm, + ) + for layer in range(config.num_hidden_layers) + ] + + output_layer_norm_choice = [ + nni.choice( + f"layer_{layer}_output_layer_norm_choice", + config.space_output_layer_norm, + ) + for layer in range(config.num_hidden_layers) + ] + + intermediate_size_choice = [ + nni.choice( + f"layer_{layer}_intermediate_size_choice", + self.config.space_intermediate_size, + ) + for layer in range(config.num_hidden_layers) + ] + + self_attention_num_heads_choice = [ + nni.choice( + f"layer_{layer}_num_attention_heads_choice", + self.config.space_num_attention_heads, + ) + for layer in range(config.num_hidden_layers) + ] + + self.layer = nn.ModuleList( + [ + BertLayer( + config, + self_attention_choice=self_attention_choice[layer], + self_attention_layer_norm_choice=self_attention_layer_norm_choice[ + layer + ], + output_layer_norm_choice=output_layer_norm_choice[layer], + intermediate_size_choice=intermediate_size_choice[layer], + hidden_size_choice=hidden_size_choice, + num_attention_heads_choice=self_attention_num_heads_choice[layer], + ) + for layer in range(config.num_hidden_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.dense = nn.Linear(hidden_size_choice, hidden_size_choice) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.dense = nn.Linear(hidden_size_choice, hidden_size_choice) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + + self.LayerNorm = nn.LayerNorm(hidden_size_choice, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(hidden_size_choice, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertPreTrainingHeads(ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(hidden_size_choice, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +class BertForPreTrainingOutput(ModelOutput): + """ + Output type of [`BertForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: torch.FloatTensor = None + seq_relationship_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class BertModel(BertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _no_split_modules = ["BertEmbeddings", "BertLayer"] + + def __init__(self, config, add_pooling_layer=True, hidden_size_choice=768): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config, hidden_size_choice=hidden_size_choice) + self.encoder = BertEncoder(config, hidden_size_choice=hidden_size_choice) + + self.pooler = ( + BertPooler(config, hidden_size_choice=hidden_size_choice) + if add_pooling_layer + else None + ) + + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand( + batch_size, seq_length + ) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device + ) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length + past_key_values_length), device=device + ) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = ( + encoder_hidden_states.size() + ) + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertForPreTraining(BertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, BertForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output + ) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + next_sentence_loss = loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertForSequenceClassification(BertPreTrainedModel, ParametrizedModule): + def __init__(self, config, hidden_size_choice=768): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.hidden_size_choice = hidden_size_choice + + self.bert = BertModel(config, hidden_size_choice=self.hidden_size_choice) + classifier_dropout = ( + config.classifier_dropout + if config.classifier_dropout is not None + else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = BertLinear(self.hidden_size_choice, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class NasBertSpace(ModelSpace): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size_choice = nni.choice( + "hidden_size_choice", config.space_hidden_size + ) + self.model = BertForSequenceClassification( + config, hidden_size_choice=self.hidden_size_choice + ) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + # Note: Trainer uses the model signature to determined which dataset columns to remove when processing the dataset, + # so the signature here needs to match the model signature + return self.model( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + inputs_embeds, + labels, + output_attentions, + output_hidden_states, + return_dict, + ) diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py index 9bf6d191c..749279ff2 100644 --- a/src/chop/actions/simulate.py +++ b/src/chop/actions/simulate.py @@ -1,20 +1,26 @@ +import sys from os import getenv, PathLike + import torch from pathlib import Path +import time +import warnings +from cocotb.runner import get_runner, get_results + +from chop.tools import get_logger import mase_components from mase_components import get_modules - from .emit import emit -import warnings +import glob, os warnings.filterwarnings( "ignore", category=UserWarning, message="Python runners and associated APIs are an experimental feature and subject to change.", ) -from cocotb.runner import get_runner, get_results - +logger = get_logger(__name__) +logger.setLevel("DEBUG") def simulate( model: torch.nn.Module = None, @@ -27,8 +33,12 @@ def simulate( run_emit: bool = False, skip_build: bool = False, skip_test: bool = False, + trace_depth: int = 3, + gui: bool = False, + waves: bool = False, + simulator: str = "verilator" ): - SIM = getenv("SIM", "verilator") + SIM = getenv("SIM", simulator) runner = get_runner(SIM) project_dir = Path.home() / ".mase" / "top" @@ -38,33 +48,60 @@ def simulate( if not skip_build: # To do: extract from mz checkpoint - sources = [ - project_dir / "hardware" / "rtl" / "top.sv", + if (simulator == "questa"): + sources = glob.glob(os.path.join(project_dir / "hardware" / "rtl", "*.sv")) + build_args = [] + + elif (simulator == "verilator"): + sources = ["../../../top.sv"] + build_args = [ + "-Wno-fatal", + "-Wno-lint", + "-Wno-style", + "--trace-fst", + "--trace-structs", + "--trace-depth", + str(trace_depth), + ] + + else: + raise ValueError(f"Unrecognized simulator: {simulator}") + + includes = [ + project_dir / "hardware" / "rtl", + ] + [ + Path(mase_components.__file__).parent / module / "rtl" + for module in get_modules() ] + build_start = time.time() + runner.build( verilog_sources=sources, - includes=[ - project_dir / "hardware" / "rtl", - ] - # Include all mase components - + [ - Path(mase_components.__file__).parent / module / "rtl" - for module in get_modules() - ], + includes=includes, hdl_toplevel="top", - build_args=["-Wno-fatal", "-Wno-lint", "-Wno-style", "--trace"], + build_args=build_args, parameters=[], # use default parameters, ) + build_end = time.time() + logger.info(f"Build finished. Time taken: {build_end - build_start:.2f}s") + + if not skip_test: # Add tb file to python path - import sys sys.path.append(str(project_dir / "hardware" / "test")) + test_start = time.time() runner.test( - hdl_toplevel="top", test_module="mase_top_tb", hdl_toplevel_lang="verilog" + hdl_toplevel="top", + test_module="mase_top_tb", + hdl_toplevel_lang="verilog", + gui=gui, + waves=waves ) + test_end = time.time() + logger.info(f"Test finished. Time taken: {test_end - test_start:.2f}s") # num_tests, fail = get_results("build/results.xml") # return num_tests, fail diff --git a/src/chop/distributed/__init__.py b/src/chop/distributed/__init__.py new file mode 100644 index 000000000..cc807c9ad --- /dev/null +++ b/src/chop/distributed/__init__.py @@ -0,0 +1,2 @@ + +from .launcher import MaseLauncher \ No newline at end of file diff --git a/src/chop/distributed/debug.py b/src/chop/distributed/debug.py new file mode 100644 index 000000000..e771fbb78 --- /dev/null +++ b/src/chop/distributed/debug.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +from typing import List, Sequence, Tuple + +import numpy as np + +from torch._prims_common import ShapeType +from torch.distributed._tensor import DeviceMesh + +from torch.distributed._tensor.placement_types import Placement, Shard + + +def _mesh_to_coordinate(mesh, device_type): + """ + Given a n-dimensional list of device mesh, this function creates a map of + device and its coordinate + """ + # Convert the n-dimensional list to a NumPy array + np_mesh = np.array(mesh.mesh.tolist()) + + # Create a dictionary to map each value to its coordinate + device_to_coordinate_map = {} + for coord, value in np.ndenumerate(np_mesh): + # device is unique in device_mesh + device_to_coordinate_map[f"{device_type}:{str(value)}"] = list(coord) + + return device_to_coordinate_map + + +def _convert_offset_to_ranges(all_offsets): + """ + Using tabulate package to create a table is easier when we specify row and col ranges + This function converts offsets to ranges. + """ + converted_blocks = [] + + for offset in all_offsets: + shape, offset, value = offset + + # Calculate row_range and column_range + row_range = (offset[0], offset[0] + shape[0] - 1) + column_range = (offset[1], offset[1] + shape[1] - 1) + + # Convert value to string to match your desired format + converted_block = { + "row_range": row_range, + "column_range": column_range, + "value": str(value), + } + converted_blocks.append(converted_block) + + return converted_blocks + + +def _create_table(blocks): + """ + Creates a tabulate table given row and column ranges with device name + """ + try: + from tabulate import tabulate + except ImportError as e: + raise ImportError("tabulate package is required to visualize sharding") from e + + # Extract unique row and column ranges + row_ranges = sorted({block["row_range"] for block in blocks}) + col_ranges = sorted({block["column_range"] for block in blocks}) + + # Create a matrix initialized with empty strings + matrix = [["" for _ in col_ranges] for _ in row_ranges] + + # Fill the matrix with values + for block in blocks: + row_index = row_ranges.index(block["row_range"]) + col_index = col_ranges.index(block["column_range"]) + if matrix[row_index][col_index] == "": + matrix[row_index][col_index] = block["value"] + else: + matrix[row_index][col_index] += ", " + block["value"] + + # Prepare headers + row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges] + col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges] + + return tabulate(matrix, headers=col_headers, showindex=row_headers) + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + my_coordinate: List[int], +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but + with custom my_coordinate input. This is the modified implementation for visualize_sharding. + """ + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) + + +def visualize_sharding(dtensor, header=""): + """ + Visualizes sharding in 1D-2D dtensors + Requires tabulate, install with `pip install tabulate` + + note: no sharding info will be printed for empty tensors + """ + if dtensor.numel() == 0: # we do not print for empty dtensors + return + + if len(dtensor.shape) >= 3: + raise RuntimeError( + "visualize sharding is only implemented for 1D or 2D dtensor" + ) + placements = dtensor.placements + device_mesh = dtensor.device_mesh + device_type = dtensor.device_mesh.device_type + + if device_mesh.get_coordinate() is None: # current rank is not in the mesh + return + + # Only display the visualization once for each DTensor, on the rank whose + # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh, + # we will only print on rank 0. + local_rank_zero_on_all_dim = all( + device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim) + ) + if not local_rank_zero_on_all_dim: + return + + device_map = _mesh_to_coordinate(device_mesh, device_type) + all_offsets = [] + for device in device_map: + local_shape, global_offset = compute_local_shape_and_global_offset( + dtensor.shape, device_mesh, placements, device_map[device] + ) + all_offsets.append([local_shape, global_offset, device]) + + # Convert offsets to blocks with row_ranges for tabulate + blocks = _convert_offset_to_ranges(all_offsets) + + # Print the table + print(header) + print(_create_table(blocks)) \ No newline at end of file diff --git a/src/chop/distributed/launcher.py b/src/chop/distributed/launcher.py new file mode 100644 index 000000000..a3d18fdb1 --- /dev/null +++ b/src/chop/distributed/launcher.py @@ -0,0 +1,85 @@ +import os +from functools import partial +from time import time + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp + +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + Replicate, + Shard, +) + +from chop.distributed.utils import rlog +from ..tools import get_logger +from .utils import placement_from_sharding_config + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def dist_model_fn( + name: str, module: nn.Module, device_mesh: DeviceMesh, rank: int, module_map={} +) -> None: + """ + This function gets called by torch.distributed._tensor.distribute_module on each module in the model. + Each tensor in each module is distributed according to the sharding configuration in module_map. + """ + if module in module_map: + for parameter, sharding_config in module_map[module].items(): + if parameter in ["data_in_0", "output", "data_out_0"]: + continue + if not hasattr(module, parameter): + rlog(logger, rank, f"Module {module} does not have parameter {parameter}", level="warning") + continue + + placement = placement_from_sharding_config(sharding_config) + + rlog(logger, rank, f"Distributing parameter {parameter} of module {module} to {placement}", level="debug") + try: + distributed_tensor = distribute_tensor(getattr(module, parameter), device_mesh, placement) + setattr(module, parameter, torch.nn.Parameter(distributed_tensor)) + except Exception as e: + rlog(logger, rank, f"Error distributing parameter {parameter} of module {module}: {e}", level="error") + + +def device_fn(rank, world_size, model=None, device_mesh=None, module_map={}, inputs=[]): + """ + This function gets called on each GPU device to set up the distributed environment and distribute the model, + following the SPMD model. + """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + os.environ["RANK"] = str(rank) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + + mesh = DeviceMesh("cuda", mesh=device_mesh) + rlog(logger, rank, f"Distributing module parameters...", level="info") + start = time() + model = distribute_module( + model, mesh, partial(dist_model_fn, rank=rank, module_map=module_map), input_fn=None, output_fn=None + ) + end = time() + rlog(logger, rank, f"Module distribution done. Time taken: {end - start} seconds.") + + inputs = [distribute_tensor(in_tensor, mesh, [Replicate(), Replicate()]) for in_tensor in inputs] + out = model(*inputs) + + dist.destroy_process_group() + +class MaseLauncher(): + def __init__(self, mase_graph, world_size = None, device_mesh=None): + self.mg = mase_graph + self.model = mase_graph.model + self.world_size = world_size + self.device_mesh = device_mesh + + def run(self, module_map = {}, inputs=[]): + logger.info(f"Launching model with world size {self.world_size}.") + mp.spawn(partial(device_fn, model=self.model, device_mesh=self.device_mesh, module_map=module_map, inputs=inputs), args=(self.world_size,), nprocs=self.world_size, join=True) \ No newline at end of file diff --git a/src/chop/distributed/utils.py b/src/chop/distributed/utils.py new file mode 100644 index 000000000..55104edef --- /dev/null +++ b/src/chop/distributed/utils.py @@ -0,0 +1,35 @@ + +from torch.distributed._tensor import ( + Replicate, + Shard, +) + +from chop.passes.graph.analysis.autosharding.common import SpmdShard + +def placement_from_sharding_config(sharding_config): + """ + Sharding config is given as a tuple such as (R, S_0) where a symbol S_x at index i indicates + that tensor dimension i is sharded along the x-th dimension of the device mesh. However, + the distribute_tensor API expects a tuple of Shard() and Replicate() objects where a Shard(x) + at index i indicates that tensor dimension x is sharded along device mesh dimension i. + """ + placement = [Replicate(), Replicate()] + for shard_type in [SpmdShard.S_0, SpmdShard.S_1]: + if shard_type in sharding_config: + idx = sharding_config.index(shard_type) + # Preserve batch dimension + if (len(sharding_config) > 2): + idx = idx - (len(sharding_config) - 2) + placement[shard_type.value] = Shard(idx) + + if placement == [Shard(1), Shard(1)]: + print(f"Warning: Invalid sharding config {sharding_config}") + return placement + +def rlog(logger, rank, msg, level="info"): + """ + Only log on rank 0 to avoid repeated messages. + """ + log_fn = getattr(logger, level, logger.info) + if rank == 0: + log_fn(msg) \ No newline at end of file diff --git a/src/chop/ir/graph/mase_graph.py b/src/chop/ir/graph/mase_graph.py index 39f1ab2d3..21ddfe5a3 100644 --- a/src/chop/ir/graph/mase_graph.py +++ b/src/chop/ir/graph/mase_graph.py @@ -142,9 +142,16 @@ def is_leaf_module( is_hf_built_in_leaf_module = hf_is_leaf_module( self, m, module_qualified_name ) - is_custom_module = isinstance(m, custom_modules) - return is_hf_built_in_leaf_module or is_custom_module + is_mase_leaf_layer = isinstance(m, MASE_LEAF_LAYERS) + + return any( + ( + is_hf_built_in_leaf_module, + is_custom_module, + is_mase_leaf_layer, + ) + ) return is_leaf_module @@ -178,13 +185,13 @@ def is_leaf_module( custom_leaf_functions += tuple(patched_nodes["functions"]) custom_leaf_layers += tuple(patched_nodes["layers"]) - tracer = MaseTracer( + self.tracer = MaseTracer( custom_leaf_modules=custom_leaf_modules, custom_leaf_functions=custom_leaf_functions, custom_leaf_layers=custom_leaf_layers, ) - graph_module = fx.GraphModule(model, tracer.trace(model, cf_args)) + graph_module = fx.GraphModule(model, self.tracer.trace(model, cf_args)) if patched_nodes is not None: graph_module.patched_op_names = [ diff --git a/src/chop/models/manual/rms_norm.py b/src/chop/models/manual/rms_norm.py deleted file mode 100644 index ca4a771f7..000000000 --- a/src/chop/models/manual/rms_norm.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from torch import nn, Tensor - - -def _rms_norm(x: Tensor, eps, scale: Tensor | None): - mean_squares = x.square().mean(dim=(1, 2, 3), keepdim=True) - rms_x = mean_squares.sqrt() - x_normed = x / (rms_x + eps) - if scale != None: - return scale * x_normed - else: - return x_normed - - -class RMSNorm(nn.Module): - """Root Mean Square Layer Normalization""" - - def __init__( - self, - normalized_shape, - eps: float = 1e-8, - elementwise_affine: bool = False, - device=None, - dtype=None, - ): - super().__init__() - self.eps = eps - self.normalized_shape = normalized_shape - self.elementwise_affine = elementwise_affine - - factory_kwargs = {"device": device, "dtype": dtype} - if self.elementwise_affine: - self.weight = nn.Parameter( - torch.ones(self.normalized_shape, **factory_kwargs) - ) - else: - self.register_parameter("weight", None) - - def forward(self, x: Tensor): - return _rms_norm(x, self.eps, self.weight) diff --git a/src/chop/models/patched/bert/configuration_bert.py b/src/chop/models/patched/bert/configuration_bert.py index 6ef356b67..a3f82a8d2 100644 --- a/src/chop/models/patched/bert/configuration_bert.py +++ b/src/chop/models/patched/bert/configuration_bert.py @@ -25,11 +25,6 @@ logger = logging.get_logger(__name__) -from transformers.models.deprecated._archive_maps import ( - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, -) # noqa: F401, E402 - - class BertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to diff --git a/src/chop/models/patched/bert/modeling_bert.py b/src/chop/models/patched/bert/modeling_bert.py index 10c7088a9..88d705079 100644 --- a/src/chop/models/patched/bert/modeling_bert.py +++ b/src/chop/models/patched/bert/modeling_bert.py @@ -80,9 +80,9 @@ _SEQ_CLASS_EXPECTED_LOSS = 0.01 -from transformers.models.deprecated._archive_maps import ( - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, -) # noqa: F401, E402 +@torch.fx.wrap +def df_split(x): + return (x, x) def load_tf_weights_in_bert(model, config, tf_checkpoint_path): @@ -184,7 +184,9 @@ def __init__(self, config): # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False + ) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.position_embedding_type = getattr( @@ -419,7 +421,9 @@ class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False + ) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -473,8 +477,9 @@ def forward( past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: + in_1, in_2 = df_split(hidden_states) self_outputs = self.self( - hidden_states, + in_1, attention_mask, head_mask, encoder_hidden_states, @@ -482,7 +487,7 @@ def forward( past_key_value, output_attentions, ) - attention_output = self.output(self_outputs, hidden_states) + attention_output = self.output(self_outputs, in_2) # outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them # return outputs return attention_output @@ -507,7 +512,9 @@ class BertOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False + ) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -611,8 +618,11 @@ def forward( # return outputs def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) + # ! TO DO: automate this + # from chop.nn.functional.splitter import splitter + att_1, att_2 = df_split(attention_output) + intermediate_output = self.intermediate(att_1) + layer_output = self.output(intermediate_output, att_2) return layer_output @@ -738,7 +748,9 @@ def __init__(self, config): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) @@ -823,8 +835,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + if module.weight is not None: + module.weight.data.fill_(1.0) @dataclass @@ -948,7 +962,11 @@ def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config - self.embeddings = BertEmbeddings(config) + # self.embeddings = BertEmbeddings(config) + def passthrough_function(input_ids): + return input_ids + + self.embeddings = passthrough_function self.encoder = BertEncoder(config) # self.pooler = BertPooler(config) if add_pooling_layer else None @@ -1104,7 +1122,7 @@ def forward( # sequence_output = encoder_outputs[0] sequence_output = encoder_outputs pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None + self.pooler(sequence_output) if self.pooler is not None else sequence_output ) return pooled_output diff --git a/src/chop/models/patched/llama/__init__.py b/src/chop/models/patched/llama/__init__.py new file mode 100644 index 000000000..96df0f27e --- /dev/null +++ b/src/chop/models/patched/llama/__init__.py @@ -0,0 +1,2 @@ +from .configuration_llama import LlamaConfig +from .modeling_llama import LlamaModel diff --git a/src/chop/models/patched/llama/configuration_llama.py b/src/chop/models/patched/llama/configuration_llama.py new file mode 100644 index 000000000..3189948ac --- /dev/null +++ b/src/chop/models/patched/llama/configuration_llama.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if ( + rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0 + ): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}" + ) diff --git a/src/chop/models/patched/llama/modeling_llama.py b/src/chop/models/patched/llama/modeling_llama.py new file mode 100644 index 000000000..af004253b --- /dev/null +++ b/src/chop/models/patched/llama/modeling_llama.py @@ -0,0 +1,1872 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +@torch.fx.wrap +def df_split(x): + return (x, x) + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ( + self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + ) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) + / self.dim + ) + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=config.mlp_bias + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [ + F.linear(x, gate_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + up_proj = torch.cat( + [ + F.linear(x, up_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ], + dim=-1, + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + x1, x2 = df_split(x) + down_proj = self.down_proj( + self.act_fn(self.gate_proj(x1)) * self.up_proj(x2) + ) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.config.pretraining_tp, dim=2 + ) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + # return attn_output, attn_weights, past_key_value + return attn_output + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + # =============================================================== + # Moved here from model scope so this is hidden in tracing + if cache_position is None: + past_seen_tokens = ( + past_key_value.get_seq_length() if past_key_value is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + hidden_states.shape[1], + device=hidden_states.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + # =============================================================== + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + # return attn_output, None, past_key_value + return attn_output + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + # residual = hidden_states + hidden_states, residual = df_split(hidden_states) + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + # residual = hidden_states + hidden_states, residual = df_split(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + # if output_attentions: + # outputs += (self_attn_weights,) + + # if use_cache: + # outputs += (present_key_value,) + + return outputs[0] + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # self.embed_tokens = nn.Embedding( + # config.vocab_size, config.hidden_size, self.padding_idx + # ) + def passthrough_function(input_ids): + return input_ids + + self.embed_tokens = passthrough_function + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + # if use_cache and not isinstance( + # past_key_values, Cache + # ): # kept for BC (non `Cache` `past_key_values` inputs) + # return_legacy_cache = True + # past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + # causal_mask = self._update_causal_mask( + # attention_mask, + # inputs_embeds, + # cache_position, + # past_key_values, + # output_attentions, + # ) + causal_mask = attention_mask + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not using_static_cache + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError( + "Custom 4D attention mask should be passed in inverted form with max==0`" + ) + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand( + input_tensor.shape[0], 1, -1, -1 + ) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + + attention_mask[:, None, None, :] + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = ( + cache_position[0] + if cache_position is not None + else past_key_values.get_seq_length() + ) + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length + if max_cache_length is None + else torch.min(max_cache_length, past_length) + ) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = ( + position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + ) + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + input_length, device=input_ids.device + ) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ) + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/chop/models/patched/mistral/__init__.py b/src/chop/models/patched/mistral/__init__.py new file mode 100644 index 000000000..211f2ea9f --- /dev/null +++ b/src/chop/models/patched/mistral/__init__.py @@ -0,0 +1,2 @@ +from .configuration_mistral import MistralConfig +from .modeling_mistral import MistralModel diff --git a/src/chop/models/patched/mistral/configuration_mistral.py b/src/chop/models/patched/mistral/configuration_mistral.py new file mode 100644 index 000000000..8cb536285 --- /dev/null +++ b/src/chop/models/patched/mistral/configuration_mistral.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mistral model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class MistralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an + Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. + + [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) + [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MistralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Mistral's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import MistralModel, MistralConfig + + >>> # Initializing a Mistral 7B style configuration + >>> configuration = MistralConfig() + + >>> # Initializing a model from the Mistral 7B style configuration + >>> model = MistralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/src/chop/models/patched/mistral/modeling_mistral.py b/src/chop/models/patched/mistral/modeling_mistral.py new file mode 100644 index 000000000..494caecfc --- /dev/null +++ b/src/chop/models/patched/mistral/modeling_mistral.py @@ -0,0 +1,1531 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mistral model.""" + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_mistral import MistralConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward with Gemma->Mistral + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, +} + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + + # cache_position must be valid here no matter which cache we use + past_seen_tokens = cache_position[0] if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: + target_length = past_key_values.get_max_length() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + # Omit tokens covered by past_key_values + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # crop the attention_mask to sliding window size during decode phase if using SlidingWindowCache + if ( + past_length > 0 + and attention_mask is not None + and isinstance(past_key_values, SlidingWindowCache) + and attention_mask.shape[1] > past_key_values.max_cache_len + ): + attention_mask = attention_mask[:, -past_key_values.max_cache_len :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForTokenClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/src/chop/nn/functional/__init__.py b/src/chop/nn/functional/__init__.py index e69de29bb..5dce0ef23 100644 --- a/src/chop/nn/functional/__init__.py +++ b/src/chop/nn/functional/__init__.py @@ -0,0 +1 @@ +from .softermax import softermax diff --git a/src/chop/nn/functional/softermax.py b/src/chop/nn/functional/softermax.py new file mode 100644 index 000000000..e6fe5bde1 --- /dev/null +++ b/src/chop/nn/functional/softermax.py @@ -0,0 +1,19 @@ +from torch import Tensor + + +def softermax(input: Tensor, dim: int) -> Tensor: + """Softermax implementation, according to "Softermax: Hardware/Software Co-Design of an Efficient Softmax for Transformers" paper (https://arxiv.org/abs/2103.09301). + + Args: + input (Tensor): Input tensor + + Returns: + Tensor: Output tensor + """ + input = input.squeeze() + out = input - input.max(dim=1).values.floor() + out = 2**out + row_sum = out.sum(dim=1).reshape((-1, 1)).expand(input.shape) + # Elementwise division + out = out / row_sum + return out diff --git a/src/chop/nn/functional/splitter.py b/src/chop/nn/functional/splitter.py new file mode 100644 index 000000000..37229d1dc --- /dev/null +++ b/src/chop/nn/functional/splitter.py @@ -0,0 +1,6 @@ + +import torch + +@torch.fx.wrap +def splitter(x): + return (x, x) diff --git a/src/chop/nn/quantized/__init__.py b/src/chop/nn/quantized/__init__.py index e79e43c38..1d02146e4 100644 --- a/src/chop/nn/quantized/__init__.py +++ b/src/chop/nn/quantized/__init__.py @@ -1,2 +1,12 @@ -from .modules import quantized_module_map -from .functional import quantized_func_map +from .modules import ( + quantized_module_map, + BertSelfAttentionInteger, + BertSelfAttentionHeadInteger, + LlamaSdpaAttentionInteger, + LinearInteger, + LayerNormInteger, + GELUInteger, + SiLUInteger, + RMSNormInteger, +) +from .functional import quantized_func_map, fixed_softermax diff --git a/src/chop/nn/quantized/functional/__init__.py b/src/chop/nn/quantized/functional/__init__.py index b9fa26e15..65ccb1915 100644 --- a/src/chop/nn/quantized/functional/__init__.py +++ b/src/chop/nn/quantized/functional/__init__.py @@ -1,3 +1,5 @@ +from .softermax import fixed_softermax + from .add import ( add_block_fp, add_block_log, diff --git a/src/chop/nn/quantized/functional/matmul.py b/src/chop/nn/quantized/functional/matmul.py index f8548c4fc..d06eb1ece 100644 --- a/src/chop/nn/quantized/functional/matmul.py +++ b/src/chop/nn/quantized/functional/matmul.py @@ -9,6 +9,7 @@ block_log_quantizer, block_minifloat_quantizer, integer_quantizer, + integer_floor_quantizer, log_quantizer, minifloat_denorm_quantizer, minifloat_ieee_quantizer, @@ -20,22 +21,37 @@ matmul_mapping = {"matmul": torch.matmul, "bmm": torch.bmm} -def generic_matmul_integer(x, y, config, style="matmul"): +def generic_matmul_integer(x, y, config, style="matmul", out_config=None, floor=False): bypass = config.get("bypass", False) matmul = matmul_mapping[style] + if bypass: return matmul(x, y) else: + base_quantizer = integer_quantizer + x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"] - x_quantizer = partial(integer_quantizer, width=x_width, frac_width=x_frac_width) y_width, y_frac_width = config["weight_width"], config["weight_frac_width"] - y_quantizer = partial(integer_quantizer, width=y_width, frac_width=y_frac_width) + + x_quantizer = partial(base_quantizer, width=x_width, frac_width=x_frac_width) + y_quantizer = partial(base_quantizer, width=y_width, frac_width=y_frac_width) + + if out_config is not None: + out_width, out_frac_width = ( + out_config["data_out_width"], + out_config["data_out_frac_width"], + ) + out_quantizer = partial( + integer_floor_quantizer, width=out_width, frac_width=out_frac_width + ) x = x_quantizer(x) y = y_quantizer(y) - # y = x_quantizer(y) - return matmul(x, y) + if out_config is not None: + return out_quantizer(matmul(x, y)) + else: + return matmul(x, y) def generic_matmul_binary(x, y, config, style="matmul"): @@ -329,8 +345,8 @@ def generic_matmul_block_log(x, y, config, style="matmul"): return matmul(x, y) -def matmul_integer(x, y, config): - return generic_matmul_integer(x, y, config, "matmul") +def matmul_integer(x, y, config, out_config=None, floor=False): + return generic_matmul_integer(x, y, config, "matmul", out_config, floor) def matmul_binary(x, y, config): diff --git a/src/chop/nn/quantized/functional/softermax.py b/src/chop/nn/quantized/functional/softermax.py new file mode 100644 index 000000000..b5829f1ae --- /dev/null +++ b/src/chop/nn/quantized/functional/softermax.py @@ -0,0 +1,30 @@ +from torch import Tensor +from math import ceil + +from chop.nn.quantizers import ( + integer_quantizer, + integer_floor_quantizer, +) +from chop.nn.functional import softermax + + +def fixed_softermax( + input: Tensor, q_config: dict = None, out_q_config: dict = None, dim: int = 0 +) -> Tensor: + """Fixed-point softermax implementation, according to "Softermax: Hardware/Software Co-Design of an Efficient Softmax for Transformers" paper (https://arxiv.org/abs/2103.09301). + + Args: + input (Tensor): Input tensor + + Returns: + Tensor: Output tensor + """ + if q_config is not None: + input = integer_quantizer(input, **q_config) + + out = softermax(input, dim=dim) + + if out_q_config is not None: + out = integer_floor_quantizer(out, **out_q_config) + + return out diff --git a/src/chop/nn/quantized/modules/__init__.py b/src/chop/nn/quantized/modules/__init__.py index 42713ff31..64a971aa8 100644 --- a/src/chop/nn/quantized/modules/__init__.py +++ b/src/chop/nn/quantized/modules/__init__.py @@ -1,3 +1,6 @@ +from .attention_head import BertSelfAttentionHeadInteger +from .attention import BertSelfAttentionInteger, LlamaSdpaAttentionInteger + # from .add import AddInteger from .conv1d import ( Conv1dBlockFP, @@ -67,9 +70,7 @@ from .group_norm import GroupNormInteger from .instance_norm2d import InstanceNorm2dInteger -# from .rms_norm import ( -# RMSNormInteger -# ) +from .rms_norm import RMSNormInteger from .selu import ( SELUBlockFP, @@ -83,6 +84,18 @@ SELUTernary, ) +from .silu import ( + SiLUBlockFP, + SiLUBlockMinifloat, + SiLUInteger, + SiLULog, + SiLUBlockLog, + SiLUMinifloatDenorm, + SiLUMinifloatIEEE, + SiLUBinary, + SiLUTernary, +) + from .tanh import ( TanhBlockFP, TanhBlockMinifloat, @@ -188,7 +201,7 @@ "layer_norm_integer": LayerNormInteger, "group_norm_integer": GroupNormInteger, "instance_norm2d_integer": InstanceNorm2dInteger, - # "rms_norm_integer": RMSNormInteger, + "rms_norm_integer": RMSNormInteger, "selu_block_minifloat": SELUBlockMinifloat, "selu_integer": SELUInteger, "selu_fixed": SELUInteger, @@ -199,6 +212,16 @@ "selu_block_fp": SELUBlockFP, "selu_binary": SELUBinary, "selu_ternary": SELUTernary, + "silu_block_minifloat": SiLUBlockMinifloat, + "silu_integer": SiLUInteger, + "silu_fixed": SiLUInteger, + "silu_log": SiLULog, + "silu_block_log": SiLUBlockLog, + "silu_minifloat_ieee": SiLUMinifloatIEEE, + "silu_minifloat_denorm": SiLUMinifloatDenorm, + "silu_block_fp": SiLUBlockFP, + "silu_binary": SiLUBinary, + "silu_ternary": SiLUTernary, "tanh_block_minifloat": TanhBlockMinifloat, "tanh_integer": TanhInteger, "tanh_fixed": TanhInteger, @@ -241,4 +264,7 @@ "softplus_ternary": SoftplusTernary, "batch_norm1d_fixed": BatchNorm1dInteger, "batch_norm1d_linear": BatchNorm1dInteger, + "bert_self_attention_head_integer": BertSelfAttentionHeadInteger, + "bert_self_attention_integer": BertSelfAttentionInteger, + "llama_sdpa_attention_integer": LlamaSdpaAttentionInteger, } diff --git a/src/chop/nn/quantized/modules/attention.py b/src/chop/nn/quantized/modules/attention.py new file mode 100644 index 000000000..b769647a5 --- /dev/null +++ b/src/chop/nn/quantized/modules/attention.py @@ -0,0 +1,204 @@ +from functools import partial + +import torch +from torch import Tensor +from torch.nn import functional as F + +from transformers.models.bert.modeling_bert import BertSelfAttention +from chop.models.patched.llama.modeling_llama import LlamaSdpaAttention +from chop.models.patched.llama.configuration_llama import LlamaConfig + +from chop.nn.quantized.modules.linear import ( + LinearInteger, +) +from chop.nn.quantized.functional import fixed_softermax +from chop.nn.quantized.functional import matmul_integer + +from typing import Optional, Tuple + + +class _BertSelfAttentionBase(BertSelfAttention): + def __init__( + self, + config, + q_config: dict = None, + out_q_config: dict = None, + position_embedding_type=None, + bias=True, + output_tensor_only=False, + ) -> None: + super().__init__(config, position_embedding_type) + self.bypass = False + self.q_config = q_config + self.out_q_config = out_q_config + self.bias = bias + self.output_tensor_only = output_tensor_only + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + out = super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + if self.output_tensor_only: + return out[0] + return out + + +class _LlamaSdpaAttentionBase(LlamaSdpaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + q_config: dict = None, + out_q_config: dict = None, + output_tensor_only=False, + ): + super().__init__(config, layer_idx) + self.bypass = False + self.q_config = q_config + self.out_q_config = out_q_config + self.output_tensor_only = output_tensor_only + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[int] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + out = super().forward( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + ) + if self.output_tensor_only: + return out[0] + return out + + +class BertSelfAttentionInteger(_BertSelfAttentionBase): + def __init__( + self, + config, + q_config: dict = None, + out_q_config: dict = None, + position_embedding_type=None, + bias=True, + floor=False, + output_tensor_only=False, + ) -> None: + super().__init__( + config, + q_config, + out_q_config, + position_embedding_type, + bias=bias, + output_tensor_only=output_tensor_only, + ) + self.query = LinearInteger( + config.hidden_size, + config.hidden_size, + config=q_config, + out_config=out_q_config, + bias=bias, + floor=floor, + ) + self.key = LinearInteger( + config.hidden_size, + config.hidden_size, + config=q_config, + out_config=out_q_config, + bias=bias, + floor=floor, + ) + self.value = LinearInteger( + config.hidden_size, + config.hidden_size, + config=q_config, + out_config=out_q_config, + bias=bias, + floor=floor, + ) + # * Matmul is used for Q @ K^T and Scores @ V where the input values have already + # * been casted to the output precision, so we provide the output precision to the + # * software model + self.matmul = partial( + matmul_integer, + config={ + "data_in_width": self.q_config["data_out_width"], + "data_in_frac_width": self.q_config["data_out_frac_width"], + "weight_width": self.q_config["data_out_width"], + "weight_frac_width": self.q_config["data_out_frac_width"], + }, + out_config=out_q_config, + floor=floor, + ) + + +class LlamaSdpaAttentionInteger(_LlamaSdpaAttentionBase): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + q_config: dict = None, + out_q_config: dict = None, + output_tensor_only=False, + ): + super().__init__( + config, + layer_idx, + q_config, + out_q_config, + output_tensor_only=output_tensor_only, + ) + self.q_proj = LinearInteger( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + config=q_config, + out_config=out_q_config, + ) + self.k_proj = LinearInteger( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + config=q_config, + out_config=out_q_config, + ) + self.v_proj = LinearInteger( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + config=q_config, + out_config=out_q_config, + ) + self.o_proj = LinearInteger( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + config=q_config, + out_config=out_q_config, + ) diff --git a/src/chop/nn/quantized/modules/attention_head.py b/src/chop/nn/quantized/modules/attention_head.py new file mode 100644 index 000000000..8f9ea5969 --- /dev/null +++ b/src/chop/nn/quantized/modules/attention_head.py @@ -0,0 +1,91 @@ +import torch +from torch import Tensor +import torch.nn as nn +import math + +from typing import Optional, Tuple +from functools import partial + +from chop.nn.quantized.functional.matmul import ( + generic_matmul_integer, +) +from chop.nn.quantizers.integer import integer_quantizer + + +class _BertSelfAttentionHeadBase(torch.nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + # ! TO DO: replace these with quantized functions? + self.matmul = torch.matmul + self.softmax = nn.functional.softmax + + def self_attention_head( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tensor: + attention_scores = self.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = self.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = self.matmul(attention_probs, value_layer) + return context_layer + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tensor: + return self.self_attention_head( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + ) + + +class BertSelfAttentionHeadInteger(_BertSelfAttentionHeadBase): + def __init__(self, config, q_config: dict = None) -> None: + super().__init__(config) + + self.query_quantizer = partial( + integer_quantizer, + **q_config, + ) + self.key_quantizer = partial(integer_quantizer, **q_config) + self.value_quantizer = partial(integer_quantizer, **q_config) + + def forward( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tensor: + query_layer = self.query_quantizer(query_layer) + key_layer = self.key_quantizer(key_layer) + value_layer = self.value_quantizer(value_layer) + + return self.self_attention_head( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + ) diff --git a/src/chop/nn/quantized/modules/gelu.py b/src/chop/nn/quantized/modules/gelu.py index 1eb4583ba..074cf4df8 100644 --- a/src/chop/nn/quantized/modules/gelu.py +++ b/src/chop/nn/quantized/modules/gelu.py @@ -22,6 +22,7 @@ class _GELUBase(torch.nn.GELU): def __init__(self, inplace: bool = False): super().__init__(inplace) + self.inplace = inplace self.bypass = False self.x_quantizer = None @@ -30,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor: return F.gelu(x) else: x = self.x_quantizer(x) - return F.gelu(x, self.inplace) + return F.gelu(x) def get_quantized_output(self, x: Tensor) -> Tensor: x = self.x_quantizer(x) @@ -43,6 +44,7 @@ class GELUInteger(_GELUBase): def __init__(self, inplace: bool = False, config: dict = None): super().__init__(inplace) assert config is not None, "config is None!" + self.config = config self.bypass = config.get("bypass", False) if self.bypass: diff --git a/src/chop/nn/quantized/modules/linear.py b/src/chop/nn/quantized/modules/linear.py index 5ceae8bbe..00d6099f1 100644 --- a/src/chop/nn/quantized/modules/linear.py +++ b/src/chop/nn/quantized/modules/linear.py @@ -11,6 +11,7 @@ block_log_quantizer, block_minifloat_quantizer, integer_quantizer, + integer_floor_quantizer, log_quantizer, minifloat_denorm_quantizer, minifloat_ieee_quantizer, @@ -53,36 +54,19 @@ def __init__( self.x_quantizer = None self.w_quantizer = None self.b_quantizer = None + self.out_quantizer = lambda x: x self.pruning_masks = None def forward(self, x: Tensor) -> Tensor: if self.bypass: - # if bypss, there is no quantization + # if bypass, there is no quantization return F.linear(x, self.weight, self.bias) else: x = self.x_quantizer(x) w = self.w_quantizer(self.weight) bias = self.b_quantizer(self.bias) if self.bias is not None else None - return F.linear(x, w, bias) - - # TODO: implement these as passes - # def get_quantized_weight(self) -> Tensor: - # return self.w_quantizer(self.weight) - - # def get_quantized_weights_with_inputs(self, x: Tensor) -> Tensor: - # x = self.x_quantizer(x) - # w = self.w_quantizer(self.weight) - # bias = self.b_quantizer(self.bias) if self.bias is not None else None - # y = F.linear(x, w, bias) - # return { - # "x": x, - # "w": w, - # "bias": bias, - # "y": y, - # } - - # def get_output_bitwidth(self) -> dict: - # raise NotImplementedError() + out = F.linear(x, w, bias) + return self.out_quantizer(out) class LinearInteger(_LinearBase): @@ -94,10 +78,13 @@ def __init__( device=None, dtype=None, config=None, + out_config=None, + floor=False, ) -> None: super().__init__(in_features, out_features, bias, device, dtype) assert config is not None, "config is None!" self.config = config + self.out_config = out_config self.bypass = config.get("bypass", False) if self.bypass: return @@ -106,6 +93,12 @@ def __init__( x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"] # check bias quantizer, if not, use weight quantizer b_width, b_frac_width = config["bias_width"], config["bias_frac_width"] + if out_config is not None: + out_width, out_frac_width = ( + out_config["data_out_width"], + out_config["data_out_frac_width"], + ) + base_quantizer = integer_floor_quantizer if floor else integer_quantizer self.w_quantizer = partial( integer_quantizer, width=w_width, frac_width=w_frac_width ) @@ -115,26 +108,10 @@ def __init__( self.b_quantizer = partial( integer_quantizer, width=b_width, frac_width=b_frac_width ) - - # def get_output_bitwidth(self): - # config = self.config - # w_width, w_frac = config["weight_width"], config["weight_frac_width"] - # x_width, x_frac = config["data_in_width"], config["data_in_frac_width"] - # bias_width = config["bias_width"] - - # ops = self.in_features - # product_width = w_width + x_width - # product_frac_width = w_frac + x_frac - # # *: + 1 for bias - # output_width = max(bias_width, product_width + ceil(log2(ops))) + 1 - # output_frac_width = product_frac_width - - # o_bitwidth = {} - # o_bitwidth["data_out_width"] = output_width - # o_bitwidth["data_out_frac_width"] = output_frac_width - # # o_bitwidth["product_width"] = product_width - # # o_bitwidth["product_frac_width"] = product_frac_width - # return o_bitwidth + if out_config is not None: + self.out_quantizer = partial( + integer_floor_quantizer, width=out_width, frac_width=out_frac_width + ) class LinearMinifloatDenorm(_LinearBase): @@ -247,20 +224,6 @@ def __init__( exponent_bias=b_exponent_bias, ) - # def get_output_bitwidth(self) -> dict: - # num_ops = self.in_features - # product_bitwidth = self.w_width + self.x_width - # product_frac = self.w_frac_width + self.x_frac_width - - # addition_bitwidth = math.ceil(math.log(num_ops)) - # output_bitwidth = product_bitwidth + addition_bitwidth - # return { - # "output_width": output_bitwidth, - # "output_frac_width": product_frac, - # "product_width": product_bitwidth, - # "product_frac_width": product_frac, - # } - class LinearLog(_LinearBase): def __init__( diff --git a/src/chop/nn/quantized/modules/rms_norm.py b/src/chop/nn/quantized/modules/rms_norm.py index 134e67ae8..91dd9d9d6 100644 --- a/src/chop/nn/quantized/modules/rms_norm.py +++ b/src/chop/nn/quantized/modules/rms_norm.py @@ -8,10 +8,47 @@ from chop.nn.quantizers import ( integer_quantizer, ) -import chop.models.manual.rms_norm as rms -class _RMSNormBase(rms.RMSNorm): +def _rms_norm(x: Tensor, eps, scale: Tensor | None): + mean_squares = x.square().mean(-1, keepdim=True) + rms_x = mean_squares.sqrt() + x_normed = x / (rms_x + eps) + if scale != None: + return scale * x_normed + else: + return x_normed + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization""" + + def __init__( + self, + normalized_shape, + eps: float = 1e-8, + elementwise_affine: bool = False, + device=None, + dtype=None, + ): + super().__init__() + self.eps = eps + self.normalized_shape = normalized_shape + self.elementwise_affine = elementwise_affine + + factory_kwargs = {"device": device, "dtype": dtype} + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.ones(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + + def forward(self, x: Tensor): + return _rms_norm(x, self.eps, self.weight) + + +class _RMSNormBase(RMSNorm): def __init__( self, normalized_shape, @@ -26,7 +63,7 @@ def __init__( self.w_quantizer = None def forward(self, x: Tensor): - return rms._rms_norm(x, self.eps, self.weight) + return _rms_norm(x, self.eps, self.weight) class RMSNormInteger(_RMSNormBase): diff --git a/src/chop/nn/quantized/modules/silu.py b/src/chop/nn/quantized/modules/silu.py new file mode 100644 index 000000000..2a3446ba9 --- /dev/null +++ b/src/chop/nn/quantized/modules/silu.py @@ -0,0 +1,294 @@ +from functools import partial + +import torch +from torch import Tensor +from torch.nn import functional as F + +from .utils import get_stats, quantiser_passthrough + +from chop.nn.quantizers import ( + block_fp_quantizer, + block_log_quantizer, + block_minifloat_quantizer, + integer_quantizer, + log_quantizer, + minifloat_denorm_quantizer, + minifloat_ieee_quantizer, +) + + +class _SiLUBase(torch.nn.SiLU): + def __init__(self, inplace: bool = False): + super().__init__(inplace) + self.bypass = False + self.x_quantizer = None + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.silu(x) + else: + x = self.x_quantizer(x) + return F.silu(x, self.inplace) + + def get_quantized_output(self, x: Tensor) -> Tensor: + x = self.x_quantizer(x) + return {"x": x} + + +class SiLUInteger(_SiLUBase): + bypass = None + + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + # establish quantizers + x_width, x_frac_width = config["data_in_width"], config["data_in_frac_width"] + self.x_quantizer = partial( + integer_quantizer, width=x_width, frac_width=x_frac_width, is_signed=False + ) + self.config = config + self.x_width = x_width + self.x_frac_width = x_frac_width + + +class SiLUMinifloatDenorm(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_width, x_exponent_bias = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_exponent_bias"], + ) + self.x_quantizer = partial( + minifloat_denorm_quantizer, + width=x_width, + exponent_width=x_exponent_width, + exponent_bias=x_exponent_bias, + ) + + +class SiLUMinifloatIEEE(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_width, x_exponent_bias = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_exponent_bias"], + ) + self.x_quantizer = partial( + minifloat_ieee_quantizer, + width=x_width, + exponent_width=x_exponent_width, + exponent_bias=x_exponent_bias, + ) + + +class SiLULog(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_bias = ( + config["data_in_width"], + config["data_in_exponent_bias"], + ) + self.x_quantizer = partial( + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, + ) + + +class SiLULog(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_bias = ( + config["data_in_width"], + config["data_in_exponent_bias"], + ) + self.x_quantizer = partial( + log_quantizer, + width=x_width, + exponent_bias=x_exponent_bias, + ) + + +class SiLUBlockFP(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_width, x_exponent_bias, x_block_size = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_exponent_bias"], + config["data_in_block_size"], + ) + self.x_quantizer = partial( + block_fp_quantizer, + width=x_width, + exponent_width=x_exponent_width, + exponent_bias=x_exponent_bias, + block_size=x_block_size, + skip_first_dim=True, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.silu(x) + else: + x_shape = [i for i in x.shape] + if x.ndim > 2: + x = torch.flatten(x, 0, -3) + x = self.x_quantizer(x) + x = torch.reshape(x, x_shape) + return F.silu(x, self.inplace) + + +class SiLUBlockMinifloat(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_width, x_exponent_bias_width, x_block_size = ( + config["data_in_width"], + config["data_in_exponent_width"], + config["data_in_exponent_bias_width"], + config["data_in_block_size"], + ) + self.x_quantizer = partial( + block_minifloat_quantizer, + width=x_width, + exponent_width=x_exponent_width, + exponent_bias_width=x_exponent_bias_width, + block_size=x_block_size, + skip_first_dim=True, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.silu(x) + else: + x_shape = [i for i in x.shape] + if x.ndim > 2: + x = torch.flatten(x, 0, -3) + x = self.x_quantizer(x) + x = torch.reshape(x, x_shape) + return F.silu(x, self.inplace) + + +class SiLUBlockLog(_SiLUBase): + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + + x_width, x_exponent_bias_width, x_block_size = ( + config["data_in_width"], + config["data_in_exponent_bias_width"], + config["data_in_block_size"], + ) + self.x_quantizer = partial( + block_log_quantizer, + width=x_width, + exponent_bias_width=x_exponent_bias_width, + block_size=x_block_size, + skip_first_dim=True, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.bypass: + return F.silu(x) + else: + x_shape = [i for i in x.shape] + if x.ndim > 2: + x = torch.flatten(x, 0, -3) + x = self.x_quantizer(x) + x = torch.reshape(x, x_shape) + return F.silu(x, self.inplace) + + +class SiLUBinary(_SiLUBase): + bypass = None + + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + # establish quantizers + x_stochastic = config["data_in_stochastic"] + x_bipolar = config["data_in_bipolar"] + self.x_quantizer = quantiser_passthrough + # self.x_quantizer = partial( + # binary_quantizer, stochastic=x_stochastic, bipolar=x_bipolar + # ) + self.config = config + # self.x_width = x_width + # self.x_frac_width = x_frac_width + + +class SiLUTernary(_SiLUBase): + bypass = None + + def __init__(self, inplace: bool = False, config: dict = None): + super().__init__(inplace) + assert config is not None, "config is None!" + self.config = config + self.bypass = config.get("bypass", False) + if self.bypass: + return + # establish quantisers + x_scaling_factor = config["data_in_scaling_factor"] + x_mean = get_stats(config, "data_in_mean") + x_median = get_stats(config, "data_in_median") + x_max = get_stats(config, "data_in_max") + self.x_quantizer = quantiser_passthrough + # self.x_quantizer = partial( + # ternary_quantizer, + # scaling_factor=x_scaling_factor, + # median=x_median, + # maximum=x_max, + # mean=x_mean, + # ) + self.config = config diff --git a/src/chop/nn/quantizers/__init__.py b/src/chop/nn/quantizers/__init__.py index 68c320689..f4af2f05d 100644 --- a/src/chop/nn/quantizers/__init__.py +++ b/src/chop/nn/quantizers/__init__.py @@ -1,7 +1,7 @@ from .block_fp import block_fp_quantizer from .block_log import block_log_quantizer from .block_minifloat import block_minifloat_quantizer -from .integer import integer_quantizer +from .integer import integer_quantizer, integer_floor_quantizer from .binary import binary_quantizer, residual_sign_quantizer from .ternary import ternary_quantizer from .log import log_quantizer diff --git a/src/chop/nn/quantizers/quantizers_for_hw.py b/src/chop/nn/quantizers/quantizers_for_hw.py index a8d705aa1..ac0575614 100644 --- a/src/chop/nn/quantizers/quantizers_for_hw.py +++ b/src/chop/nn/quantizers/quantizers_for_hw.py @@ -6,9 +6,6 @@ # from .quantizers import integer_quantizer from .utils import block, my_clamp, my_round, unblock, my_floor -from mase_cocotb.utils import sign_extend - - def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) scale = 2**frac_width @@ -19,6 +16,16 @@ def integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): return fixed_point_value +def unsigned_integer_quantizer_for_hw(x: Tensor, width: int, frac_width: int): + thresh = 2**width - 1 + scale = 2**frac_width + + fixed_point_value = my_clamp(my_floor(x.mul(scale)), 0, thresh) + fixed_point_value = fixed_point_value.to(torch.int) + fixed_point_value = fixed_point_value % (2**width) + return fixed_point_value + + def integer_floor_quantizer_for_hw(x: Tensor, width: int, frac_width: int): thresh = 2 ** (width - 1) scale = 2**frac_width diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index b6eec63df..f0af404ed 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -31,14 +31,18 @@ emit_bram_transform_pass, emit_internal_rtl_transform_pass, emit_cocotb_transform_pass, + emit_vivado_project_transform_pass, raise_granularity_transform_pass, tensorrt_calibrate_transform_pass, tensorrt_fine_tune_transform_pass, tensorrt_fake_quantize_transform_pass, + patch_metadata_transform_pass, ) from .module.analysis import calculate_avg_bits_module_analysis_pass -from .module.transforms import quantize_module_transform_pass +from .module.transforms import quantize_module_transform_pass, resharding_transform_pass from .onnx.analysis import ( export_fx_graph_analysis_pass, ) + +from .graph.analysis.autosharding import autosharding_analysis_pass diff --git a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py index 53785b105..ae9e5e008 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_common_metadata.py @@ -121,9 +121,6 @@ def graph_iterator_for_mase_ops(graph): mase_op = "softshrink" elif isinstance(module, nn.LogSigmoid): mase_op = "logsigmoid" - # TODO: temporary. Support all patched attention layers - elif "attention" in module.__name__.lower(): - mase_op = "attention" else: mase_op = None for module_cls in graph.model.custom_ops["modules"].keys(): diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 397199915..83493a57d 100644 --- a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -8,7 +8,7 @@ get_input_nodes, get_output_nodes, ) -from chop.passes.graph.utils import get_mase_op +from chop.passes.graph.utils import get_mase_op, deepgetattr from torch import nn @@ -34,7 +34,18 @@ def add_component_source(node): node.meta["mase"]["hardware"]["interface"] = {} mase_op = node.meta["mase"]["common"]["mase_op"] - if mase_op in INTERNAL_COMP.keys(): + if mase_op == "user_defined_module": + for custom_op, op_info in node.meta["mase"].model.custom_ops["modules"].items(): + if isinstance( + deepgetattr(node.meta["mase"].model, node.target), + custom_op, + ): + node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" + node.meta["mase"]["hardware"]["module"] = op_info["module"] + node.meta["mase"]["hardware"]["dependence_files"] = op_info[ + "dependence_files" + ] + elif mase_op in INTERNAL_COMP.keys(): node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" # take the first ip in the component list by default node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"] @@ -380,7 +391,7 @@ def add_hardware_metadata_analysis_pass(graph, pass_args=None): # * Fix max parallelism to small value to enable verilator simulation # ! TO DO: enable this to be overriden by user for node in graph.nodes: - node.meta["mase"]["hardware"]["max_parallelism"] = [4, 4, 4, 4] + node.meta["mase"]["hardware"]["max_parallelism"] = pass_args.get("max_parallelism", [4] * 4) # Add hardware parameters for node in graph.nodes: diff --git a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py index a5e30a04d..55a589a5f 100644 --- a/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/common_metadata_layers.py @@ -5,6 +5,7 @@ import inspect from chop.tools.utils import to_numpy_if_tensor as to_numpy from chop.passes.graph.utils import vf, get_node_by_name +from chop.passes.graph.patching import MASE_LEAF_FUNCTIONS, MASE_LEAF_LAYERS import traceback from functools import reduce @@ -114,6 +115,8 @@ "sigmoid": {"input": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.argmax.html "argmax": {"input": "data_in"}, + # dataflow_split + "df_split": {"x": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.split.html "split": {"input": "data_in", "split_size_or_sections": "config", "dim": "config"}, # https://pytorch.org/docs/stable/generated/torch.logical_not.html @@ -153,11 +156,12 @@ # https://pytorch.org/docs/stable/generated/torch.full.html "full": {"size": "config", "fill_value": "data_in"}, # get item - "getitem": {"a": "data_in", "b": "data_in"}, + "getitem": {"in": "data_in", "select": "config"}, # getattr "getattr": {"a": "data_in", "b": "data_in"}, # https://pytorch.org/docs/stable/generated/torch.ones.html "ones": {"size": "config", "device": "config"}, + "finfo": {"dtype": "config"}, } module_data = { @@ -207,6 +211,7 @@ "silu": {"input": "data_in"}, "elu": {"input": "data_in"}, "softmax": {"input": "data_in"}, + "gelu": {"input": "data_in"}, } @@ -247,26 +252,63 @@ "transpose": {"dim_0": "config", "dim_1": "config"}, # https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous "contiguous": {}, + "masked_fill": {"mask": "data_in", "value": "data_in"}, } +def get_type_and_precision(meta): + # * Fetch type and precision from q_config for quantized modules + if isinstance(meta.module, MASE_LEAF_LAYERS): + cf = ( + meta.module.q_config + if hasattr(meta.module, "q_config") + else meta.module.config + ) + arg_type = "fixed" + arg_precision = [ + cf["data_in_width"], + cf["data_in_frac_width"], + ] + else: + arg_type = "float" + arg_precision = [32] + return arg_type, arg_precision + + def match_args_and_kwargs(meta, args, kwargs, data, add_value): ordered_func_data = [(k, v) for k, v in data.items()] meta.parameters["common"]["args"] = {} meta_kwargs = {} + + arg_type, arg_precision = get_type_and_precision(meta) + + # * Assign metadata for each argument j = 0 for i, x in enumerate(args): if isinstance(x, torch.Tensor) and ordered_func_data[i][1] == "data_in": arg_meta = { "shape": list(x.shape), "torch_dtype": x.dtype, - "type": "float", - "precision": [32], + "type": arg_type, + "precision": arg_precision, } if add_value: arg_meta["value"] = x meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta j += 1 + # check if it's a tuple of tensors + elif isinstance(x, tuple) and all([isinstance(x, torch.Tensor) for x in x]): + for k, x in enumerate(x): + arg_meta = { + "shape": list(x.shape), + "torch_dtype": x.dtype, + "type": arg_type, + "precision": arg_precision, + } + if add_value: + arg_meta["value"] = x + meta.parameters["common"]["args"][f"data_in_{j}"] = arg_meta + j += 1 else: # this is not an data_in, but just actually an named arg n, vtype = ordered_func_data[i] @@ -288,8 +330,8 @@ def get_shape(x): arg_meta = { "shape": get_shape(v), "torch_dtype": v.dtype if isinstance(v, torch.Tensor) else type(v), - "type": "float", - "precision": [32], + "type": arg_type, + "precision": arg_precision, } if add_value: arg_meta["value"] = v @@ -306,21 +348,39 @@ def get_shape(x): def analyse_result(meta, result, add_value): # deal with results meta.parameters["common"]["results"] = {} + + result_type, result_precision = get_type_and_precision(meta) + if isinstance(result, torch.Tensor): meta.parameters["common"]["results"]["data_out_0"] = { - "type": "float", - "precision": [32], + "type": result_type, + "precision": result_precision, "shape": list(result.shape), "torch_dtype": result.dtype, } if add_value: meta.parameters["common"]["results"]["data_out_0"]["value"] = result + + # check if it's a tuple of tensors + elif isinstance(result, tuple) and all( + [isinstance(x, torch.Tensor) for x in result] + ): + for i, x in enumerate(result): + meta.parameters["common"]["results"][f"data_out_{i}"] = { + "type": result_type, + "precision": result_precision, + "shape": list(x.shape), + "torch_dtype": x.dtype, + } + if add_value: + meta.parameters["common"]["results"][f"data_out_{i}"]["value"] = x else: meta.parameters["common"]["results"]["data_out_0"] = { "type": type(result), "shape": [1], "value": result, } + return meta @@ -392,11 +452,19 @@ def analyse_common_parameters_module(meta, result, args, kwargs, add_value=True) module_args = module_data[mase_op] meta = match_args_and_kwargs(meta, args, kwargs, module_args, add_value) + + arg_type, arg_precision = get_type_and_precision(meta) + for name, parameter in meta.module.named_parameters(): + name = name.replace(".", "_") meta.parameters["common"]["args"][name] = { - "type": "float", - "precision": [32], - "shape": list(parameter.shape), + "type": arg_type, + "precision": arg_precision, + "shape": ( + list(parameter.shape) + if len(parameter.shape) > 1 + else list(parameter.unsqueeze(dim=0).shape) + ), "from": None, } if add_value: diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py index e7d05c949..25d3d4673 100644 --- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py @@ -45,6 +45,7 @@ "fixed_arithmetic/rtl/fixed_adder_tree.sv", "fixed_arithmetic/rtl/fixed_adder_tree_layer.sv", "fixed_arithmetic/rtl/fixed_mult.sv", + "common/rtl/unpacked_repeat_circular_buffer.sv", "common/rtl/register_slice.sv", "common/rtl/join2.sv", "common/rtl/skid_buffer.sv", @@ -147,7 +148,7 @@ "name": "fixed_gelu", "dependence_files": [ "activations/rtl/fixed_gelu.sv", - "arithmetic/rtl/fixed_mult.sv", + "activations/rtl/gelu_lut.sv", ], }, ], @@ -156,7 +157,7 @@ "name": "fixed_softsign", "dependence_files": [ "activations/rtl/fixed_softsign.sv", - "arithmetic/rtl/fixed_mult.sv", + "fixed_arithmetic/rtl/fixed_mult.sv", ], }, ], @@ -168,4 +169,34 @@ ], }, ], + "add": [ + { + "name": "fixed_adder", + "dependence_files": [ + "fixed_arithmetic/rtl/fixed_adder.sv", + ], + } + ], + "mul": [ + { + "name": "fixed_elementwise_multiplier", + "dependence_files": [ + "fixed_arithmetic/rtl/fixed_elementwise_multiplier.sv", + ], + } + ], + "df_split": [ + { + "name": "df_split", + "dependence_files": ["common/rtl/df_split.sv", "common/rtl/split2.sv"], + } + ], + "getitem": [ + { + "name": "buffer", + "dependence_files": [ + "common/rtl/buffer.sv", + ], + } + ], } diff --git a/src/chop/passes/graph/analysis/autosharding/__init__.py b/src/chop/passes/graph/analysis/autosharding/__init__.py new file mode 100644 index 000000000..2e53f199b --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/__init__.py @@ -0,0 +1,2 @@ + +from .autosharding import autosharding_analysis_pass \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/alpa.py b/src/chop/passes/graph/analysis/autosharding/alpa.py new file mode 100644 index 000000000..924aa55c9 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/alpa.py @@ -0,0 +1,151 @@ + +import functools + +import torch.nn as nn +import numpy as np +import cvxpy as cp + +from chop.tools import get_logger + +from .common import SpmdShard +from .alpa_layers import ALPA_LAYERS +from .alpa_cost_modelling import get_resharding_matrix + +logger = get_logger(__name__) +import sys, pdb, traceback + +def excepthook(exc_type, exc_value, exc_traceback): + traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + +# Set the custom exception hook +sys.excepthook = excepthook + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default + +def get_node_target(node): + if isinstance(node.target, str): + return deepgetattr(node.meta["mase"].model, node.target, None) + else: + return node.target + +def assign_default_sharding(node): + rank = len(node.meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": [{"data_in_0": (SpmdShard.R,) * rank}], + "valid_output_shardings": [(SpmdShard.R,) * rank], + "compute_cost_vector": [0], + "communication_cost_vector": [0], + "opt_var": np.array([1]), + } + +def alpa_intra_op_sharding_pass(mg, mesh): + """ + Intra-operator auto parallelization pass. + """ + + module_map = {} + + # Setup for the ILP optimization + expr = 0 + constr = [] + + # Write cost vectors into metadata for each operator + # This will later be used to solve the ILP optimization + for node in mg.fx_graph.nodes: + + target = get_node_target(node) + target_cls = type(target) + num_params = len([i for i in target.parameters()]) if isinstance(target, nn.Module) else 0 + + if node.op != "call_module" or num_params == 0: + assign_default_sharding(node) + + elif target_cls in ALPA_LAYERS.keys(): + # Enumerate shardings and costs for this operator + ( + input_shardings, + output_shardings, + compute_cost_vector, + communication_cost_vector, + ) = ALPA_LAYERS[target_cls](node.meta, mesh, target) + + # Formulate optimization variable and consider compute/communication cost + opt_var = cp.Variable(len(input_shardings), boolean=True) + constr += [ + cp.sum(opt_var) == 1, + ] + expr += opt_var.T @ (compute_cost_vector + communication_cost_vector) + + # Write into metadata + node.meta["mase"]["software"]["autosharding"] = { + "valid_input_shardings": input_shardings, + "valid_output_shardings": output_shardings, + "compute_cost_vector": compute_cost_vector, + "communication_cost_vector": communication_cost_vector, + "opt_var": opt_var, + } + + # Consider resharding cost + for in_node in node.all_input_nodes: + in_opt_var = in_node.meta["mase"]["software"]["autosharding"]["opt_var"] + + resharding_costs = get_resharding_matrix( + mesh, + src_shardings = in_node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"], + dest_shardings = [sharding["data_in_0"] for sharding in node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"]], + dest_node_meta = node.meta["mase"] + ).flatten() + + # Formulate resharding cost term with linearized variable + e_var = cp.Variable(opt_var.shape[0] * in_opt_var.shape[0], boolean=True) + expr += e_var.T @ resharding_costs + constr += [ + cp.sum(e_var) == 1, + ] + + # Scalar construction of the inequality constraints for the linearized variable + for i in range(e_var.shape[0]): + constr += [ + e_var[i] <= opt_var[i // in_opt_var.shape[0]], + e_var[i] <= in_opt_var[i % in_opt_var.shape[0]], + e_var[i] >= opt_var[i // in_opt_var.shape[0]] + in_opt_var[i % in_opt_var.shape[0]] - 1 + ] + + # No sharding algorithm found for this operator, but this has parameter attributes + # (i.e. not an elementwise or implicit function) + elif (len([i for i in target.parameters()]) > 0): + logger.warning(f"No sharding algorithm found for operator: {target_cls}, but the parameter count is non-zero.") + logger.warning(f" MaseLauncher will fully replicate the parameters of this module.") + + else: + logger.debug(f"Skipping implicit/elementwise operator: {target_cls}") + + # Solve the ILP problem + prob = cp.Problem(cp.Minimize(expr), constr) + prob.solve() + + for node in mg.fx_graph.nodes: + chosen_idx = 0 if isinstance(node.meta["mase"]["software"]["autosharding"]["opt_var"], np.ndarray) else np.where(node.meta["mase"]["software"]["autosharding"]["opt_var"].value == 1)[0][0] + node.meta["mase"]["software"]["autosharding"]["input_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_input_shardings"][chosen_idx] + node.meta["mase"]["software"]["autosharding"]["output_sharding"] = node.meta["mase"]["software"]["autosharding"]["valid_output_shardings"][chosen_idx] + + # Write into module map (used by distributed launcher) + target = get_node_target(node) + if node.op == "call_module" and target is not None: + module_map[target] = { + key: node.meta["mase"]["software"]["autosharding"]["input_sharding"][key] for key in node.meta["mase"]["software"]["autosharding"]["input_sharding"].keys() + } + module_map[target]["output"] = node.meta["mase"]["software"]["autosharding"]["output_sharding"] + + return mg, module_map + +def alpa_autosharding_pass(mg, mesh): + mg, module_map = alpa_intra_op_sharding_pass(mg, mesh) + return mg, module_map \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py new file mode 100644 index 000000000..7e051cb5b --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/alpa_cost_modelling.py @@ -0,0 +1,101 @@ + +import numpy as np +from functools import lru_cache + +from chop.ir.graph import MaseMetadata + +from .common import SpmdShard +from .mesh_model import MeshModel + +BYTES_PER_ELEMENT = 4 + +def get_communication_cost(sharding: tuple, node_meta: MaseMetadata, mesh: MeshModel): + assert sharding[0][-1] == sharding[1][-2], f"Inconsistent sharding for node: {node_meta.node}" + inner_dim_sharding = sharding[1][0] + + out_shape = node_meta["common"]["results"]["data_out_0"]["shape"] + + if inner_dim_sharding == SpmdShard.R: + return 0 + + else: + ar_dim = inner_dim_sharding.value # 0 for S_0, 1 for S_1 + return mesh.all_reduce_cost(num_bytes = BYTES_PER_ELEMENT * np.prod(out_shape), mesh_dim = ar_dim) + +@lru_cache(maxsize=None) +def get_resharding_cost(mesh: MeshModel, src: tuple, dest: tuple, dest_node_meta: MaseMetadata): + """ + Obtain the resharding cost given a source and destination sharding profile for a tensor. + The mesh object is assumed to have been initialized with alpha, beta parameters so that + the communication cost can be estimated for each MPI operator. + """ + + + # If original sharding is fully replicated, no resharding is required + if src == dest or all(i == SpmdShard.R for i in src): + return 0 + + num_bytes = BYTES_PER_ELEMENT * np.prod(dest_node_meta["common"]["args"]["data_in_0"]["shape"]) + + # No cost (simple split along given mesh dimension) + if ( + # Keep dim 0, split dim 1 + # E.g. (R, R) -> (R, S_0), (S_0, R) -> (S_0, S_1) + (src[0] == dest[0]) and (src[1] == SpmdShard.R) and (dest[1] in [SpmdShard.S_0, SpmdShard.S_1]) + # Split dim 0, keep dim 1 + # E.g. (R, R) -> (S_1, R), (R, S_1) -> (S_0, S_1) + or (src[1] == dest[1]) and (src[0] == SpmdShard.R) and (dest[0] in [SpmdShard.S_0, SpmdShard.S_1]) + ): + return 0 + + # Split -> Replicate (All Gather) + elif ( + # Keep dim 0, gather along dim 1 + # E.g. (S_1, S_0) -> (S_1, R) + (src[0] == dest[0]) and (src[1] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[1] == SpmdShard.R) + # Gather along dim 0, keep dim 1 + # E.g. (S_0, S_1) -> (R, S_1) + or (src[1] == dest[1]) and (src[0] in [SpmdShard.S_0, SpmdShard.S_1]) and (dest[0] == SpmdShard.R) + ): + ag_dim = 1 if src[0] == dest[0] else 0 + return mesh.all_gather_cost( + num_bytes = num_bytes, + mesh_dim = ag_dim, + ) + + # All-to-all + # E.g. (R, S_0) -> (S_0, R), (S_1, R) -> (R, S_1) + elif (src[0] == dest[1] and src[1] == dest[0] and (SpmdShard.R in src)): + # all to all + a2a_dim = src[0].value if src[0] != SpmdShard.R else src[1].value + try: + return mesh.all_to_all_cost( + num_bytes = num_bytes, + mesh_dim = a2a_dim, + ) + except: + breakpoint() + + # Two-stage resharding: when the resharding cannot be resolved with a single split, all-gather or all-to-all, + # must first gather along the first non-replicated dimension, then recursively compute the cost for the + # reduced sharding + else: + # Reduce one dimension and re-compute + if (src[0] != SpmdShard.R): + new_src = (SpmdShard.R, src[1]) + ag_dim = src[0].value + else: + new_src = (SpmdShard.R, SpmdShard.R) + ag_dim = src[1].value + + return mesh.all_gather_cost( + num_bytes = num_bytes, + mesh_dim = ag_dim + ) + get_resharding_cost(mesh, new_src, dest, dest_node_meta) + +def get_resharding_matrix(mesh, src_shardings, dest_shardings, dest_node_meta): + mat = np.zeros((len(dest_shardings), len(src_shardings))) + for src_idx, src in enumerate(src_shardings): + for dest_idx, dest in enumerate(dest_shardings): + mat[dest_idx, src_idx] = get_resharding_cost(mesh, src, dest, dest_node_meta) + return mat diff --git a/src/chop/passes/graph/analysis/autosharding/alpa_layers.py b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py new file mode 100644 index 000000000..74651a624 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/alpa_layers.py @@ -0,0 +1,94 @@ +import itertools +import numpy as np +import torch.nn as nn + +from chop.tools import get_logger +from chop.models.patched.bert.modeling_bert import BertSelfAttention + +from .common import SpmdShard, VALID_2D_TENSOR_SHARDINGS +from .alpa_cost_modelling import get_communication_cost + + +logger = get_logger(__name__) + +def is_valid_2d_sharding(sharding): + if len(sharding) > 2: + return sharding[1:] in VALID_2D_TENSOR_SHARDINGS + else: + return sharding in VALID_2D_TENSOR_SHARDINGS + +def is_valid_sharding_pair(sharding_pair): + return sharding_pair[0][-1] == sharding_pair[1][-2] + +def is_fully_replicated(sharding_pair): + return all(all(dimp == SpmdShard.R for dimp in subp) for subp in sharding_pair) + +def get_valid_2d_shardings(node_meta, mesh, module): + """ + Return every valid combination of shardings for the input tensors. For an operator + sharding to be valid, the inner dimension must have the same sharding. + E.g. ((R, S_0), (S_0, R)) are valid, but ((R, S_0), (S_1, R)) is not. + """ + input_shardings = [] + output_shardings = [] + compute_cost_vector = [] + communication_cost_vector = [] + + out_rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + + for perm in itertools.product(VALID_2D_TENSOR_SHARDINGS, repeat=2): + if out_rank > 2: + perm = tuple((SpmdShard.R,) * (out_rank - 2) + p for p in perm) + output_sharding = tuple((SpmdShard.R,) * (out_rank - 2) + (perm[0][-2], perm[1][-1])) + if not is_fully_replicated(perm) and is_valid_sharding_pair(perm) and is_valid_2d_sharding(output_sharding): + input_shardings.append({ + "data_in_0": perm[0], + "weight": perm[1] + }) + output_shardings.append(output_sharding) + + compute_cost_vector.append(0) + communication_cost_vector.append(get_communication_cost(perm, node_meta["mase"], mesh)) + + return ( + input_shardings, + output_shardings, + np.array(compute_cost_vector), + np.array(communication_cost_vector), + ) + +def get_valid_linear_shardings(node_meta, mesh, module): + return get_valid_2d_shardings(node_meta, mesh, module) + +def get_valid_layernorm_shardings(node_meta, mesh, module): + rank = len(node_meta["mase"]["common"]["results"]["data_out_0"]["shape"]) + valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * rank}] + valid_output_shardings = [(SpmdShard.R,) * rank] + compute_cost_vector = [0] + communication_cost_vector = [0] + return ( + valid_input_shardings, + valid_output_shardings, + np.array(compute_cost_vector), + np.array(communication_cost_vector), + ) + +def get_valid_embedding_shardings(node_meta, mesh, module): + weight_rank = len(module.weight.shape) + data_in_rank = len(node_meta["mase"]["common"]["args"]["data_in_0"]["shape"]) + valid_input_shardings = [{"data_in_0": (SpmdShard.R,) * data_in_rank, "weight": (SpmdShard.R,) * weight_rank}] + valid_output_shardings = [(SpmdShard.R,) * data_in_rank] + compute_cost_vector = [0] + communication_cost_vector = [0] + return ( + valid_input_shardings, + valid_output_shardings, + np.array(compute_cost_vector), + np.array(communication_cost_vector), + ) + +ALPA_LAYERS = { + nn.Linear: get_valid_linear_shardings, + nn.LayerNorm: get_valid_layernorm_shardings, + nn.Embedding: get_valid_embedding_shardings, +} diff --git a/src/chop/passes/graph/analysis/autosharding/autosharding.py b/src/chop/passes/graph/analysis/autosharding/autosharding.py new file mode 100644 index 000000000..3f6ae4b60 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/autosharding.py @@ -0,0 +1,45 @@ + +import numpy as np +import cvxpy as cp +from time import time + +from chop.tools import get_logger + +from .mesh_model import MeshModel +from .alpa import alpa_autosharding_pass + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def autosharding_analysis_pass(mg, pass_args: dict = {}): + """ + A lightweight implementation of the core algorithm from the Alpa paper: https://arxiv.org/abs/2201.12023 + """ + + assert "mesh_shape" in pass_args, "Logical description for device cluster was not specified." + assert "inter_node_bandwidth" in pass_args, "Inter-node bandwidth not specified" + assert "intra_node_bandwidth" in pass_args, "Intra-node bandwidth not specified" + + # Timing + start_time = time() + + # Initialize device mesh model, used for cost estimation + mesh = MeshModel(pass_args["mesh_shape"]) + + algo = pass_args.get("sharding_algo", "alpa") + + # Communication cost model depends + mesh.set_cost_model_parameters( + intra_node_bandwidth=pass_args["intra_node_bandwidth"], + inter_node_bandwidth=pass_args["inter_node_bandwidth"], + backend = pass_args.get("communications_backend", "default") + ) + + # Run intra-operator pass + if algo == "alpa": + mg, module_map = alpa_autosharding_pass(mg, mesh) + + end_time = time() + logger.info(f"Autosharding pass complete. Time taken: {end_time - start_time} seconds.") + + return mg, module_map diff --git a/src/chop/passes/graph/analysis/autosharding/common.py b/src/chop/passes/graph/analysis/autosharding/common.py new file mode 100644 index 000000000..e0b98001a --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/common.py @@ -0,0 +1,25 @@ +from enum import Enum + +class SpmdShard(Enum): + S_0 = 0 + S_1 = 1 + R = 3 + + def __repr__(self): + return self.name + + def __gt__(self, other): + if self.__class__ is other.__class__: + return self.value > other.value + return NotImplemented + + +VALID_2D_TENSOR_SHARDINGS = [ + (SpmdShard.R, SpmdShard.R), + (SpmdShard.R, SpmdShard.S_0), + (SpmdShard.R, SpmdShard.S_1), + (SpmdShard.S_0, SpmdShard.R), + (SpmdShard.S_0, SpmdShard.S_1), + (SpmdShard.S_1, SpmdShard.R), + (SpmdShard.S_1, SpmdShard.S_0), +] \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/autosharding/mesh_model.py b/src/chop/passes/graph/analysis/autosharding/mesh_model.py new file mode 100644 index 000000000..79aa93ce9 --- /dev/null +++ b/src/chop/passes/graph/analysis/autosharding/mesh_model.py @@ -0,0 +1,53 @@ +import torch +import numpy as np + +class MeshModel(): + def __init__(self, mesh_shape, mesh_alpha = None, mesh_beta = None): + self.mesh_shape = mesh_shape + + num_devices = np.prod(mesh_shape) + self.id_mesh = torch.arange(0, num_devices).reshape(mesh_shape) + + # Alpha/beta model is used to estimate communication cost between devices + self.mesh_alpha = [0] * 2 if mesh_alpha is None else mesh_alpha + self.mesh_beta = [None] * 2 if mesh_beta is None else mesh_beta + + def set_cost_model_parameters(self, intra_node_bandwidth: int, inter_node_bandwidth:int, backend:str = "default"): + # Assign differently depending if backend is NVLink, Infiniband, etc + if (backend == "default"): + # Assuming a setup with ethernet-connected nodes and devices connected through + # PCIe within each node + self.mesh_beta = [ + 1 / inter_node_bandwidth, + 1 / intra_node_bandwidth + ] + + def all_gather_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices * num_bytes + 0.1) + + def all_reduce_cost(self, num_bytes, mesh_dim, num_devices = None): + """ + The term multiplied by beta represents the total number of bytes + transferred over the full transaction. For the ring implementation + of all reduce there are 2 rounds of (n-1) transfers, hence 2(n-1). + In each case num_bytes/num_devices is transferred, where num_bytes + is the number of bytes for the full tensor on each device. + """ + if num_devices is None: + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * + (num_devices - 1) / num_devices * num_bytes + 0.01) + + def reduce_scatter_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices * num_bytes + 0.001) + + def all_to_all_cost(self, num_bytes, mesh_dim): + num_devices = self.id_mesh.shape[mesh_dim] + penalty_factor = num_devices / 2.0 + return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * + (num_devices - 1) / num_devices / num_devices * num_bytes * + penalty_factor + 0.001) \ No newline at end of file diff --git a/src/chop/passes/graph/analysis/report/report_graph.py b/src/chop/passes/graph/analysis/report/report_graph.py index 6a617febd..990416b22 100644 --- a/src/chop/passes/graph/analysis/report/report_graph.py +++ b/src/chop/passes/graph/analysis/report/report_graph.py @@ -1,4 +1,5 @@ import logging +from tabulate import tabulate logger = logging.getLogger(__name__) @@ -15,8 +16,19 @@ def report_graph_analysis_pass(graph, pass_args={"file_name": None}): :rtype: tuple(MaseGraph, dict) """ file_name = pass_args.get("file_name") - buff = "" - buff += str(graph.fx_graph) + buff = """ +Graph Analysis Report + +===================== Graph Summary ===================== + + """ + node_specs = [ + [n.op, n.name, n.target, n.args, n.kwargs] for n in graph.fx_graph.nodes + ] + + buff += str( + tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) + ) count = { "placeholder": 0, "get_attr": 0, @@ -33,8 +45,17 @@ def report_graph_analysis_pass(graph, pass_args={"file_name": None}): for node in graph.fx_graph.nodes: count[node.op] += 1 - buff += f"""\nNetwork overview: + buff += f""" + +===================== Graph Syntax ===================== + +{str(graph.fx_graph)} + +===================== Graph Overview ===================== + +Network overview: {count} + Layer types: {layer_types}""" if file_name is None: diff --git a/src/chop/passes/graph/common.py b/src/chop/passes/graph/common.py index 8c441b83a..56cdc917d 100644 --- a/src/chop/passes/graph/common.py +++ b/src/chop/passes/graph/common.py @@ -31,7 +31,6 @@ "where", "_assert", "getattr", - "getitem", "long", "type_as", "clamp", @@ -48,6 +47,8 @@ "full", "ones", "dim", + "finfo", + "masked_fill" ] MASE_MODULE_RELATED_FUNCS = [ @@ -115,6 +116,8 @@ "log", "range", "gelu", + "df_split", + "getitem", ] diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py index 3f5ae8f58..4966ec744 100644 --- a/src/chop/passes/graph/transforms/__init__.py +++ b/src/chop/passes/graph/transforms/__init__.py @@ -7,6 +7,7 @@ emit_mlir_hls_transform_pass, emit_cocotb_transform_pass, emit_verilog_top_transform_pass, + emit_vivado_project_transform_pass, ) from .utils import ( conv_bn_fusion_transform_pass, @@ -23,3 +24,5 @@ tensorrt_fine_tune_transform_pass, tensorrt_fake_quantize_transform_pass, ) + +from .patching import patch_metadata_transform_pass diff --git a/src/chop/passes/graph/transforms/patching/__init__.py b/src/chop/passes/graph/transforms/patching/__init__.py new file mode 100644 index 000000000..b59c4b822 --- /dev/null +++ b/src/chop/passes/graph/transforms/patching/__init__.py @@ -0,0 +1 @@ +from .patch_metadata import patch_metadata_transform_pass diff --git a/src/chop/passes/graph/transforms/patching/patch_metadata.py b/src/chop/passes/graph/transforms/patching/patch_metadata.py new file mode 100644 index 000000000..17210864b --- /dev/null +++ b/src/chop/passes/graph/transforms/patching/patch_metadata.py @@ -0,0 +1,62 @@ +import operator + +PYTHON_NATIVE_FUNCTIONS = [ + operator.add, + operator.mul, + operator.getitem, +] + + +def patch_metadata_transform_pass(mg, pass_args: dict = {}): + """ + Typically, metadata such as precision and type are inferred from each node during the add_common_metadata_analysis_pass. + However, for call_function nodes where the target is a Python-native function, some metadata is not correctly inferred since + we avoid overriding Python native functions with mase-specific primitives. Hence, in this pass we manually patch the metadata + for these nodes according to the requested payloads. + """ + + precision = pass_args.get("precision", "fixed") + + for node in mg.fx_graph.nodes: + # Update args + if ( + node.target in PYTHON_NATIVE_FUNCTIONS + or node.meta["mase"]["common"]["mase_op"] == "df_split" + ): + node.meta["mase"]["common"]["args"]["data_in_0"]["type"] = precision + node.meta["mase"]["common"]["args"]["data_in_0"]["precision"] = [ + pass_args["q_config"]["data_in_width"], + pass_args["q_config"]["data_in_frac_width"], + ] + if "data_in_1" in node.meta["mase"]["common"]["args"]: + node.meta["mase"]["common"]["args"]["data_in_1"]["type"] = precision + node.meta["mase"]["common"]["args"]["data_in_1"]["precision"] = [ + pass_args["q_config"]["data_in_width"], + pass_args["q_config"]["data_in_frac_width"], + ] + + # Update results + if ( + node.target in PYTHON_NATIVE_FUNCTIONS + or node.meta["mase"]["common"]["mase_op"] == "df_split" + or node.op == "placeholder" + or node.op == "output" + ): + node.meta["mase"]["common"]["results"]["data_out_0"]["type"] = precision + node.meta["mase"]["common"]["results"]["data_out_0"]["precision"] = [ + pass_args["q_config"]["data_out_width"], + pass_args["q_config"]["data_out_frac_width"], + ] + if "data_out_1" in node.meta["mase"]["common"]["results"]: + node.meta["mase"]["common"]["results"]["data_out_1"]["type"] = precision + node.meta["mase"]["common"]["results"]["data_out_1"]["precision"] = [ + pass_args["q_config"]["data_out_width"], + pass_args["q_config"]["data_out_frac_width"], + ] + + # Set one of the args to none according to the select value + if node.target == operator.getitem: + select = 0 if node.args[1] == 1 else 1 + node.meta["mase"]["common"]["args"][f"data_in_{select}"] = None + + return mg, {} diff --git a/src/chop/passes/graph/transforms/verilog/__init__.py b/src/chop/passes/graph/transforms/verilog/__init__.py index 57903beac..262e7905f 100644 --- a/src/chop/passes/graph/transforms/verilog/__init__.py +++ b/src/chop/passes/graph/transforms/verilog/__init__.py @@ -4,3 +4,4 @@ from .emit_hls import emit_mlir_hls_transform_pass from .emit_internal import emit_internal_rtl_transform_pass from .emit_logicnets import emit_logicnets_transform_pass +from .emit_vivado_project import emit_vivado_project_transform_pass diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py index 56ee07969..8aeeb663f 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py @@ -35,25 +35,30 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): """ # ! TO DO: currently emitting too many parameters + verilog_param_name = param_name.replace(".", "_") total_size = math.prod( - node.meta["mase"].parameters["common"]["args"][param_name]["shape"] + node.meta["mase"].parameters["common"]["args"][verilog_param_name]["shape"] ) # TO DO: change setting parallelism for weight in metadata # node.meta["mase"].parameters["hardware"]["verilog_param"][f"{_cap(param_name)}_PARALLELISM_DIM_1"] out_size = int( node.meta["mase"].parameters["hardware"]["verilog_param"][ - f"{_cap(param_name)}_PARALLELISM_DIM_0" + f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0" + ] + * node.meta["mase"].parameters["hardware"]["verilog_param"][ + f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1" ] - * 4 ) out_depth = int(total_size / out_size) out_width = int( - node.meta["mase"].parameters["common"]["args"][param_name]["precision"][0] + node.meta["mase"].parameters["common"]["args"][verilog_param_name]["precision"][ + 0 + ] ) addr_width = clog2(out_depth) + 1 - node_param_name = f"{vf(node.name)}_{param_name}" + node_param_name = f"{vf(node.name)}_{verilog_param_name}" rom_verilog = f""" // ===================================== @@ -114,19 +119,19 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): `timescale 1ns / 1ps module {node_param_name}_source #( - parameter {_cap(param_name)}_TENSOR_SIZE_DIM_0 = 32, - parameter {_cap(param_name)}_TENSOR_SIZE_DIM_1 = 1, - parameter {_cap(param_name)}_PRECISION_0 = 16, - parameter {_cap(param_name)}_PRECISION_1 = 3, - - parameter {_cap(param_name)}_PARALLELISM_DIM_0 = 1, - parameter {_cap(param_name)}_PARALLELISM_DIM_1 = 1, - parameter OUT_DEPTH = {_cap(param_name)}_TENSOR_SIZE_DIM_0 / {_cap(param_name)}_PARALLELISM_DIM_0 + parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 = 32, + parameter {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_1 = 1, + parameter {_cap(verilog_param_name)}_PRECISION_0 = 16, + parameter {_cap(verilog_param_name)}_PRECISION_1 = 3, + + parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_0 = 1, + parameter {_cap(verilog_param_name)}_PARALLELISM_DIM_1 = 1, + parameter OUT_DEPTH = {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0 / {_cap(verilog_param_name)}_PARALLELISM_DIM_0 ) ( input clk, input rst, - output logic [{_cap(param_name)}_PRECISION_0-1:0] data_out [{_cap(param_name)}_PARALLELISM_DIM_0 * {_cap(param_name)}_PARALLELISM_DIM_1-1:0], + output logic [{_cap(verilog_param_name)}_PRECISION_0-1:0] data_out [{_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1-1:0], output data_out_valid, input data_out_ready ); @@ -146,9 +151,9 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): logic ce0; assign ce0 = 1; - logic [{_cap(param_name)}_PRECISION_0*{_cap(param_name)}_TENSOR_SIZE_DIM_0-1:0] data_vector; + logic [{_cap(verilog_param_name)}_PRECISION_0*{_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0-1:0] data_vector; {node_param_name} #( - .DATA_WIDTH({_cap(param_name)}_PRECISION_0 * {_cap(param_name)}_TENSOR_SIZE_DIM_0), + .DATA_WIDTH({_cap(verilog_param_name)}_PRECISION_0 * {_cap(verilog_param_name)}_TENSOR_SIZE_DIM_0), .ADDR_RANGE(OUT_DEPTH) ) {node_param_name}_mem ( .clk(clk), @@ -160,8 +165,8 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): // Cocotb/verilator does not support array flattening, so // we need to manually add some reshaping process. - for (genvar j = 0; j < {_cap(param_name)}_TENSOR_SIZE_DIM_0; j++) - assign data_out[j] = data_vector[{_cap(param_name)}_PRECISION_0*j+{_cap(param_name)}_PRECISION_0-1:{_cap(param_name)}_PRECISION_0*j]; + for (genvar j = 0; j < {_cap(verilog_param_name)}_PARALLELISM_DIM_0 * {_cap(verilog_param_name)}_PARALLELISM_DIM_1; j++) + assign data_out[j] = data_vector[{_cap(verilog_param_name)}_PRECISION_0*j+{_cap(verilog_param_name)}_PRECISION_0-1:{_cap(verilog_param_name)}_PRECISION_0*j]; assign data_out_valid = 1; @@ -170,42 +175,39 @@ def emit_parameters_in_mem_internal(node, param_name, file_name, data_name): with open(file_name, "w", encoding="utf-8") as outf: outf.write(rom_verilog) - logger.debug(f"ROM module {param_name} successfully written into {file_name}") + logger.debug( + f"ROM module {verilog_param_name} successfully written into {file_name}" + ) assert os.path.isfile(file_name), "ROM Verilog generation failed." - os.system(f"verible-verilog-format --inplace {file_name}") + # os.system(f"verible-verilog-format --inplace {file_name}") def emit_parameters_in_dat_internal(node, param_name, file_name): """ Emit initialised data for the ROM block. Each element must be in 8 HEX digits. """ + verilog_param_name = param_name.replace(".", "_") total_size = math.prod( - node.meta["mase"].parameters["common"]["args"][param_name]["shape"] + node.meta["mase"].parameters["common"]["args"][verilog_param_name]["shape"] ) - if "IN_DEPTH" in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): - if param_name == "bias": - out_depth = 1 - else: - out_depth = node.meta["mase"].parameters["hardware"]["verilog_param"][ - "IN_DEPTH" - ] - else: - out_depth = total_size - - out_size = iceil(total_size / out_depth) - # The depth of parameters must match with the input depth of data - assert ( - total_size % out_depth == 0 - ), f"Cannot partition imperfect size for now {node.name}.{param_name} = {total_size} / {out_depth}." - # Assume the first index is the total width - out_width = node.meta["mase"].parameters["common"]["args"][param_name]["precision"][ - 0 - ] + # TO DO: change setting parallelism for weight in metadata + # node.meta["mase"].parameters["hardware"]["verilog_param"][f"{_cap(param_name)}_PARALLELISM_DIM_1"] + out_size = int( + node.meta["mase"].parameters["hardware"]["verilog_param"][ + f"{_cap(verilog_param_name)}_PARALLELISM_DIM_0" + ] + * node.meta["mase"].parameters["hardware"]["verilog_param"][ + f"{_cap(verilog_param_name)}_PARALLELISM_DIM_1" + ] + ) + out_depth = int(total_size / out_size) data_buff = "" param_data = node.meta["mase"].module.get_parameter(param_name).data - if node.meta["mase"].parameters["hardware"]["interface"][param_name]["transpose"]: + if node.meta["mase"].parameters["hardware"]["interface"][verilog_param_name][ + "transpose" + ]: param_data = torch.reshape( param_data, ( @@ -223,11 +225,14 @@ def emit_parameters_in_dat_internal(node, param_name, file_name): param_data = torch.transpose(param_data, 0, 1) param_data = torch.flatten(param_data).tolist() - if node.meta["mase"].parameters["common"]["args"][param_name]["type"] == "fixed": - width = node.meta["mase"].parameters["common"]["args"][param_name]["precision"][ - 0 - ] - frac_width = node.meta["mase"].parameters["common"]["args"][param_name][ + if ( + node.meta["mase"].parameters["common"]["args"][verilog_param_name]["type"] + == "fixed" + ): + width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][ + "precision" + ][0] + frac_width = node.meta["mase"].parameters["common"]["args"][verilog_param_name][ "precision" ][1] @@ -328,15 +333,22 @@ def emit_bram_handshake(node, rtl_dir): """ node_name = vf(node.name) for param_name, parameter in node.meta["mase"].module.named_parameters(): + param_verilog_name = param_name.replace(".", "_") if ( - node.meta["mase"].parameters["hardware"]["interface"][param_name]["storage"] + node.meta["mase"].parameters["hardware"]["interface"][param_verilog_name][ + "storage" + ] == "BRAM" ): logger.debug( - f"Emitting DAT file for node: {node_name}, parameter: {param_name}" + f"Emitting DAT file for node: {node_name}, parameter: {param_verilog_name}" + ) + verilog_name = os.path.join( + rtl_dir, f"{node_name}_{param_verilog_name}_source.sv" + ) + data_name = os.path.join( + rtl_dir, f"{node_name}_{param_verilog_name}_rom.dat" ) - verilog_name = os.path.join(rtl_dir, f"{node_name}_{param_name}_source.sv") - data_name = os.path.join(rtl_dir, f"{node_name}_{param_name}_rom.dat") emit_parameters_in_mem_internal(node, param_name, verilog_name, data_name) emit_parameters_in_dat_internal(node, param_name, data_name) else: @@ -432,7 +444,7 @@ def emit_parameters_in_mem_hls(node, param_name, file_name, data_name): outf.write(rom_verilog) logger.debug(f"ROM module {param_name} successfully written into {file_name}") assert os.path.isfile(file_name), "ROM Verilog generation failed." - os.system(f"verible-verilog-format --inplace {file_name}") + # os.system(f"verible-verilog-format --inplace {file_name}") def emit_bram_hls(node, rtl_dir): diff --git a/src/chop/passes/graph/transforms/verilog/emit_internal.py b/src/chop/passes/graph/transforms/verilog/emit_internal.py index 01574f2ea..63e27563b 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_internal.py +++ b/src/chop/passes/graph/transforms/verilog/emit_internal.py @@ -38,6 +38,7 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]: if ( hasattr(node.meta["mase"].module, "config") + and isinstance(node.meta["mase"].module.config, dict) and node.meta["mase"].module.config.get("name", "") == "logicnets" ): # LogicNets hardware is generated programmatically from a mase node @@ -50,18 +51,13 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): files = include_ip_to_project(node) rtl_dependencies = _append(rtl_dependencies, files) elif "INTERNAL_HLS" in node.meta["mase"].parameters["hardware"]["toolchain"]: - assert False, "Intenral HLS not implemented yet." + assert False, "Internal HLS not implemented yet." else: # QOL change to log a warning. May be useful for adding future modules to mase hardware. logger.warning(f"Node {node.name} has no toolchain specified. Skipping...") - hardware_dir = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "..", - "..", - "..", - "..", - "..", - "mase_components", - ) + + import mase_components + + hardware_dir = mase_components.__path__[0] for f in rtl_dependencies: shutil.copy(os.path.join(hardware_dir, f), rtl_dir) diff --git a/src/chop/passes/graph/transforms/verilog/emit_tb.py b/src/chop/passes/graph/transforms/verilog/emit_tb.py index 948417394..ee4c05189 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_tb.py +++ b/src/chop/passes/graph/transforms/verilog/emit_tb.py @@ -1,4 +1,6 @@ -import math, time, os, logging, torch, glob, shutil +import logging, torch +from pathlib import Path +from textwrap import indent from chop.passes.graph.utils import vf, v2p, init_project from chop.nn.quantizers import ( @@ -13,12 +15,28 @@ import cocotb from mase_cocotb.testbench import Testbench from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor -from mase_cocotb.z_qlayers.tensor_cast import quantize_to_int + import dill import inspect +def _cap(name): + """ + capitalize a string + """ + return str(name).upper() + + +def _emit_cocotb_test(graph, pass_args={}): + + wait_time = pass_args.get("wait_time", 2) + wait_unit = pass_args.get("wait_units", "ms") + batch_size = pass_args.get("batch_size", 1) + + test_template = f""" +import cocotb + @cocotb.test() async def test(dut): from pathlib import Path @@ -31,21 +49,13 @@ async def test(dut): await tb.initialize() - in_tensors = tb.generate_inputs(batches=3) + in_tensors = tb.generate_inputs(batches={batch_size}) exp_out = tb.model(*list(in_tensors.values())) tb.load_drivers(in_tensors) tb.load_monitors(exp_out) - await Timer(10000, units="us") - tb.end_checks() - - -def _emit_cocotb_test(graph): - test_template = f""" -import cocotb - -{inspect.getsource(test)} + await tb.wait_end(timeout={wait_time}, timeout_unit="{wait_unit}") """ tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb" @@ -60,26 +70,34 @@ def __init__(self, dut, fail_on_checks=True): super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks) # Instantiate as many drivers as required inputs to the model - for arg in graph.meta["mase"]["common"]["args"].keys(): - self.input_drivers.append( - StreamDriver( + self.input_drivers = {} + self.output_monitors = {} + + for node in graph.nodes_in: + for arg in node.meta["mase"]["common"]["args"].keys(): + if "data_in" not in arg: + continue + self.input_drivers[arg] = StreamDriver( dut.clk, getattr(dut, arg), getattr(dut, f"{arg}_valid"), getattr(dut, f"{arg}_ready"), ) - ) + self.input_drivers[arg].log.setLevel(logging.DEBUG) # Instantiate as many monitors as required outputs - for result in graph.meta["mase"]["common"]["results"].keys(): - self.output_monitors.append( - StreamMonitor( + for node in graph.nodes_out: + for result in node.meta["mase"]["common"]["results"].keys(): + if "data_out" not in result: + continue + self.output_monitors[result] = StreamMonitor( dut.clk, getattr(dut, result), getattr(dut, f"{result}_valid"), getattr(dut, f"{result}_ready"), + check=False, ) - ) + self.output_monitors[result].log.setLevel(logging.DEBUG) self.model = graph.model @@ -98,39 +116,79 @@ def generate_inputs(self, batches): :return: a dictionary of input arguments and their corresponding tensors :rtype: Dict """ + # ! TO DO: iterate through graph.args instead to generalize inputs = {} - for arg, arg_info in graph.meta["mase"]["common"]["args"].items(): - # Batch dimension always set to 1 in metadata - inputs[arg] = torch.rand(([batches] + arg_info["shape"][1:])) + for node in graph.nodes_in: + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): + # Batch dimension always set to 1 in metadata + if "data_in" not in arg: + continue + # print(f"Generating data for node {node}, arg {arg}: {arg_info}") + inputs[f"{arg}"] = torch.rand(([batches] + arg_info["shape"][1:])) return inputs def load_drivers(self, in_tensors): - for arg_idx, arg_batches in enumerate(in_tensors.values()): + for arg, arg_batches in in_tensors.items(): # Quantize input tensor according to precision if len(self.input_precision) > 1: - arg_batches = integer_quantizer( - arg_batches, - width=self.input_precision[0], - frac_width=self.input_precision[1], + from mase_cocotb.utils import fixed_preprocess_tensor + + in_data_blocks = fixed_preprocess_tensor( + tensor=arg_batches, + q_config={ + "width": self.get_parameter(f"{_cap(arg)}_PRECISION_0"), + "frac_width": self.get_parameter( + f"{_cap(arg)}_PRECISION_1" + ), + }, + parallelism=[ + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"), + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"), + ], ) - # Convert to integer equivalent of fixed point representation - arg_batches = (arg_batches * (2 ** self.input_precision[1])).int() - # Convert to input data blocks by reshaping to parallelism - in_data_blocks = arg_batches.reshape((-1, 4)).tolist() else: # TO DO: convert to integer equivalent of floating point representation pass # Append all input blocks to input driver + # ! TO DO: generalize + block_size = self.get_parameter( + "DATA_IN_0_PARALLELISM_DIM_0" + ) * self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1") + for block in in_data_blocks: - self.input_drivers[arg_idx].append(block) + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.input_drivers[arg].append(block) def load_monitors(self, expectation): - # TO DO: reshape according to output parallelism - output_blocks = expectation.reshape(-1, 4) + from mase_cocotb.utils import fixed_preprocess_tensor + + # Process the expectation tensor + output_blocks = fixed_preprocess_tensor( + tensor=expectation, + q_config={ + "width": self.get_parameter(f"DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter(f"DATA_OUT_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"), + ], + ) + + # Set expectation for each monitor for block in output_blocks: - self.output_monitors[-1].expect(block.tolist()) + # ! TO DO: generalize to multi-output models + if len(block) < self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"): + block = block + [0] * ( + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0") - len(block) + ) + self.output_monitors["data_out_0"].expect(block) + + # Drive the in-flight flag for each monitor + self.output_monitors["data_out_0"].in_flight = True # Serialize testbench object to be instantiated within test by cocotb runner cls_obj = MaseGraphTB @@ -155,15 +213,18 @@ def emit_cocotb_transform_pass(graph, pass_args={}): - pass_args - project_dir -> str : the directory of the project + - trace -> bool : trace waves in the simulation """ logger.info("Emitting testbench...") project_dir = ( - pass_args["project_dir"] if "project_dir" in pass_args.keys() else "top" + pass_args["project_dir"] + if "project_dir" in pass_args.keys() + else Path.home() / ".mase" / "top" ) init_project(project_dir) - _emit_cocotb_test(graph) + _emit_cocotb_test(graph, pass_args=pass_args) _emit_cocotb_tb(graph) return graph, None diff --git a/src/chop/passes/graph/transforms/verilog/emit_top.py b/src/chop/passes/graph/transforms/verilog/emit_top.py index 16c6897f8..9e27bfb82 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_top.py +++ b/src/chop/passes/graph/transforms/verilog/emit_top.py @@ -5,6 +5,7 @@ import time from multiprocessing import Process, Queue +import torch.fx as fx from chop.passes.graph.utils import vf, v2p, init_project import mase_components.activations.test.generate_memory as gen_lut import torch.nn as nn @@ -58,6 +59,17 @@ def param_needs_signals(node, param, value, qualifier="data_in"): ) +def is_real_input_arg(node, arg_idx): + return ( + # module parameter arguments are appended after fx args + arg_idx < len(node.args) + # Drop None arguments + and isinstance(node.args[arg_idx], fx.Node) + # Drop arguments that are inputs to this node, but not the whole graph + and node.args[arg_idx].op == "placeholder" + ) + + # ============================================================================= # Verilog parameters # ============================================================================= @@ -113,16 +125,19 @@ def emit(self, graph, parameter_map): i = 0 for node in nodes_in: node_name = vf(node.name) - for arg in node.meta["mase"].parameters["common"]["args"].keys(): - if "data_in" in arg: + for arg_idx, arg in enumerate( + node.meta["mase"].parameters["common"]["args"].keys() + ): + if is_real_input_arg(node, arg_idx): + # if "data_in" in arg: arg_name = _cap(arg) parallelism_params = [ param for param in parameter_map - if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + if param.startswith(f"{arg_name}_PARALLELISM_DIM") ] interface += f""" - input [{node_name}_{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0], + input [{arg_name}_PRECISION_0-1:0] data_in_{i} [{'*'.join(parallelism_params)}-1:0], input data_in_{i}_valid, output data_in_{i}_ready,""" i += 1 @@ -136,10 +151,10 @@ def emit(self, graph, parameter_map): parallelism_params = [ param for param in parameter_map - if f"{node_name}_{result_name}_PARALLELISM_DIM" in param + if param.startswith(f"{result_name}_PARALLELISM_DIM") ] interface += f""" - output [{node_name}_{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0], + output [{result_name}_PRECISION_0-1:0] data_out_{i} [{'*'.join(parallelism_params)}-1:0], output data_out_{i}_valid, input data_out_{i}_ready,""" i += 1 @@ -178,6 +193,12 @@ def _emit_signals_top_internal(self, node, parameter_map): for param in parameter_map if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param ] + + # Getitem argument always get mapped to port 0 irrespective of + # actual argument index + if node.meta["mase"]["common"]["mase_op"] == "getitem": + arg = "data_in_0" + signals += f""" logic [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0]; logic {node_name}_{arg}_valid; @@ -302,7 +323,7 @@ def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): parameters = "" for param in node.meta["mase"].parameters["hardware"]["verilog_param"].keys(): if f"{_cap(key)}_" in param: - parameters += f".{param}({node_name}_{param}),\n" + parameters += f" .{param}({node_name}_{param}),\n" parameters = _remove_last_comma(parameters) return f""" @@ -317,6 +338,25 @@ def _emit_module_parameters_top_internal(self, key, value, node, parameter_map): ); """ + def _emit_getitem_signals(self, node): + """ + Getitem nodes have arg list like (None, None, None, Arg, None, None) + where the meaningful arg is at an arbitrary index, but always maps to + data_in_0 interface of the hardware + """ + + node_name = vf(node.name) + + return f""" + .data_in_0 ({node_name}_data_in_0), + .data_in_0_valid ({node_name}_data_in_0_valid), + .data_in_0_ready ({node_name}_data_in_0_ready), + + .data_out_0 ({node_name}_data_out_0), + .data_out_0_valid ({node_name}_data_out_0_valid), + .data_out_0_ready ({node_name}_data_out_0_ready), + """ + def emit(self, node, parameter_map): node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] @@ -327,30 +367,39 @@ def emit(self, node, parameter_map): for key, value in ( node.meta["mase"].parameters["hardware"]["verilog_param"].items() ): + if value is None: + continue key_value = parameter_map[f"{node_name}_{key}"] debug_info = f"// = {key_value}" parameters += f""" .{key}({node_name}_{key}), {debug_info}\n""" parameters = _remove_last_comma(parameters) - # Emit component instantiation input signals - for key, value in node.meta["mase"].parameters["common"]["args"].items(): - if "data" not in key: - continue - signals += f""" + # Handle getitem nodes separately since an arbitrary argument index + # will always be mapped to data_in_0 interface of the hardware + if node.meta["mase"]["common"]["mase_op"] == "getitem": + signals += self._emit_getitem_signals(node) + + # All other node types + else: + # Emit component instantiation input signals + for key, value in node.meta["mase"].parameters["common"]["args"].items(): + if "inplace" in key or not isinstance(value, dict): + continue + signals += f""" .{key}({node_name}_{key}), .{key}_valid({node_name}_{key}_valid), .{key}_ready({node_name}_{key}_ready), - """ + """ - # Emit component instantiation output signals - for key, value in node.meta["mase"].parameters["common"]["results"].items(): - if "data" not in key: - continue - signals += f""" + # Emit component instantiation output signals + for key, value in node.meta["mase"].parameters["common"]["results"].items(): + signals += f""" .{key}({node_name}_{key}), .{key}_valid({node_name}_{key}_valid), .{key}_ready({node_name}_{key}_ready), - """ + """ + + # Remove final comma in signal list signals = _remove_last_comma(signals) # Combine component instantiation @@ -534,12 +583,14 @@ def _emit_top_wires(self): i = 0 for node in nodes_in: node_name = vf(node.name) - for arg in node.meta["mase"].parameters["common"]["args"].keys(): - if "data_in" in arg: + for arg_idx, arg in enumerate( + node.meta["mase"].parameters["common"]["args"].keys() + ): + if is_real_input_arg(node, arg_idx): wires += f""" - assign data_in_{i}_ready = {node_name}_{arg}_ready; - assign {node_name}_{arg}_valid = data_in_{i}_valid; - assign {node_name}_{arg} = data_in_{i}; +assign data_in_{i}_ready = {node_name}_{arg}_ready; +assign {node_name}_{arg}_valid = data_in_{i}_valid; +assign {node_name}_{arg} = data_in_{i}; """ i += 1 i = 0 @@ -548,9 +599,9 @@ def _emit_top_wires(self): for result in node.meta["mase"].parameters["common"]["results"].keys(): if "data_out" in result: wires += f""" - assign data_out_{i}_valid = {node_name}_{result}_valid; - assign {node_name}_{result}_ready = data_out_{i}_ready; - assign data_out_{i} = {node_name}_{result}; +assign data_out_{i}_valid = {node_name}_{result}_valid; +assign {node_name}_{result}_ready = data_out_{i}_ready; +assign data_out_{i} = {node_name}_{result}; """ i += 1 @@ -558,27 +609,49 @@ def _emit_top_wires(self): return wires + def _emit_getitem_wires(self, node): + """ + Getitem nodes may receive an output from an arbitrary index of the parent node, + which is always driven to port 0 of the getitem node + """ + + from_name = vf(node.args[0].name) + to_name = vf(node.name) + select = node.args[1] + + return f""" +assign {from_name}_data_out_{select}_ready = {to_name}_data_in_0_ready; +assign {to_name}_data_in_0_valid = {from_name}_data_out_{select}_valid; +assign {to_name}_data_in_0 = {from_name}_data_out_{select}; +""" + def _emit_node2node_wires(self): nodes_in = self.graph.nodes_in - # Ignore the input of the input nodes - # (as they are already connected by the previous process) - # For each other explicit node, emit the edge of their inputs. - # Assume all the node has only one output. wires = "" for node in self.graph.fx_graph.nodes: - if node.meta["mase"].parameters["hardware"]["is_implicit"]: + + if ( + # Skip implicit nodes + node.meta["mase"].parameters["hardware"]["is_implicit"] + # Input nodes were already connected by the previous process + or node in nodes_in + ): continue - if node in nodes_in: + + # Getitem nodes are handled separately + if node.meta["mase"]["common"]["mase_op"] == "getitem": + wires += self._emit_getitem_wires(node) continue to_name = vf(node.name) + for i, node_in in enumerate(node.all_input_nodes): from_name = vf(node_in.name) wires += f""" - assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready; - assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid; - assign {to_name}_data_in_{i} = {from_name}_data_out_0; +assign {from_name}_data_out_0_ready = {to_name}_data_in_{i}_ready; +assign {to_name}_data_in_{i}_valid = {from_name}_data_out_0_valid; +assign {to_name}_data_in_{i} = {from_name}_data_out_0; """ return wires diff --git a/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py new file mode 100644 index 000000000..09e6ef3f7 --- /dev/null +++ b/src/chop/passes/graph/transforms/verilog/emit_vivado_project.py @@ -0,0 +1,111 @@ +import os +import subprocess +from pathlib import Path + +from chop.passes.graph.utils import init_project +from chop.tools import get_logger, set_logging_verbosity +import mase_components +from mase_components.deps import MASE_HW_DEPS + +logger = get_logger(f"emit_vivado_project") +set_logging_verbosity("debug") + +COMPONENTS_PATH = Path(mase_components.__file__).parents[0] + + +def generate_tcl_script(top_name, vivado_project_path, include_groups, project_dir): + logger.info( + f"Writing Vivado project generation script: {vivado_project_path}/build.tcl" + ) + + tcl_script_template = f""" +set_param board.repoPaths {{{str(Path.home())}/shared/board-files}} +create_project {top_name}_build_project {vivado_project_path} -part xcu280-fsvh2892-2L-e +set_property board_part xilinx.com:au280:part0:1.1 [current_project] +""" + for include_group in include_groups: + tcl_script_template += f"""\nadd_files {include_group}""" + + tcl_script_template += f"\n\nset_property top top [current_fileset]" + + tcl_script_template += f""" +update_compile_order -fileset sources_1 +""" + + # * Package IP + tcl_script_template += f""" +ipx::package_project -root_dir {project_dir}/hardware/ip_repo -vendor user.org -library user -taxonomy /UserIP -import_files +ipx::create_xgui_files [ipx::current_core] +ipx::update_checksums [ipx::current_core] +ipx::check_integrity [ipx::current_core] +ipx::save_core [ipx::current_core] +set_property ip_repo_paths {project_dir}/hardware/ip_repo [current_project] +update_ip_catalog +""" + + with open(f"{vivado_project_path}/build.tcl", "w") as file: + file.write(tcl_script_template) + + +def emit_vivado_project_transform_pass(graph, pass_args={}): + """Emit the Vivado project containing the generated Verilog and all required IPs + + :param graph: a MaseGraph + :type graph: MaseGraph + :param pass_args: this pass requires additional arguments which is explained below, defaults to {} + :type pass_args: _type_, optional + :return: return a tuple of a MaseGraph and an empty dict (no additional info to return) + :rtype: tuple(MaseGraph, Dict) + + + - pass_args + - project_dir -> str : the directory of the project for cosimulation + - top_name -> str : top-level name + """ + + # * Check if Vivado is available by running the command + try: + subprocess.run(["vivado", "-version"], capture_output=True, text=True) + except: + logger.warning( + "Vivado is not available, skipping emit_vivado_project_transform_pass." + ) + return graph, {} + + logger.info("Emitting Vivado project...") + + # Create project directory, and the verilog is emmited to {project_name}/hardware/rtl + project_dir = ( + pass_args["project_dir"] + if "project_dir" in pass_args.keys() + else Path.home() / ".mase" / "top" + ) + top_name = pass_args["top_name"] if "top_name" in pass_args.keys() else "top" + init_project(project_dir) + vivado_project_path = os.path.join( + project_dir, "hardware", f"{top_name}_build_project" + ) + os.makedirs(vivado_project_path, exist_ok=True) + + # * List include files + include_groups = [ + f"{COMPONENTS_PATH / group / 'rtl'}" + for group in mase_components.get_modules() + if group != "vivado" + ] + [project_dir / "hardware" / "rtl"] + + generate_tcl_script(top_name, vivado_project_path, include_groups, project_dir) + + logger.info(f"Emitting Vivado project at: {vivado_project_path}") + cmd = [ + "vivado", + "-mode", + "batch", + "-log", + f"{vivado_project_path}/project_build.log", + "-source", + f"{vivado_project_path}/build.tcl", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + + return graph, {} diff --git a/src/chop/passes/graph/transforms/verilog/util.py b/src/chop/passes/graph/transforms/verilog/util.py index 5ddbdee6a..35abe9560 100644 --- a/src/chop/passes/graph/transforms/verilog/util.py +++ b/src/chop/passes/graph/transforms/verilog/util.py @@ -15,6 +15,8 @@ def get_verilog_parameters(graph): for key, value in ( node.meta["mase"].parameters["hardware"]["verilog_param"].items() ): + if value is None: + continue if not isinstance(value, (int, float, complex, bool)): value = '"' + value + '"' assert ( @@ -22,6 +24,14 @@ def get_verilog_parameters(graph): ), f"{node_name}_{key} already exists in the parameter map" parameter_map[f"{node_name}_{key}"] = value + # * Return graph level parameters + for node in graph.nodes_in + graph.nodes_out: + for key, value in ( + node.meta["mase"].parameters["hardware"]["verilog_param"].items() + ): + if "DATA_IN" in key or "DATA_OUT" in key: + parameter_map[key] = value + return parameter_map @@ -30,7 +40,4 @@ def include_ip_to_project(node): Copy internal files to the project """ mase_op = node.meta["mase"].parameters["common"]["mase_op"] - assert ( - mase_op in INTERNAL_COMP - ), f"Cannot find mase op {mase_op} in internal components" - return INTERNAL_COMP[mase_op][0]["dependence_files"] + return node.meta["mase"].parameters["hardware"]["dependence_files"] diff --git a/src/chop/passes/graph/utils.py b/src/chop/passes/graph/utils.py index aeda4402a..ebeceba03 100644 --- a/src/chop/passes/graph/utils.py +++ b/src/chop/passes/graph/utils.py @@ -6,6 +6,7 @@ import torch from pathlib import Path +from functools import reduce def check_func_type(node, my_func): @@ -156,3 +157,25 @@ def init_project(project_dir): Path(hardware_dir / "test" / "mase_top_tb").mkdir(parents=True, exist_ok=True) Path(hardware_dir / "test").mkdir(parents=True, exist_ok=True) Path(hardware_dir / "hls").mkdir(parents=True, exist_ok=True) + + +def sign_extend(value: int, bits: int): + sign_bit = 1 << (bits - 1) + return (value & (sign_bit - 1)) - (value & sign_bit) + + +def deepgetattr(obj, attr): + """Recurses through an attribute chain to get the ultimate value.""" + return reduce(getattr, attr.split("."), obj) + + +def deepsetattr(obj, attr_str, value): + """Recurses through an attribute chain to set the ultimate value.""" + attrs = attr_str.split(".") + if len(attrs) == 1: + setattr(obj, attrs[0], value) + else: + first_attr = attrs.pop(0) + if not hasattr(obj, first_attr): + setattr(obj, first_attr, {}) + deepsetattr(getattr(obj, first_attr), ".".join(attrs), value) diff --git a/src/chop/passes/module/__init__.py b/src/chop/passes/module/__init__.py index 400d5122d..946827339 100644 --- a/src/chop/passes/module/__init__.py +++ b/src/chop/passes/module/__init__.py @@ -1,5 +1,5 @@ from .analysis import calculate_avg_bits_module_analysis_pass -from .transforms import quantize_module_transform_pass +from .transforms import quantize_module_transform_pass, resharding_transform_pass ANALYSIS_PASSES = ["calculate_avg_bits_module_analysis_pass"] diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 3fcc8c5b3..754c4f0ce 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1 +1,2 @@ from .quantize import quantize_module_transform_pass +from .autosharding import resharding_transform_pass \ No newline at end of file diff --git a/src/chop/passes/module/transforms/autosharding/__init__.py b/src/chop/passes/module/transforms/autosharding/__init__.py new file mode 100644 index 000000000..95de26f74 --- /dev/null +++ b/src/chop/passes/module/transforms/autosharding/__init__.py @@ -0,0 +1,2 @@ + +from .resharding import resharding_transform_pass \ No newline at end of file diff --git a/src/chop/passes/module/transforms/autosharding/resharding.py b/src/chop/passes/module/transforms/autosharding/resharding.py new file mode 100644 index 000000000..ecfc8ad78 --- /dev/null +++ b/src/chop/passes/module/transforms/autosharding/resharding.py @@ -0,0 +1,75 @@ + +import functools + +import torch +import torch.nn as nn + +from torch.distributed._tensor import ( + DeviceMesh, +) + +from torch.distributed._tensor.api import Redistribute + +from chop.distributed.utils import placement_from_sharding_config, rlog +from chop.tools import get_logger + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +def deepsetattr(obj, attr, value): + """Recurses through an attribute chain to set the ultimate value.""" + attrs = attr.split(".") + if len(attrs) > 1: + deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) + else: + setattr(obj, attr, value) + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default + +class ReshardingWrapper(nn.Module): + def __init__(self, device_mesh, module, resharding_config): + super().__init__() + self.module = module + self.resharding_config = resharding_config + self.device_mesh = device_mesh + + def forward(self, x): + rank = torch.distributed.get_rank() + device_mesh = DeviceMesh("cuda", self.device_mesh) + + required_placement = placement_from_sharding_config(self.resharding_config["data_in_0"]) + if (x.placements != required_placement): + rlog(logger, rank, f"For module {self.module}, resharding tensor x from {x.placements} to {required_placement}", level="debug") + x = Redistribute.apply(x, device_mesh, required_placement) + + return self.module(x) + +def resharding_transform_pass(mg, pass_args={}): + """ + This pass inserts a wrapper around each module in the graph to handle resharding + activation tensors when the output of the previous module has a different sharding + profile to the one assigned to the current module. + """ + + module_map = pass_args.get("module_map", None) + device_mesh = pass_args.get("device_mesh", None) + if module_map is None or device_mesh is None: + raise ValueError("module_map and device_mesh are required for resharding_transform_pass") + + for node in mg.fx_graph.nodes: + if node.op != "call_module": + continue + module = deepgetattr(mg.model, node.target, None) + if module is not None: + resharding_config = module_map[module] + logger.info(f"Inserting resharding wrapper around node: {node}") + deepsetattr(mg.model, node.target, ReshardingWrapper(device_mesh, module, resharding_config)) + + mg.model.recompile() + + return mg, {} \ No newline at end of file diff --git a/src/chop/pipelines/__init__.py b/src/chop/pipelines/__init__.py new file mode 100644 index 000000000..8ae093986 --- /dev/null +++ b/src/chop/pipelines/__init__.py @@ -0,0 +1,2 @@ +from .auto_pipeline import AutoPipeline +from .emit_verilog import AutoPipelineForEmitVerilog diff --git a/src/chop/pipelines/auto_pipeline.py b/src/chop/pipelines/auto_pipeline.py new file mode 100644 index 000000000..d62f15720 --- /dev/null +++ b/src/chop/pipelines/auto_pipeline.py @@ -0,0 +1,18 @@ +from chop.ir import MaseGraph +from chop.tools.logger import get_logger + +logger = get_logger(__name__) + + +class AutoPipeline: + def __init__(self, pass_list=[]) -> None: + self.pass_list = pass_list + + def __call__(self, mg: MaseGraph, pass_args: dict, skip_passes: list = []): + for pass_fn in self.pass_list: + if pass_fn in skip_passes: + logger.debug(f"Skipping pass: {pass_fn.__name__}") + continue + logger.debug(f"Running pass: {pass_fn.__name__}") + args = pass_args.get(pass_fn.__name__, {}) + mg, _ = pass_fn(mg, pass_args=args) diff --git a/src/chop/pipelines/emit_verilog.py b/src/chop/pipelines/emit_verilog.py new file mode 100644 index 000000000..437526a49 --- /dev/null +++ b/src/chop/pipelines/emit_verilog.py @@ -0,0 +1,23 @@ +import chop.passes as passes + +from .auto_pipeline import AutoPipeline + + +class AutoPipelineForEmitVerilog(AutoPipeline): + def __init__(self) -> None: + + pass_list = [ + passes.init_metadata_analysis_pass, + passes.report_graph_analysis_pass, + passes.add_common_metadata_analysis_pass, + passes.patch_metadata_transform_pass, + passes.add_hardware_metadata_analysis_pass, + passes.report_node_meta_param_analysis_pass, + passes.emit_verilog_top_transform_pass, + passes.emit_bram_transform_pass, + passes.emit_internal_rtl_transform_pass, + passes.emit_cocotb_transform_pass, + passes.emit_vivado_project_transform_pass, + ] + + super().__init__(pass_list) diff --git a/src/chop/tools/__init__.py b/src/chop/tools/__init__.py index 5d7787b4d..47b22b690 100644 --- a/src/chop/tools/__init__.py +++ b/src/chop/tools/__init__.py @@ -9,3 +9,16 @@ from .logger import root_logger, set_logging_verbosity, get_logger from .get_input import get_cf_args, get_dummy_input + +from .utils import ( + set_excepthook, + deepsetattr, + deepgetattr, + get_checkpoint_file, + copy_weights, + to_numpy, + to_numpy_if_tensor, + to_tensor, + to_tensor_if_numpy, + is_tensor, +) diff --git a/src/chop/tools/utils.py b/src/chop/tools/utils.py index 7dd292e82..64922b3d0 100644 --- a/src/chop/tools/utils.py +++ b/src/chop/tools/utils.py @@ -1,9 +1,8 @@ import numpy as np import os -import pickle import torch +import functools -import colorlog import torch import subprocess @@ -256,3 +255,29 @@ def parse_accelerator(accelerator: str): else: raise RuntimeError(f"Unsupported accelerator {accelerator}") return device + + +def set_excepthook(): + import sys, pdb, traceback + + def excepthook(exc_type, exc_value, exc_traceback): + traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + + sys.excepthook = excepthook + +def deepsetattr(obj, attr, value): + """Recurses through an attribute chain to set the ultimate value.""" + attrs = attr.split(".") + if len(attrs) > 1: + deepsetattr(getattr(obj, attrs[0]), '.'.join(attrs[1:]), value) + else: + setattr(obj, attr, value) + +def deepgetattr(obj, attr, default=None): + """Recurses through an attribute chain to get the ultimate value.""" + try: + return functools.reduce(getattr, attr.split("."), obj) + except AttributeError: + return default \ No newline at end of file diff --git a/src/mase_cocotb/interfaces/streaming.py b/src/mase_cocotb/interfaces/streaming.py index 9f9ab6419..325783407 100644 --- a/src/mase_cocotb/interfaces/streaming.py +++ b/src/mase_cocotb/interfaces/streaming.py @@ -7,7 +7,10 @@ from mase_cocotb.driver import Driver from mase_cocotb.monitor import Monitor -from mase_cocotb.utils import sign_extend + +def _sign_extend(value: int, bits: int): + sign_bit = 1 << (bits - 1) + return (value & (sign_bit - 1)) - (value & sign_bit) class StreamDriver(Driver): @@ -23,17 +26,28 @@ def set_valid_prob(self, prob): assert prob >= 0.0 and prob <= 1.0 self.valid_prob = prob - async def _driver_send(self, data) -> None: + async def _driver_send(self, transaction) -> None: while True: await RisingEdge(self.clk) - self.data.value = data + if type(self.data) == tuple: + # Drive multiple data bus + for wire, val in zip(self.data, transaction): + wire.value = val + else: + # Drive single data + self.data.value = transaction if random.random() > self.valid_prob: self.valid.value = 0 continue # Try roll random valid again at next clock self.valid.value = 1 await ReadOnly() if self.ready.value == 1: - self.log.debug("Sent %s" % data) + if type(self.data) == tuple: + # Drive multiple data bus + for t in transaction: + self.log.debug("Sent %s" % t) + else: + self.log.debug("Sent %s" % transaction) break if self.send_queue.empty(): @@ -42,7 +56,7 @@ async def _driver_send(self, data) -> None: class StreamMonitor(Monitor): - def __init__(self, clk, data, valid, ready, check=True, name=None): + def __init__(self, clk, data, valid, ready, check=True, name=None, unsigned=False): super().__init__(clk, check=check, name=name) self.clk = clk self.data = data @@ -50,18 +64,39 @@ def __init__(self, clk, data, valid, ready, check=True, name=None): self.ready = ready self.check = check self.name = name + self.unsigned = unsigned def _trigger(self): + if "x" in self.valid.value.binstr or "x" in self.ready.value.binstr: + return False return self.valid.value == 1 and self.ready.value == 1 def _recv(self): - if type(self.data.value) == list: - return [int(x) for x in self.data.value] - elif type(self.data.value) == BinaryValue: - return int(self.data.value) + + def _get_sig_value(sig): + + if type(sig.value) == list: + if self.unsigned: + return [x.integer for x in sig.value] + else: + return [x.signed_integer for x in sig.value] + + elif type(sig.value) == BinaryValue: + if self.unsigned: + return int(sig.value.integer) + else: + return int(sig.value.signed_integer) + + if type(self.data) == tuple: + # Multiple synchronised data signals + return tuple(_get_sig_value(s) for s in self.data) + else: + # Single data signal + return _get_sig_value(self.data) def _check(self, got, exp): - if self.check: + + def _check_sig(got, exp): if not np.equal(got, exp).all(): self.log.error( "%s: \nGot \n%s, \nExpected \n%s" @@ -71,7 +106,26 @@ def _check(self, got, exp): exp, ) ) - raise TestFailure("\nGot \n%s, \nExpected \n%s" % (got, exp)) + assert False, "Test Failed!" + else: + self.log.debug( + "Passed | %s: \nGot \n%s, \nExpected \n%s" + % ( + self.name if self.name != None else "Unnamed StreamMonitor", + got, + exp, + ) + ) + + if self.check: + if type(self.data) == tuple: + assert type(got) == tuple + assert type(exp) == tuple + assert len(got) == len(exp), "Got & Exp Tuples are different length" + for g, e in zip(got, exp): + _check_sig(g, e) + else: + _check_sig(got, exp) class StreamMonitorFloat(StreamMonitor): @@ -111,6 +165,7 @@ def __init__( self.signed = signed self.error_bits = error_bits self.error_log = [] if log_error else None + self.recv_log = [] if log_error else None self.log_error = log_error self.log.setLevel("INFO") @@ -124,11 +179,12 @@ def _check(self, got, exp): g = np.array(got) e = np.array(exp) if self.signed: - g = sign_extend(g, self.width) - e = sign_extend(e, self.width) + g = _sign_extend(g, self.width) + e = _sign_extend(e, self.width) err = np.abs(g - e) if self.log_error: self.error_log.append(err) + self.recv_log.append(got) max_biterr = np.full_like(err, self.error_bits) if not (err <= max_biterr).all(): self.log.error("Failed | Got: %20s Exp: %20s Err: %14s" % (g, e, err)) @@ -138,14 +194,18 @@ def _check(self, got, exp): elif type(got) == int: g, e = got, exp if self.signed: - g = sign_extend(g, self.width) - e = sign_extend(e, self.width) + g = _sign_extend(g, self.width) + e = _sign_extend(e, self.width) err = abs(g - e) if self.log_error: self.error_log.append(err) + self.recv_log.append(got) if not err <= self.error_bits: self.log.error("Failed | Got: %20s Exp: %20s Err: %10s" % (g, e, err)) assert fail, "Test Failed!" return - self.log.debug("Passed | Got: %20s Exp: %20s Err: %10s" % (g, e, err)) + else: + g, e = got, exp + err = np.abs(g - e) + self.log.debug("Passed | Got: %20s Exp: %20s Err: %10s" % (g, e, err)) diff --git a/src/mase_cocotb/matrix_tools.py b/src/mase_cocotb/matrix_tools.py index 6c500fa13..3a05b3db4 100644 --- a/src/mase_cocotb/matrix_tools.py +++ b/src/mase_cocotb/matrix_tools.py @@ -84,6 +84,7 @@ def matrix_mult_model( out_symmetric, a_input, b_input, + debug=False, ): A = rebuild_matrix( a_input, a_total_dim0, a_total_dim1, a_compute_dim0, a_compute_dim1 @@ -95,12 +96,13 @@ def matrix_mult_model( B_signed = sign_extend_t(B, b_width) C_signed = torch.matmul(A_signed, B_signed) - logger.debug("Matrix A") - logger.debug(A_signed) - logger.debug("Matrix B") - logger.debug(B_signed) - logger.debug("Matrix C") - logger.debug(C_signed) + if debug: + logger.debug("Matrix A") + logger.debug(A_signed) + logger.debug("Matrix B") + logger.debug(B_signed) + logger.debug("Matrix C") + logger.debug(C_signed) # Floor rounding acc_frac_width = a_frac_width + b_frac_width @@ -111,14 +113,16 @@ def matrix_mult_model( max_val = (2 ** (out_width - 1)) - 1 C_clamped = torch.clamp(C_signed, min_val, max_val) - logger.debug("Matrix C (clamp)") - logger.debug(C_clamped) + if debug: + logger.debug("Matrix C (clamp)") + logger.debug(C_clamped) # Changed to unsigned number C_unsigned_rep = signed_to_unsigned(C_clamped, out_width) - logger.debug("Matrix C (clamp -> unsigned)") - logger.debug(C_unsigned_rep) + if debug: + logger.debug("Matrix C (clamp -> unsigned)") + logger.debug(C_unsigned_rep) # Split into lists of ints return split_matrix( diff --git a/src/mase_cocotb/monitor.py b/src/mase_cocotb/monitor.py index 3b09afbe4..682dd2f6e 100644 --- a/src/mase_cocotb/monitor.py +++ b/src/mase_cocotb/monitor.py @@ -15,6 +15,7 @@ def __init__(self, clk, check=True, name=None): self.exp_queue = Queue() self.check = check self.name = name + self.in_flight = False if not hasattr(self, "log"): self.log = SimLog( @@ -42,13 +43,23 @@ async def _recv_thread(self): self.recv_queue.put(tr) if self.exp_queue.empty(): - raise TestFailure( - "\nGot \n%s,\nbut we did not expect anything." - % self.recv_queue.get() + assert False, ( + "Got %s but we did not expect anything." % self.recv_queue.get() ) self._check(self.recv_queue.get(), self.exp_queue.get()) + # * If the monitor is in-flight (expectation queue has been populated) + # * and the expectation queue is now empty (after running the check), + # * the test is finished + if ( + self.in_flight == True + and self.recv_queue.empty() + and self.exp_queue.empty() + ): + self.in_flight = False + self.log.info(f"Monitor has been drained.") + def _trigger(self): raise NotImplementedError() diff --git a/src/mase_cocotb/runner.py b/src/mase_cocotb/runner.py index fe03715db..ac61715d3 100644 --- a/src/mase_cocotb/runner.py +++ b/src/mase_cocotb/runner.py @@ -2,11 +2,16 @@ import logging from shutil import rmtree from pathlib import Path +from copy import deepcopy import re import inspect from typing import Any +from concurrent.futures import ProcessPoolExecutor, as_completed +from time import time + import torch +import cocotb from cocotb.runner import get_runner, get_results from mase_components.deps import MASE_HW_DEPS @@ -14,50 +19,30 @@ logger.setLevel("INFO") -def mase_runner( - module_param_list: list[dict[str, Any]] = [dict()], +def _single_test( + i: int, # id + deps: list[str], + module: str, + module_params: dict, + module_path: Path, + comp_path: Path, + test_work_dir: Path, extra_build_args: list[str] = [], - trace: bool = False, seed: int = None, + trace: bool = False, + skip_build: bool = False, ): - assert type(module_param_list) == list, "Need to pass in a list of dicts!" - - # Get file which called this function - # Should be of form components//test/_tb.py - test_filepath = inspect.stack()[1].filename - matches = re.search(r"mase_components/(\w*)/test/(\w*)_tb\.py", test_filepath) - assert matches != None, "Function only works when called from test" - group, module = matches.groups() - - # Group path is components/ - group_path = Path(test_filepath).parent.parent + print("# ---------------------------------------") + print(f"# Test {i}") + print("# ---------------------------------------") + print(f"# Parameters:") + print(f"# - {'Test Index'}: {i}") + for k, v in module_params.items(): + print(f"# - {k}: {v}") + print("# ---------------------------------------") - # Components path is components/ - comp_path = group_path.parent - - # Try to find RTL file: - # components//rtl/.py - module_path = group_path.joinpath("rtl").joinpath(f"{module}.sv") - assert path.exists(module_path), f"{module_path} does not exist." - - SIM = getenv("SIM", "verilator") - - deps = MASE_HW_DEPS[f"{group}/{module}"] - - total_tests = 0 - total_fail = 0 - - for i, module_params in enumerate(module_param_list): - print("# ---------------------------------------") - print(f"# Test {i+1}/{len(module_param_list)}") - print("# ---------------------------------------") - print(f"# Parameters:") - print(f"# - {'Test Index'}: {i}") - for k, v in module_params.items(): - print(f"# - {k}: {v}") - print("# ---------------------------------------") - test_work_dir = group_path.joinpath(f"test/build/{module}/test_{i}") - runner = get_runner(SIM) + runner = get_runner(getenv("SIM", "verilator")) + if not skip_build: runner.build( verilog_sources=[module_path], includes=[str(comp_path.joinpath(f"{d}/rtl/")) for d in deps], @@ -72,6 +57,7 @@ def mase_runner( # Simulation Optimisation "-Wno-UNOPTFLAT", "-prof-c", + "--assert", "--stats", # Signal trace in dump.fst *(["--trace-fst", "--trace-structs"] if trace else []), @@ -86,21 +72,159 @@ def mase_runner( parameters=module_params, build_dir=test_work_dir, ) - results_file = runner.test( + try: + runner.test( hdl_toplevel=module, - test_module=f"mase_components.{group}.test.{module}_tb", + hdl_toplevel_lang="verilog", + test_module=module + "_tb", seed=seed, results_xml="results.xml", + build_dir=test_work_dir, ) - logger.info(f"Results are at {results_file}") - num_tests, fail = get_results(results_file) - total_tests += num_tests - total_fail += fail + num_tests, fail = get_results(test_work_dir.joinpath("results.xml")) + except Exception as e: + print(f"Error occured while running Verilator simulation: {e}") + num_tests = fail = 1 + + return { + "num_tests": num_tests, + "failed_tests": fail, + "params": module_params, + } - print("TEST RESULTS") - print(" PASSED: %d" % (total_tests - total_fail)) - print(" FAILED: %d" % (total_fail)) - print(" NUM TESTS: %d" % (total_tests)) + +def mase_runner( + module_param_list: list[dict[str, Any]] = [dict()], + extra_build_args: list[str] = [], + trace: bool = False, + seed: int = None, + jobs: int = 1, + skip_build: bool = False, +): + assert type(module_param_list) == list, "Need to pass in a list of dicts!" + + start_time = time() + + # Get file which called this function + # Should be of form components//test/_tb.py + test_filepath = inspect.stack()[1].filename + matches = re.search(r"mase_components/(\w*)/test/(\w*)_tb\.py", test_filepath) + assert ( + matches != None + ), "Did not find file that matches _tb.py in the test folder!" + group, module = matches.groups() + + # Group path is components/ + group_path = Path(test_filepath).parent.parent + + # Components path is components/ + comp_path = group_path.parent + + # Try to find RTL file: + # components//rtl/.py + module_path = group_path.joinpath("rtl").joinpath(f"{module}.sv") + assert path.exists(module_path), f"{module_path} does not exist." + + deps = MASE_HW_DEPS[f"{group}/{module}"] + + total_tests = 0 + total_fail = 0 + passed_cfgs = [] + failed_cfgs = [] + + # Single threaded run + if jobs == 1: + + for i, module_params in enumerate(module_param_list): + test_work_dir = group_path.joinpath(f"test/build/{module}/test_{i}") + results = _single_test( + i=i, + deps=deps, + module=module, + module_params=module_params, + module_path=module_path, + comp_path=comp_path, + test_work_dir=test_work_dir, + extra_build_args=extra_build_args, + seed=seed, + trace=trace, + skip_build=skip_build, + ) + total_tests += results["num_tests"] + total_fail += results["failed_tests"] + if results["failed_tests"]: + failed_cfgs.append((i, module_params)) + else: + passed_cfgs.append((i, module_params)) + + # Multi threaded run + else: + with ProcessPoolExecutor(max_workers=jobs) as executor: + # TODO: add timeout + future_to_job_meta = {} + for i, module_params in enumerate(module_param_list): + test_work_dir = group_path.joinpath(f"test/build/{module}/test_{i}") + future = executor.submit( + _single_test, + i=i, + deps=deps, + module=module, + module_params=module_params, + module_path=module_path, + comp_path=comp_path, + test_work_dir=test_work_dir, + extra_build_args=extra_build_args, + seed=seed, + trace=trace, + ) + future_to_job_meta[future] = { + "id": i, + "params": deepcopy(module_params), + } + + # Wait for futures to complete + for future in as_completed(future_to_job_meta): + meta = future_to_job_meta[future] + id = meta["id"] + params = meta["params"] + try: + result = future.result() + except Exception as exc: + print("Test %r generated an exception: %s" % (id, exc)) + else: + print("Test %r is done. Result: %s" % (id, result)) + total_tests += result["num_tests"] + total_fail += result["failed_tests"] + if result["failed_tests"]: + failed_cfgs.append((id, params)) + else: + passed_cfgs.append((id, params)) + + print("# ---------------------------------------") + print("# Test Results") + print("# ---------------------------------------") + print("# - Time elapsed: %.2f seconds" % (time() - start_time)) + print("# - Jobs: %d" % (jobs)) + print("# - Passed: %d" % (total_tests - total_fail)) + print("# - Failed: %d" % (total_fail)) + print("# - Total : %d" % (total_tests)) + print("# ---------------------------------------") + + if len(passed_cfgs): + passed_cfgs = sorted(passed_cfgs, key=lambda t: t[0]) + print(f"# Passed Configs") + print("# ---------------------------------------") + for i, params in passed_cfgs: + print(f"# - test_{i}: {params}") + print("# ---------------------------------------") + + if len(failed_cfgs): + failed_cfgs = sorted(failed_cfgs, key=lambda t: t[0]) + print(f"# Failed Configs") + print("# ---------------------------------------") + for i, params in failed_cfgs: + print(f"# - test_{i}: {params}") + print("# ---------------------------------------") return total_fail @@ -135,6 +259,7 @@ def simulate_pass( *(["--trace-fst", "--trace-structs"] if trace else []), "-prof-c", "--stats", + "--assert", "-O2", "-build-jobs", "8", diff --git a/src/mase_cocotb/testbench.py b/src/mase_cocotb/testbench.py index 0ca86d307..fa76983e5 100644 --- a/src/mase_cocotb/testbench.py +++ b/src/mase_cocotb/testbench.py @@ -1,6 +1,7 @@ import cocotb from cocotb.triggers import * from cocotb.clock import Clock +from cocotb.utils import get_sim_time class Testbench: @@ -11,23 +12,31 @@ def __init__(self, dut, clk=None, rst=None, fail_on_checks=True) -> None: self.clk = clk self.rst = rst - self.input_drivers = [] - self.output_monitors = [] + self.input_drivers = {} + self.output_monitors = {} self.input_precision = [32] self.fail_on_checks = fail_on_checks - if self.clk != None: + if self.clk is not None: self.clock = Clock(self.clk, 20, units="ns") cocotb.start_soon(self.clock.start()) def assign_self_params(self, attrs): for att in attrs: - setattr(self, att, getattr(self.dut, att).value) + setattr(self, att, int(getattr(self.dut, att).value)) + + def get_parameter(self, parameter_name): + parameter = getattr(self.dut, parameter_name) + return int(parameter) + + def get_parameter(self, parameter_name): + parameter = getattr(self.dut, parameter_name) + return int(parameter) async def reset(self, active_high=True): - if self.rst == None: + if self.rst is None: raise Exception( "Cannot reset. Either a reset wire was not provided or " + "the module does not have a reset." @@ -43,7 +52,7 @@ async def initialize(self): await self.reset() # Set all monitors ready - for monitor in self.output_monitors: + for monitor in self.output_monitors.values(): monitor.ready.value = 1 def generate_inputs(self, batches=1): @@ -55,10 +64,22 @@ def load_drivers(self, in_tensors): def load_monitors(self, expectation): raise NotImplementedError - def end_checks(self): - if self.fail_on_checks: - for monitor in self.output_monitors: - assert monitor.exp_queue.empty() + async def wait_end(self, timeout=1, timeout_unit="ms"): + while True: + await RisingEdge(self.clk) - for driver in self.input_drivers: - assert driver.send_queue.empty() + # ! TODO: check if this slows down test significantly + if get_sim_time(timeout_unit) > timeout: + raise TimeoutError("Timed out waiting for test to end.") + + if all( + [ + monitor.in_flight == False + for monitor in self.output_monitors.values() + ] + ): + break + + if self.fail_on_checks: + for driver in self.input_drivers.values(): + assert driver.send_queue.empty(), "Driver still has data to send." diff --git a/src/mase_cocotb/utils.py b/src/mase_cocotb/utils.py index cbdea58bb..62c811ebe 100644 --- a/src/mase_cocotb/utils.py +++ b/src/mase_cocotb/utils.py @@ -1,5 +1,6 @@ import random from copy import copy +import itertools from cocotb.triggers import RisingEdge import torch @@ -9,6 +10,16 @@ sys.path.append("../") from mase_cocotb.z_qlayers import quantize_to_int +from functools import partial +from chop.nn.quantizers import integer_quantizer + + +# Apparently this function only exists in Python 3.12 ... +def batched(iterable, n=1): + l = len(iterable) + for ndx in range(0, l, n): + yield iterable[ndx : min(ndx + n, l)] + # Apparently this function only exists in Python 3.12 ... def batched(iterable, n=1): @@ -83,6 +94,56 @@ def verilator_str_param(s): return f'"{s}"' +def product_dict(**kwargs): + keys = kwargs.keys() + for instance in itertools.product(*kwargs.values()): + yield dict(zip(keys, instance)) + + +def fixed_preprocess_tensor(tensor: Tensor, q_config: dict, parallelism: list) -> list: + """Preprocess a tensor before driving it into the DUT. + 1. Quantize to requested fixed-point precision. + 2. Convert to integer format to be compatible with Cocotb drivers. + 3. Split into blocks according to parallelism in each dimension. + + Args: + tensor (Tensor): Input tensor + q_config (dict): Quantization configuration. + parallelism (list): Parallelism in each dimension. + + Returns: + list: Processed blocks in nested list format. + """ + if len(tensor.shape) == 1: + tensor = tensor.unsqueeze(0) + + if len(parallelism) == 1: + parallelism = [1, parallelism[0]] + + # * Flatten batch dimension + tensor = tensor.view((-1, tensor.shape[-1])) + + # Quantize + quantizer = partial(integer_quantizer, **q_config) + q_tensor = quantizer(tensor) + + # Convert to integer format + q_tensor = (q_tensor * 2 ** q_config["frac_width"]).int() + q_tensor = signed_to_unsigned(q_tensor, bits=q_config["width"]) + + # Split into chunks according to parallelism in each dimension + # parallelism[0]: along rows, parallelism[1]: along columns + dim_0_split = q_tensor.split(parallelism[0], dim=0) + dim_1_split = [x.split(parallelism[1], dim=1) for x in dim_0_split] + blocks = [] + # Flatten the list of blocks + for i in range(len(dim_1_split)): + for j in range(len(dim_1_split[i])): + blocks.append(dim_1_split[i][j].flatten().tolist()) + + return blocks + + def large_num_generator(large_num_thres=127, large_num_limit=500, large_num_prob=0.1): """ Generator large numbers & small numbers with a given probability distribution. diff --git a/src/mase_components/ViT/test/test_synth_ViT.py b/src/mase_components/ViT/test/test_synth_ViT.py new file mode 100644 index 000000000..ad5a860cd --- /dev/null +++ b/src/mase_components/ViT/test/test_synth_ViT.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_ViT(): + run_synth("ViT") + + +if __name__ == "__main__": + test_synth_ViT() diff --git a/src/mase_components/__init__.py b/src/mase_components/__init__.py index 3dee93b96..de4c54db6 100644 --- a/src/mase_components/__init__.py +++ b/src/mase_components/__init__.py @@ -13,3 +13,23 @@ def get_modules(): if "__pycache__" in mods: mods.remove("__pycache__") return mods + + +def get_group_files(group): + current_dir = os.path.dirname(os.path.abspath(__file__)) + group_dir = os.path.join(current_dir, group, "rtl") + files = [ + f"{group}/rtl/{f}" + for f in os.listdir(group_dir) + if os.path.isfile(os.path.join(group_dir, f)) and "__init__" not in f + ] + return files + + +def get_module_dependencies(module): + group, mod = module.split("/") + group_deps = MASE_HW_DEPS.get(module, []) + file_deps = [] + for group_dep in group_deps: + file_deps += get_group_files(group_dep) + return file_deps diff --git a/src/mase_components/activations/rtl/fixed_elu.sv b/src/mase_components/activations/rtl/fixed_elu.sv index b753e5a3b..63fe10025 100644 --- a/src/mase_components/activations/rtl/fixed_elu.sv +++ b/src/mase_components/activations/rtl/fixed_elu.sv @@ -63,7 +63,7 @@ module fixed_elu #( if (STRAIGHT_THROUGH) begin unpacked_register_slice_quick #( .DATA_WIDTH(DATA_IN_0_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) ) single_roll ( .clk(clk), .rst(rst), diff --git a/src/mase_components/activations/rtl/fixed_gelu.sv b/src/mase_components/activations/rtl/fixed_gelu.sv index 40fb76156..5490795ce 100644 --- a/src/mase_components/activations/rtl/fixed_gelu.sv +++ b/src/mase_components/activations/rtl/fixed_gelu.sv @@ -1,164 +1,113 @@ `timescale 1ns / 1ps - - module fixed_gelu #( /* verilator lint_off UNUSEDPARAM */ parameter DATA_IN_0_PRECISION_0 = 8, - parameter DATA_IN_0_PRECISION_1 = 3, - parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_IN_0_PRECISION_1 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 10, parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_IN_0_PARALLELISM_DIM_0 = 8, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + parameter IN_0_DEPTH = $rtoi($ceil(DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0)), - parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + 2*DATA_IN_0_PRECISION_1 + APPROXIMATION_PRECISION-1, - parameter DATA_OUT_0_PRECISION_1 = 2 * DATA_IN_0_PRECISION_1 + APPROXIMATION_PRECISION, - parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 1, + parameter DATA_OUT_0_PRECISION_0 = 8, + parameter DATA_OUT_0_PRECISION_1 = 4, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 10, parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, - parameter DATA_OUT_0_PARALLELISM_DIM_0 = 8, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1, - - parameter APPROXIMATION_PRECISION = 16, - parameter APPROXIMATION_N = 8 + parameter DATA_OUT_0_PARALLELISM_DIM_2 = 1 ) ( /* verilator lint_off UNUSEDSIGNAL */ - input rst, input clk, - input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0], - output logic [DATA_IN_0_PRECISION_0 + APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_1 -1 :0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0], - + input rst, + input logic data_in_0_valid, output logic data_in_0_ready, + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, - input logic data_out_0_ready + input logic data_out_0_ready, + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0] ); - logic [DATA_IN_0_PRECISION_0-1:0] data_in_0_delayed0 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1 -1:0]; - logic [DATA_IN_0_PRECISION_0-1:0] data_in_0_delayed1 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1 -1:0]; - logic data_out_0_valid_delayed0; - logic data_out_0_valid_delayed1; + logic [DATA_IN_0_PRECISION_0-1:0] ff_data[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0]; + logic [DATA_IN_0_PRECISION_0-1:0] roll_data[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION-1 : 0] coefficient_a [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION-1 : 0] coefficient_b [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_0 - 1 :0] coefficient_c [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_0 - 1 :0] product_bx_scaled [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_0 - 1 :0] coefficient_c_scaled [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_0 - 1 :0] product_ax2 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [2*DATA_IN_0_PRECISION_0 - 1 :0] product_x2 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [APPROXIMATION_PRECISION + DATA_IN_0_PRECISION_0 - 1 :0] product_bx [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; + logic ff_data_valid; + logic ff_data_ready; - logic [APPROXIMATION_PRECISION*3 - 1:0] coefficients [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [DATA_IN_0_PRECISION_0 + APPROXIMATION_PRECISION + 2*DATA_IN_0_PRECISION_1 -1 :0] sum [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; - logic [2:0] index[DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0]; + logic roll_data_valid; + logic roll_data_ready; + unpacked_fifo #( + .DEPTH(IN_0_DEPTH), + .DATA_WIDTH(DATA_IN_0_PRECISION_0), + .IN_NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) + ) roller_buffer ( + .clk(clk), + .rst(rst), + .data_in(data_in_0), + .data_in_valid(data_in_0_valid), + .data_in_ready(data_in_0_ready), // write enable + .data_out(ff_data), + .data_out_valid(ff_data_valid), + .data_out_ready(ff_data_ready) // read enable + ); - genvar i; - generate - for ( - i = 0; i <= (DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1 - 1); i = i + 1 - ) begin + localparam STRAIGHT_THROUGH = (DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1 == DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1); - - const - logic [APPROXIMATION_PRECISION*3 - 1:0] - lut[0:APPROXIMATION_N-1] = { - 48'b111111111010111111111101100100011111101101001100, - 48'b111111011010010111110001101010001110100111100001, - 48'b111111110110010011110110100011111110110011001001, - 48'b000100110010100100011100101010011111111110101111, - 48'b000100110101101100100011001001011111111110111001, - 48'b111111111000000001001001000110111110110100001001, - 48'b111111011001110101001110011111011110100110110100, - 48'b111111111010110101000010011111111111101100110001 - }; - - fixed_mult #( - .IN_A_WIDTH(DATA_IN_0_PRECISION_0), - .IN_B_WIDTH(DATA_IN_0_PRECISION_0) - ) MX_multiplier_x2 ( - .data_a (data_in_0_delayed1[i]), - .data_b (data_in_0_delayed1[i]), - .product(product_x2[i]) - ); - - fixed_mult #( - .IN_A_WIDTH(2 * DATA_IN_0_PRECISION_0), - .IN_B_WIDTH(APPROXIMATION_PRECISION) - ) MX_multiplier_ax2 ( - .data_a (product_x2[i]), - .data_b (coefficient_a[i]), - .product(product_ax2[i]) + generate + if (STRAIGHT_THROUGH) begin + unpacked_register_slice_quick #( + .DATA_WIDTH(DATA_IN_0_PRECISION_0), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) + ) single_roll ( + .clk(clk), + .rst(rst), + .in_data(ff_data), + .in_valid(ff_data_valid), + .in_ready(ff_data_ready), + .out_data(roll_data), + .out_valid(roll_data_valid), + .out_ready(roll_data_ready) ); - fixed_mult #( - .IN_A_WIDTH(DATA_IN_0_PRECISION_0), - .IN_B_WIDTH(APPROXIMATION_PRECISION) - ) MX_multiplier_bx ( - .data_a (data_in_0_delayed1[i]), - .data_b (coefficient_b[i]), - .product(product_bx[i]) + end else begin + + roller #( + .DATA_WIDTH(DATA_IN_0_PRECISION_0), + .NUM(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .ROLL_NUM(DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1) + ) roller_inst ( + .clk(clk), + .rst(rst), + .data_in(ff_data), + .data_in_valid(ff_data_valid), + .data_in_ready(ff_data_ready), + .data_out(roll_data), + .data_out_valid(roll_data_valid), + .data_out_ready(roll_data_ready) ); - - always_ff @(posedge clk) begin - coefficients[i] <= lut[index[i]]; - end - - - always_ff @(posedge clk) begin - if (data_in_0_valid) begin - data_in_0_delayed0[i] <= data_in_0[i]; - data_in_0_delayed1[i] <= data_in_0_delayed0[i]; - end - end - - always_comb begin - if ($signed(data_in_0[i][DATA_IN_0_PRECISION_0-1 : DATA_IN_0_PRECISION_1]) >= 4) - data_out_0[i] = ($signed( - data_in_0_delayed1[i] - )) <<< APPROXIMATION_PRECISION - 2 + DATA_IN_0_PRECISION_1; - else if ($signed(data_in_0[i][DATA_IN_0_PRECISION_0-1 : DATA_IN_0_PRECISION_1]) <= -4) - data_out_0[i] = 0; - else data_out_0[i] = $signed(sum[i]); - end - - always_comb sum[i] = product_ax2[i] + product_bx_scaled[i] + coefficient_c_scaled[i]; - - always_comb - index[i] = (data_in_0_delayed0[i][DATA_IN_0_PRECISION_1+2 : DATA_IN_0_PRECISION_1]) + 4; - - always_comb - coefficient_c_scaled[i] = ($signed(coefficient_c[i])) <<< (2 * DATA_IN_0_PRECISION_1); - - always_comb product_bx_scaled[i] = ($signed(product_bx[i])) <<< (DATA_IN_0_PRECISION_1); - - assign coefficient_c[i] = $signed(coefficients[i][APPROXIMATION_PRECISION-1 : 0]); - - assign - coefficient_b[i] = coefficients[i] [((APPROXIMATION_PRECISION)*2-1) : APPROXIMATION_PRECISION]; - - assign - coefficient_a[i] = coefficients[i] [((APPROXIMATION_PRECISION*3)-1) : APPROXIMATION_PRECISION*2]; - - end endgenerate - - - always_ff @(posedge clk) begin - if (rst) data_out_0_valid_delayed1 <= 0; - else if (data_out_0_ready && !data_in_0_valid) begin - data_out_0_valid_delayed0 <= 0; - data_out_0_valid_delayed1 <= data_out_0_valid_delayed0; - end else if (data_in_0_valid) begin - data_out_0_valid_delayed0 <= 1; - data_out_0_valid_delayed1 <= data_out_0_valid_delayed0; - end else data_out_0_valid_delayed1 <= data_out_0_valid_delayed1; + for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++) begin : elu + gelu_lut #( + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1) + ) elu_map ( + .data_in_0 (roll_data[i]), + .data_out_0(data_out_0[i]) + ); end - assign data_in_0_ready = 1; - - assign data_out_0_valid = data_out_0_valid_delayed1; - + assign data_out_0_valid = roll_data_valid; + assign roll_data_ready = data_out_0_ready; endmodule diff --git a/src/mase_components/activations/rtl/fixed_hardswish.sv b/src/mase_components/activations/rtl/fixed_hardswish.sv index 7e076763a..e10b97d52 100644 --- a/src/mase_components/activations/rtl/fixed_hardswish.sv +++ b/src/mase_components/activations/rtl/fixed_hardswish.sv @@ -52,7 +52,7 @@ module fixed_hardswish #( tmp_0[i] = 3 <<< DATA_IN_0_PRECISION_1; // 3 in the same fx tmp_1[i] = data_in_0[i] + tmp_0[i]; // x + 3 tmp_2[i] = (tmp_1[i] >>> 3) + (tmp_1[i] >>> 4); // tmp/8 + tmp/16 ~ tmp/6 - assign data_out_0[i] = tmp_3[i]; // dout = x(x+3) * 3/16 [Original HardSwish is x(x+3)/6] + data_out_0[i] = tmp_3[i]; // dout = x(x+3) * 3/16 [Original HardSwish is x(x+3)/6] end end diff --git a/src/mase_components/activations/rtl/fixed_logsigmoid.sv b/src/mase_components/activations/rtl/fixed_logsigmoid.sv index 02ff77a14..9e70cec92 100644 --- a/src/mase_components/activations/rtl/fixed_logsigmoid.sv +++ b/src/mase_components/activations/rtl/fixed_logsigmoid.sv @@ -62,7 +62,7 @@ module fixed_logsigmoid #( if (STRAIGHT_THROUGH) begin unpacked_register_slice_quick #( .DATA_WIDTH(DATA_IN_0_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) ) single_roll ( .clk(clk), .rst(rst), diff --git a/src/mase_components/activations/rtl/fixed_relu.sv b/src/mase_components/activations/rtl/fixed_relu.sv index 8b9426668..2ceb0a242 100644 --- a/src/mase_components/activations/rtl/fixed_relu.sv +++ b/src/mase_components/activations/rtl/fixed_relu.sv @@ -1,3 +1,5 @@ +`timescale 1ns / 1ps + module fixed_relu #( /* verilator lint_off UNUSEDPARAM */ parameter DATA_IN_0_PRECISION_0 = 8, diff --git a/src/mase_components/activations/rtl/fixed_sigmoid.sv b/src/mase_components/activations/rtl/fixed_sigmoid.sv index b828d02d2..618a4c8d6 100644 --- a/src/mase_components/activations/rtl/fixed_sigmoid.sv +++ b/src/mase_components/activations/rtl/fixed_sigmoid.sv @@ -62,7 +62,7 @@ module fixed_sigmoid #( if (STRAIGHT_THROUGH) begin unpacked_register_slice_quick #( .DATA_WIDTH(DATA_IN_0_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) ) single_roll ( .clk(clk), .rst(rst), diff --git a/src/mase_components/activations/rtl/fixed_silu.sv b/src/mase_components/activations/rtl/fixed_silu.sv index 27ee82bee..7fab48b42 100644 --- a/src/mase_components/activations/rtl/fixed_silu.sv +++ b/src/mase_components/activations/rtl/fixed_silu.sv @@ -62,7 +62,7 @@ module fixed_silu #( if (STRAIGHT_THROUGH) begin unpacked_register_slice_quick #( .DATA_WIDTH(DATA_IN_0_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) ) single_roll ( .clk(clk), .rst(rst), diff --git a/src/mase_components/activations/rtl/fixed_softermax.sv b/src/mase_components/activations/rtl/fixed_softermax.sv new file mode 100644 index 000000000..3d875e0aa --- /dev/null +++ b/src/mase_components/activations/rtl/fixed_softermax.sv @@ -0,0 +1,100 @@ +/* +Module : fixed_softermax +Description : This module implements softermax. + https://arxiv.org/abs/2103.09301 + + It depends on the "softermax_local_window" and + "softermax_global_norm" modules. +*/ +`timescale 1ns / 1ps +module fixed_softermax #( + parameter DATA_IN_0_PRECISION_0 = 8, + parameter DATA_IN_0_PRECISION_1 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 10, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + + parameter DATA_OUT_0_PRECISION_0 = 8, + parameter DATA_OUT_0_PRECISION_1 = 4, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 10, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1 +) ( + input logic clk, + input logic rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + + // * Declarations + // * ================================================================= + + logic [DATA_IN_0_PRECISION_0-1:0] in_data_unflattened [DATA_IN_0_PARALLELISM_DIM_1-1:0] [DATA_IN_0_PARALLELISM_DIM_0-1:0]; + logic [DATA_OUT_0_PRECISION_0-1:0] out_data_unflattened [DATA_OUT_0_PARALLELISM_DIM_1-1:0] [DATA_OUT_0_PARALLELISM_DIM_0-1:0]; + + logic [DATA_IN_0_PARALLELISM_DIM_1-1:0] in_data_valid; + logic [DATA_IN_0_PARALLELISM_DIM_1-1:0] in_data_ready; + logic [DATA_IN_0_PARALLELISM_DIM_1-1:0] out_data_valid; + logic [DATA_IN_0_PARALLELISM_DIM_1-1:0] out_data_ready; + + // * Instances + // * ================================================================= + + // * Split handshake signals into the rows + split_n #( + .N(DATA_IN_0_PARALLELISM_DIM_1) + ) split_n_i ( + .data_in_valid (data_in_0_valid), + .data_in_ready (data_in_0_ready), + .data_out_valid(in_data_valid), + .data_out_ready(in_data_ready) + ); + + // * Softermax 1d instance for each row + for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_1; i++) begin + + assign in_data_unflattened [i] = data_in_0 [(i + 1) * DATA_IN_0_PARALLELISM_DIM_0 - 1 : i * DATA_IN_0_PARALLELISM_DIM_0]; + + fixed_softermax_1d #( + .TOTAL_DIM (DATA_IN_0_TENSOR_SIZE_DIM_0), + .PARALLELISM (DATA_IN_0_PARALLELISM_DIM_0), + .IN_WIDTH (DATA_IN_0_PRECISION_0), + .IN_FRAC_WIDTH (DATA_IN_0_PRECISION_1), + .POW2_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH (DATA_OUT_0_PRECISION_1) + ) fixed_softermax_1d_i ( + .clk, + .rst, + + .in_data (in_data_unflattened[i]), + .in_valid(in_data_valid[i]), + .in_ready(in_data_ready[i]), + + .out_data (out_data_unflattened[i]), + .out_valid(out_data_valid[i]), + .out_ready(out_data_ready[i]) + ); + + assign data_out_0 [(i + 1) * DATA_IN_0_PARALLELISM_DIM_0 - 1 : i * DATA_IN_0_PARALLELISM_DIM_0] = out_data_unflattened[i]; + end + + // * Join handshake signals from all the rows + join_n #( + .NUM_HANDSHAKES(DATA_OUT_0_PARALLELISM_DIM_1) + ) join_n_i ( + .data_in_valid (out_data_valid), + .data_in_ready (out_data_ready), + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) + ); + +endmodule diff --git a/src/mase_components/activations/rtl/fixed_softermax_1d.sv b/src/mase_components/activations/rtl/fixed_softermax_1d.sv new file mode 100644 index 000000000..4a2b7c0b1 --- /dev/null +++ b/src/mase_components/activations/rtl/fixed_softermax_1d.sv @@ -0,0 +1,93 @@ +/* +Module : softermax +Description : This module implements softermax. + https://arxiv.org/abs/2103.09301 + + It depends on the "softermax_local_window" and + "softermax_global_norm" modules. +*/ +`timescale 1ns / 1ps +module fixed_softermax_1d #( + // Shape Parameters + parameter TOTAL_DIM = 16, + parameter PARALLELISM = 4, + + // Width Parameters + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter POW2_WIDTH = 16, + // POW2_FRAC_WIDTH should always be POW2_WIDTH - 1, since local values are + // two to the power of a number in the range of (-inf, 0]. + localparam POW2_FRAC_WIDTH = 15, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 7 +) ( + input logic clk, + input logic rst, + + input logic [IN_WIDTH-1:0] in_data [PARALLELISM-1:0], + input logic in_valid, + output logic in_ready, + + output logic [OUT_WIDTH-1:0] out_data [PARALLELISM-1:0], + output logic out_valid, + input logic out_ready +); + + // ----- + // Params + // ----- + + localparam MAX_WIDTH = IN_WIDTH - IN_FRAC_WIDTH; + + + // ----- + // Wires + // ----- + + logic [ MAX_WIDTH-1:0] local_max; + logic [POW2_WIDTH-1:0] local_values[PARALLELISM-1:0]; + logic local_window_valid, local_window_ready; + + // ----- + // Modules + // ----- + + softermax_local_window #( + .PARALLELISM(PARALLELISM), + .IN_WIDTH(IN_WIDTH), + .IN_FRAC_WIDTH(IN_FRAC_WIDTH), + .OUT_WIDTH(POW2_WIDTH), + .OUT_FRAC_WIDTH(POW2_FRAC_WIDTH) + ) local_window_inst ( + .clk(clk), + .rst(rst), + .in_data(in_data), + .in_valid(in_valid), + .in_ready(in_ready), + .out_values(local_values), + .out_max(local_max), + .out_valid(local_window_valid), + .out_ready(local_window_ready) + ); + + softermax_global_norm #( + .TOTAL_DIM(TOTAL_DIM), + .PARALLELISM(PARALLELISM), + .IN_VALUE_WIDTH(POW2_WIDTH), + .IN_MAX_WIDTH(MAX_WIDTH), + .OUT_WIDTH(OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) global_norm_inst ( + .clk(clk), + .rst(rst), + .in_values(local_values), + .in_max(local_max), + .in_valid(local_window_valid), + .in_ready(local_window_ready), + .out_data(out_data), + .out_valid(out_valid), + .out_ready(out_ready) + ); + +endmodule diff --git a/src/mase_components/activations/rtl/fixed_softermax_2d.sv b/src/mase_components/activations/rtl/fixed_softermax_2d.sv new file mode 100644 index 000000000..96d492c4a --- /dev/null +++ b/src/mase_components/activations/rtl/fixed_softermax_2d.sv @@ -0,0 +1,82 @@ +/* +Module : fixed_softermax_2d +Description : This module extends the softermax operation into an additional + parallelism dimension by instantiating N fixed_softermax modules. + + !! This module only does softmax across DIM0 !! +*/ + +`timescale 1ns / 1ps + +module fixed_softermax_2d #( + // Shape Parameters + parameter TOTAL_DIM0 = 16, + parameter TOTAL_DIM1 = 16, + parameter COMPUTE_DIM0 = 4, + parameter COMPUTE_DIM1 = 4, + + // Width Parameters + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter POW2_WIDTH = 16, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 7 +) ( + input logic clk, + input logic rst, + + input logic [IN_WIDTH-1:0] in_data [COMPUTE_DIM0*COMPUTE_DIM1-1:0], + input logic in_valid, + output logic in_ready, + + output logic [OUT_WIDTH-1:0] out_data [COMPUTE_DIM0*COMPUTE_DIM1-1:0], + output logic out_valid, + input logic out_ready +); + + // ----- + // Wires + // ----- + + // Shape is (COMPUTE_DIM1, COMPUTE_DIM0, IN_WIDTH) + logic [IN_WIDTH-1:0] softermax_in_rows[COMPUTE_DIM1-1:0][COMPUTE_DIM0-1:0]; + logic softermax_in_ready[COMPUTE_DIM1-1:0]; + + logic [OUT_WIDTH-1:0] softermax_out_rows[COMPUTE_DIM1-1:0][COMPUTE_DIM0-1:0]; + logic softermax_out_valid[COMPUTE_DIM1-1:0]; + + + // ----- + // Modules + // ----- + + for (genvar i = 0; i < COMPUTE_DIM1; i++) begin : softermax_row + + assign softermax_in_rows[i] = in_data[(i+1)*COMPUTE_DIM1-1:i*COMPUTE_DIM1]; + assign out_data[(i+1)*COMPUTE_DIM1-1:i*COMPUTE_DIM1] = softermax_out_rows[i]; + + fixed_softermax #( + .TOTAL_DIM (TOTAL_DIM0), + .PARALLELISM (COMPUTE_DIM0), + .IN_WIDTH (IN_WIDTH), + .IN_FRAC_WIDTH (IN_FRAC_WIDTH), + .POW2_WIDTH (POW2_WIDTH), + .OUT_WIDTH (OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) softermax_inst ( + .clk (clk), + .rst (rst), + .in_data (softermax_in_rows[i]), + .in_valid (in_valid), + .in_ready (softermax_in_ready[i]), + .out_data (softermax_out_rows[i]), + .out_valid(softermax_out_valid[i]), + .out_ready(out_ready) + ); + + end + + assign in_ready = softermax_in_ready[0]; + assign out_valid = softermax_out_valid[0]; + +endmodule diff --git a/src/mase_components/activations/rtl/fixed_softshrink.sv b/src/mase_components/activations/rtl/fixed_softshrink.sv index fedb1c447..7b8182b3d 100644 --- a/src/mase_components/activations/rtl/fixed_softshrink.sv +++ b/src/mase_components/activations/rtl/fixed_softshrink.sv @@ -64,7 +64,7 @@ module fixed_softshrink #( if (STRAIGHT_THROUGH) begin unpacked_register_slice_quick #( .DATA_WIDTH(DATA_IN_0_PRECISION_0), - .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1), + .IN_SIZE(DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1) ) single_roll ( .clk(clk), .rst(rst), diff --git a/src/mase_components/activations/rtl/softermax_global_norm.sv b/src/mase_components/activations/rtl/softermax_global_norm.sv new file mode 100644 index 000000000..0cd55a0bd --- /dev/null +++ b/src/mase_components/activations/rtl/softermax_global_norm.sv @@ -0,0 +1,417 @@ +/* +Module : softermax_global_norm +Description : This module implements the second section of the softermax compute + pipeline which calculates renormalizes all local values and + calculates the final. + + Refer to bottom half of Fig. 4.a) and 4.b) in the Softermax Paper. + https://arxiv.org/abs/2103.09301 +*/ + +`timescale 1ns / 1ps + +module softermax_global_norm #( + // Input shape dimensions + parameter TOTAL_DIM = 16, + parameter PARALLELISM = 4, + + // Widths + parameter IN_VALUE_WIDTH = 16, + // IN_VALUE_FRAC_WIDTH should always be (IN_VALUE_WIDTH-1) since it is an + // unsigned fixed-point number in range [0, 2) + localparam IN_VALUE_FRAC_WIDTH = IN_VALUE_WIDTH - 1, + parameter IN_MAX_WIDTH = 5, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 7 +) ( + input logic clk, + input logic rst, + + // in_values: Unsigned fixed-point in range [0, 2) + input logic [IN_VALUE_WIDTH-1:0] in_values[PARALLELISM-1:0], + // in_max: Signed integers + input logic [ IN_MAX_WIDTH-1:0] in_max, + input logic in_valid, + output logic in_ready, + + output logic [OUT_WIDTH-1:0] out_data [PARALLELISM-1:0], + output logic out_valid, + input logic out_ready +); + + // ----- + // Parameters + // ----- + + localparam DEPTH = TOTAL_DIM / PARALLELISM; + + // Max is integer only + localparam SUBTRACT_WIDTH = IN_MAX_WIDTH + 1; + + localparam ADDER_TREE_IN_WIDTH = 1 + IN_VALUE_WIDTH; // Pad single zero for unsigned + localparam ADDER_TREE_OUT_WIDTH = $clog2(PARALLELISM) + ADDER_TREE_IN_WIDTH; + localparam ADDER_TREE_FRAC_WIDTH = IN_VALUE_FRAC_WIDTH; + + localparam ACC_WIDTH = $clog2(DEPTH) + ADDER_TREE_OUT_WIDTH; + localparam ACC_FRAC_WIDTH = ADDER_TREE_FRAC_WIDTH; + + // TODO: Maybe set this at top level? + localparam RECIP_WIDTH = 2 * ACC_WIDTH; + localparam RECIP_FRAC_WIDTH = 2 * ACC_FRAC_WIDTH; + + localparam MULT_WIDTH = IN_VALUE_WIDTH + RECIP_WIDTH; + localparam MULT_FRAC_WIDTH = IN_VALUE_FRAC_WIDTH + RECIP_FRAC_WIDTH; + + + initial begin + assert (TOTAL_DIM > 1); + assert (DEPTH * PARALLELISM == TOTAL_DIM); + + // Sanity Check + assert (ADDER_TREE_OUT_WIDTH >= ADDER_TREE_FRAC_WIDTH); + assert (ADDER_TREE_IN_WIDTH >= ADDER_TREE_FRAC_WIDTH); + assert (ACC_WIDTH >= ACC_FRAC_WIDTH); + assert (RECIP_WIDTH >= RECIP_FRAC_WIDTH); + assert (MULT_WIDTH >= MULT_FRAC_WIDTH); + end + + + // ----- + // Wires + // ----- + + logic [IN_MAX_WIDTH-1:0] local_max_out; + logic local_max_in_valid, local_max_in_ready; + logic local_max_out_valid, local_max_out_ready; + + logic [IN_MAX_WIDTH-1:0] global_max_out; + logic global_max_in_valid, global_max_in_ready; + logic global_max_out_valid, global_max_out_ready; + + logic [IN_MAX_WIDTH-1:0] repeat_global_max_out; + logic repeat_global_max_out_valid, repeat_global_max_out_ready; + + logic [IN_VALUE_WIDTH-1:0] local_values_out[PARALLELISM-1:0]; + logic local_values_in_valid, local_values_in_ready; + logic local_values_out_valid, local_values_out_ready; + + logic [SUBTRACT_WIDTH-1:0] subtract_in_data; + logic [SUBTRACT_WIDTH-1:0] subtract_out_data; + logic subtract_in_valid, subtract_in_ready; + logic subtract_out_valid, subtract_out_ready; + + logic [IN_VALUE_WIDTH-1:0] shift_in_data[PARALLELISM-1:0]; + logic [IN_VALUE_WIDTH-1:0] shift_out_data[PARALLELISM-1:0]; + logic shift_in_valid; + logic shift_in_ready[PARALLELISM-1:0]; + logic shift_out_valid[PARALLELISM-1:0]; + logic shift_out_ready; + + logic [IN_VALUE_WIDTH-1:0] adjusted_values_in_data[PARALLELISM-1:0]; + logic [IN_VALUE_WIDTH-1:0] adjusted_values_out_data[PARALLELISM-1:0]; + logic adjusted_values_in_valid, adjusted_values_in_ready; + logic adjusted_values_out_valid, adjusted_values_out_ready; + + logic [ ADDER_TREE_IN_WIDTH-1:0] adder_tree_in_data [PARALLELISM-1:0]; + logic [ADDER_TREE_OUT_WIDTH-1:0] adder_tree_out_data; + logic adder_tree_in_valid, adder_tree_in_ready; + logic adder_tree_out_valid, adder_tree_out_ready; + + logic [ACC_WIDTH-1:0] acc_out_data; + logic acc_out_valid, acc_out_ready; + + logic [RECIP_WIDTH-1:0] norm_recip_data; + logic norm_recip_valid, norm_recip_ready; + + logic [RECIP_WIDTH-1:0] repeat_norm_recip_data; + logic repeat_norm_recip_valid, repeat_norm_recip_ready; + + logic [MULT_WIDTH-1:0] mult_in_data[PARALLELISM-1:0]; + logic [MULT_WIDTH-1:0] mult_out_data[PARALLELISM-1:0]; + logic mult_in_valid; + logic mult_in_ready[PARALLELISM-1:0]; + logic mult_out_valid[PARALLELISM-1:0]; + logic mult_out_ready[PARALLELISM-1:0]; + + logic [OUT_WIDTH-1:0] mult_cast_data[PARALLELISM-1:0]; + + logic [OUT_WIDTH-1:0] out_reg_data[PARALLELISM-1:0]; + logic out_reg_valid[PARALLELISM-1:0]; + logic out_reg_ready; + + + // ----- + // Modules + // ----- + + split_n #( + .N(3) + ) input_split ( + .data_in_valid (in_valid), + .data_in_ready (in_ready), + .data_out_valid({local_max_in_valid, global_max_in_valid, local_values_in_valid}), + .data_out_ready({local_max_in_ready, global_max_in_ready, local_values_in_ready}) + ); + + fifo #( + .DATA_WIDTH(IN_MAX_WIDTH), + .DEPTH (4 * DEPTH) // TODO: resize +1 ? + ) local_max_buffer ( + .clk(clk), + .rst(rst), + .in_data(in_max), + .in_valid(local_max_in_valid), + .in_ready(local_max_in_ready), + .out_data(local_max_out), + .out_valid(local_max_out_valid), + .out_ready(local_max_out_ready), + .empty(), + .full() + ); + + comparator_accumulator #( + .DATA_WIDTH(IN_MAX_WIDTH), + .DEPTH(DEPTH), + .MAX1_MIN0(1), + .SIGNED(1) + ) global_max_accumulator ( + .clk(clk), + .rst(rst), + .in_data(in_max), + .in_valid(global_max_in_valid), + .in_ready(global_max_in_ready), + .out_data(global_max_out), + .out_valid(global_max_out_valid), + .out_ready(global_max_out_ready) + ); + + single_element_repeat #( + .DATA_WIDTH(IN_MAX_WIDTH), + .REPEAT(DEPTH) + ) global_max_repeater ( + .clk(clk), + .rst(rst), + .in_data(global_max_out), + .in_valid(global_max_out_valid), + .in_ready(global_max_out_ready), + .out_data(repeat_global_max_out), + .out_valid(repeat_global_max_out_valid), + .out_ready(repeat_global_max_out_ready) + ); + + matrix_fifo #( + .DATA_WIDTH(IN_VALUE_WIDTH), + .DIM0(PARALLELISM), + .DIM1(1), + .FIFO_SIZE(4 * DEPTH) // TODO: resize? + ) local_values_buffer ( + .clk(clk), + .rst(rst), + .in_data(in_values), + .in_valid(local_values_in_valid), + .in_ready(local_values_in_ready), + .out_data(local_values_out), + .out_valid(local_values_out_valid), + .out_ready(local_values_out_ready) + ); + + + join2 subtract_join ( + .data_in_valid ({repeat_global_max_out_valid, local_max_out_valid}), + .data_in_ready ({repeat_global_max_out_ready, local_max_out_ready}), + .data_out_valid(subtract_in_valid), + .data_out_ready(subtract_in_ready) + ); + + assign subtract_in_data = $signed(repeat_global_max_out) - $signed(local_max_out); + + skid_buffer #( + .DATA_WIDTH(SUBTRACT_WIDTH) + ) sub_reg ( + .clk(clk), + .rst(rst), + .data_in(subtract_in_data), + .data_in_valid(subtract_in_valid), + .data_in_ready(subtract_in_ready), + .data_out(subtract_out_data), + .data_out_valid(subtract_out_valid), + .data_out_ready(subtract_out_ready) + ); + + join2 shift_join ( + .data_in_valid ({local_values_out_valid, subtract_out_valid}), + .data_in_ready ({local_values_out_ready, subtract_out_ready}), + .data_out_valid(shift_in_valid), + .data_out_ready(shift_in_ready[0]) + ); + + + // Batched shift + for (genvar i = 0; i < PARALLELISM; i++) begin : shift + assign shift_in_data[i] = local_values_out[i] >> subtract_out_data; + + skid_buffer #( + .DATA_WIDTH(IN_VALUE_WIDTH) + ) shift_reg ( + .clk(clk), + .rst(rst), + .data_in(shift_in_data[i]), + .data_in_valid(shift_in_valid), + .data_in_ready(shift_in_ready[i]), + .data_out(shift_out_data[i]), + .data_out_valid(shift_out_valid[i]), + .data_out_ready(shift_out_ready) + ); + end + + split2 norm_split ( + .data_in_valid (shift_out_valid[0]), + .data_in_ready (shift_out_ready), + .data_out_valid({adjusted_values_in_valid, adder_tree_in_valid}), + .data_out_ready({adjusted_values_in_ready, adder_tree_in_ready}) + ); + + assign adjusted_values_in_data = shift_out_data; + for (genvar i = 0; i < PARALLELISM; i++) begin : unsigned_hack + assign adder_tree_in_data[i] = {1'b0, shift_out_data[i]}; + end + + matrix_fifo #( + .DATA_WIDTH(IN_VALUE_WIDTH), + .DIM0(PARALLELISM), + .DIM1(1), + .FIFO_SIZE(4 * DEPTH) // TODO: resize? + ) adjusted_values_buffer ( + .clk(clk), + .rst(rst), + .in_data(adjusted_values_in_data), + .in_valid(adjusted_values_in_valid), + .in_ready(adjusted_values_in_ready), + .out_data(adjusted_values_out_data), + .out_valid(adjusted_values_out_valid), + .out_ready(adjusted_values_out_ready) + ); + + + generate + if (PARALLELISM == 1) begin : gen_skip_adder_tree + assign adder_tree_out_data = adder_tree_in_data[0]; + assign adder_tree_out_valid = adder_tree_in_valid; + assign adder_tree_in_ready = adder_tree_out_ready; + end else begin : gen_adder_tree + fixed_adder_tree #( + .IN_SIZE (PARALLELISM), + .IN_WIDTH(ADDER_TREE_IN_WIDTH) + ) adder_tree ( + .clk(clk), + .rst(rst), + .data_in(adder_tree_in_data), + .data_in_valid(adder_tree_in_valid), + .data_in_ready(adder_tree_in_ready), + .data_out(adder_tree_out_data), + .data_out_valid(adder_tree_out_valid), + .data_out_ready(adder_tree_out_ready) + ); + end + endgenerate + + fixed_accumulator #( + .IN_DEPTH(DEPTH), + .IN_WIDTH(ADDER_TREE_OUT_WIDTH) + ) norm_accumulator ( + .clk(clk), + .rst(rst), + .data_in(adder_tree_out_data), + .data_in_valid(adder_tree_out_valid), + .data_in_ready(adder_tree_out_ready), + .data_out(acc_out_data), + .data_out_valid(acc_out_valid), + .data_out_ready(acc_out_ready) + ); + + softermax_lpw_reciprocal #( + .ENTRIES(8), + .IN_WIDTH(ACC_WIDTH), + .IN_FRAC_WIDTH(ACC_FRAC_WIDTH), + .OUT_WIDTH(RECIP_WIDTH), + .OUT_FRAC_WIDTH(RECIP_FRAC_WIDTH) + ) norm_recip ( + .clk(clk), + .rst(rst), + .in_data(acc_out_data), + .in_valid(acc_out_valid), + .in_ready(acc_out_ready), + .out_data(norm_recip_data), + .out_valid(norm_recip_valid), + .out_ready(norm_recip_ready) + ); + + single_element_repeat #( + .DATA_WIDTH(RECIP_WIDTH), + .REPEAT(DEPTH) + ) repeat_norm_recip ( + .clk(clk), + .rst(rst), + .in_data(norm_recip_data), + .in_valid(norm_recip_valid), + .in_ready(norm_recip_ready), + .out_data(repeat_norm_recip_data), + .out_valid(repeat_norm_recip_valid), + .out_ready(repeat_norm_recip_ready) + ); + + join2 mult_join ( + .data_in_valid ({adjusted_values_out_valid, repeat_norm_recip_valid}), + .data_in_ready ({adjusted_values_out_ready, repeat_norm_recip_ready}), + .data_out_valid(mult_in_valid), + .data_out_ready(mult_in_ready[0]) + ); + + // Batched mult & cast + for (genvar i = 0; i < PARALLELISM; i++) begin : output_mult_cast + assign mult_in_data[i] = adjusted_values_out_data[i] * repeat_norm_recip_data; + + skid_buffer #( + .DATA_WIDTH(MULT_WIDTH) + ) mult_reg ( + .clk(clk), + .rst(rst), + .data_in(mult_in_data[i]), + .data_in_valid(mult_in_valid), + .data_in_ready(mult_in_ready[i]), + .data_out(mult_out_data[i]), + .data_out_valid(mult_out_valid[i]), + .data_out_ready(mult_out_ready[0]) + ); + + fixed_signed_cast #( + .IN_WIDTH(MULT_WIDTH), + .IN_FRAC_WIDTH(MULT_FRAC_WIDTH), + .OUT_WIDTH(OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH), + .SYMMETRIC(0), + .ROUND_FLOOR(1) + ) output_cast ( + .in_data (mult_out_data[i]), + .out_data(mult_cast_data[i]) + ); + + skid_buffer #( + .DATA_WIDTH(MULT_WIDTH) + ) out_reg ( + .clk(clk), + .rst(rst), + .data_in(mult_cast_data[i]), + .data_in_valid(mult_out_valid[i]), + .data_in_ready(mult_out_ready[i]), + .data_out(out_reg_data[i]), + .data_out_valid(out_reg_valid[i]), + .data_out_ready(out_reg_ready) + ); + end + + assign out_data = out_reg_data; + assign out_valid = out_reg_valid[0]; + assign out_reg_ready = out_ready; + +endmodule diff --git a/src/mase_components/activations/rtl/softermax_local_window.sv b/src/mase_components/activations/rtl/softermax_local_window.sv new file mode 100644 index 000000000..4a7cc5871 --- /dev/null +++ b/src/mase_components/activations/rtl/softermax_local_window.sv @@ -0,0 +1,269 @@ +/* +Module : softermax_local_window +Description : This module implements the first section of the softermax compute + pipeline which calculates the local maximum and power of 2. + + Refer to the top half of Fig. 4.a) in the Softermax Paper. + https://arxiv.org/abs/2103.09301 +*/ + +`timescale 1ns / 1ps + +module softermax_local_window #( + // Input shape independent of TOTAL_DIM since this module only + // operates on local windows + parameter PARALLELISM = 4, + + // Widths + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 7, + + // Derived params + localparam MAX_WIDTH = IN_WIDTH - IN_FRAC_WIDTH +) ( + input logic clk, + input logic rst, + + // Input streaming interface + input logic [IN_WIDTH-1:0] in_data [PARALLELISM-1:0], + input logic in_valid, + output logic in_ready, + + // Output streaming interface with pow of 2 values & max + output logic [OUT_WIDTH-1:0] out_values[PARALLELISM-1:0], + output logic [MAX_WIDTH-1:0] out_max, + output logic out_valid, + input logic out_ready +); + + // ----- + // Parameters + // ----- + + localparam MAX_TREE_DEPTH = $clog2(PARALLELISM); + + localparam SUBTRACT_WIDTH = IN_WIDTH + 1; + localparam SUBTRACT_FRAC_WIDTH = IN_FRAC_WIDTH; + + initial begin + assert (PARALLELISM > 1); + assert (IN_WIDTH > IN_FRAC_WIDTH); + assert (OUT_WIDTH > OUT_FRAC_WIDTH); + end + + // ----- + // Wires + // ----- + + logic [MAX_WIDTH-1:0] rounded_int_in_data[PARALLELISM-1:0]; + logic [MAX_WIDTH-1:0] rounded_int_out_data[PARALLELISM-1:0]; + logic rounded_int_in_valid; + logic rounded_int_in_ready[PARALLELISM-1:0]; + logic rounded_int_out_valid[PARALLELISM-1:0]; + logic rounded_int_out_ready; + + // logic [MAX_WIDTH-1:0] max_tree_in_data [PARALLELISM-1:0]; + logic [MAX_WIDTH-1:0] max_tree_out_data; + // logic max_tree_in_valid, max_tree_in_ready; + logic max_tree_out_valid, max_tree_out_ready; + + logic [IN_WIDTH-1:0] input_fifo_in_data [PARALLELISM-1:0]; + logic [IN_WIDTH-1:0] input_fifo_out_data[PARALLELISM-1:0]; + logic input_fifo_in_valid, input_fifo_in_ready; + logic input_fifo_out_valid, input_fifo_out_ready; + + logic [MAX_WIDTH-1:0] max_fifo_in_data, max_fifo_out_data; + logic max_fifo_in_valid, max_fifo_in_ready; + logic max_fifo_out_valid, max_fifo_out_ready; + + logic [MAX_WIDTH-1:0] sub_max_in_data; + logic [MAX_WIDTH-1:0] sub_max_out_data; + logic sub_max_in_valid, sub_max_in_ready; + logic sub_max_out_valid, sub_max_out_ready; + + logic [SUBTRACT_WIDTH-1:0] subtract_in_data[PARALLELISM-1:0]; + logic [SUBTRACT_WIDTH-1:0] subtract_out_data[PARALLELISM-1:0]; + logic subtract_in_valid; + logic subtract_in_ready[PARALLELISM-1:0]; + logic subtract_out_valid[PARALLELISM-1:0]; + logic subtract_out_ready[PARALLELISM-1:0]; + + logic [OUT_WIDTH-1:0] lpw_out_data[PARALLELISM-1:0]; + logic lpw_out_valid[PARALLELISM-1:0]; + logic lpw_out_ready; + + // ----- + // Modules + // ----- + + split2 input_split ( + .data_in_valid (in_valid), + .data_in_ready (in_ready), + .data_out_valid({rounded_int_in_valid, input_fifo_in_valid}), + .data_out_ready({rounded_int_in_ready[0], input_fifo_in_ready}) + ); + + for (genvar i = 0; i < PARALLELISM; i++) begin : rounding + fixed_signed_cast #( + .IN_WIDTH(IN_WIDTH), + .IN_FRAC_WIDTH(IN_FRAC_WIDTH), + .OUT_WIDTH(MAX_WIDTH), + .OUT_FRAC_WIDTH(0), // No fraction + .SYMMETRIC(0), + .ROUND_FLOOR(1) + ) rounding_inst ( + .in_data (in_data[i]), + .out_data(rounded_int_in_data[i]) + ); + + // Required even though cast is only doing slice since theres another split + skid_buffer #( + .DATA_WIDTH(MAX_WIDTH) + ) cast_buff ( + .clk(clk), + .rst(rst), + .data_in(rounded_int_in_data[i]), + .data_in_valid(rounded_int_in_valid), + .data_in_ready(rounded_int_in_ready[i]), + .data_out(rounded_int_out_data[i]), + .data_out_valid(rounded_int_out_valid[i]), + .data_out_ready(rounded_int_out_ready) + ); + end + + assign input_fifo_in_data = in_data; + + comparator_tree #( + .SIZE(PARALLELISM), + .DATA_WIDTH(MAX_WIDTH), + .MAX1_MIN0(1), // MAX + .SIGNED(1) + ) max_tree ( + .clk(clk), + .rst(rst), + .in_data(rounded_int_out_data), + .in_valid(rounded_int_out_valid[0]), + .in_ready(rounded_int_out_ready), + .out_data(max_tree_out_data), + .out_valid(max_tree_out_valid), + .out_ready(max_tree_out_ready) + ); + + matrix_fifo #( + .DATA_WIDTH(IN_WIDTH), + .DIM0(PARALLELISM), + .DIM1(1), + .FIFO_SIZE(16) + ) input_buffer ( + .clk(clk), + .rst(rst), + .in_data(input_fifo_in_data), + .in_valid(input_fifo_in_valid), + .in_ready(input_fifo_in_ready), + .out_data(input_fifo_out_data), + .out_valid(input_fifo_out_valid), + .out_ready(input_fifo_out_ready) + ); + + assign max_fifo_in_data = max_tree_out_data; + + fifo #( + .DATA_WIDTH(MAX_WIDTH), + // Should be enough. Subtract takes 1 cycle and LPW is max 3 cycles + .DEPTH(8) + ) max_buffer ( + .clk(clk), + .rst(rst), + .in_data(max_fifo_in_data), + .in_valid(max_fifo_in_valid), + .in_ready(max_fifo_in_ready), + .out_data(max_fifo_out_data), + .out_valid(max_fifo_out_valid), + .out_ready(max_fifo_out_ready), + .empty(), + .full() + ); + + split2 max_tree_split ( + .data_in_valid (max_tree_out_valid), + .data_in_ready (max_tree_out_ready), + .data_out_valid({sub_max_in_valid, max_fifo_in_valid}), + .data_out_ready({sub_max_in_ready, max_fifo_in_ready}) + ); + + assign sub_max_in_data = max_tree_out_data; + + skid_buffer #( + .DATA_WIDTH(MAX_WIDTH) + ) max_intermediate_buff ( + .clk(clk), + .rst(rst), + .data_in(sub_max_in_data), + .data_in_valid(sub_max_in_valid), + .data_in_ready(sub_max_in_ready), + .data_out(sub_max_out_data), + .data_out_valid(sub_max_out_valid), + .data_out_ready(sub_max_out_ready) + ); + + join2 subtract_join ( + .data_in_valid ({sub_max_out_valid, input_fifo_out_valid}), + .data_in_ready ({sub_max_out_ready, input_fifo_out_ready}), + .data_out_valid(subtract_in_valid), + .data_out_ready(subtract_in_ready[0]) + ); + + // Batched subtract & power of 2 + for (genvar i = 0; i < PARALLELISM; i++) begin : subtract_pow2 + // Need to pad the maxInt with fixed-point zeros in fraction + assign subtract_in_data[i] = $signed( + input_fifo_out_data[i] + ) - $signed( + {sub_max_out_data, {IN_FRAC_WIDTH{1'b0}}} + ); + + skid_buffer #( + .DATA_WIDTH(SUBTRACT_WIDTH) + ) sub_reg ( + .clk(clk), + .rst(rst), + .data_in(subtract_in_data[i]), + .data_in_valid(subtract_in_valid), + .data_in_ready(subtract_in_ready[i]), + .data_out(subtract_out_data[i]), + .data_out_valid(subtract_out_valid[i]), + .data_out_ready(subtract_out_ready[i]) + ); + + softermax_lpw_pow2 #( + .IN_WIDTH(SUBTRACT_WIDTH), + .IN_FRAC_WIDTH(SUBTRACT_FRAC_WIDTH), + .OUT_WIDTH(OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) lpw_pow2 ( + .clk(clk), + .rst(rst), + .in_data(subtract_out_data[i]), + .in_valid(subtract_out_valid[i]), + .in_ready(subtract_out_ready[i]), + .out_data(lpw_out_data[i]), + .out_valid(lpw_out_valid[i]), + .out_ready(lpw_out_ready) + ); + + end + + // Final synchronize + join2 output_sync ( + .data_in_valid ({max_fifo_out_valid, lpw_out_valid[0]}), + .data_in_ready ({max_fifo_out_ready, lpw_out_ready}), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + + assign out_values = lpw_out_data; + assign out_max = max_fifo_out_data; + +endmodule diff --git a/src/mase_components/activations/rtl/softermax_lpw_pow2.sv b/src/mase_components/activations/rtl/softermax_lpw_pow2.sv new file mode 100644 index 000000000..77f2b5f0d --- /dev/null +++ b/src/mase_components/activations/rtl/softermax_lpw_pow2.sv @@ -0,0 +1,269 @@ +/* +Module : softermax_lpw_pow2 +Description : This module implements 2^x with linear piecewise approximation. + + Uses 4 linear pieces between [0, 1) for the fraction then shifts + it by the integer part. + + TODO: need to support (-inf, 1) -> (0, 2) +*/ + +`timescale 1ns / 1ps + +module softermax_lpw_pow2 #( + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 4 +) ( + input logic clk, + input logic rst, + + input logic [IN_WIDTH-1:0] in_data, + input logic in_valid, + output logic in_ready, + + output logic [OUT_WIDTH-1:0] out_data, + output logic out_valid, + input logic out_ready +); + + // ----- + // Parameters + // ----- + + // Input: x + localparam INT_WIDTH = IN_WIDTH - IN_FRAC_WIDTH; + + // Slope: m + localparam SLOPE_FRAC_WIDTH = OUT_WIDTH; + localparam SLOPE_WIDTH = 2 + SLOPE_FRAC_WIDTH; + + // Mult: mx + localparam MULT_FRAC_WIDTH = IN_FRAC_WIDTH + SLOPE_FRAC_WIDTH; + localparam MULT_WIDTH = IN_WIDTH + SLOPE_WIDTH; + + // Intercept (need to match mx frac): c + localparam INTERCEPT_FRAC_WIDTH = MULT_FRAC_WIDTH; + localparam INTERCEPT_WIDTH = 2 + INTERCEPT_FRAC_WIDTH; + + // Output width: mx + c + localparam LPW_WIDTH = MULT_WIDTH + 1; + localparam LPW_FRAC_WIDTH = MULT_FRAC_WIDTH; // == INTERCEPT_FRAC_WIDTH + + // PARAMETERS BELOW ONLY USED IN 1/2-BIT CASE + // Output result of 2^[0,1] is in [1,2] which requires 2 integer bits + localparam LUT_WIDTH = IN_FRAC_WIDTH + 2; + + + initial begin + assert (INT_WIDTH > 0); // Untested for 0 int width + assert (OUT_WIDTH > OUT_FRAC_WIDTH); // Untested for 0 out frac width + assert (IN_WIDTH > 0); + assert (IN_FRAC_WIDTH >= 0); + end + + // Wires + logic [INT_WIDTH-1:0] in_data_int; // Q INT.0 + logic [IN_FRAC_WIDTH-1:0] in_data_frac; // Q 0.FRAC + + logic [OUT_WIDTH-1:0] result_data; + logic result_valid, result_ready; + + + // Function to generate LUT (Only used 1/2-bit case) + function automatic logic [OUT_WIDTH-1:0] pow2_func(real x); + real res, res_shifted; + bit [OUT_WIDTH-1:0] return_val; + res = 2.0 ** x; + + // Output cast + res_shifted = res * (2.0 ** OUT_FRAC_WIDTH); + return_val = logic'(res_shifted); + return return_val; + endfunction + + // Function to generate slope variable (m) + function automatic logic [SLOPE_WIDTH-1:0] slope(real x1, real x2); + real y1, y2, res, res_shifted; + bit [SLOPE_WIDTH-1:0] return_val; + y1 = 2.0 ** x1; + y2 = 2.0 ** x2; + res = (y2 - y1) / (x2 - x1); + + // Output cast + res_shifted = res * (2.0 ** SLOPE_FRAC_WIDTH); + return_val = logic'(res_shifted); + return return_val; + endfunction + + // Function to intercept variable (c) + function automatic logic [INTERCEPT_WIDTH-1:0] intercept(real x1, real x2); + real m, y1, y2, res, res_shifted; + bit [INTERCEPT_WIDTH-1:0] return_val; + y1 = 2.0 ** x1; + y2 = 2.0 ** x2; + m = (y2 - y1) / (x2 - x1); + res = y1 - (m * x1); + + // Output cast + res_shifted = res * (2 ** INTERCEPT_FRAC_WIDTH); + return_val = logic'(res_shifted); + return return_val; + endfunction + + // ----- + // Logic + // ----- + + assign {in_data_int, in_data_frac} = in_data; + + generate + if (IN_FRAC_WIDTH <= 1) begin : one_bit_frac + + logic [OUT_WIDTH-1:0] lookup_result; + always_comb begin + case (in_data_frac) + 1'b0: lookup_result = pow2_func(0.0); + 1'b1: lookup_result = pow2_func(0.5); + endcase + // TODO: Fix pipelining for the shifter + result_data = lookup_result >> -in_data_int; + result_valid = in_valid; + in_ready = result_ready; + end + + end else if (IN_FRAC_WIDTH == 2) begin : two_bit_frac + + logic [OUT_WIDTH-1:0] lookup_result; + always_comb begin + case (in_data_frac) + 2'b00: lookup_result = pow2_func(0.0); + 2'b01: lookup_result = pow2_func(0.25); + 2'b10: lookup_result = pow2_func(0.5); + 2'b11: lookup_result = pow2_func(0.75); + endcase + // TODO: Fix pipelining for the shifter + result_data = lookup_result >> -in_data_int; + result_valid = in_valid; + in_ready = result_ready; + end + + end else begin : lpw_approx + + // Split out the top two bits of the frac again to figure out which + // piecewise part if lies on + logic [1:0] frac_top_in, frac_top_out; + assign frac_top_in = in_data_frac[IN_FRAC_WIDTH-1:IN_FRAC_WIDTH-2]; + + logic [INT_WIDTH-1:0] in_data_int_buff[1:0]; + + logic [MULT_WIDTH-1:0] mult_in, mult_out; + logic mult_out_valid, mult_out_ready; + logic intercept_out_valid, intercept_out_ready; + + logic [LPW_WIDTH-1:0] lpw_int_in, lpw_int_out, lpw_result; + logic [LPW_WIDTH-1:0] lpw_out_data; + logic lpw_out_valid, lpw_out_ready; + + logic [OUT_WIDTH:0] lpw_cast_out; + + always_comb begin + // Multiplication Stage + case (frac_top_in) + 2'b00: mult_in = in_data_frac * slope(0.00, 0.25); + 2'b01: mult_in = in_data_frac * slope(0.25, 0.50); + 2'b10: mult_in = in_data_frac * slope(0.50, 0.75); + 2'b11: mult_in = in_data_frac * slope(0.75, 1.00); + endcase + end + + // Buffer multiplication, top frac bits, and int part + skid_buffer #( + .DATA_WIDTH(MULT_WIDTH + 2 + INT_WIDTH) + ) out_reg ( + .clk(clk), + .rst(rst), + .data_in({mult_in, frac_top_in, in_data_int}), + .data_in_valid(in_valid), + .data_in_ready(in_ready), + .data_out({mult_out, frac_top_out, in_data_int_buff[0]}), + .data_out_valid(mult_out_valid), + .data_out_ready(mult_out_ready) + ); + + // Add Intercept + always_comb begin + case (frac_top_out) + 2'b00: lpw_int_in = mult_out + intercept(0.00, 0.25); + 2'b01: lpw_int_in = mult_out + intercept(0.25, 0.50); + 2'b10: lpw_int_in = mult_out + intercept(0.50, 0.75); + 2'b11: lpw_int_in = mult_out + intercept(0.75, 1.00); + endcase + end + + skid_buffer #( + .DATA_WIDTH(LPW_WIDTH + INT_WIDTH) + ) intercept_reg ( + .clk(clk), + .rst(rst), + .data_in({lpw_int_in, in_data_int_buff[0]}), + .data_in_valid(mult_out_valid), + .data_in_ready(mult_out_ready), + .data_out({lpw_int_out, in_data_int_buff[1]}), + .data_out_valid(intercept_out_valid), + .data_out_ready(intercept_out_ready) + ); + + // TODO: Shift up for positive x + assign lpw_result = lpw_int_out >> -in_data_int_buff[1]; + + skid_buffer #( + .DATA_WIDTH(LPW_WIDTH) + ) lpw_reg ( + .clk(clk), + .rst(rst), + .data_in(lpw_result), + .data_in_valid(intercept_out_valid), + .data_in_ready(intercept_out_ready), + .data_out(lpw_out_data), + .data_out_valid(lpw_out_valid), + .data_out_ready(lpw_out_ready) + ); + + fixed_signed_cast #( + .IN_WIDTH(LPW_WIDTH + 1), + .IN_FRAC_WIDTH(LPW_FRAC_WIDTH), + .OUT_WIDTH(OUT_WIDTH + 1), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH), + .SYMMETRIC(0), + .ROUND_FLOOR(1) + ) fixed_cast ( + .in_data ({1'b0, lpw_out_data}), + .out_data(lpw_cast_out) + ); + + assign result_data = lpw_cast_out[OUT_WIDTH-1:0]; + assign result_valid = lpw_out_valid; + assign lpw_out_ready = result_ready; + + end + + endgenerate + + + // Output Register + skid_buffer #( + .DATA_WIDTH(OUT_WIDTH) + ) out_reg ( + .clk(clk), + .rst(rst), + .data_in(result_data), + .data_in_valid(result_valid), + .data_in_ready(result_ready), + .data_out(out_data), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + +endmodule diff --git a/src/mase_components/activations/rtl/softermax_lpw_reciprocal.sv b/src/mase_components/activations/rtl/softermax_lpw_reciprocal.sv new file mode 100644 index 000000000..a255b2c32 --- /dev/null +++ b/src/mase_components/activations/rtl/softermax_lpw_reciprocal.sv @@ -0,0 +1,299 @@ +/* +Module : softermax_lpw_reciprocal +Description : This module implements 1/x using linear piecewise approximation. + + The softermax module allows us to assume: + - Input is unsigned. (x >= 0) + - Therefore, the output is also unsigned. (y >= 0) + + This module calculates 1/x using linear piecewise approx. in the + domain: [1, 2). It will shift all numbers into that range and then + shift the number back once the 1/x calculation is done. + + Safer to use 8 entry LUT rather than 4. There will be a few more + bits of error with 4 entries, however 8 entries can achieve a + single-bit of error vs. software model. +*/ + +`timescale 1ns / 1ps + +module softermax_lpw_reciprocal #( + parameter ENTRIES = 8, // Must be power of 2 + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 5 +) ( + input logic clk, + input logic rst, + + // Input streaming interface + input logic [IN_WIDTH-1:0] in_data, + input logic in_valid, + output logic in_ready, + + output logic [OUT_WIDTH-1:0] out_data, + output logic out_valid, + input logic out_ready +); + + // ----- + // Parameters + // ----- + + let max(a, b) = (a > b) ? a : b; + + localparam ENTRIES_WIDTH = $clog2(ENTRIES); + + // Range reduced num: x + localparam RANGE_REDUCED_WIDTH = IN_WIDTH; + localparam RANGE_REDUCED_FRAC_WIDTH = IN_WIDTH - 1; + + // Slope: m + // IMPORTANT: This determines how precise the output of this module is. + // SLOPE_FRAC_WIDTH = OUT_WIDTH is maximum precision. + localparam SLOPE_FRAC_WIDTH = OUT_WIDTH; + localparam SLOPE_WIDTH = 1 + SLOPE_FRAC_WIDTH; + + // Mult: mx + localparam MULT_WIDTH = RANGE_REDUCED_WIDTH + SLOPE_WIDTH; + localparam MULT_FRAC_WIDTH = RANGE_REDUCED_FRAC_WIDTH + SLOPE_FRAC_WIDTH; + + // Intercept (need to match mx frac): c + localparam INTERCEPT_FRAC_WIDTH = MULT_FRAC_WIDTH; + localparam INTERCEPT_WIDTH = 2 + INTERCEPT_FRAC_WIDTH; // Needs 2 integer bits + + // Output width: mx + c + localparam LPW_WIDTH = MULT_WIDTH + 1; + localparam LPW_FRAC_WIDTH = MULT_FRAC_WIDTH; // == INTERCEPT_FRAC_WIDTH + + + // Recip width calculation: Need to pad extra 2 * max(intwidth, fracwidth) to + // make sure recip is not shifted out + localparam IN_INT_WIDTH = IN_WIDTH - IN_FRAC_WIDTH; + localparam EXTRA_WIDTH = max(IN_INT_WIDTH, IN_FRAC_WIDTH); + localparam RECIP_WIDTH = LPW_WIDTH + EXTRA_WIDTH; + localparam RECIP_FRAC_WIDTH = LPW_FRAC_WIDTH; + + // Shift num widths + localparam MSB_WIDTH = $clog2(IN_WIDTH); + localparam SHIFT_WIDTH = MSB_WIDTH + 1; + + initial begin + // Params + assert (ENTRIES >= 4); + assert (2 ** ENTRIES_WIDTH == ENTRIES); + assert (ENTRIES_WIDTH <= RANGE_REDUCED_FRAC_WIDTH); + assert (IN_WIDTH > IN_FRAC_WIDTH); + assert (IN_FRAC_WIDTH >= ENTRIES_WIDTH); + assert (OUT_WIDTH > OUT_FRAC_WIDTH); + assert (OUT_FRAC_WIDTH >= ENTRIES_WIDTH); + + // Sanity Asserts + assert (RANGE_REDUCED_WIDTH > RANGE_REDUCED_FRAC_WIDTH); + assert (SLOPE_WIDTH > SLOPE_FRAC_WIDTH); + assert (MULT_WIDTH > MULT_FRAC_WIDTH); + assert (INTERCEPT_WIDTH > INTERCEPT_FRAC_WIDTH); + assert (LPW_WIDTH > LPW_FRAC_WIDTH); + assert (RECIP_WIDTH > RECIP_FRAC_WIDTH); + end + + // ----- + // Wires + // ----- + + logic [RANGE_REDUCED_WIDTH-1:0] range_reduced_num[1:0]; + logic [MSB_WIDTH-1:0] msb[2:0]; + logic msb_not_found[4:0]; + logic range_reduce_out_valid, range_reduce_out_ready; + + logic [ENTRIES_WIDTH-1:0] frac_top_in, frac_top_out; + + logic [MULT_WIDTH-1:0] mult_in, mult_out; + logic mult_out_valid, mult_out_ready; + + logic [LPW_WIDTH-1:0] lpw_in_data, lpw_out_data; + logic lpw_out_valid, lpw_out_ready; + + logic [SHIFT_WIDTH-1:0] shift_amt_in, shift_amt_out; + + logic [RECIP_WIDTH-1:0] recip_in_data, recip_out_data; + logic recip_out_valid, recip_out_ready; + + logic [ OUT_WIDTH:0] cast_out_data; + + logic [OUT_WIDTH-1:0] output_reg_in_data; + + + // ----- + // Functions + // ----- + + // Function to generate slope variable (m) + function automatic logic [SLOPE_WIDTH-1:0] slope(real x1, real x2); + real y1, y2, res, res_shifted; + bit [SLOPE_WIDTH-1:0] return_val; + + // Calculate real result + y1 = 1.0 / x1; + y2 = 1.0 / x2; + res = (y2 - y1) / (x2 - x1); + + // Output cast + res_shifted = res * (2.0 ** SLOPE_FRAC_WIDTH); + return_val = logic'(res_shifted); + return return_val; + endfunction + + // Function to intercept variable (c) + function automatic logic [INTERCEPT_WIDTH-1:0] intercept(real x1, real x2); + real m, y1, y2, res, res_shifted; + bit [INTERCEPT_WIDTH-1:0] return_val; + + // Calculate real result + y1 = 1.0 / x1; + y2 = 1.0 / x2; + m = (y2 - y1) / (x2 - x1); + res = y1 - (m * x1); + + // Output cast + res_shifted = res * (2.0 ** INTERCEPT_FRAC_WIDTH); + return_val = logic'(res_shifted); + return return_val; + endfunction + + + // ----- + // Tables + // ----- + + logic [SLOPE_WIDTH-1:0] slope_lut[ENTRIES-1:0]; + logic [INTERCEPT_WIDTH-1:0] intercept_lut[ENTRIES-1:0]; + + initial begin + real step = 1.0 / ENTRIES; + for (int i = 0; i < ENTRIES; i++) begin + real start, stop; + start = 1.00 + (i * step); + stop = 1.00 + ((i + 1) * step); + slope_lut[i] = slope(start, stop); + intercept_lut[i] = intercept(start, stop); + end + end + + + // ----- + // Modules + // ----- + + fixed_range_reduction #( + .WIDTH(IN_WIDTH) + ) range_reduce ( + .data_a(in_data), + .data_out(range_reduced_num[0]), // This num is in the format Q1.(IN_WIDTH-1) + .msb_index(msb[0]), + .not_found(msb_not_found[0]) // if msb_not_found, then x = 0 + ); + + skid_buffer #( + .DATA_WIDTH(RANGE_REDUCED_WIDTH + MSB_WIDTH + 1) + ) range_reduce_reg ( + .clk(clk), + .rst(rst), + .data_in({range_reduced_num[0], msb[0], msb_not_found[0]}), + .data_in_valid(in_valid), + .data_in_ready(in_ready), + .data_out({range_reduced_num[1], msb[1], msb_not_found[1]}), + .data_out_valid(range_reduce_out_valid), + .data_out_ready(range_reduce_out_ready) + ); + + + // Multiplication Stage + assign frac_top_in = range_reduced_num[1][RANGE_REDUCED_WIDTH-2:RANGE_REDUCED_WIDTH-1-ENTRIES_WIDTH]; + assign mult_in = $signed({1'b0, range_reduced_num[1]}) * $signed(slope_lut[frac_top_in]); + + skid_buffer #( + .DATA_WIDTH(MULT_WIDTH + ENTRIES_WIDTH + MSB_WIDTH + 1) + ) mult_stage_reg ( + .clk(clk), + .rst(rst), + .data_in({mult_in, frac_top_in, msb[1], msb_not_found[1]}), + .data_in_valid(range_reduce_out_valid), + .data_in_ready(range_reduce_out_ready), + .data_out({mult_out, frac_top_out, msb[2], msb_not_found[2]}), + .data_out_valid(mult_out_valid), + .data_out_ready(mult_out_ready) + ); + + // Add Intercept to Mult + assign lpw_in_data = $signed(mult_out) + $signed({1'b0, intercept_lut[frac_top_out]}); + // Also convert MSB into a shift amount + assign shift_amt_in = IN_FRAC_WIDTH - msb[2]; + + skid_buffer #( + .DATA_WIDTH(LPW_WIDTH + SHIFT_WIDTH + 1) + ) lpw_stage_reg ( + .clk(clk), + .rst(rst), + .data_in({lpw_in_data, shift_amt_in, msb_not_found[2]}), + .data_in_valid(mult_out_valid), + .data_in_ready(mult_out_ready), + .data_out({lpw_out_data, shift_amt_out, msb_not_found[3]}), + .data_out_valid(lpw_out_valid), + .data_out_ready(lpw_out_ready) + ); + + always_comb begin + // Shift stage + if ($signed(shift_amt_out) >= 0) begin + recip_in_data = $signed(lpw_out_data) <<< shift_amt_out; + end else begin + recip_in_data = $signed(lpw_out_data) >>> -shift_amt_out; + end + end + + skid_buffer #( + .DATA_WIDTH(RECIP_WIDTH + 1) + ) recip_stage_reg ( + .clk(clk), + .rst(rst), + .data_in({recip_in_data, msb_not_found[3]}), + .data_in_valid(lpw_out_valid), + .data_in_ready(lpw_out_ready), + .data_out({recip_out_data, msb_not_found[4]}), + .data_out_valid(recip_out_valid), + .data_out_ready(recip_out_ready) + ); + + + // TODO: change to unsigned cast + fixed_signed_cast #( + .IN_WIDTH(RECIP_WIDTH + 1), + .IN_FRAC_WIDTH(LPW_FRAC_WIDTH), + .OUT_WIDTH(OUT_WIDTH + 1), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH), + .SYMMETRIC(0), + .ROUND_FLOOR(1) + ) signed_cast ( + .in_data ({1'b0, recip_out_data}), + .out_data(cast_out_data) + ); + + // Mux between INT_MAX and 1/x result (edge case for 1/0) + assign output_reg_in_data = (msb_not_found[4]) ? '1 : cast_out_data[OUT_WIDTH-1:0]; + + skid_buffer #( + .DATA_WIDTH(OUT_WIDTH) + ) output_reg ( + .clk(clk), + .rst(rst), + .data_in(output_reg_in_data), + .data_in_valid(recip_out_valid), + .data_in_ready(recip_out_ready), + .data_out(out_data), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + +endmodule diff --git a/src/mase_components/activations/test/fixed_softermax_1d_tb.py b/src/mase_components/activations/test/fixed_softermax_1d_tb.py new file mode 100644 index 000000000..e2809cd28 --- /dev/null +++ b/src/mase_components/activations/test/fixed_softermax_1d_tb.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +import os +from random import randint, choice + +import torch +import logging +from functools import partial + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import Timer + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, ErrorThresholdStreamMonitor +from mase_cocotb.runner import mase_runner +from mase_cocotb.utils import fixed_preprocess_tensor, bit_driver + +from chop.nn.quantized.functional import fixed_softermax + + +class SoftermaxTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.INFO) + + self.assign_self_params( + [ + "TOTAL_DIM", + "PARALLELISM", + "IN_WIDTH", + "IN_FRAC_WIDTH", + "POW2_WIDTH", + "POW2_FRAC_WIDTH", + "OUT_WIDTH", + "OUT_FRAC_WIDTH", + ] + ) + self.depth = self.TOTAL_DIM // self.PARALLELISM + + self.in_data_driver = StreamDriver( + dut.clk, dut.in_data, dut.in_valid, dut.in_ready + ) + + self.out_data_monitor = ErrorThresholdStreamMonitor( + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + width=self.OUT_WIDTH, + signed=False, + error_bits=1, + check=True, + log_error=True, + ) + + # Model + self.model = partial( + fixed_softermax, + q_config={ + "width": self.IN_WIDTH, + "frac_width": self.IN_FRAC_WIDTH, + }, + ) + + # Set verbosity of driver and monitor loggers to debug + # self.in_data_driver.log.setLevel(logging.DEBUG) + # self.out_data_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self, batches): + return torch.randn( + (batches, self.TOTAL_DIM), + ) + + async def run_test(self, batches, us): + await self.reset() + self.log.debug(f"Reset finished") + + inputs = self.generate_inputs(batches) + + for batch in inputs: + + exp_out = self.model(batch) + + # * Load the inputs driver + self.log.debug(f"Processing inputs: {batch}") + driver_input = fixed_preprocess_tensor( + tensor=batch, + q_config={ + "width": self.IN_WIDTH, + "frac_width": self.IN_FRAC_WIDTH, + }, + parallelism=[self.PARALLELISM], + ) + self.in_data_driver.load_driver(driver_input) + + # * Load the output monitor + self.log.debug(f"Processing outputs: {exp_out}") + outs = fixed_preprocess_tensor( + tensor=exp_out, + q_config={ + "width": self.OUT_WIDTH, + "frac_width": self.OUT_FRAC_WIDTH, + }, + parallelism=[self.PARALLELISM], + ) + self.out_data_monitor.load_monitor(outs) + + await Timer(us, units="us") + assert self.out_data_monitor.exp_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = SoftermaxTB(dut) + tb.out_data_monitor.ready.value = 1 + await tb.run_test(batches=1, us=10) + + +@cocotb.test() +async def stream(dut): + tb = SoftermaxTB(dut) + tb.out_data_monitor.ready.value = 1 + await tb.run_test(batches=1000, us=2000) + + +@cocotb.test() +async def valid_toggle(dut): + tb = SoftermaxTB(dut) + tb.in_data_driver.set_valid_prob(0.5) + tb.out_data_monitor.ready.value = 1 + await tb.run_test(batches=1000, us=2000) + + +@cocotb.test() +async def valid_backpressure_toggle(dut): + tb = SoftermaxTB(dut) + tb.in_data_driver.set_valid_prob(0.5) + cocotb.start_soon(bit_driver(tb.out_data_monitor.ready, dut.clk, 0.5)) + await tb.run_test(batches=1000, us=2000) + + +def get_fixed_softermax_config(kwargs={}): + config = { + "TOTAL_DIM": 20, + "PARALLELISM": 4, + "IN_WIDTH": 16, + "IN_FRAC_WIDTH": 6, + "POW2_WIDTH": 16, + "OUT_WIDTH": 16, + "OUT_FRAC_WIDTH": 6, + } + config.update(kwargs) + return config + + +def get_random_width(): + width = randint(2, 16) + frac_width = randint(1, width) + return width, frac_width + + +def get_random_softermax_config(): + parallelism = choice([2, 4, 8]) + depth = randint(2, 5) + in_width, in_frac_width = get_random_width() + out_width, out_frac_width = get_random_width() + config = { + "TOTAL_DIM": parallelism * depth, + "PARALLELISM": parallelism, + "IN_WIDTH": in_width, + "IN_FRAC_WIDTH": in_frac_width, + "POW2_WIDTH": 16, + "OUT_WIDTH": out_width, + "OUT_FRAC_WIDTH": out_frac_width, + } + return config + + +def test_fixed_softermax_1d_smoke(): + """ + Some quick tests to check if the module is working. + """ + mase_runner( + trace=True, + module_param_list=[ + get_fixed_softermax_config(), + *[get_random_softermax_config() for _ in range(50)], + ], + jobs=12, + # skip_build=True, + ) + + +if __name__ == "__main__": + test_fixed_softermax_1d_smoke() diff --git a/src/mase_components/activations/test/fixed_softermax_tb.py b/src/mase_components/activations/test/fixed_softermax_tb.py new file mode 100644 index 000000000..8c7e3be26 --- /dev/null +++ b/src/mase_components/activations/test/fixed_softermax_tb.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +import os + +import torch +import logging +from functools import partial + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import Timer + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.runner import mase_runner +from mase_cocotb.utils import fixed_preprocess_tensor + +from chop.nn.quantized.functional import fixed_softermax + + +class SoftermaxTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.DEBUG) + + self.in_data_driver = StreamDriver( + dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready + ) + + self.out_data_monitor = StreamMonitor( + dut.clk, + dut.data_out_0, + dut.data_out_0_valid, + dut.data_out_0_ready, + check=True, + ) + # Model + self.model = partial( + fixed_softermax, + q_config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + ) + + # Set verbosity of driver and monitor loggers to debug + self.in_data_driver.log.setLevel(logging.DEBUG) + self.out_data_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self): + return torch.randn( + ( + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_1"), + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"), + ) + ) + + async def run_test(self): + await self.reset() + self.log.info(f"Reset finished") + self.out_data_monitor.ready.value = 1 + + inputs = self.generate_inputs() + exp_out = self.model(inputs) + + # * Load the inputs driver + self.log.info(f"Processing inputs: {inputs}") + inputs = fixed_preprocess_tensor( + tensor=inputs, + q_config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], + ) + self.in_data_driver.load_driver(inputs) + + # * Load the output monitor + self.log.info(f"Processing outputs: {exp_out}") + outs = fixed_preprocess_tensor( + tensor=exp_out, + q_config={ + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), + ], + ) + self.out_data_monitor.load_monitor(outs) + + await Timer(1, units="ms") + assert self.out_data_monitor.exp_queue.empty() + + +@cocotb.test() +async def cocotb_test(dut): + tb = SoftermaxTB(dut) + await tb.run_test() + + +def get_fixed_softermax_config(kwargs={}): + config = { + "DATA_IN_0_PRECISION_0": 16, + "DATA_IN_0_PRECISION_1": 6, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, + "DATA_IN_0_TENSOR_SIZE_DIM_1": 20, + "DATA_IN_0_PARALLELISM_DIM_0": 2, + "DATA_IN_0_PARALLELISM_DIM_1": 2, + "DATA_OUT_0_PRECISION_0": 16, + "DATA_OUT_0_PRECISION_1": 6, + "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, + "DATA_OUT_0_TENSOR_SIZE_DIM_1": 20, + "DATA_OUT_0_PARALLELISM_DIM_0": 2, + "DATA_OUT_0_PARALLELISM_DIM_1": 2, + } + config.update(kwargs) + return config + + +def test_fixed_softermax_smoke(): + """ + Some quick tests to check if the module is working. + """ + mase_runner( + trace=True, + module_param_list=[ + get_fixed_softermax_config(), + ], + # skip_build=True, + ) + + +if __name__ == "__main__": + test_fixed_softermax_smoke() diff --git a/src/mase_components/activations/test/generate_memory.py b/src/mase_components/activations/test/generate_memory.py index 10df73ad7..d977876b7 100644 --- a/src/mase_components/activations/test/generate_memory.py +++ b/src/mase_components/activations/test/generate_memory.py @@ -21,6 +21,7 @@ def make_quantizer(data_width: int, f_width: int): "sigmoid": nn.Sigmoid(), "logsigmoid": nn.LogSigmoid(), "softshrink": nn.Softshrink(), + "gelu": nn.GELU(), "exp": torch.exp, "softmax": torch.exp, } @@ -270,21 +271,6 @@ def lookup_to_sv_file( print(f"SystemVerilog module generated and saved as {file_path}.") -# DEPRECATED -# def generate_mem(function_name, in_data_width, in_f_width, data_width, f_width): -# assert ( -# function_name in FUNCTION_TABLE -# ), f"Function {function_name} not found in FUNCTION_TABLE" -# lookup_to_file( -# in_data_width, -# in_f_width, -# data_width, -# f_width, -# function_name, -# f"/home/aw23/mase/machop/mase_components/activations/rtl/{function_name}_IN{in_data_width}_{in_f_width}_OUT{data_width}_{f_width}_map.mem", -# ) - - def generate_sv_lut( function_name, in_data_width, diff --git a/src/mase_components/activations/test/softermax.py b/src/mase_components/activations/test/softermax.py new file mode 100644 index 000000000..a1b3271f5 --- /dev/null +++ b/src/mase_components/activations/test/softermax.py @@ -0,0 +1,85 @@ +from math import exp + + +def softmax(l: list[float], pow2=False): + + max_num = max(l) + + norm = 0.0 + for x in l: + if pow2: + norm += 2 ** (x - max_num) + else: + norm += exp(x - max_num) + + if pow2: + out = [(2 ** (x - max_num)) / norm for x in l] + else: + out = [exp(x - max_num) / norm for x in l] + + assert abs(sum(out) - 1) < 1e-5, f"Sum is {sum(out)}" + + return out + + +def _softmax_model(l: list[int], parallelism: int, pow2=False): + """Model used to understand hardware implementation.""" + + assert len(l) % parallelism == 0 + iters = len(l) // parallelism + + # Calculate local max & local pow2 values + + local_values_buffer = [] + local_max_buffer = [] + + for i in range(iters): + local_window = l[i * parallelism : (i + 1) * parallelism] + local_max = max(local_window) + if pow2: + local_pow = [2 ** (x - local_max) for x in local_window] + else: + local_pow = [exp(x - local_max) for x in local_window] + local_max_buffer.append(local_max) + local_values_buffer.append(local_pow) + + # Calculate global max + + global_max = max(local_max_buffer) + local_max_diff = [global_max - x for x in local_max_buffer] + + adjusted_vals = [] + norm = 0.0 + + for diff, vals in zip(local_max_diff, local_values_buffer): + if pow2: + adj = [x * (2**-diff) for x in vals] + else: + adj = [x * exp(-diff) for x in vals] + norm += sum(adj) + adjusted_vals.append(adj) + + out = [] + for i in range(iters): + vals = adjusted_vals[i] + out.extend([x / norm for x in vals]) + + assert abs(sum(out) - 1) < 1e-5, f"Sum is {sum(out)}" + + return out + + +def check_lists(l1, l2, eps: float = 1e-5): + for x, y in zip(l1, l2): + assert abs(x - y) < eps + + +if __name__ == "__main__": + + LIST = [1, 2, 3, 4, 5, 6, 7, 8, 9] + PARALLELISM = 3 + + sw_softmax = softmax(LIST, pow2=True) + hw_softmax = _softmax_model(LIST, PARALLELISM, pow2=True) + check_lists(sw_softmax, hw_softmax) + print(hw_softmax) diff --git a/src/mase_components/activations/test/softermax_global_norm_tb.py b/src/mase_components/activations/test/softermax_global_norm_tb.py new file mode 100644 index 000000000..553a69e5a --- /dev/null +++ b/src/mase_components/activations/test/softermax_global_norm_tb.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 + +import logging +from math import ceil + +import torch +import numpy as np + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver, sign_extend_t, batched +from mase_cocotb.interfaces.streaming import StreamDriver, ErrorThresholdStreamMonitor + +import cocotb +from cocotb.triggers import * + +from chop.passes.graph.transforms.quantize.quantizers.integer import ( + integer_floor_quantizer, +) +from chop.passes.graph.transforms.quantize.quantizers.quantizers_for_hw import ( + unsigned_integer_quantizer_for_hw, +) + +logger = logging.getLogger("testbench") +logger.setLevel("INFO") + + +class SoftermaxGlobalNormTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + [ + "TOTAL_DIM", + "PARALLELISM", + "DEPTH", + "IN_VALUE_WIDTH", + "RECIP_WIDTH", + "RECIP_FRAC_WIDTH", + "IN_VALUE_FRAC_WIDTH", + "IN_MAX_WIDTH", + "OUT_WIDTH", + "OUT_FRAC_WIDTH", + ] + ) + + # Driver/Monitor + self.in_driver = StreamDriver( + dut.clk, (dut.in_values, dut.in_max), dut.in_valid, dut.in_ready + ) + + # Specify Error Threshold + self.percentage_error = 0.05 # 5% + self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) + + self.output_monitor = ErrorThresholdStreamMonitor( + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + width=self.OUT_WIDTH, + signed=False, + error_bits=self.error_threshold_bits, + log_error=True, + check=True, + ) + + def generate_inputs(self, batches=10): + # TODO: Take a look at all zero case again + local_vals = torch.randint( + 1, 2**self.IN_VALUE_WIDTH, size=(batches * self.DEPTH, self.PARALLELISM) + ) + local_max = torch.randint( + 0, 2**self.IN_MAX_WIDTH, size=(batches * self.DEPTH, 1) + ) + + logger.debug("local_vals: %s" % (local_vals)) + logger.debug( + "local_vals (float): %s" % (local_vals / (2**self.IN_VALUE_FRAC_WIDTH)) + ) + logger.debug("local_max: %s" % (local_max)) + logger.debug( + "local_max (signed): %s" % (sign_extend_t(local_max, self.IN_MAX_WIDTH)) + ) + + return local_vals.tolist(), local_max.flatten().tolist() + + def model(self, inputs): + batched_in = list(batched(inputs, self.DEPTH)) + exp_output = [] + + for batch in batched_in: + local_vals, local_max = list(zip(*batch)) + local_vals = torch.tensor(list(local_vals), dtype=torch.float) / ( + 2**self.IN_VALUE_FRAC_WIDTH + ) + local_max = torch.tensor(list(local_max), dtype=torch.float) + local_max = sign_extend_t( + torch.tensor(list(local_max), dtype=torch.float), self.IN_MAX_WIDTH + ) + + global_max = local_max.max() + adj_amt = global_max - local_max.reshape(self.DEPTH, 1) + adj_values = integer_floor_quantizer( + x=local_vals / (2**adj_amt), + width=self.IN_VALUE_WIDTH, + frac_width=self.IN_VALUE_FRAC_WIDTH, + is_signed=False, + ) + norm = adj_values.sum() + inv_norm = integer_floor_quantizer( + x=1 / (norm + 1e-10), + width=self.RECIP_WIDTH, + frac_width=self.RECIP_FRAC_WIDTH, + is_signed=False, + ) + softermax = adj_values * inv_norm + softermax_int = unsigned_integer_quantizer_for_hw( + softermax, self.OUT_WIDTH, self.OUT_FRAC_WIDTH + ) + + logger.debug("Values: %s" % (local_vals)) + logger.debug("Max: %s -> %s" % (local_max, global_max)) + logger.debug("Diff: %s" % (adj_amt)) + logger.debug("ADJ Values: %s" % (adj_values)) + logger.debug("norm: %s" % (norm)) + logger.debug("softermax: %s" % (softermax)) + logger.debug("softermax (int): %s" % (softermax_int)) + logger.debug("sanity sum: %s" % (softermax.sum().item())) + logger.debug("integer sum: %s" % (softermax_int.sum().item())) + + # logger.info(adj_values) + # logger.info(norm) + # logger.info(softermax) + + # assert abs(softermax.sum().item() - 1) < 0.1, f"Sum is {softermax.sum().item()}" + + exp_output.append(softermax_int) + + return torch.cat(exp_output, dim=0).tolist() + + async def run_test(self, batches, us): + inputs = self.generate_inputs(batches) + driver_inputs = list(zip(*inputs)) + exp_out = self.model(driver_inputs) + self.in_driver.load_driver(driver_inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, "us") + assert self.output_monitor.recv_queue.empty() + self._final_check() + + def _final_check(self): + if len(self.output_monitor.error_log) == 0: + logger.info("No Errors.") + # No errors + return + errors = np.stack(self.output_monitor.error_log) + max_bit_err = np.max(errors) + logger.info("Maximum bit-error: %d", max_bit_err) + if max_bit_err > self.error_threshold_bits: + assert False, ( + "Test failed due to high approximation error. Got %d bits of error!" + % max_bit_err + ) + + +@cocotb.test() +async def basic(dut): + tb = SoftermaxGlobalNormTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=1, us=2) + + +@cocotb.test() +async def stream(dut): + tb = SoftermaxGlobalNormTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=1000, us=2000) + + +@cocotb.test() +async def backpressure(dut): + tb = SoftermaxGlobalNormTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + await tb.reset() + await tb.run_test(batches=100, us=2000) + + +@cocotb.test() +async def valid(dut): + tb = SoftermaxGlobalNormTB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=100, us=2000) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = SoftermaxGlobalNormTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=1000, us=2000) + + +if __name__ == "__main__": + + DEFAULT = { + "TOTAL_DIM": 16, + "PARALLELISM": 4, + "IN_VALUE_WIDTH": 16, + "IN_MAX_WIDTH": 3, + "OUT_WIDTH": 8, + "OUT_FRAC_WIDTH": 7, + } + + def parallelism_cfgs(cfgs: list): + new_cfgs = [] + for cfg in cfgs: + for par in [1, 4, 8]: # parallelism + for depth in [2, 4, 8]: + total = depth * par + new_cfgs.append({**cfg, "TOTAL_DIM": total, "PARALLELISM": par}) + return new_cfgs + + def in_value_cfgs(cfgs: list): + new_cfgs = [] + for cfg in cfgs: + for in_width in [4, 7, 10]: + new_cfgs.append( + { + **cfg, + "IN_VALUE_WIDTH": in_width, + } + ) + return new_cfgs + + def in_max_cfgs(cfgs: list): + new_cfgs = [] + for cfg in cfgs: + for in_max in [2, 3, 4]: + new_cfgs.append( + { + **cfg, + "IN_MAX_WIDTH": in_max, + } + ) + return new_cfgs + + gen_cfgs = parallelism_cfgs([{}]) + gen_cfgs = in_value_cfgs(gen_cfgs) + gen_cfgs = in_max_cfgs(gen_cfgs) + + cfgs = [ + DEFAULT, + *gen_cfgs, + ] + + # cfgs = [{'TOTAL_DIM': 32, 'PARALLELISM': 4, 'IN_VALUE_WIDTH': 16, 'IN_MAX_WIDTH': 2}] + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=12, + ) diff --git a/src/mase_components/activations/test/softermax_local_window_tb.py b/src/mase_components/activations/test/softermax_local_window_tb.py new file mode 100644 index 000000000..41e4c20a1 --- /dev/null +++ b/src/mase_components/activations/test/softermax_local_window_tb.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 + +import logging +from random import randint + +import torch + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver, sign_extend_t, sign_extend, signed_to_unsigned +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + StreamMonitor, +) + +from mase_components.cast.test.fixed_signed_cast_tb import _fixed_signed_cast_model + +import cocotb +from cocotb.triggers import * + +logger = logging.getLogger("testbench") +logger.setLevel("INFO") + + +class SoftermaxLocalWindowTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + [ + "PARALLELISM", + "IN_WIDTH", + "IN_FRAC_WIDTH", + "OUT_WIDTH", + "OUT_FRAC_WIDTH", + "MAX_WIDTH", + "SUBTRACT_WIDTH", + "SUBTRACT_FRAC_WIDTH", + ] + ) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + self.output_monitor = StreamMonitor( + dut.clk, + (dut.out_values, dut.out_max), + dut.out_valid, + dut.out_ready, + check=True, + ) + + def generate_inputs(self, batches=10): + return [ + [randint(0, 2**self.IN_WIDTH - 1) for _ in range(self.PARALLELISM)] + for _ in range(batches) + ] + + # def _lpw_pow2_model(self, inputs): + # """ + # Copy over model for lpw_pow2 since verilator does not allow acessing + # signals/modules which are in generate block in cocotb. + # https://github.com/cocotb/cocotb/issues/1884 + # """ + # in_t = torch.tensor(inputs) + # num = sign_extend_t(in_t, self.SUBTRACT_WIDTH) / (2 ** self.SUBTRACT_FRAC_WIDTH) + # res = 2 ** num + # res = (res * 2**self.OUT_FRAC_WIDTH).int() + # res = torch.clamp(res, 0, 2**self.OUT_WIDTH-1) + # return res + + # def _comparator_tree_model(self, inputs): + # inputs = [[sign_extend(x, self.IN_WIDTH) for x in l] for l in inputs] + # exp_out = [max(l) for l in inputs] + # exp_out = [signed_to_unsigned(x, self.IN_WIDTH) for x in exp_out] + # return exp_out + + def model(self, inputs): + sign_ext = sign_extend_t( + torch.tensor(inputs, dtype=torch.float), bits=self.IN_WIDTH + ) + float_inputs = sign_ext / (2**self.IN_FRAC_WIDTH) + # float_inputs = torch.tensor([[-31.5, -32]]) + rounded_inputs_float, rounded_inputs_uint = _fixed_signed_cast_model( + float_inputs, self.MAX_WIDTH, 0, False, "floor" + ) + local_max = rounded_inputs_float.max(dim=1, keepdim=True).values + local_max_uint = signed_to_unsigned(local_max.int(), self.MAX_WIDTH) + + difference = float_inputs - local_max + pow2 = 2**difference + res = torch.clamp( + (pow2 * 2**self.OUT_FRAC_WIDTH).int(), 0, 2**self.OUT_WIDTH - 1 + ) + + logger.debug("float_inputs: %s" % float_inputs) + logger.debug("rounded_inputs_float: %s" % rounded_inputs_float) + logger.debug("local_max: %s" % local_max) + logger.debug("local_max_uint: %s" % local_max_uint) + logger.debug("difference: %s" % difference) + logger.debug("pow2: %s" % pow2) + logger.debug("res: %s" % res) + + exp_vals = res.tolist() + exp_max = local_max_uint.tolist() + + logger.debug("exp_vals: %s" % exp_vals) + logger.debug("exp_max: %s" % exp_max) + + return list(zip(exp_vals, exp_max)) + + async def run_test(self, batches, us): + inputs = self.generate_inputs(batches) + exp_out = self.model(inputs) + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, "us") + assert self.output_monitor.recv_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = SoftermaxLocalWindowTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=3, us=20) + + +@cocotb.test() +async def stream(dut): + tb = SoftermaxLocalWindowTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=1000, us=200) + + +@cocotb.test() +async def backpressure(dut): + tb = SoftermaxLocalWindowTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + await tb.reset() + await tb.run_test(batches=100, us=200) + + +@cocotb.test() +async def valid(dut): + tb = SoftermaxLocalWindowTB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=100, us=200) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = SoftermaxLocalWindowTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=1000, us=200) + + +if __name__ == "__main__": + + DEFAULT = { + "PARALLELISM": 4, + "IN_WIDTH": 8, + "IN_FRAC_WIDTH": 2, + "OUT_WIDTH": 8, + "OUT_FRAC_WIDTH": 7, + } + + def parallelism_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + for d in [2, 4, 16]: + out.append({**cfg, "PARALLELISM": d}) + return out + + cfgs = [DEFAULT] + cfgs = parallelism_cfgs(cfgs) + + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=4, + ) diff --git a/src/mase_components/activations/test/softermax_lpw_pow2_tb.py b/src/mase_components/activations/test/softermax_lpw_pow2_tb.py new file mode 100644 index 000000000..501243918 --- /dev/null +++ b/src/mase_components/activations/test/softermax_lpw_pow2_tb.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 + +import os, logging +from random import randint +from pathlib import Path + +import torch + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver, sign_extend_t +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + ErrorThresholdStreamMonitor, +) + +import cocotb +from cocotb.triggers import * + +import pandas as pd +import altair as alt + +logger = logging.getLogger("testbench") +logger.setLevel("DEBUG") + + +class LPW_Pow2TB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + ["IN_WIDTH", "IN_FRAC_WIDTH", "OUT_WIDTH", "OUT_FRAC_WIDTH"] + ) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + self.error_threshold_bits = 2 + self.output_monitor = ErrorThresholdStreamMonitor( + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + width=self.OUT_WIDTH, + log_error=True, + signed=True, + error_bits=self.error_threshold_bits, + check=False, + ) + + def generate_inputs(self): + negative_nums = torch.arange( + start=2 ** (self.IN_WIDTH - 1), end=2**self.IN_WIDTH, dtype=torch.int32 + ) + zero_to_one = torch.arange( + start=0, end=2**self.IN_FRAC_WIDTH, dtype=torch.int32 # one + ) + return torch.cat((negative_nums, zero_to_one)).tolist() + + def model(self, inputs): + in_t = torch.tensor(inputs) + num = sign_extend_t(in_t, self.IN_WIDTH) / (2**self.IN_FRAC_WIDTH) + res = 2**num + res = (res * 2**self.OUT_FRAC_WIDTH).int() + res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) + return res.tolist() + + async def run_test(self, us): + await self.reset() + inputs = self.generate_inputs() + # logger.debug(inputs) + exp_out = self.model(inputs) + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, "us") + assert self.output_monitor.exp_queue.empty() + self._final_check() + + def _final_check(self): + max_bit_err = max(self.output_monitor.error_log) + logger.info("Maximum bit-error: %d", max_bit_err) + if max_bit_err > self.error_threshold_bits: + assert False, ( + "Test failed due to high approximation error. Got %d bits of error!" + % max_bit_err + ) + + +@cocotb.test() +async def sweep(dut): + tb = LPW_Pow2TB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + inputs = tb.generate_inputs() + exp_out = tb.model(inputs) + tb.in_driver.load_driver(inputs) + tb.output_monitor.load_monitor(exp_out) + await Timer(20, "us") + assert tb.output_monitor.exp_queue.empty() + + # Graphing error + recv_log = tb.output_monitor.recv_log + assert len(exp_out) == len(recv_log) + + x = sign_extend_t(torch.tensor(inputs), tb.IN_WIDTH) / (2**tb.IN_FRAC_WIDTH) + ref = 2**x + ref *= 2**tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) + + data = pd.DataFrame( + { + "x": x.tolist(), + "reference": ref.tolist(), + "software": exp_out, + "hardware": recv_log, + } + ).melt( + id_vars="x", + value_vars=["reference", "software", "hardware"], + value_name="Value", + var_name="Type", + ) + + graph_id = f"{tb.IN_WIDTH}_{tb.IN_FRAC_WIDTH}_to_{tb.OUT_WIDTH}_{tb.OUT_FRAC_WIDTH}" + alt.Chart(data).mark_line().encode( + x="x", + y="Value", + color="Type", + ).properties( + width=600, + height=300, + ).save( + Path(__file__).parent / f"build/softermax_lpw_pow2/error_graph_{graph_id}.png", + scale_factor=3, + ) + + tb._final_check() + + +@cocotb.test() +async def backpressure(dut): + tb = LPW_Pow2TB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.6)) + await tb.run_test(us=100) + + +@cocotb.test() +async def valid(dut): + tb = LPW_Pow2TB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.5) + await tb.run_test(us=100) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = LPW_Pow2TB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.run_test(us=100) + + +if __name__ == "__main__": + + DEFAULT = {"IN_WIDTH": 16, "IN_FRAC_WIDTH": 3, "OUT_WIDTH": 16, "OUT_FRAC_WIDTH": 3} + + def width_cfgs(): + bitwidths = [2, 4, 8] + cfgs = [] + for in_width in bitwidths: + for in_frac_width in range(1, in_width): + for out_width in bitwidths: + for out_frac_width in range(1, out_width): + cfgs.append( + { + "IN_WIDTH": in_width, + "IN_FRAC_WIDTH": in_frac_width, + "OUT_WIDTH": out_width, + "OUT_FRAC_WIDTH": out_frac_width, + } + ) + return cfgs + + cfgs = width_cfgs() + # cfgs = [DEFAULT] + + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=8, + ) diff --git a/src/mase_components/activations/test/softermax_lpw_reciprocal_tb.py b/src/mase_components/activations/test/softermax_lpw_reciprocal_tb.py new file mode 100644 index 000000000..e0d2ee930 --- /dev/null +++ b/src/mase_components/activations/test/softermax_lpw_reciprocal_tb.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +import logging +from pathlib import Path +from random import randint +from math import ceil + +import torch + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + ErrorThresholdStreamMonitor, +) + +import cocotb +from cocotb.triggers import * + +import pandas as pd +import altair as alt + +logger = logging.getLogger("testbench") +logger.setLevel("DEBUG") + + +class LPW_Reciprocal2TB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + ["ENTRIES", "IN_WIDTH", "IN_FRAC_WIDTH", "OUT_WIDTH", "OUT_FRAC_WIDTH"] + ) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + + # Specify Error Threshold + self.percentage_error = 0.05 # 5% + self.error_threshold_bits = ceil(self.percentage_error * (2**self.OUT_WIDTH)) + + self.output_monitor = ErrorThresholdStreamMonitor( + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + width=self.OUT_WIDTH, + log_error=True, + signed=False, + error_bits=self.error_threshold_bits, + check=False, # We manually assert later + ) + + def generate_inputs(self, batches=100): + return [randint(0, 2**self.IN_WIDTH - 1) for _ in range(batches)] + + def sweep_input(self): + return list(range(2**self.IN_WIDTH)) + + def model(self, inputs): + in_t = torch.tensor(inputs) / (2**self.IN_FRAC_WIDTH) + recip = 1.0 / in_t + res = torch.floor(recip * 2**self.OUT_FRAC_WIDTH) + res = torch.nan_to_num(res) + res = torch.clamp(res, 0, 2**self.OUT_WIDTH - 1) + res = res.int() + return res.tolist() + + async def run_test(self, batches, us): + await self.reset() + inputs = self.generate_inputs(batches=batches) + exp_out = self.model(inputs) + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, "us") + assert self.output_monitor.exp_queue.empty() + self._final_check() + + def _final_check(self): + max_bit_err = max(self.output_monitor.error_log) + logger.info("Maximum bit-error: %d", max_bit_err) + if max_bit_err > self.error_threshold_bits: + assert False, ( + "Test failed due to high approximation error. Got %d bits of error!" + % max_bit_err + ) + + +@cocotb.test() +async def sweep(dut): + tb = LPW_Reciprocal2TB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + if tb.IN_WIDTH > 16: + logger.warning("Not doing full sweep due to large input bitwidth.") + return + else: + inputs = tb.sweep_input() + exp_out = tb.model(inputs) + tb.in_driver.load_driver(inputs) + tb.output_monitor.load_monitor(exp_out) + await Timer(4000, "us") + assert tb.output_monitor.exp_queue.empty() + + # Graphing error + recv_log = tb.output_monitor.recv_log + assert len(exp_out) == len(recv_log) + + x = torch.tensor(inputs) / (2**tb.IN_FRAC_WIDTH) + ref = 1.0 / x + ref *= 2**tb.OUT_FRAC_WIDTH # scale up + ref = torch.clamp(ref, 0, 2**tb.OUT_WIDTH - 1) + + data = pd.DataFrame( + { + "x": x.tolist(), + "reference": ref.tolist(), + "software": exp_out, + "hardware": recv_log, + } + ) + data["hw_error"] = (data["hardware"] - data["reference"]).abs() + data["sw_error"] = (data["software"] - data["reference"]).abs() + data["model_error"] = (data["hardware"] - data["software"]).abs() + + curve_data = data.melt( + id_vars="x", + value_vars=["reference", "software", "hardware"], + value_name="Value", + var_name="Type", + ) + curve = ( + alt.Chart(curve_data) + .mark_line() + .encode( + x="x", + y=alt.Y("Value").title("Curves"), + color="Type", + ) + .properties( + width=600, + height=300, + ) + ) + + error_data = data.melt( + id_vars="x", + value_vars=["hw_error", "sw_error"], + value_name="Value", + var_name="Type", + ) + error = ( + alt.Chart(error_data) + .mark_line() + .encode( + x="x", + y=alt.Y("Value").title("Error vs. Perfect Reference"), + color="Type", + ) + .properties( + width=600, + height=100, + ) + ) + + model_error_data = data.melt( + id_vars="x", value_vars=["model_error"], value_name="Value", var_name="Type" + ) + model_error = ( + alt.Chart(model_error_data) + .mark_line() + .encode( + x="x", + y=alt.Y("Value").title("Bit error vs software model"), + color="Type", + ) + .properties( + width=600, + height=100, + ) + ) + + graph_id = f"{tb.ENTRIES}e_{tb.IN_WIDTH}_{tb.IN_FRAC_WIDTH}_to_{tb.OUT_WIDTH}_{tb.OUT_FRAC_WIDTH}" + (curve & error & model_error).save( + Path(__file__).parent + / f"build/softermax_lpw_reciprocal/error_graph_{graph_id}.png", + scale_factor=3, + ) + + tb._final_check() + + +@cocotb.test() +async def backpressure(dut): + tb = LPW_Reciprocal2TB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.6)) + await tb.run_test(batches=1000, us=400) + + +@cocotb.test() +async def valid(dut): + tb = LPW_Reciprocal2TB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.5) + await tb.run_test(batches=1000, us=400) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = LPW_Reciprocal2TB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.run_test(batches=1000, us=400) + + +if __name__ == "__main__": + + DEFAULT = { + "ENTRIES": 8, + "IN_WIDTH": 8, + "IN_FRAC_WIDTH": 3, + "OUT_WIDTH": 8, + "OUT_FRAC_WIDTH": 7, + } + + def random_cfg(): + in_width = randint(4, 30) + out_width = randint(4, 30) + return { + "ENTRIES": 8, + "IN_WIDTH": in_width, + "IN_FRAC_WIDTH": randint(3, in_width - 1), + "OUT_WIDTH": out_width, + "OUT_FRAC_WIDTH": randint(3, out_width - 1), + } + + NUM_RANDOM_CFGS = 40 + random_cfgs = [random_cfg() for _ in range(NUM_RANDOM_CFGS)] + + cfgs = [ + DEFAULT, + { + "ENTRIES": 8, + "IN_WIDTH": 20, + "IN_FRAC_WIDTH": 10, + "OUT_WIDTH": 20, + "OUT_FRAC_WIDTH": 3, + }, + *random_cfgs, + ] + + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=12, + ) diff --git a/src/mase_components/activations/test/test_lint_activations.py b/src/mase_components/activations/test/test_lint_activations.py index a45bd89b4..db6031a01 100644 --- a/src/mase_components/activations/test/test_lint_activations.py +++ b/src/mase_components/activations/test/test_lint_activations.py @@ -1,7 +1,11 @@ from mase_components.linter import run_lint +from .generate_memory import generate_sv_lut, FUNCTION_TABLE + def test_lint_activations(): + for func, _ in FUNCTION_TABLE.items(): + generate_sv_lut(func, 8, 4, data_width=8, f_width=4, path_with_dtype=False) run_lint("activations") diff --git a/src/mase_components/activations/test/test_synth_activations.py b/src/mase_components/activations/test/test_synth_activations.py new file mode 100644 index 000000000..33c9aea5a --- /dev/null +++ b/src/mase_components/activations/test/test_synth_activations.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_activations(): + run_synth("activations") + + +if __name__ == "__main__": + test_synth_activations() diff --git a/src/mase_components/arbiters/__init__.py b/src/mase_components/arbiters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mase_components/arbiters/rtl/__init__.py b/src/mase_components/arbiters/rtl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mase_components/arbiters/rtl/find_first_arbiter.sv b/src/mase_components/arbiters/rtl/find_first_arbiter.sv new file mode 100644 index 000000000..b6d2729ab --- /dev/null +++ b/src/mase_components/arbiters/rtl/find_first_arbiter.sv @@ -0,0 +1,23 @@ +`timescale 1ns / 1ps +module find_first_arbiter #( + parameter NUM_REQUESTERS = 4 +) ( + input [ NUM_REQUESTERS - 1:0] request, + output logic [ NUM_REQUESTERS - 1:0] grant_oh, + output logic [$clog2(NUM_REQUESTERS)-1:0] grant_bin +); + + always_comb begin + grant_oh = '0; + grant_bin = '0; + + for (int i = 0; i < NUM_REQUESTERS; i++) begin + if (request[i]) begin + grant_oh = ({{(NUM_REQUESTERS - 1) {1'b0}}, 1'b1} << i); + grant_bin = i; + break; + end + end + end + +endmodule diff --git a/src/mase_components/arbiters/test/__init__.py b/src/mase_components/arbiters/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mase_components/arithmetic/test/test_synth_arithmetic.py b/src/mase_components/arithmetic/test/test_synth_arithmetic.py new file mode 100644 index 000000000..0f0da84b5 --- /dev/null +++ b/src/mase_components/arithmetic/test/test_synth_arithmetic.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_arithmetic(): + run_synth("arithmetic") + + +if __name__ == "__main__": + test_synth_arithmetic() diff --git a/src/mase_components/attention/rtl/fixed_att.sv b/src/mase_components/attention/rtl/fixed_att.sv deleted file mode 100644 index 5a3a80a2d..000000000 --- a/src/mase_components/attention/rtl/fixed_att.sv +++ /dev/null @@ -1,401 +0,0 @@ -`timescale 1ns / 1ps -/* -The first version without softmax -noted in order to make output dimension match -make w_parallelism = w_size - w_num_parallelism = in_depth - but the dimension constraint, shown in the testbench part but not here - data_in [IN_PARALLELISM * IN_NUM_PARALLELISM][IN_SIZE * IN_DEPTH] - weight_q[W_PARALLELISM * W_NUM_PARALLELISM][W_SIZE * IN_DEPTH] - weight_k[W_PARALLELISM * W_NUM_PARALLELISM][W_SIZE * IN_DEPTH] - weight_v[W_PARALLELISM * W_NUM_PARALLELISM][W_SIZE * IN_DEPTH] - - data_q [W_PARALLELISM * W_NUM_PARALLELISM][IN_PARALLELISM * IN_NUM_PARALLELISM] - data_k [W_PARALLELISM * W_NUM_PARALLELISM][IN_PARALLELISM * IN_NUM_PARALLELISM] - - data_v_t[IN_PARALLELISM * IN_NUM_PARALLELISM][W_PARALLELISM * W_NUM_PARALLELISM] - - data_s [W_PARALLELISM * W_NUM_PARALLELISM][W_PARALLELISM * W_NUM_PARALLELISM] - - data_z [W_PARALLELISM * W_NUM_PARALLELISM][IN_PARALLELISM * IN_NUM_PARALLELISM] - data_out[IN_PARALLELISM][W_PARALLELISM] - - realize the function - data_z = att(data_in) -*/ -module fixed_att #( - parameter DQIN_WIDTH = 8, - parameter DQIN_FRAC_WIDTH = 1, - parameter DKIN_WIDTH = 8, - parameter DKIN_FRAC_WIDTH = 1, - parameter DVIN_WIDTH = 8, - parameter DVIN_FRAC_WIDTH = 1, - - parameter WQ_WIDTH = 8, - parameter WQ_FRAC_WIDTH = 1, - parameter WK_WIDTH = 8, - parameter WK_FRAC_WIDTH = 1, - parameter WV_WIDTH = 8, - parameter WV_FRAC_WIDTH = 1, - - parameter BQ_WIDTH = 8, - parameter BQ_FRAC_WIDTH = 1, - parameter BK_WIDTH = 8, - parameter BK_FRAC_WIDTH = 1, - parameter BV_WIDTH = 8, - parameter BV_FRAC_WIDTH = 1, - - parameter DQ_WIDTH = 8, - parameter DQ_FRAC_WIDTH = 1, - parameter DK_WIDTH = 8, - parameter DK_FRAC_WIDTH = 1, - parameter DV_WIDTH = 8, - parameter DV_FRAC_WIDTH = 1, - - parameter DS_WIDTH = 8, - parameter DS_FRAC_WIDTH = 1, - parameter EXP_WIDTH = 8, - parameter EXP_FRAC_WIDTH = 4, - parameter DIV_WIDTH = 10, - parameter DS_SOFTMAX_WIDTH = 8, - parameter DS_SOFTMAX_FRAC_WIDTH = 7, - - parameter DZ_WIDTH = 8, - parameter DZ_FRAC_WIDTH = 1, - - parameter IN_PARALLELISM = 3, - parameter IN_NUM_PARALLELISM = 2, - - parameter IN_SIZE = 3, - //define for matrix multilication - parameter IN_DEPTH = 3, - - parameter W_PARALLELISM = 2, - parameter W_NUM_PARALLELISM = 3, - parameter W_SIZE = IN_SIZE, - - - parameter OUT_PARALLELISM = IN_PARALLELISM, - parameter OUT_SIZE = W_PARALLELISM -) ( - input clk, - input rst, - - input [WQ_WIDTH - 1:0] weight_q[W_PARALLELISM * W_SIZE -1 : 0], - input weight_q_valid, - output weight_q_ready, - - input [WK_WIDTH - 1:0] weight_k[W_PARALLELISM * W_SIZE -1 : 0], - input weight_k_valid, - output weight_k_ready, - - input [WV_WIDTH - 1:0] weight_v[W_PARALLELISM * W_SIZE -1 : 0], - input weight_v_valid, - output weight_v_ready, - - input [BQ_WIDTH - 1:0] bias_q[W_PARALLELISM -1 : 0], - input bias_q_valid, - output bias_q_ready, - - input [BK_WIDTH - 1:0] bias_k[W_PARALLELISM -1 : 0], - input bias_k_valid, - output bias_k_ready, - - input [BV_WIDTH - 1:0] bias_v[W_PARALLELISM -1 : 0], - input bias_v_valid, - output bias_v_ready, - - input [DQIN_WIDTH -1:0] data_in_q[IN_PARALLELISM * IN_SIZE - 1 : 0], - input data_in_q_valid, - output data_in_q_ready, - - input [DKIN_WIDTH -1:0] data_in_k[IN_PARALLELISM * IN_SIZE - 1 : 0], - input data_in_k_valid, - output data_in_k_ready, - - input [DVIN_WIDTH -1:0] data_in_v[IN_PARALLELISM * IN_SIZE - 1 : 0], - input data_in_v_valid, - output data_in_v_ready, - - output [DZ_WIDTH -1:0] data_out_0[OUT_PARALLELISM * OUT_SIZE - 1:0], - output data_out_0_valid, - input data_out_0_ready -); - - logic [DQIN_WIDTH-1:0] ff_data_in_q[IN_PARALLELISM * IN_SIZE - 1:0]; - logic [DKIN_WIDTH-1:0] ff_data_in_k[IN_PARALLELISM * IN_SIZE - 1:0]; - logic [DVIN_WIDTH-1:0] ff_data_in_v[IN_PARALLELISM * IN_SIZE - 1:0]; - logic ff_data_in_q_valid, ff_data_in_k_valid, ff_data_in_v_valid; - logic ff_data_in_q_ready, ff_data_in_k_ready, ff_data_in_v_ready; - // assign ff_data_in_qk_ready = ff_data_in_q_ready&&ff_data_in_k_ready; - // fifo for qk - unpacked_fifo #( - .DEPTH(IN_DEPTH * IN_NUM_PARALLELISM), - .DATA_WIDTH(DQIN_WIDTH), - .IN_NUM(IN_PARALLELISM * IN_SIZE) - ) fifo_q ( - .clk(clk), - .rst(rst), - .data_in(data_in_q), - .data_in_valid(data_in_q_valid), - .data_in_ready(data_in_q_ready), - .data_out(ff_data_in_q), - .data_out_valid(ff_data_in_q_valid), - .data_out_ready(ff_data_in_q_ready) - ); - - unpacked_fifo #( - .DEPTH(IN_DEPTH * IN_NUM_PARALLELISM), - .DATA_WIDTH(DKIN_WIDTH), - .IN_NUM(IN_PARALLELISM * IN_SIZE) - ) fifo_k ( - .clk(clk), - .rst(rst), - .data_in(data_in_k), - .data_in_valid(data_in_k_valid), - .data_in_ready(data_in_k_ready), - .data_out(ff_data_in_k), - .data_out_valid(ff_data_in_k_valid), - .data_out_ready(ff_data_in_k_ready) - ); - - logic [DQ_WIDTH - 1 : 0] data_q[IN_PARALLELISM * W_PARALLELISM - 1:0]; - logic data_q_valid, data_q_ready; - logic [DK_WIDTH - 1 : 0] data_k[IN_PARALLELISM * W_PARALLELISM - 1:0]; - logic data_k_valid, data_k_ready; - //matmul qk - fixed_2d_linear #( - .IN_WIDTH(DQIN_WIDTH), - .IN_FRAC_WIDTH(DQIN_FRAC_WIDTH), - .WEIGHT_WIDTH(WQ_WIDTH), - .WEIGHT_FRAC_WIDTH(WQ_FRAC_WIDTH), - .BIAS_WIDTH(BQ_WIDTH), - .BIAS_FRAC_WIDTH(BQ_FRAC_WIDTH), - .OUT_WIDTH(DQ_WIDTH), - .OUT_FRAC_WIDTH(DQ_FRAC_WIDTH), - .IN_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN_Y(IN_PARALLELISM), - .IN_X(IN_DEPTH * IN_SIZE), - .UNROLL_IN_X(IN_SIZE), - .W_Y(W_PARALLELISM * W_NUM_PARALLELISM), - .UNROLL_W_Y(W_PARALLELISM) - ) inst_fmmc_q ( - .clk(clk), - .rst(rst), - .data_in(ff_data_in_q), - .data_in_valid(ff_data_in_q_valid), - .data_in_ready(ff_data_in_q_ready), - .weight(weight_q), - .weight_valid(weight_q_valid), - .weight_ready(weight_q_ready), - .bias(bias_q), - .bias_valid(bias_q_valid), - .bias_ready(bias_q_ready), - .data_out(data_q), - .data_out_valid(data_q_valid), - .data_out_ready(data_q_ready) - ); - fixed_2d_linear #( - .IN_WIDTH(DKIN_WIDTH), - .IN_FRAC_WIDTH(DKIN_FRAC_WIDTH), - .WEIGHT_WIDTH(WK_WIDTH), - .WEIGHT_FRAC_WIDTH(WK_FRAC_WIDTH), - .BIAS_WIDTH(BK_WIDTH), - .BIAS_FRAC_WIDTH(BK_FRAC_WIDTH), - .OUT_WIDTH(DK_WIDTH), - .OUT_FRAC_WIDTH(DK_FRAC_WIDTH), - .IN_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN_Y(IN_PARALLELISM), - .IN_X(IN_DEPTH * IN_SIZE), - .UNROLL_IN_X(IN_SIZE), - .W_Y(W_PARALLELISM * W_NUM_PARALLELISM), - .UNROLL_W_Y(W_PARALLELISM) - ) inst_fmmc_k ( - .clk(clk), - .rst(rst), - .data_in(ff_data_in_k), - .data_in_valid(ff_data_in_k_valid), - .data_in_ready(ff_data_in_k_ready), - .weight(weight_k), - .weight_valid(weight_k_valid), - .weight_ready(weight_k_ready), - .bias(bias_k), - .bias_valid(bias_k_valid), - .bias_ready(bias_k_ready), - .data_out(data_k), - .data_out_valid(data_k_valid), - .data_out_ready(data_k_ready) - ); - logic [DS_WIDTH - 1 : 0] data_s[IN_PARALLELISM * IN_PARALLELISM - 1:0]; - logic data_s_valid, data_s_ready; - // matmul s - /* verilator lint_off PINMISSING */ - fixed_matmul #( - .IN1_WIDTH(DQ_WIDTH), - .IN1_FRAC_WIDTH(DQ_FRAC_WIDTH), - .IN2_WIDTH(DK_WIDTH), - .IN2_FRAC_WIDTH(DK_FRAC_WIDTH), - .OUT_WIDTH(DS_WIDTH), - .OUT_FRAC_WIDTH(DS_FRAC_WIDTH), - .IN1_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN1_Y(IN_PARALLELISM), - .IN1_X(W_PARALLELISM * W_NUM_PARALLELISM), - .UNROLL_IN1_X(W_PARALLELISM), - .IN2_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN2_Y(IN_PARALLELISM) - ) inst_fmmc_s ( - .clk(clk), - .rst(rst), - .data_in1(data_q), - .data_in1_valid(data_q_valid), - .data_in1_ready(data_q_ready), - .data_in2(data_k), - .data_in2_valid(data_k_valid), - .data_in2_ready(data_k_ready), - .data_out(data_s), - .data_out_valid(data_s_valid), - .data_out_ready(data_s_ready) - ); - logic [DS_SOFTMAX_WIDTH - 1:0] softmax_s[IN_PARALLELISM * IN_PARALLELISM - 1:0]; - logic softmax_s_valid, softmax_s_ready; - - hash_softmax #( - .IN_WIDTH(DS_WIDTH), - .IN_FRAC_WIDTH(DS_FRAC_WIDTH), - .EXP_WIDTH(EXP_WIDTH), - .EXP_FRAC_WIDTH(EXP_FRAC_WIDTH), - .DIV_WIDTH(DIV_WIDTH), - .OUT_WIDTH(DS_SOFTMAX_WIDTH), - .OUT_FRAC_WIDTH(DS_SOFTMAX_FRAC_WIDTH), - .IN_SIZE(IN_PARALLELISM * IN_PARALLELISM), - .OUT_SIZE(IN_PARALLELISM * IN_PARALLELISM), - .IN_DEPTH(IN_NUM_PARALLELISM) - ) softmax_inst ( - .data_in(data_s), - .data_in_valid(data_s_valid), - .data_in_ready(data_s_ready), - .data_out(softmax_s), - .data_out_valid(softmax_s_valid), - .data_out_ready(softmax_s_ready), - .* - ); - - logic [BV_WIDTH-1:0] bias_v_extend[IN_PARALLELISM * W_PARALLELISM - 1:0]; - logic [BV_WIDTH-1:0] ib_bias_v[IN_PARALLELISM * W_PARALLELISM - 1:0]; - logic ib_bias_v_valid, ib_bias_v_ready; - // bias_v require transpose here - for (genvar i = 0; i < W_PARALLELISM; i++) - for (genvar j = 0; j < IN_PARALLELISM; j++) assign bias_v_extend[i*IN_PARALLELISM+j] = bias_v[i]; - - input_buffer #( - .IN_WIDTH(BV_WIDTH), - .IN_PARALLELISM(W_PARALLELISM), - .IN_SIZE(IN_PARALLELISM), - .BUFFER_SIZE(1), - .REPEAT(IN_NUM_PARALLELISM) - ) bias_v_buffer ( - .clk(clk), - .rst(rst), - .data_in(bias_v_extend), - .data_in_valid(bias_v_valid), - .data_in_ready(bias_v_ready), - .data_out(ib_bias_v), - .data_out_valid(ib_bias_v_valid), - .data_out_ready(ib_bias_v_ready) - ); - unpacked_fifo #( - .DEPTH(IN_DEPTH * IN_NUM_PARALLELISM), - .DATA_WIDTH(DVIN_WIDTH), - .IN_NUM(IN_PARALLELISM * IN_SIZE) - ) fifo_v ( - .clk(clk), - .rst(rst), - .data_in(data_in_v), - .data_in_valid(data_in_v_valid), - .data_in_ready(data_in_v_ready), - .data_out(ff_data_in_v), - .data_out_valid(ff_data_in_v_valid), - .data_out_ready(ff_data_in_v_ready) - ); - //matmul_v - logic [DV_WIDTH - 1 : 0] data_v_t[W_PARALLELISM * IN_PARALLELISM - 1:0]; - logic data_v_t_valid, data_v_t_ready; - fixed_matmul #( - .IN1_WIDTH(WV_WIDTH), - .IN1_FRAC_WIDTH(WV_FRAC_WIDTH), - .IN2_WIDTH(DVIN_WIDTH), - .IN2_FRAC_WIDTH(DVIN_FRAC_WIDTH), - .HAS_BIAS(1), - .BIAS_WIDTH(BV_WIDTH), - .BIAS_FRAC_WIDTH(BV_FRAC_WIDTH), - .OUT_WIDTH(DV_WIDTH), - .OUT_FRAC_WIDTH(DV_FRAC_WIDTH), - .IN1_Y(W_PARALLELISM * W_NUM_PARALLELISM), - .UNROLL_IN1_Y(W_PARALLELISM), - .IN1_X(IN_SIZE * IN_DEPTH), - .UNROLL_IN1_X(IN_SIZE), - .IN2_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN2_Y(IN_PARALLELISM) - ) inst_fmmc_v ( - .clk(clk), - .rst(rst), - .data_in1(weight_v), - .data_in1_valid(weight_v_valid), - .data_in1_ready(weight_v_ready), - .data_in2(ff_data_in_v), - .data_in2_valid(ff_data_in_v_valid), - .data_in2_ready(ff_data_in_v_ready), - .bias(ib_bias_v), - .bias_valid(ib_bias_v_valid), - .bias_ready(ib_bias_v_ready), - .data_out(data_v_t), - .data_out_valid(data_v_t_valid), - .data_out_ready(data_v_t_ready) - ); - - logic [DZ_WIDTH - 1:0] data_z[IN_PARALLELISM * W_PARALLELISM - 1:0]; - logic data_z_valid, data_z_ready; - //z = s*v_t - always_ff @(posedge clk) $display("%b, %b, data_q", data_q_valid, data_q_ready); - always_ff @(posedge clk) $display("%b, %b, data_k", data_k_valid, data_k_ready); - always_ff @(posedge clk) $display("%b, %b, data_s", data_s_valid, data_s_ready); - always_ff @(posedge clk) $display("%b, %b, data_in_q", data_in_q_valid, data_in_q_ready); - always_ff @(posedge clk) $display("%b, %b, data_in_k", data_in_k_valid, data_in_k_ready); - always_ff @(posedge clk) $display("%b, %b, ff_in_k", ff_data_in_k_valid, ff_data_in_k_ready); - always_ff @(posedge clk) $display("%b, %b, ff_in_q", ff_data_in_q_valid, ff_data_in_q_ready); - always_ff @(posedge clk) $display(""); - fixed_matmul #( - .IN1_WIDTH(DS_SOFTMAX_WIDTH), - .IN1_FRAC_WIDTH(DS_SOFTMAX_FRAC_WIDTH), - .IN2_WIDTH(DV_WIDTH), - .IN2_FRAC_WIDTH(DV_FRAC_WIDTH), - .OUT_WIDTH(DZ_WIDTH), - .OUT_FRAC_WIDTH(DZ_FRAC_WIDTH), - - .IN1_Y(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN1_Y(IN_PARALLELISM), - .IN1_X(IN_PARALLELISM * IN_NUM_PARALLELISM), - .UNROLL_IN1_X(IN_PARALLELISM), - - .IN2_Y(W_PARALLELISM * W_NUM_PARALLELISM), - .UNROLL_IN2_Y(W_PARALLELISM) - ) inst_fmmc_z ( - .clk(clk), - .rst(rst), - .data_in1(softmax_s), - .data_in1_valid(softmax_s_valid), - .data_in1_ready(softmax_s_ready), - .data_in2(data_v_t), - .data_in2_valid(data_v_t_valid), - .data_in2_ready(data_v_t_ready), - .data_out(data_z), - .data_out_valid(data_z_valid), - .data_out_ready(data_z_ready) - ); - assign data_out_0 = data_z; - assign data_out_0_valid = data_z_valid; - assign data_z_ready = data_out_ready; - - -endmodule - diff --git a/src/mase_components/attention/rtl/fixed_gqa_head.sv b/src/mase_components/attention/rtl/fixed_gqa_head.sv new file mode 100644 index 000000000..1dd8618ba --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_gqa_head.sv @@ -0,0 +1,272 @@ +/* +Module : fixed_gqa_head + +Description : Implements an attention head which is used in group query + attention (GQA). It has no K, V matrix multiplications as this is + done outside of the heads due to the shared K, V weight matrices. + + The module has parameterised intermediate fixed-point widths and + Q has a linear layer where the embedding dimension of the input + can reduced from EMBEDDING_DIM to HEAD_DIM for this head. + + !! The K, V matrix multiply and K transpose is done outside of + this module. !! + + Dimensions of each input/output port is in the comments below. + +Dataflow : 1. Q get projected from EMBEDDING_DIM to HEAD_DIM + 2. QK^T matrix multiplication + 3. Softermax on attention scores: softermax(QK^T) + 4. Final matrix multiply to get attention: softermax(QK^T) * V + +Assumptions : 1. All activation inputs share same total and compute dimensions. + 2. All weight inputs share same total and compute dimensions. + 3. Activations and weights share same compute dimensions. + 4. The K input is transposed already. +*/ + +`timescale 1ns / 1ps + +module fixed_gqa_head #( + // Dimensions + parameter TOTAL_EMBEDDING_DIM = 32, + parameter TOTAL_HEAD_DIM = 16, + parameter TOTAL_SEQUENCE_DIM = 16, + parameter COMPUTE_EMBEDDING_DIM = 4, + parameter COMPUTE_HEAD_DIM = 4, + parameter COMPUTE_SEQUENCE_DIM = 4, + + // Input Port Widths + parameter Q_ACT_WIDTH = 8, + parameter Q_ACT_FRAC_WIDTH = 2, + parameter Q_WEIGHT_WIDTH = 8, + parameter Q_WEIGHT_FRAC_WIDTH = 2, + + parameter K_ACT_WIDTH = 8, + parameter K_ACT_FRAC_WIDTH = 2, + + parameter V_ACT_WIDTH = 8, + parameter V_ACT_FRAC_WIDTH = 2, + + // Output Port Widths + parameter OUT_ACT_WIDTH = 8, + parameter OUT_ACT_FRAC_WIDTH = 2, + + // Intermediate widths + // Output widths for query activation & weight multiplication + parameter Q_OUT_WIDTH = 16, + parameter Q_OUT_FRAC_WIDTH = 8, + // Output width for QK^T matrix multiplication + parameter QK_OUT_WIDTH = 16, + parameter QK_OUT_FRAC_WIDTH = 8, + // Widths for Softermax module + parameter SOFTERMAX_POW2_WIDTH = 16, + parameter SOFTERMAX_OUT_WIDTH = 16, + parameter SOFTERMAX_OUT_FRAC_WIDTH = 15 +) ( + input logic clk, + input logic rst, + + // Query Activation & Weight Matrices + // The multiplication between q_act and q_weight will reduce the embedding dimension + // from COMPUTE_EMBEDDING_DIM down to the COMPUTE_HEAD_DIM for this head. + + // Query Activation Input (dims = seq_dim x embedding_dim) + input logic [Q_ACT_WIDTH-1:0] q_act_data [COMPUTE_SEQUENCE_DIM*COMPUTE_EMBEDDING_DIM-1:0], + input logic q_act_valid, + output logic q_act_ready, + + // Query Weights for this Head (dims = embedding_dim x head_dim) + input logic [Q_WEIGHT_WIDTH-1:0] q_weight_data [COMPUTE_EMBEDDING_DIM*COMPUTE_HEAD_DIM-1:0], + input logic q_weight_valid, + output logic q_weight_ready, + + // Key & Value Matmul has been done outside of this module so they are already + // in the specified head embedding dim. + + // Pre-Calculated & Transposed Key Activation Matrix (dims = head_dim x seq_dim) + input logic [K_ACT_WIDTH-1:0] k_transposed_act_data[COMPUTE_HEAD_DIM*COMPUTE_SEQUENCE_DIM-1:0], + input logic k_transposed_act_valid, + output logic k_transposed_act_ready, + + // Pre-Calculated Value Activation Matrix (dims = seq_dim x head_dim) + input logic [V_ACT_WIDTH-1:0] v_act_data [COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0], + input logic v_act_valid, + output logic v_act_ready, + + // Output Activation Matrix (dims = seq_dim x head_dim) + output logic [OUT_ACT_WIDTH-1:0] out_act_data [COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0], + output logic out_act_valid, + input logic out_act_ready +); + + // ----- + // Params + // ----- + + localparam EMBEDDING_DEPTH = TOTAL_EMBEDDING_DIM / COMPUTE_EMBEDDING_DIM; + localparam HEAD_DEPTH = TOTAL_HEAD_DIM / COMPUTE_HEAD_DIM; + localparam SEQUENCE_DEPTH = TOTAL_SEQUENCE_DIM / COMPUTE_SEQUENCE_DIM; + + initial begin + // Check divisibility + assert (EMBEDDING_DEPTH * COMPUTE_EMBEDDING_DIM == TOTAL_EMBEDDING_DIM); + assert (HEAD_DEPTH * COMPUTE_HEAD_DIM == TOTAL_HEAD_DIM); + assert (SEQUENCE_DEPTH * COMPUTE_SEQUENCE_DIM == TOTAL_SEQUENCE_DIM); + end + + + // ----- + // Wires + // ----- + + // Output of q_act x q_weight (dims = seq_dim x head_dim) + logic [Q_OUT_WIDTH-1:0] q_out_data[COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0]; + logic q_out_valid, q_out_ready; + + // Output of q_out x k_transposed_act (dims = seq_dim x seq_dim) + logic [QK_OUT_WIDTH-1:0] qk_out_data[COMPUTE_SEQUENCE_DIM*COMPUTE_SEQUENCE_DIM-1:0]; + logic qk_out_valid, qk_out_ready; + + // Output of softermax(q_out x k_transposed_act) (dims = seq_dim x seq_dim) + logic [SOFTERMAX_OUT_WIDTH-1:0] softermax_out_data[COMPUTE_SEQUENCE_DIM*COMPUTE_SEQUENCE_DIM-1:0]; + logic [SOFTERMAX_OUT_WIDTH:0] softermax_unsigned_out_data [COMPUTE_SEQUENCE_DIM*COMPUTE_SEQUENCE_DIM-1:0]; + logic softermax_out_valid, softermax_out_ready; + + // ----- + // Modules + // ----- + + matmul #( + // Activations + .A_TOTAL_DIM0 (TOTAL_EMBEDDING_DIM), + .A_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .A_COMPUTE_DIM0(COMPUTE_EMBEDDING_DIM), + .A_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .A_WIDTH (Q_ACT_WIDTH), + .A_FRAC_WIDTH (Q_ACT_FRAC_WIDTH), + // Weights + .B_TOTAL_DIM0 (TOTAL_HEAD_DIM), + .B_TOTAL_DIM1 (TOTAL_EMBEDDING_DIM), + .B_COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .B_COMPUTE_DIM1(COMPUTE_EMBEDDING_DIM), + .B_WIDTH (Q_WEIGHT_WIDTH), + .B_FRAC_WIDTH (Q_WEIGHT_FRAC_WIDTH), + // Output + .OUT_WIDTH (Q_OUT_WIDTH), + .OUT_FRAC_WIDTH(Q_OUT_FRAC_WIDTH), + .OUT_SYMMETRIC (0) + ) q_matmul ( + .clk (clk), + .rst (rst), + .a_data (q_act_data), + .a_valid (q_act_valid), + .a_ready (q_act_ready), + .b_data (q_weight_data), + .b_valid (q_weight_valid), + .b_ready (q_weight_ready), + .out_data (q_out_data), + .out_valid(q_out_valid), + .out_ready(q_out_ready) + ); + + // TODO: Fix buffering problem + // Ideally, we want to buffer port A instead of port B because the critical path + // is on the second port anyway due to the transpose. This means that we need to + // insert a large fifo on the Q path to latency/cycle match the K path to + // prevent throughput issues and deadlocks. + matmul #( + // Port A: q_out + .A_TOTAL_DIM0 (TOTAL_HEAD_DIM), + .A_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .A_COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .A_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .A_WIDTH (Q_OUT_WIDTH), + .A_FRAC_WIDTH (Q_OUT_FRAC_WIDTH), + // Port B: k_transpose + .B_TOTAL_DIM0 (TOTAL_SEQUENCE_DIM), + .B_TOTAL_DIM1 (TOTAL_HEAD_DIM), + .B_COMPUTE_DIM0(COMPUTE_SEQUENCE_DIM), + .B_COMPUTE_DIM1(COMPUTE_HEAD_DIM), + .B_WIDTH (K_ACT_WIDTH), + .B_FRAC_WIDTH (K_ACT_FRAC_WIDTH), + // Output + .OUT_WIDTH (QK_OUT_WIDTH), + .OUT_FRAC_WIDTH(QK_OUT_FRAC_WIDTH), + .OUT_SYMMETRIC (0) + ) qk_matmul ( + .clk (clk), + .rst (rst), + .a_data (q_out_data), + .a_valid (q_out_valid), + .a_ready (q_out_ready), + .b_data (k_transposed_act_data), + .b_valid (k_transposed_act_valid), + .b_ready (k_transposed_act_ready), + .out_data (qk_out_data), + .out_valid(qk_out_valid), + .out_ready(qk_out_ready) + ); + + fixed_softermax_2d #( + .TOTAL_DIM0 (TOTAL_SEQUENCE_DIM), + .TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .COMPUTE_DIM0 (COMPUTE_SEQUENCE_DIM), + .COMPUTE_DIM1 (COMPUTE_SEQUENCE_DIM), + .IN_WIDTH (QK_OUT_WIDTH), + .IN_FRAC_WIDTH (QK_OUT_FRAC_WIDTH), + .POW2_WIDTH (SOFTERMAX_POW2_WIDTH), + .OUT_WIDTH (SOFTERMAX_OUT_WIDTH), + .OUT_FRAC_WIDTH(SOFTERMAX_OUT_FRAC_WIDTH) + ) qk_softermax ( + .clk (clk), + .rst (rst), + .in_data (qk_out_data), + .in_valid (qk_out_valid), + .in_ready (qk_out_ready), + .out_data (softermax_out_data), + .out_valid(softermax_out_valid), + .out_ready(softermax_out_ready) + ); + + // Unsigned pad 0 to softmax result + for ( + genvar i = 0; i < COMPUTE_SEQUENCE_DIM * COMPUTE_SEQUENCE_DIM; i++ + ) begin : gen_softermax_unsigned + assign softermax_unsigned_out_data[i] = {1'b0, softermax_out_data[i]}; + end + + matmul #( + // Port A: softermax attention scores + .A_TOTAL_DIM0 (TOTAL_SEQUENCE_DIM), + .A_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .A_COMPUTE_DIM0(COMPUTE_SEQUENCE_DIM), + .A_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .A_WIDTH (SOFTERMAX_OUT_WIDTH + 1), // Added 1 bit for unsigned + .A_FRAC_WIDTH (SOFTERMAX_OUT_FRAC_WIDTH), + // Port B: value matrix + .B_TOTAL_DIM0 (TOTAL_HEAD_DIM), + .B_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .B_COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .B_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .B_WIDTH (V_ACT_WIDTH), + .B_FRAC_WIDTH (V_ACT_FRAC_WIDTH), + // Output + .OUT_WIDTH (OUT_ACT_WIDTH), + .OUT_FRAC_WIDTH(OUT_ACT_FRAC_WIDTH), + .OUT_SYMMETRIC (0) + ) attn_matmul ( + .clk (clk), + .rst (rst), + .a_data (softermax_unsigned_out_data), + .a_valid (softermax_out_valid), + .a_ready (softermax_out_ready), + .b_data (v_act_data), + .b_valid (v_act_valid), + .b_ready (v_act_ready), + .out_data (out_act_data), + .out_valid(out_act_valid), + .out_ready(out_act_ready) + ); + +endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_att.sv b/src/mase_components/attention/rtl/fixed_self_att.sv deleted file mode 100644 index b5ec54990..000000000 --- a/src/mase_components/attention/rtl/fixed_self_att.sv +++ /dev/null @@ -1,154 +0,0 @@ -`timescale 1ns / 1ps -module fixed_self_att #( - parameter DATA_WIDTH = 8, - parameter DATA_FRAC_WIDTH = 1, - - parameter WEIGHT_Q_WIDTH = 8, - parameter WEIGHT_Q_FRAC_WIDTH = 1, - parameter WEIGHT_K_WIDTH = 8, - parameter WEIGHT_K_FRAC_WIDTH = 1, - parameter WEIGHT_V_WIDTH = 8, - parameter WEIGHT_V_FRAC_WIDTH = 1, - - parameter BIAS_Q_WIDTH = 8, - parameter BIAS_Q_FRAC_WIDTH = 1, - parameter BIAS_K_WIDTH = 8, - parameter BIAS_K_FRAC_WIDTH = 1, - parameter BIAS_V_WIDTH = 8, - parameter BIAS_V_FRAC_WIDTH = 1, - - parameter DQ_WIDTH = 8, - parameter DQ_FRAC_WIDTH = 1, - parameter DK_WIDTH = 8, - parameter DK_FRAC_WIDTH = 1, - parameter DV_WIDTH = 8, - parameter DV_FRAC_WIDTH = 1, - - parameter DS_WIDTH = 8, - parameter DS_FRAC_WIDTH = 1, - parameter EXP_WIDTH = 8, - parameter EXP_FRAC_WIDTH = 4, - parameter DIV_WIDTH = 10, - parameter DS_SOFTMAX_WIDTH = 8, - parameter DS_SOFTMAX_FRAC_WIDTH = 7, - - parameter DZ_WIDTH = 8, - parameter DZ_FRAC_WIDTH = 1, - - parameter IN_PARALLELISM = 3, - parameter IN_NUM_PARALLELISM = 2, - - parameter IN_SIZE = 3, - //define for matrix multilication - parameter IN_DEPTH = 3, - - parameter W_PARALLELISM = 3, - parameter W_NUM_PARALLELISM = 2, - parameter W_SIZE = IN_SIZE, - - parameter OUT_PARALLELISM = IN_PARALLELISM, - - parameter BIAS_Q_SIZE = 3, - parameter BIAS_K_SIZE = 3, - parameter BIAS_V_SIZE = 3, - parameter WEIGHT_Q_SIZE = 9, - parameter WEIGHT_K_SIZE = 9, - parameter WEIGHT_V_SIZE = 9, - - parameter OUT_SIZE = OUT_PARALLELISM * OUT_SIZE, - parameter OUT_WIDTH = DZ_WIDTH -) ( - input clk, - input rst, - - input [WEIGHT_Q_WIDTH - 1:0] weight_q[WEIGHT_Q_SIZE -1 : 0], - input weight_q_valid, - output weight_q_ready, - - input [WEIGHT_K_WIDTH - 1:0] weight_k[WEIGHT_K_SIZE -1 : 0], - input weight_k_valid, - output weight_k_ready, - - input [WEIGHT_V_WIDTH - 1:0] weight_v[WEIGHT_V_SIZE -1 : 0], - input weight_v_valid, - output weight_v_ready, - - input [BIAS_Q_WIDTH - 1:0] bias_q[BIAS_Q_SIZE -1 : 0], - input bias_q_valid, - output bias_q_ready, - - input [BIAS_K_WIDTH - 1:0] bias_k[BIAS_K_SIZE -1 : 0], - input bias_k_valid, - output bias_k_ready, - - input [BIAS_V_WIDTH - 1:0] bias_v[BIAS_V_SIZE -1 : 0], - input bias_v_valid, - output bias_v_ready, - - input [DATA_WIDTH -1:0] data_in_0[IN_PARALLELISM * IN_SIZE - 1 : 0], - input data_in_0_valid, - output data_in_0_ready, - - output [OUT_WIDTH -1:0] data_out_0[OUT_SIZE - 1:0], - output data_out_0_valid, - input data_out_0_ready -); - logic data_in_q_ready, data_in_k_ready, data_in_v_ready; - logic data_in_q_valid, data_in_k_valid, data_in_v_valid; - - assign data_in_q_valid = data_in_v_ready && data_in_k_ready && data_in_valid; - assign data_in_k_valid = data_in_q_ready && data_in_v_ready && data_in_valid; - assign data_in_v_valid = data_in_q_ready && data_in_k_ready && data_in_valid; - assign data_in_ready = data_in_q_ready && data_in_k_ready && data_in_v_ready; - fixed_att #( - .DQIN_WIDTH(DATA_WIDTH), - .DQIN_FRAC_WIDTH(DATA_FRAC_WIDTH), - .DKIN_WIDTH(DATA_WIDTH), - .DKIN_FRAC_WIDTH(DATA_FRAC_WIDTH), - .DVIN_WIDTH(DATA_WIDTH), - .DVIN_FRAC_WIDTH(DATA_FRAC_WIDTH), - - .WQ_WIDTH(WEIGHT_Q_WIDTH), - .WQ_FRAC_WIDTH(WEIGHT_Q_FRAC_WIDTH), - .WK_WIDTH(WEIGHT_K_WIDTH), - .WL_FRAC_WIDTH(WEIGHT_K_FRAC_WIDTH), - .WV_WIDTH(WEIGHT_V_WIDTH), - .WV_FRAC_WIDTH(WEIGHT_V_FRAC_WIDTH), - - .BQ_WIDTH(BIAS_Q_WIDTH), - .BQ_FRAC_WIDTH(BIAS_Q_FRAC_WIDTH), - .BK_WIDTH(BIAS_K_WIDTH), - .BK_FRAC_WIDTH(BIAS_K_FRAC_WIDTH), - .BV_WIDTH(BIAS_V_WIDTH), - .BV_FRAC_WIDTH(BIAS_V_FRAC_WIDTH), - - .DQ_WIDTH(DQ_WIDTH), - .DQ_FRAC_WIDTH(DQ_FRAC_WIDTH), - .DK_WIDTH(DK_WIDTH), - .DK_FRAC_WIDTH(DK_FRAC_WIDTH), - .DV_WIDTH(DV_WIDTH), - .DV_FRAC_WIDTH(DV_FRAC_WIDTH), - - .DS_WIDTH(DS_WIDTH), - .DS_FRAC_WIDTH(DS_FRAC_WIDTH), - .EXP_WIDTH(EXP_WIDTH), - .EXP_FRAC_WIDTH(EXP_FRAC_WIDTH), - .DIV_WIDTH(DIV_WIDTH), - .DS_SOFTMAX_WIDTH(DS_SOFTMAX_WIDTH), - .DS_SOFTMAX_FRAC_WIDTH(DS_SOFTMAX_FRAC_WIDTH), - - .DZ_WIDTH(DZ_WIDTH), - .DZ_FRAC_WIDTH(DZ_FRAC_WIDTH), - .IN_PARALLELISM(IN_PARALLELISM), - .IN_NUM_PARALLELISM(IN_NUM_PARALLELISM), - .IN_SIZE(IN_SIZE), - .IN_DEPTH(IN_DEPTH), - .W_PARALLELISM(W_PARALLELISM), - .W_NUM_PARALLELISM(W_NUM_PARALLELISM) - ) att_inst ( - .data_in_q(data_in_0), - .data_in_k(data_in_0), - .data_in_v(data_in_0), - .* - ); -endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_attention.sv b/src/mase_components/attention/rtl/fixed_self_attention.sv new file mode 100644 index 000000000..a54aec306 --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_self_attention.sv @@ -0,0 +1,282 @@ +`timescale 1ns / 1ps +module fixed_self_attention #( + parameter NUM_HEADS = 12, + parameter ACTIVATION = 0, + + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 768, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 20, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + + parameter WEIGHTS_PRE_TRANSPOSED = 0, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 768, + parameter WEIGHT_TENSOR_SIZE_DIM_1 = 768, + parameter WEIGHT_PARALLELISM_DIM_0 = 4, + parameter WEIGHT_PARALLELISM_DIM_1 = 4, + parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_1 = 3, + + parameter HAS_BIAS = 1, + parameter BIAS_TENSOR_SIZE_DIM_0 = 64, + parameter BIAS_TENSOR_SIZE_DIM_1 = 20, + parameter BIAS_PARALLELISM_DIM_0 = 4, + parameter BIAS_PARALLELISM_DIM_1 = 4, + parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_1 = 3, + + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = WEIGHT_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = WEIGHT_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0, + parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + +) ( + input logic clk, + input logic rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + // Query weights + input logic [WEIGHT_PRECISION_0-1:0] weight_query [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_query_valid, + output logic weight_query_ready, + + // Query bias + input logic [BIAS_PRECISION_0-1:0] bias_query [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_query_valid, + output logic bias_query_ready, + + // Key weights + input logic [WEIGHT_PRECISION_0-1:0] weight_key [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_key_valid, + output logic weight_key_ready, + + // Key bias + input logic [BIAS_PRECISION_0-1:0] bias_key [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_key_valid, + output logic bias_key_ready, + + // Value weights + input logic [WEIGHT_PRECISION_0-1:0] weight_value [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_value_valid, + output logic weight_value_ready, + + // Value bias + input logic [BIAS_PRECISION_0-1:0] bias_value [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_value_valid, + output logic bias_value_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + + // * Declarations + // * ================================================================= + + // Query + logic [DATA_OUT_0_PRECISION_0-1:0] query[DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0]; + logic joint_query_valid, joint_query_ready; + logic [NUM_HEADS-1:0] split_query_valid, split_query_ready; + + // Key + logic [DATA_OUT_0_PRECISION_0-1:0] key[DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0]; + logic joint_key_valid, joint_key_ready; + logic [NUM_HEADS-1:0] split_key_valid, split_key_ready; + + // Value + logic [DATA_OUT_0_PRECISION_0-1:0] value[DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0]; + logic joint_value_valid, joint_value_ready; + logic [NUM_HEADS-1:0] split_value_valid, split_value_ready; + + // Head output + logic [DATA_OUT_0_PRECISION_0-1:0] head_out [NUM_HEADS-1:0] [DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic [NUM_HEADS-1:0] head_out_valid; + logic [NUM_HEADS-1:0] head_out_ready; + + // * Instances + // * ================================================================= + + fixed_self_attention_input_block_batched #( + .DATA_IN_0_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .DATA_IN_0_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0(DATA_IN_0_PARALLELISM_DIM_0), + .DATA_IN_0_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1), + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + + .WEIGHTS_PRE_TRANSPOSED (WEIGHTS_PRE_TRANSPOSED), + .WEIGHT_TENSOR_SIZE_DIM_0(WEIGHT_TENSOR_SIZE_DIM_0), + .WEIGHT_TENSOR_SIZE_DIM_1(WEIGHT_TENSOR_SIZE_DIM_1), + .WEIGHT_PARALLELISM_DIM_0(WEIGHT_PARALLELISM_DIM_0), + .WEIGHT_PARALLELISM_DIM_1(WEIGHT_PARALLELISM_DIM_1), + .WEIGHT_PRECISION_0 (WEIGHT_PRECISION_0), + .WEIGHT_PRECISION_1 (WEIGHT_PRECISION_1), + + .HAS_BIAS (HAS_BIAS), + .BIAS_TENSOR_SIZE_DIM_0(BIAS_TENSOR_SIZE_DIM_0), + .BIAS_TENSOR_SIZE_DIM_1(BIAS_TENSOR_SIZE_DIM_1), + .BIAS_PARALLELISM_DIM_0(BIAS_PARALLELISM_DIM_0), + .BIAS_PARALLELISM_DIM_1(BIAS_PARALLELISM_DIM_1), + .BIAS_PRECISION_0 (BIAS_PRECISION_0), + .BIAS_PRECISION_1 (BIAS_PRECISION_1), + + .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1) + ) batched_input_block_i ( + .clk(clk), + .rst(rst), + + .data_in_0(data_in_0), + .data_in_0_valid(data_in_0_valid), + .data_in_0_ready(data_in_0_ready), + + // Query parameters + .weight_query(weight_query), + .weight_query_valid(weight_query_valid), + .weight_query_ready(weight_query_ready), + + .bias_query(bias_query), + .bias_query_valid(bias_query_valid), + .bias_query_ready(bias_query_ready), + + // Key parameters + .weight_key(weight_key), + .weight_key_valid(weight_key_valid), + .weight_key_ready(weight_key_ready), + + .bias_key(bias_key), + .bias_key_valid(bias_key_valid), + .bias_key_ready(bias_key_ready), + + // Value parameters + .weight_value(weight_value), + .weight_value_valid(weight_value_valid), + .weight_value_ready(weight_value_ready), + + .bias_value(bias_value), + .bias_value_valid(bias_value_valid), + .bias_value_ready(bias_value_ready), + + // Query output + .data_out_query(query), + .data_out_query_valid(joint_query_valid), + .data_out_query_ready(joint_query_ready), + + // Key output + .data_out_key(key), + .data_out_key_valid(joint_key_valid), + .data_out_key_ready(joint_key_ready), + + // Value output + .data_out_value(value), + .data_out_value_valid(joint_value_valid), + .data_out_value_ready(joint_value_ready) + ); + + // * Scatter query, key, value + + self_attention_head_scatter #( + .NUM_HEADS(NUM_HEADS), + + .IN_DATA_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .IN_DATA_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .IN_DATA_PARALLELISM_DIM_0(WEIGHT_PARALLELISM_DIM_0), + .IN_DATA_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1) + + ) scatter_qkv_i ( + .clk, + .rst, + + .query_valid(joint_query_valid), + .query_ready(joint_query_ready), + + .key_valid(joint_key_valid), + .key_ready(joint_key_ready), + + .value_valid(joint_value_valid), + .value_ready(joint_value_ready), + + .split_query_valid(split_query_valid), + .split_query_ready(split_query_ready), + + .split_key_valid(split_key_valid), + .split_key_ready(split_key_ready), + + .split_value_valid(split_value_valid), + .split_value_ready(split_value_ready) + ); + + // * Heads + + for (genvar head = 0; head < NUM_HEADS; head++) begin + + fixed_self_attention_head #( + .IN_DATA_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0 / NUM_HEADS), + .IN_DATA_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .IN_DATA_PARALLELISM_DIM_0(DATA_IN_0_PARALLELISM_DIM_0), + .IN_DATA_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1), + .IN_DATA_PRECISION_0 (DATA_OUT_0_PRECISION_0), + .IN_DATA_PRECISION_1 (DATA_OUT_0_PRECISION_1), + + .OUT_DATA_TENSOR_SIZE_DIM_0(DATA_OUT_0_TENSOR_SIZE_DIM_0 / NUM_HEADS), + .OUT_DATA_TENSOR_SIZE_DIM_1(DATA_OUT_0_TENSOR_SIZE_DIM_1), + .OUT_DATA_PARALLELISM_DIM_0(DATA_OUT_0_PARALLELISM_DIM_0), + .OUT_DATA_PARALLELISM_DIM_1(DATA_OUT_0_PARALLELISM_DIM_1), + .OUT_DATA_PRECISION_0 (DATA_OUT_0_PRECISION_0), + .OUT_DATA_PRECISION_1 (DATA_OUT_0_PRECISION_1) + + ) head_i ( + .clk, + .rst, + + .query (query), + .query_valid(split_query_valid[head]), + .query_ready(split_query_ready[head]), + + .key (key), + .key_valid(split_key_valid[head]), + .key_ready(split_key_ready[head]), + + .value (value), + .value_valid(split_value_valid[head]), + .value_ready(split_value_ready[head]), + + .out (head_out[head]), + .out_valid(head_out_valid[head]), + .out_ready(head_out_ready[head]) + ); + + end + + // * Gather heads + + self_attention_head_gather #( + .NUM_HEADS(NUM_HEADS), + + .IN_DATA_TENSOR_SIZE_DIM_0(DATA_OUT_0_TENSOR_SIZE_DIM_0), + .IN_DATA_TENSOR_SIZE_DIM_1(DATA_OUT_0_TENSOR_SIZE_DIM_1), + .IN_DATA_PARALLELISM_DIM_0(DATA_OUT_0_PARALLELISM_DIM_0), + .IN_DATA_PARALLELISM_DIM_1(DATA_OUT_0_PARALLELISM_DIM_1), + .IN_DATA_PRECISION_0 (DATA_OUT_0_PRECISION_0), + .IN_DATA_PRECISION_1 (DATA_OUT_0_PRECISION_1) + + ) gather_qkv_i ( + .clk, + .rst, + + .split_head_out (head_out), + .split_head_out_valid(head_out_valid), + .split_head_out_ready(head_out_ready), + + .updated_tokens (data_out_0), + .updated_tokens_valid(data_out_0_valid), + .updated_tokens_ready(data_out_0_ready) + ); + +endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_attention_head.sv b/src/mase_components/attention/rtl/fixed_self_attention_head.sv new file mode 100644 index 000000000..845e9a3c7 --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_self_attention_head.sv @@ -0,0 +1,281 @@ +`timescale 1ns / 1ps +module fixed_self_attention_head #( + + // * Queries, keys and values are assumed to have the same + // * precision, dimensions and parallelism + parameter IN_DATA_TENSOR_SIZE_DIM_0 = 64, + parameter IN_DATA_TENSOR_SIZE_DIM_1 = 32, + parameter IN_DATA_PARALLELISM_DIM_0 = 4, + parameter IN_DATA_PARALLELISM_DIM_1 = 4, + parameter IN_DATA_PRECISION_0 = 16, + parameter IN_DATA_PRECISION_1 = 3, + + // * Output tokens are casted to requested precision + parameter OUT_DATA_TENSOR_SIZE_DIM_0 = 64, + parameter OUT_DATA_TENSOR_SIZE_DIM_1 = 32, + parameter OUT_DATA_PARALLELISM_DIM_0 = IN_DATA_PARALLELISM_DIM_0, + parameter OUT_DATA_PARALLELISM_DIM_1 = IN_DATA_PARALLELISM_DIM_1, + parameter OUT_DATA_PRECISION_0 = 16, + parameter OUT_DATA_PRECISION_1 = 3 + +) ( + input logic clk, + input logic rst, + + input logic [IN_DATA_PRECISION_0-1:0] query [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0], + input logic query_valid, + output logic query_ready, + + input logic [IN_DATA_PRECISION_0-1:0] key [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0], + input logic key_valid, + output logic key_ready, + + input logic [IN_DATA_PRECISION_0-1:0] value [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0], + input logic value_valid, + output logic value_ready, + + output logic [OUT_DATA_PRECISION_0-1:0] out [OUT_DATA_PARALLELISM_DIM_0*OUT_DATA_PARALLELISM_DIM_1-1:0], + output logic out_valid, + input logic out_ready +); + + initial begin + assert (OUT_DATA_TENSOR_SIZE_DIM_0 == IN_DATA_TENSOR_SIZE_DIM_0) + else + $fatal( + "Module incorrectly parametrized. OUT_DATA_TENSOR_SIZE_DIM_0 != IN_DATA_TENSOR_SIZE_DIM_0" + ); + + assert (OUT_DATA_TENSOR_SIZE_DIM_1 == IN_DATA_TENSOR_SIZE_DIM_1) + else + $fatal( + "Module incorrectly parametrized. OUT_DATA_TENSOR_SIZE_DIM_1 != IN_DATA_TENSOR_SIZE_DIM_1" + ); + + assert (OUT_DATA_PARALLELISM_DIM_0 == IN_DATA_PARALLELISM_DIM_0) + else + $fatal( + "Parallelism conversion not yet supported. OUT_DATA_PARALLELISM_DIM_0 != IN_DATA_PARALLELISM_DIM_0" + ); + + assert (OUT_DATA_PARALLELISM_DIM_1 == IN_DATA_PARALLELISM_DIM_1) + else + $fatal( + "Parallelism conversion not yet supported. OUT_DATA_PARALLELISM_DIM_1 != IN_DATA_PARALLELISM_DIM_1" + ); + end + + parameter IN_DATA_DEPTH_0 = IN_DATA_TENSOR_SIZE_DIM_0 / IN_DATA_PARALLELISM_DIM_0; + parameter IN_DATA_DEPTH_1 = IN_DATA_TENSOR_SIZE_DIM_1 / IN_DATA_PARALLELISM_DIM_1; + + // Query key transpose + parameter QUERY_TRANSPOSE_PRECISION_0 = 2 * IN_DATA_PRECISION_0 + $clog2( + IN_DATA_PARALLELISM_DIM_0 + ) + $clog2( + IN_DATA_DEPTH_1 + ); + parameter QUERY_TRANSPOSE_PRECISION_1 = 2 * IN_DATA_PRECISION_1; + + // Attention scores + // ! TO DO: check precision transformation post softmax + parameter ATTENTION_SCORES_PRECISION_0 = QUERY_TRANSPOSE_PRECISION_0; + parameter ATTENTION_SCORES_PRECISION_1 = QUERY_TRANSPOSE_PRECISION_1; + + parameter OUT_PRE_CAST_PRECISION_0 = IN_DATA_PRECISION_0 + ATTENTION_SCORES_PRECISION_0 + $clog2( + IN_DATA_PARALLELISM_DIM_1 + ) + $clog2( + IN_DATA_TENSOR_SIZE_DIM_1 / IN_DATA_PARALLELISM_DIM_1 + ); + parameter OUT_PRE_CAST_PRECISION_1 = IN_DATA_PRECISION_1 + ATTENTION_SCORES_PRECISION_1; + + // * Declarations + // * ================================================================= + + logic [IN_DATA_PRECISION_0-1:0] key_transpose [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0]; + logic key_transpose_valid; + logic key_transpose_ready; + + logic [OUT_DATA_PRECISION_0-1:0] query_key_transpose [IN_DATA_PARALLELISM_DIM_1 * IN_DATA_PARALLELISM_DIM_1-1:0]; + logic query_key_transpose_valid; + logic query_key_transpose_ready; + + logic [OUT_DATA_PRECISION_0-1:0] attention_scores [IN_DATA_PARALLELISM_DIM_1 * IN_DATA_PARALLELISM_DIM_1-1:0]; + logic attention_scores_valid; + logic attention_scores_ready; + + logic [OUT_DATA_PRECISION_0-1:0] out_pre_cast [OUT_DATA_PARALLELISM_DIM_0*OUT_DATA_PARALLELISM_DIM_1-1:0]; + logic [OUT_DATA_PRECISION_0-1:0] out_casted [OUT_DATA_PARALLELISM_DIM_0*OUT_DATA_PARALLELISM_DIM_1-1:0]; + logic out_cast_valid; + logic out_cast_ready; + + // * Instances + // * ================================================================= + + // * Transpose projected keys + + matrix_stream_transpose #( + .TOTAL_DIM0 (IN_DATA_TENSOR_SIZE_DIM_0), + .TOTAL_DIM1 (IN_DATA_TENSOR_SIZE_DIM_1), + .COMPUTE_DIM0(IN_DATA_PARALLELISM_DIM_0), + .COMPUTE_DIM1(IN_DATA_PARALLELISM_DIM_1), + + .DATA_WIDTH(IN_DATA_PRECISION_0) + ) key_transpose_i ( + .clk, + .rst, + + // In Matrix + .in_data (key), + .in_valid(key_valid), + .in_ready(key_ready), + + // Out Matrix + .out_data (key_transpose), + .out_valid(key_transpose_valid), + .out_ready(key_transpose_ready) + ); + + // * Query x Key^T + + matmul #( + .A_TOTAL_DIM0(IN_DATA_TENSOR_SIZE_DIM_0), + .A_TOTAL_DIM1(IN_DATA_TENSOR_SIZE_DIM_1), + + .B_TOTAL_DIM0(IN_DATA_TENSOR_SIZE_DIM_1), + .B_TOTAL_DIM1(IN_DATA_TENSOR_SIZE_DIM_0), + + .A_COMPUTE_DIM0(IN_DATA_PARALLELISM_DIM_0), + .A_COMPUTE_DIM1(IN_DATA_PARALLELISM_DIM_0), + .B_COMPUTE_DIM0(IN_DATA_PARALLELISM_DIM_1), + .B_COMPUTE_DIM1(IN_DATA_PARALLELISM_DIM_0), + + .A_WIDTH (IN_DATA_PRECISION_0), + .A_FRAC_WIDTH(IN_DATA_PRECISION_1), + + .B_WIDTH (IN_DATA_PRECISION_0), + .B_FRAC_WIDTH(IN_DATA_PRECISION_1), + + .OUT_WIDTH (OUT_DATA_PRECISION_0), + .OUT_FRAC_WIDTH(OUT_DATA_PRECISION_1) + + ) query_key_transpose_matmul_i ( + .clk, + .rst, + + .a_data (query), + .a_valid(query_valid), + .a_ready(query_ready), + + .b_data (key_transpose), + .b_valid(key_transpose_valid), + .b_ready(key_transpose_ready), + + .out_data (query_key_transpose), + .out_valid(query_key_transpose_valid), + .out_ready(query_key_transpose_ready) + ); + + // ! TO DO: normalize query_key_transpose + + // * Attention scores: softmax(Query x Key^T) + + fixed_softermax #( + .DATA_IN_0_PRECISION_0 (OUT_DATA_PRECISION_0), + .DATA_IN_0_PRECISION_1 (OUT_DATA_PRECISION_1), + .DATA_IN_0_TENSOR_SIZE_DIM_0(IN_DATA_TENSOR_SIZE_DIM_1), + .DATA_IN_0_TENSOR_SIZE_DIM_1(IN_DATA_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0(IN_DATA_PARALLELISM_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_1(IN_DATA_PARALLELISM_DIM_1), + + .DATA_OUT_0_PRECISION_0 (OUT_DATA_PRECISION_0), + .DATA_OUT_0_PRECISION_1 (OUT_DATA_PRECISION_1), + .DATA_OUT_0_TENSOR_SIZE_DIM_0(IN_DATA_TENSOR_SIZE_DIM_1), + .DATA_OUT_0_TENSOR_SIZE_DIM_1(IN_DATA_TENSOR_SIZE_DIM_1), + .DATA_OUT_0_PARALLELISM_DIM_0(IN_DATA_PARALLELISM_DIM_1), + .DATA_OUT_0_PARALLELISM_DIM_1(IN_DATA_PARALLELISM_DIM_1) + + ) fixed_softermax_i ( + .clk, + .rst, + + .data_in_0 (query_key_transpose), + .data_in_0_valid(query_key_transpose_valid), + .data_in_0_ready(query_key_transpose_ready), + + .data_out_0 (attention_scores), + .data_out_0_valid(attention_scores_valid), + .data_out_0_ready(attention_scores_ready) + ); + + // * Output: Attention scores x Value + + matmul #( + .A_TOTAL_DIM0(IN_DATA_TENSOR_SIZE_DIM_1), + .A_TOTAL_DIM1(IN_DATA_TENSOR_SIZE_DIM_1), + + .B_TOTAL_DIM0(IN_DATA_TENSOR_SIZE_DIM_0), + .B_TOTAL_DIM1(IN_DATA_TENSOR_SIZE_DIM_1), + + .A_COMPUTE_DIM0(IN_DATA_PARALLELISM_DIM_1), + .A_COMPUTE_DIM1(IN_DATA_PARALLELISM_DIM_1), + .B_COMPUTE_DIM0(IN_DATA_PARALLELISM_DIM_0), + .B_COMPUTE_DIM1(IN_DATA_PARALLELISM_DIM_1), + + .A_WIDTH (OUT_DATA_PRECISION_0), + .A_FRAC_WIDTH(OUT_DATA_PRECISION_1), + + .B_WIDTH (IN_DATA_PRECISION_0), + .B_FRAC_WIDTH(IN_DATA_PRECISION_1), + + .OUT_WIDTH (OUT_DATA_PRECISION_0), + .OUT_FRAC_WIDTH(OUT_DATA_PRECISION_1) + + ) attention_scores_values_matmul_i ( + .clk, + .rst, + + .a_data (attention_scores), + .a_valid(attention_scores_valid), + .a_ready(attention_scores_ready), + + .b_data (value), + .b_valid(value_valid), + .b_ready(value_ready), + + .out_data (out_pre_cast), + .out_valid(out_cast_valid), + .out_ready(out_cast_ready) + ); + + // * Output cast + + fixed_rounding #( + .IN_SIZE(OUT_DATA_PARALLELISM_DIM_0 * OUT_DATA_PARALLELISM_DIM_1), + + .IN_WIDTH (OUT_DATA_PRECISION_0), + .IN_FRAC_WIDTH(OUT_DATA_PRECISION_1), + + .OUT_WIDTH (OUT_DATA_PRECISION_0), + .OUT_FRAC_WIDTH(OUT_DATA_PRECISION_1) + ) data_out_cast ( + .data_in (out_pre_cast), + .data_out(out_casted) + ); + + unpacked_register_slice #( + .DATA_WIDTH(OUT_DATA_PRECISION_0), + .IN_SIZE (OUT_DATA_PARALLELISM_DIM_0 * OUT_DATA_PARALLELISM_DIM_1) + ) out_cast_register_slice_i ( + .clk(clk), + .rst(rst), + + .data_in (out_casted), + .data_in_valid(out_cast_valid), + .data_in_ready(out_cast_ready), + + .data_out (out), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + +endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_attention_input_block_batched.sv b/src/mase_components/attention/rtl/fixed_self_attention_input_block_batched.sv new file mode 100644 index 000000000..eb1feb07d --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_self_attention_input_block_batched.sv @@ -0,0 +1,294 @@ +`timescale 1ns / 1ps +module fixed_self_attention_input_block_batched #( + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 768, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 20, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + + parameter WEIGHTS_PRE_TRANSPOSED = 0, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 768, + parameter WEIGHT_TENSOR_SIZE_DIM_1 = 768, + parameter WEIGHT_PARALLELISM_DIM_0 = 4, + parameter WEIGHT_PARALLELISM_DIM_1 = 4, + parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_1 = 3, + + parameter HAS_BIAS = 1, + parameter BIAS_TENSOR_SIZE_DIM_0 = 64, + parameter BIAS_TENSOR_SIZE_DIM_1 = 20, + parameter BIAS_PARALLELISM_DIM_0 = 4, + parameter BIAS_PARALLELISM_DIM_1 = 4, + parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_1 = 3, + + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = WEIGHT_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = WEIGHT_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PRECISION_0 = 16, + parameter DATA_OUT_0_PRECISION_1 = 3 + +) ( + input logic clk, + input logic rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + // Query weights + input logic [WEIGHT_PRECISION_0-1:0] weight_query [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_query_valid, + output logic weight_query_ready, + + // Query bias + input logic [BIAS_PRECISION_0-1:0] bias_query [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_query_valid, + output logic bias_query_ready, + + // Key weights + input logic [WEIGHT_PRECISION_0-1:0] weight_key [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_key_valid, + output logic weight_key_ready, + + // Key bias + input logic [BIAS_PRECISION_0-1:0] bias_key [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_key_valid, + output logic bias_key_ready, + + // Value weights + input logic [WEIGHT_PRECISION_0-1:0] weight_value [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_value_valid, + output logic weight_value_ready, + + // Value bias + input logic [BIAS_PRECISION_0-1:0] bias_value [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_value_valid, + output logic bias_value_ready, + + // Query + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_query [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0], + output logic data_out_query_valid, + input logic data_out_query_ready, + + // Key + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_key [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0], + output logic data_out_key_valid, + input logic data_out_key_ready, + + // Value + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_value [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0], + output logic data_out_value_valid, + input logic data_out_value_ready +); + + // ! TO DO: add assertions about bias parallelism matching weight parallelism + + // * Inferred parameters + parameter DATA_IN_0_DEPTH_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1; + parameter WEIGHT_DEPTH_DIM_0 = WEIGHT_TENSOR_SIZE_DIM_0 / WEIGHT_PARALLELISM_DIM_0; + + // * Declarations + // * ================================================================= + + logic query_data_in_valid, query_data_in_ready; + logic key_data_in_valid, key_data_in_ready; + logic value_data_in_valid, value_data_in_ready; + + logic [DATA_OUT_0_PRECISION_0-1:0] query_buffer [DATA_IN_0_PARALLELISM_DIM_1 * WEIGHT_PARALLELISM_DIM_0-1:0]; + logic query_buffer_valid; + logic query_buffer_ready; + + // * Instances + // * ================================================================= + + // * Split the incoming data over the QKV projections + split_n #( + .N(3) + ) split_i ( + .data_in_valid (data_in_0_valid), + .data_in_ready (data_in_0_ready), + .data_out_valid({query_data_in_valid, key_data_in_valid, value_data_in_valid}), + .data_out_ready({query_data_in_ready, key_data_in_ready, value_data_in_ready}) + ); + + // * Query linear + + fixed_linear #( + .HAS_BIAS (HAS_BIAS), + .WEIGHTS_PRE_TRANSPOSED(WEIGHTS_PRE_TRANSPOSED), + + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + .DATA_IN_0_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .DATA_IN_0_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0(DATA_IN_0_PARALLELISM_DIM_0), + .DATA_IN_0_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1), + + .WEIGHT_PRECISION_0 (WEIGHT_PRECISION_0), + .WEIGHT_PRECISION_1 (WEIGHT_PRECISION_1), + .WEIGHT_TENSOR_SIZE_DIM_0(WEIGHT_TENSOR_SIZE_DIM_0), + .WEIGHT_TENSOR_SIZE_DIM_1(WEIGHT_TENSOR_SIZE_DIM_1), + .WEIGHT_PARALLELISM_DIM_0(WEIGHT_PARALLELISM_DIM_0), + .WEIGHT_PARALLELISM_DIM_1(WEIGHT_PARALLELISM_DIM_1), + + .BIAS_PRECISION_0 (BIAS_PRECISION_0), + .BIAS_PRECISION_1 (BIAS_PRECISION_1), + .BIAS_TENSOR_SIZE_DIM_0(BIAS_TENSOR_SIZE_DIM_0), + .BIAS_TENSOR_SIZE_DIM_1(BIAS_TENSOR_SIZE_DIM_1), + .BIAS_PARALLELISM_DIM_0(BIAS_PARALLELISM_DIM_0), + .BIAS_PARALLELISM_DIM_1(BIAS_PARALLELISM_DIM_1), + + .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1) + + ) fixed_linear_query ( + .clk, + .rst, + + // input port for data_inivations + .data_in_0 (data_in_0), + .data_in_0_valid(query_data_in_valid), + .data_in_0_ready(query_data_in_ready), + + // input port for weight + .weight (weight_query), + .weight_valid(weight_query_valid), + .weight_ready(weight_query_ready), + + .bias (bias_query), + .bias_valid(bias_query_valid), + .bias_ready(bias_query_ready), + + .data_out_0 (query_buffer), + .data_out_0_valid(query_buffer_valid), + .data_out_0_ready(query_buffer_ready) + ); + + // * We must buffer the queries to latency match the key transpose path + // * since the matmul for QK^T buffers K^T but streams Q + matrix_fifo #( + .DATA_WIDTH(DATA_OUT_0_PRECISION_0), + .DIM0 (WEIGHT_PARALLELISM_DIM_0), + .DIM1 (DATA_IN_0_PARALLELISM_DIM_1), + .FIFO_SIZE (DATA_IN_0_DEPTH_DIM_1 * WEIGHT_DEPTH_DIM_0) + ) query_buffer_i ( + .clk, + .rst, + .in_data (query_buffer), + .in_valid (query_buffer_valid), + .in_ready (query_buffer_ready), + .out_data (data_out_query), + .out_valid(data_out_query_valid), + .out_ready(data_out_query_ready) + ); + + // * Key linear + + fixed_linear #( + .HAS_BIAS (HAS_BIAS), + .WEIGHTS_PRE_TRANSPOSED(WEIGHTS_PRE_TRANSPOSED), + + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + .DATA_IN_0_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .DATA_IN_0_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0(DATA_IN_0_PARALLELISM_DIM_0), + .DATA_IN_0_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1), + + .WEIGHT_PRECISION_0 (WEIGHT_PRECISION_0), + .WEIGHT_PRECISION_1 (WEIGHT_PRECISION_1), + .WEIGHT_TENSOR_SIZE_DIM_0(WEIGHT_TENSOR_SIZE_DIM_0), + .WEIGHT_TENSOR_SIZE_DIM_1(WEIGHT_TENSOR_SIZE_DIM_1), + .WEIGHT_PARALLELISM_DIM_0(WEIGHT_PARALLELISM_DIM_0), + .WEIGHT_PARALLELISM_DIM_1(WEIGHT_PARALLELISM_DIM_1), + + .BIAS_PRECISION_0 (BIAS_PRECISION_0), + .BIAS_PRECISION_1 (BIAS_PRECISION_1), + .BIAS_TENSOR_SIZE_DIM_0(BIAS_TENSOR_SIZE_DIM_0), + .BIAS_TENSOR_SIZE_DIM_1(BIAS_TENSOR_SIZE_DIM_1), + .BIAS_PARALLELISM_DIM_0(BIAS_PARALLELISM_DIM_0), + .BIAS_PARALLELISM_DIM_1(BIAS_PARALLELISM_DIM_1), + + .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1) + + ) fixed_linear_key ( + .clk, + .rst, + + // input port for data_inivations + .data_in_0 (data_in_0), + .data_in_0_valid(key_data_in_valid), + .data_in_0_ready(key_data_in_ready), + + // input port for weight + .weight (weight_key), + .weight_valid(weight_key_valid), + .weight_ready(weight_key_ready), + + .bias (bias_key), + .bias_valid(bias_key_valid), + .bias_ready(bias_key_ready), + + .data_out_0 (data_out_key), + .data_out_0_valid(data_out_key_valid), + .data_out_0_ready(data_out_key_ready) + ); + + // * Value linear + + fixed_linear #( + .HAS_BIAS (HAS_BIAS), + .WEIGHTS_PRE_TRANSPOSED(WEIGHTS_PRE_TRANSPOSED), + + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + .DATA_IN_0_TENSOR_SIZE_DIM_0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .DATA_IN_0_TENSOR_SIZE_DIM_1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0(DATA_IN_0_PARALLELISM_DIM_0), + .DATA_IN_0_PARALLELISM_DIM_1(DATA_IN_0_PARALLELISM_DIM_1), + + .WEIGHT_PRECISION_0 (WEIGHT_PRECISION_0), + .WEIGHT_PRECISION_1 (WEIGHT_PRECISION_1), + .WEIGHT_TENSOR_SIZE_DIM_0(WEIGHT_TENSOR_SIZE_DIM_0), + .WEIGHT_TENSOR_SIZE_DIM_1(WEIGHT_TENSOR_SIZE_DIM_1), + .WEIGHT_PARALLELISM_DIM_0(WEIGHT_PARALLELISM_DIM_0), + .WEIGHT_PARALLELISM_DIM_1(WEIGHT_PARALLELISM_DIM_1), + + .BIAS_PRECISION_0 (BIAS_PRECISION_0), + .BIAS_PRECISION_1 (BIAS_PRECISION_1), + .BIAS_TENSOR_SIZE_DIM_0(BIAS_TENSOR_SIZE_DIM_0), + .BIAS_TENSOR_SIZE_DIM_1(BIAS_TENSOR_SIZE_DIM_1), + .BIAS_PARALLELISM_DIM_0(BIAS_PARALLELISM_DIM_0), + .BIAS_PARALLELISM_DIM_1(BIAS_PARALLELISM_DIM_1), + + .DATA_OUT_0_PRECISION_0(DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1(DATA_OUT_0_PRECISION_1) + + ) fixed_linear_value ( + .clk, + .rst, + + // input port for data_inivations + .data_in_0 (data_in_0), + .data_in_0_valid(value_data_in_valid), + .data_in_0_ready(value_data_in_ready), + + // input port for weight + .weight (weight_value), + .weight_valid(weight_value_valid), + .weight_ready(weight_value_ready), + + .bias (bias_value), + .bias_valid(bias_value_valid), + .bias_ready(bias_value_ready), + + .data_out_0 (data_out_value), + .data_out_0_valid(data_out_value_valid), + .data_out_0_ready(data_out_value_ready) + ); + +endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_attention_single_precision_wrapper.sv b/src/mase_components/attention/rtl/fixed_self_attention_single_precision_wrapper.sv new file mode 100644 index 000000000..abcf07c85 --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_self_attention_single_precision_wrapper.sv @@ -0,0 +1,209 @@ +`timescale 1ns / 1ps + +/* + * This is a workaround to use attention in single precision + * in emitted verilog, where separate precision parameters are + * emitted for each model submodule. + */ + +module fixed_self_attention_single_precision_wrapper #( + parameter NUM_HEADS = 12, + parameter ACTIVATION = 0, + parameter CHOSEN_PRECISION = "QUERY", + parameter OUTPUT_ATTENTIONS = 0, + + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 768, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 128, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + + parameter QUERY_WEIGHTS_PRE_TRANSPOSED = 0, + parameter QUERY_WEIGHT_TENSOR_SIZE_DIM_0 = 768, + parameter QUERY_WEIGHT_TENSOR_SIZE_DIM_1 = 768, + parameter QUERY_WEIGHT_PARALLELISM_DIM_0 = 4, + parameter QUERY_WEIGHT_PARALLELISM_DIM_1 = 4, + parameter QUERY_WEIGHT_PRECISION_0 = 16, + parameter QUERY_WEIGHT_PRECISION_1 = 3, + + parameter KEY_WEIGHTS_PRE_TRANSPOSED = 0, + parameter KEY_WEIGHT_TENSOR_SIZE_DIM_0 = 768, + parameter KEY_WEIGHT_TENSOR_SIZE_DIM_1 = 768, + parameter KEY_WEIGHT_PARALLELISM_DIM_0 = 4, + parameter KEY_WEIGHT_PARALLELISM_DIM_1 = 4, + parameter KEY_WEIGHT_PRECISION_0 = 16, + parameter KEY_WEIGHT_PRECISION_1 = 3, + + parameter VALUE_WEIGHTS_PRE_TRANSPOSED = 0, + parameter VALUE_WEIGHT_TENSOR_SIZE_DIM_0 = 768, + parameter VALUE_WEIGHT_TENSOR_SIZE_DIM_1 = 768, + parameter VALUE_WEIGHT_PARALLELISM_DIM_0 = 4, + parameter VALUE_WEIGHT_PARALLELISM_DIM_1 = 4, + parameter VALUE_WEIGHT_PRECISION_0 = 16, + parameter VALUE_WEIGHT_PRECISION_1 = 3, + + parameter QUERY_HAS_BIAS = 0, + parameter QUERY_BIAS_TENSOR_SIZE_DIM_0 = 64, + parameter QUERY_BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter QUERY_BIAS_PARALLELISM_DIM_0 = 4, + parameter QUERY_BIAS_PARALLELISM_DIM_1 = 4, + parameter QUERY_BIAS_PRECISION_0 = 16, + parameter QUERY_BIAS_PRECISION_1 = 3, + + parameter KEY_HAS_BIAS = 0, + parameter KEY_BIAS_TENSOR_SIZE_DIM_0 = 64, + parameter KEY_BIAS_TENSOR_SIZE_DIM_1 = 20, + parameter KEY_BIAS_PARALLELISM_DIM_0 = 4, + parameter KEY_BIAS_PARALLELISM_DIM_1 = 4, + parameter KEY_BIAS_PRECISION_0 = 16, + parameter KEY_BIAS_PRECISION_1 = 3, + + parameter VALUE_HAS_BIAS = 0, + parameter VALUE_BIAS_TENSOR_SIZE_DIM_0 = 64, + parameter VALUE_BIAS_TENSOR_SIZE_DIM_1 = 20, + parameter VALUE_BIAS_PARALLELISM_DIM_0 = 4, + parameter VALUE_BIAS_PARALLELISM_DIM_1 = 4, + parameter VALUE_BIAS_PRECISION_0 = 16, + parameter VALUE_BIAS_PRECISION_1 = 3, + + parameter CHOSEN_WEIGHTS_PRE_TRANSPOSED = QUERY_WEIGHTS_PRE_TRANSPOSED, + parameter CHOSEN_WEIGHT_TENSOR_SIZE_DIM_0 = QUERY_WEIGHT_TENSOR_SIZE_DIM_0, + parameter CHOSEN_WEIGHT_TENSOR_SIZE_DIM_1 = QUERY_WEIGHT_TENSOR_SIZE_DIM_1, + parameter CHOSEN_WEIGHT_PARALLELISM_DIM_0 = QUERY_WEIGHT_PARALLELISM_DIM_0, + parameter CHOSEN_WEIGHT_PARALLELISM_DIM_1 = QUERY_WEIGHT_PARALLELISM_DIM_1, + parameter CHOSEN_WEIGHT_PRECISION_0 = QUERY_WEIGHT_PRECISION_0, + parameter CHOSEN_WEIGHT_PRECISION_1 = QUERY_WEIGHT_PRECISION_1, + parameter CHOSEN_HAS_BIAS = QUERY_HAS_BIAS, + parameter CHOSEN_BIAS_TENSOR_SIZE_DIM_0 = QUERY_BIAS_TENSOR_SIZE_DIM_0, + parameter CHOSEN_BIAS_TENSOR_SIZE_DIM_1 = QUERY_BIAS_TENSOR_SIZE_DIM_1, + parameter CHOSEN_BIAS_PARALLELISM_DIM_0 = QUERY_BIAS_PARALLELISM_DIM_0, + parameter CHOSEN_BIAS_PARALLELISM_DIM_1 = QUERY_BIAS_PARALLELISM_DIM_1, + parameter CHOSEN_BIAS_PRECISION_0 = QUERY_BIAS_PRECISION_0, + parameter CHOSEN_BIAS_PRECISION_1 = QUERY_BIAS_PRECISION_1, + + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = CHOSEN_WEIGHT_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = CHOSEN_WEIGHT_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2, + parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0, + parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + +) ( + input logic clk, + input logic rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + // Query weights + input logic [QUERY_WEIGHT_PRECISION_0-1:0] query_weight [QUERY_WEIGHT_PARALLELISM_DIM_0 * QUERY_WEIGHT_PARALLELISM_DIM_1-1:0], + input logic query_weight_valid, + output logic query_weight_ready, + + // Query bias + input logic [QUERY_BIAS_PRECISION_0-1:0] query_bias [QUERY_BIAS_PARALLELISM_DIM_0 * QUERY_BIAS_PARALLELISM_DIM_1 -1:0], + input logic query_bias_valid, + output logic query_bias_ready, + + // Key weights + input logic [KEY_WEIGHT_PRECISION_0-1:0] key_weight [KEY_WEIGHT_PARALLELISM_DIM_0 * KEY_WEIGHT_PARALLELISM_DIM_1-1:0], + input logic key_weight_valid, + output logic key_weight_ready, + + // Key bias + input logic [KEY_BIAS_PRECISION_0-1:0] key_bias [KEY_BIAS_PARALLELISM_DIM_0 * KEY_BIAS_PARALLELISM_DIM_1 -1:0], + input logic key_bias_valid, + output logic key_bias_ready, + + // Value weights + input logic [VALUE_WEIGHT_PRECISION_0-1:0] value_weight [VALUE_WEIGHT_PARALLELISM_DIM_0 * VALUE_WEIGHT_PARALLELISM_DIM_1-1:0], + input logic value_weight_valid, + output logic value_weight_ready, + + // Value bias + input logic [VALUE_BIAS_PRECISION_0-1:0] value_bias [VALUE_BIAS_PARALLELISM_DIM_0 * VALUE_BIAS_PARALLELISM_DIM_1 -1:0], + input logic value_bias_valid, + output logic value_bias_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + +fixed_self_attention #( + .NUM_HEADS (NUM_HEADS), + .ACTIVATION (ACTIVATION), + + .DATA_IN_0_TENSOR_SIZE_DIM_0 (DATA_IN_0_TENSOR_SIZE_DIM_0), + .DATA_IN_0_TENSOR_SIZE_DIM_1 (DATA_IN_0_TENSOR_SIZE_DIM_1), + .DATA_IN_0_PARALLELISM_DIM_0 (DATA_IN_0_PARALLELISM_DIM_0), + .DATA_IN_0_PARALLELISM_DIM_1 (DATA_IN_0_PARALLELISM_DIM_1), + .DATA_IN_0_PRECISION_0 (DATA_IN_0_PRECISION_0), + .DATA_IN_0_PRECISION_1 (DATA_IN_0_PRECISION_1), + + .WEIGHTS_PRE_TRANSPOSED (CHOSEN_WEIGHTS_PRE_TRANSPOSED), + .WEIGHT_TENSOR_SIZE_DIM_0 (CHOSEN_WEIGHT_TENSOR_SIZE_DIM_0), + .WEIGHT_TENSOR_SIZE_DIM_1 (CHOSEN_WEIGHT_TENSOR_SIZE_DIM_1), + .WEIGHT_PARALLELISM_DIM_0 (CHOSEN_WEIGHT_PARALLELISM_DIM_0), + .WEIGHT_PARALLELISM_DIM_1 (CHOSEN_WEIGHT_PARALLELISM_DIM_1), + .WEIGHT_PRECISION_0 (CHOSEN_WEIGHT_PRECISION_0), + .WEIGHT_PRECISION_1 (CHOSEN_WEIGHT_PRECISION_1), + + .HAS_BIAS (CHOSEN_HAS_BIAS), + .BIAS_TENSOR_SIZE_DIM_0 (CHOSEN_BIAS_TENSOR_SIZE_DIM_0), + .BIAS_TENSOR_SIZE_DIM_1 (CHOSEN_BIAS_TENSOR_SIZE_DIM_1), + .BIAS_PARALLELISM_DIM_0 (CHOSEN_BIAS_PARALLELISM_DIM_0), + .BIAS_PARALLELISM_DIM_1 (CHOSEN_BIAS_PARALLELISM_DIM_1), + .BIAS_PRECISION_0 (CHOSEN_BIAS_PRECISION_0), + .BIAS_PRECISION_1 (CHOSEN_BIAS_PRECISION_1), + + .DATA_OUT_0_TENSOR_SIZE_DIM_0 (DATA_OUT_0_TENSOR_SIZE_DIM_0), + .DATA_OUT_0_TENSOR_SIZE_DIM_1 (DATA_OUT_0_TENSOR_SIZE_DIM_1), + .DATA_OUT_0_PARALLELISM_DIM_0 (DATA_OUT_0_PARALLELISM_DIM_0), + .DATA_OUT_0_PARALLELISM_DIM_1 (DATA_OUT_0_PARALLELISM_DIM_1), + .DATA_OUT_0_PRECISION_0 (DATA_OUT_0_PRECISION_0), + .DATA_OUT_0_PRECISION_1 (DATA_OUT_0_PRECISION_1) +) encoder_layer_0_attention_self_inst ( + .clk(clk), + .rst(rst), + + .data_in_0 (data_in_0), + .data_in_0_valid (data_in_0_valid), + .data_in_0_ready (data_in_0_ready), + + .weight_query (query_weight), + .weight_query_valid (query_weight_valid), + .weight_query_ready (query_weight_ready), + + .bias_query (query_bias), + .bias_query_valid (query_bias_valid), + .bias_query_ready (query_bias_ready), + + .weight_key (key_weight), + .weight_key_valid (key_weight_valid), + .weight_key_ready (key_weight_ready), + + .bias_key (key_bias), + .bias_key_valid (key_bias_valid), + .bias_key_ready (key_bias_ready), + + .weight_value (value_weight), + .weight_value_valid (value_weight_valid), + .weight_value_ready (value_weight_ready), + + .bias_value (value_bias), + .bias_value_valid (value_bias_valid), + .bias_value_ready (value_bias_ready), + + .data_out_0 (data_out_0), + .data_out_0_valid (data_out_0_valid), + .data_out_0_ready (data_out_0_ready) +); + +endmodule diff --git a/src/mase_components/attention/rtl/fixed_self_gqa_group.sv b/src/mase_components/attention/rtl/fixed_self_gqa_group.sv new file mode 100644 index 000000000..0f7030d3f --- /dev/null +++ b/src/mase_components/attention/rtl/fixed_self_gqa_group.sv @@ -0,0 +1,255 @@ +/* +Module : fixed_self_gqa_group +Description : Implements a single group in grouped query self-attention (GQA). +*/ + +`timescale 1ns / 1ps + +module fixed_self_gqa_group #( + // GQA Parameters + parameter GROUP_SIZE = 4, + + // Dimensions + parameter TOTAL_EMBEDDING_DIM = 32, + parameter TOTAL_SEQUENCE_DIM = 16, + parameter COMPUTE_EMBEDDING_DIM = 4, + parameter COMPUTE_SEQUENCE_DIM = 4, + + // Input Port Widths + parameter ACT_WIDTH = 8, + parameter ACT_FRAC_WIDTH = 2, + parameter Q_WEIGHT_WIDTH = 8, + parameter Q_WEIGHT_FRAC_WIDTH = 2, + parameter K_WEIGHT_WIDTH = 8, + parameter K_WEIGHT_FRAC_WIDTH = 2, + parameter V_WEIGHT_WIDTH = 8, + parameter V_WEIGHT_FRAC_WIDTH = 2, + + // Output Port Widths + parameter OUT_ACT_WIDTH = 8, + parameter OUT_ACT_FRAC_WIDTH = 2, + + // Intermediate widths + parameter Q_OUT_WIDTH = 16, + parameter Q_OUT_FRAC_WIDTH = 4, + parameter K_OUT_WIDTH = 16, + parameter K_OUT_FRAC_WIDTH = 4, + parameter V_OUT_WIDTH = 16, + parameter V_OUT_FRAC_WIDTH = 4, + parameter QK_OUT_WIDTH = 16, + parameter QK_OUT_FRAC_WIDTH = 4, + parameter SOFTERMAX_POW2_WIDTH = 16, + parameter SOFTERMAX_OUT_WIDTH = 16, + parameter SOFTERMAX_OUT_FRAC_WIDTH = 4, + + localparam TOTAL_HEAD_DIM = TOTAL_EMBEDDING_DIM / GROUP_SIZE, + localparam COMPUTE_HEAD_DIM = COMPUTE_EMBEDDING_DIM / GROUP_SIZE +) ( + input logic clk, + input logic rst, + + // Input activations + input logic [ACT_WIDTH-1:0] act_data [COMPUTE_SEQUENCE_DIM*COMPUTE_EMBEDDING_DIM-1:0], + input logic act_valid, + output logic act_ready, + + // GROUP_SIZE Channels of Query Weights + input logic [Q_WEIGHT_WIDTH-1:0] q_weight_data [GROUP_SIZE-1:0] [COMPUTE_EMBEDDING_DIM*COMPUTE_HEAD_DIM-1:0], + input logic q_weight_valid, + output logic q_weight_ready, + + // Single Channel Key Weights + input logic [K_WEIGHT_WIDTH-1:0] k_weight_data [COMPUTE_EMBEDDING_DIM*COMPUTE_HEAD_DIM-1:0], + input logic k_weight_valid, + output logic k_weight_ready, + + // Single Channel Value Weights + input logic [V_WEIGHT_WIDTH-1:0] v_weight_data [COMPUTE_EMBEDDING_DIM*COMPUTE_HEAD_DIM-1:0], + input logic v_weight_valid, + output logic v_weight_ready, + + // Output Activation Matrix + output logic [OUT_ACT_WIDTH-1:0] out_act_data [GROUP_SIZE-1:0] [COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0], + output logic out_act_valid, + input logic out_act_ready +); + + + initial begin + // Check divisibility + assert (TOTAL_HEAD_DIM * GROUP_SIZE == TOTAL_EMBEDDING_DIM); + assert (COMPUTE_HEAD_DIM * GROUP_SIZE == COMPUTE_EMBEDDING_DIM); + end + + // ----- + // Wires + // ----- + + logic [K_OUT_WIDTH-1:0] k_out_data[COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0]; + logic k_out_valid, k_out_ready; + + logic [V_OUT_WIDTH-1:0] v_out_data[COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0]; + logic v_out_valid, v_out_ready; + + logic [K_OUT_WIDTH-1:0] k_transpose_data[COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0]; + logic k_transpose_valid; + logic k_transpose_ready[GROUP_SIZE-1:0]; + + logic [OUT_ACT_WIDTH-1:0] head_act_data [GROUP_SIZE-1:0] [COMPUTE_SEQUENCE_DIM*COMPUTE_HEAD_DIM-1:0]; + logic head_act_valid[GROUP_SIZE-1:0]; + logic head_act_ready; + + + // ----- + // Modules + // ----- + + matmul #( + // Activations + .A_TOTAL_DIM0 (TOTAL_EMBEDDING_DIM), + .A_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .A_COMPUTE_DIM0(COMPUTE_EMBEDDING_DIM), + .A_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .A_WIDTH (ACT_WIDTH), + .A_FRAC_WIDTH (ACT_FRAC_WIDTH), + // Weights + .B_TOTAL_DIM0 (TOTAL_HEAD_DIM), + .B_TOTAL_DIM1 (TOTAL_EMBEDDING_DIM), + .B_COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .B_COMPUTE_DIM1(COMPUTE_EMBEDDING_DIM), + .B_WIDTH (K_WEIGHT_WIDTH), + .B_FRAC_WIDTH (K_WEIGHT_FRAC_WIDTH), + // Output + .OUT_WIDTH (K_OUT_WIDTH), + .OUT_FRAC_WIDTH(K_OUT_FRAC_WIDTH), + .OUT_SYMMETRIC (0) + ) k_matmul ( + .clk (clk), + .rst (rst), + .a_data (act_data), + .a_valid (act_valid), + .a_ready (act_ready), + .b_data (k_weight_data), + .b_valid (k_weight_valid), + .b_ready (k_weight_ready), + .out_data (k_out_data), + .out_valid(k_out_valid), + .out_ready(k_out_ready) + ); + + matmul #( + // Activations + .A_TOTAL_DIM0 (TOTAL_EMBEDDING_DIM), + .A_TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .A_COMPUTE_DIM0(COMPUTE_EMBEDDING_DIM), + .A_COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .A_WIDTH (ACT_WIDTH), + .A_FRAC_WIDTH (ACT_FRAC_WIDTH), + // Weights + .B_TOTAL_DIM0 (TOTAL_HEAD_DIM), + .B_TOTAL_DIM1 (TOTAL_EMBEDDING_DIM), + .B_COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .B_COMPUTE_DIM1(COMPUTE_EMBEDDING_DIM), + .B_WIDTH (V_WEIGHT_WIDTH), + .B_FRAC_WIDTH (V_WEIGHT_FRAC_WIDTH), + // Output + .OUT_WIDTH (V_OUT_WIDTH), + .OUT_FRAC_WIDTH(V_OUT_FRAC_WIDTH), + .OUT_SYMMETRIC (0) + ) v_matmul ( + .clk (clk), + .rst (rst), + .a_data (act_data), + .a_valid (act_valid), + .a_ready (act_ready), + .b_data (v_weight_data), + .b_valid (v_weight_valid), + .b_ready (v_weight_ready), + .out_data (v_out_data), + .out_valid(v_out_valid), + .out_ready(v_out_ready) + ); + + matrix_stream_transpose #( + .TOTAL_DIM0 (TOTAL_HEAD_DIM), + .TOTAL_DIM1 (TOTAL_SEQUENCE_DIM), + .COMPUTE_DIM0(COMPUTE_HEAD_DIM), + .COMPUTE_DIM1(COMPUTE_SEQUENCE_DIM), + .DATA_WIDTH (K_OUT_WIDTH) + ) k_transpose ( + .clk (clk), + .rst (rst), + .in_data (k_out_data), + .in_valid (k_out_valid), + .in_ready (k_out_ready), + .out_data (k_transpose_data), + .out_valid(k_transpose_valid), + .out_ready(k_transpose_ready[0]) + ); + + + for (genvar head = 0; head < GROUP_SIZE; head++) begin : gqa_heads + + fixed_gqa_head #( + // Dimensions + .TOTAL_EMBEDDING_DIM (TOTAL_EMBEDDING_DIM), + .TOTAL_HEAD_DIM (TOTAL_HEAD_DIM), + .TOTAL_SEQUENCE_DIM (TOTAL_SEQUENCE_DIM), + .COMPUTE_EMBEDDING_DIM (COMPUTE_EMBEDDING_DIM), + .COMPUTE_HEAD_DIM (COMPUTE_HEAD_DIM), + .COMPUTE_SEQUENCE_DIM (COMPUTE_SEQUENCE_DIM), + // Q Activation & Weight Widths + .Q_ACT_WIDTH (ACT_WIDTH), + .Q_ACT_FRAC_WIDTH (ACT_FRAC_WIDTH), + .Q_WEIGHT_WIDTH (Q_WEIGHT_WIDTH), + .Q_WEIGHT_FRAC_WIDTH (Q_WEIGHT_FRAC_WIDTH), + // K Activation Width + .K_ACT_WIDTH (K_OUT_WIDTH), + .K_ACT_FRAC_WIDTH (K_OUT_FRAC_WIDTH), + // V Activation Width + .V_ACT_WIDTH (V_OUT_WIDTH), + .V_ACT_FRAC_WIDTH (V_OUT_FRAC_WIDTH), + // Output Activation Width + .OUT_ACT_WIDTH (OUT_ACT_WIDTH), + .OUT_ACT_FRAC_WIDTH (OUT_ACT_FRAC_WIDTH), + // Intermediate Q Matrix Mult Widths + .Q_OUT_WIDTH (Q_OUT_WIDTH), + .Q_OUT_FRAC_WIDTH (Q_OUT_FRAC_WIDTH), + // Intermediate QK^T Matrix Mult Widths + .QK_OUT_WIDTH (QK_OUT_WIDTH), + .QK_OUT_FRAC_WIDTH (QK_OUT_FRAC_WIDTH), + // Intermediate Softermax Widths + .SOFTERMAX_POW2_WIDTH (SOFTERMAX_POW2_WIDTH), + .SOFTERMAX_OUT_WIDTH (SOFTERMAX_OUT_WIDTH), + .SOFTERMAX_OUT_FRAC_WIDTH(SOFTERMAX_OUT_FRAC_WIDTH) + ) gqa_head_inst ( + .clk (clk), + .rst (rst), + // Q Activation & Weights in + .q_act_data (act_data), + .q_act_valid (act_valid), + .q_act_ready (act_ready), + .q_weight_data (q_weight_data[head]), + .q_weight_valid (q_weight_valid), + .q_weight_ready (q_weight_ready), + // Shared K^T Data + .k_transposed_act_data (k_transpose_data), + .k_transposed_act_valid(k_transpose_valid), + .k_transposed_act_ready(k_transpose_ready[head]), + // Shared V + .v_act_data (v_out_data), + .v_act_valid (v_out_valid), + .v_act_ready (v_out_ready), + .out_act_data (head_act_data[head]), + .out_act_valid (head_act_valid[head]), + .out_act_ready (head_act_ready) + ); + + end + + assign out_act_data = head_act_data; + assign out_act_valid = head_act_valid[0]; + assign head_act_ready = out_act_ready; + + +endmodule diff --git a/src/mase_components/attention/rtl/self_attention_head_gather.sv b/src/mase_components/attention/rtl/self_attention_head_gather.sv new file mode 100644 index 000000000..1432d4aa9 --- /dev/null +++ b/src/mase_components/attention/rtl/self_attention_head_gather.sv @@ -0,0 +1,81 @@ +`timescale 1ns / 1ps +module self_attention_head_gather #( + parameter NUM_HEADS = 12, + + // * Queries, keys and values are assumed to have the same + // * precision, dimensions and parallelism + parameter IN_DATA_TENSOR_SIZE_DIM_0 = 64, + parameter IN_DATA_TENSOR_SIZE_DIM_1 = 32, + parameter IN_DATA_PARALLELISM_DIM_0 = 4, + parameter IN_DATA_PARALLELISM_DIM_1 = 4, + parameter IN_DATA_PRECISION_0 = 16, + parameter IN_DATA_PRECISION_1 = 3 + +) ( + input logic clk, + input logic rst, + + input logic [IN_DATA_PRECISION_0-1:0] split_head_out [NUM_HEADS-1:0] [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0], + input logic [NUM_HEADS-1:0] split_head_out_valid, + output logic [NUM_HEADS-1:0] split_head_out_ready, + + output logic [IN_DATA_PRECISION_0-1:0] updated_tokens [IN_DATA_PARALLELISM_DIM_0*IN_DATA_PARALLELISM_DIM_1-1:0], + output logic updated_tokens_valid, + input logic updated_tokens_ready +); + + parameter IN_DATA_DEPTH = IN_DATA_TENSOR_SIZE_DIM_0 / IN_DATA_PARALLELISM_DIM_0; + parameter BLOCKS_PER_HEAD = IN_DATA_DEPTH / NUM_HEADS; + + // Block counters + logic [NUM_HEADS-1:0][$clog2(BLOCKS_PER_HEAD):0] block_counter; + logic [NUM_HEADS-1:0] heads_flushed; + logic [$clog2(NUM_HEADS)-1:0] head_flushing_idx; + + // * Count the number of blocks received for each head + // * Create head_done mask e.g. 00000111111 (heads that have finished flushing contents) + // * Invert head_done mask e.g. 11111000000 (heads that haven't yet flushed contents) + // * Find first index gives the select signal to drive the output interface + + for (genvar head = 0; head < NUM_HEADS; head++) begin + always_ff @(posedge clk) begin + if (rst) begin + block_counter[head] <= '0; + + // * Increment block counter when accepting a block for a given head + // * But saturate at BLOCKS_PER_HEAD + end else if (split_head_out_valid[head] & split_head_out_ready[head]) begin + block_counter [head] <= (block_counter == BLOCKS_PER_HEAD - 1) ? BLOCKS_PER_HEAD : block_counter[head] + 1'b1; + + // * Reset counter when all heads done + end else if (heads_flushed == '1) begin + block_counter[head] <= '0; + end + end + + // * Create mask of heads with block count saturated at BLOCKS_PER_HEAD + // * (i.e. finished heads) + assign heads_flushed[head] = (block_counter[head] == BLOCKS_PER_HEAD); + end + + // * Find index of first (least significant) head that hasn't yet + // * finished dumping all its blocks + find_first_arbiter #( + .NUM_REQUESTERS(NUM_HEADS) + ) ff_arb_i ( + .request (~heads_flushed), + .grant_oh (), + .grant_bin(head_flushing_idx) + ); + + // * Drive output handshake interface + + assign updated_tokens = split_head_out[head_flushing_idx]; + assign updated_tokens_valid = split_head_out_valid[head_flushing_idx]; + for (genvar head = 0; head < NUM_HEADS; head++) begin + always_comb begin + split_head_out_ready[head] = updated_tokens_ready && (head_flushing_idx == head); + end + end + +endmodule diff --git a/src/mase_components/attention/rtl/self_attention_head_scatter.sv b/src/mase_components/attention/rtl/self_attention_head_scatter.sv new file mode 100644 index 000000000..e2c4bd141 --- /dev/null +++ b/src/mase_components/attention/rtl/self_attention_head_scatter.sv @@ -0,0 +1,107 @@ +`timescale 1ns / 1ps +module self_attention_head_scatter #( + parameter NUM_HEADS = 12, + + // * Queries, keys and values are assumed to have the same + // * precision, dimensions and parallelism + parameter IN_DATA_TENSOR_SIZE_DIM_0 = 64, + parameter IN_DATA_TENSOR_SIZE_DIM_1 = 32, + parameter IN_DATA_PARALLELISM_DIM_0 = 4, + parameter IN_DATA_PARALLELISM_DIM_1 = 4 +) ( + input logic clk, + input logic rst, + + input logic query_valid, + output logic query_ready, + + input logic key_valid, + output logic key_ready, + + input logic value_valid, + output logic value_ready, + + output logic [NUM_HEADS-1:0] split_query_valid, + input logic [NUM_HEADS-1:0] split_query_ready, + + output logic [NUM_HEADS-1:0] split_key_valid, + input logic [NUM_HEADS-1:0] split_key_ready, + + output logic [NUM_HEADS-1:0] split_value_valid, + input logic [NUM_HEADS-1:0] split_value_ready +); + + parameter IN_DATA_DEPTH = IN_DATA_TENSOR_SIZE_DIM_0 / IN_DATA_PARALLELISM_DIM_0; + parameter BLOCKS_PER_HEAD = IN_DATA_DEPTH / NUM_HEADS; + + // Block counters + logic [$clog2(BLOCKS_PER_HEAD):0] query_block_cnt; + logic [$clog2(BLOCKS_PER_HEAD):0] key_block_cnt; + logic [$clog2(BLOCKS_PER_HEAD):0] value_block_cnt; + + // Head counters + logic [$clog2(NUM_HEADS):0] query_head_cnt; + logic [$clog2(NUM_HEADS):0] key_head_cnt; + logic [$clog2(NUM_HEADS):0] value_head_cnt; + + // * Increment block and head counters + + always_ff @(posedge clk) begin + if (rst) begin + query_block_cnt <= '0; + key_block_cnt <= '0; + value_block_cnt <= '0; + + query_head_cnt <= '0; + key_head_cnt <= '0; + value_head_cnt <= '0; + end else begin + // Increment query counter + if (query_valid && query_ready) begin + query_block_cnt <= (query_block_cnt == BLOCKS_PER_HEAD - 1) ? '0 : query_block_cnt + 1'b1; + + if (query_block_cnt == BLOCKS_PER_HEAD - 1) begin + query_head_cnt <= (query_head_cnt == NUM_HEADS - 1) ? '0 : query_head_cnt + 1'b1; + end + end + + // Increment key counter + if (key_valid && key_ready) begin + key_block_cnt <= (key_block_cnt == BLOCKS_PER_HEAD - 1) ? '0 : key_block_cnt + 1'b1; + + if (key_block_cnt == BLOCKS_PER_HEAD - 1) begin + key_head_cnt <= (key_head_cnt == NUM_HEADS - 1) ? '0 : key_head_cnt + 1'b1; + end + end + + // Increment query counter + if (value_valid && value_ready) begin + value_block_cnt <= (value_block_cnt == BLOCKS_PER_HEAD - 1) ? '0 : value_block_cnt + 1'b1; + + if (value_block_cnt == BLOCKS_PER_HEAD - 1) begin + value_head_cnt <= (value_head_cnt == NUM_HEADS - 1) ? '0 : value_head_cnt + 1'b1; + end + end + + end + end + + // * Drive split QKV handshake interface + + for (genvar head = 0; head < NUM_HEADS; head++) begin + always_comb begin + split_query_valid[head] = query_valid && (query_head_cnt == head); + split_key_valid[head] = key_valid && (key_head_cnt == head); + split_value_valid[head] = value_valid && (value_head_cnt == head); + end + end + + always_comb begin + query_ready = split_query_ready[query_head_cnt]; + key_ready = split_key_ready[key_head_cnt]; + value_ready = split_value_ready[value_head_cnt]; + end + + + +endmodule diff --git a/src/mase_components/attention/test/fixed_gqa_head_tb.py b/src/mase_components/attention/test/fixed_gqa_head_tb.py new file mode 100644 index 000000000..cefe4a6de --- /dev/null +++ b/src/mase_components/attention/test/fixed_gqa_head_tb.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python3 + +import logging +from math import ceil + +import torch +import numpy as np + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.matrix_tools import ( + gen_random_matrix_input, + rebuild_matrix, + split_matrix, +) +from mase_cocotb.utils import bit_driver, sign_extend_t, batched, signed_to_unsigned +from mase_cocotb.interfaces.streaming import StreamDriver, ErrorThresholdStreamMonitor + +import cocotb +from cocotb.triggers import * + +from chop.nn.quantized.functional import fixed_softermax + +from chop.passes.graph.transforms.quantize.quantizers.integer import ( + integer_floor_quantizer, +) +from chop.passes.graph.transforms.quantize.quantizers.quantizers_for_hw import ( + unsigned_integer_quantizer_for_hw, +) + +logger = logging.getLogger("testbench") +logger.setLevel("INFO") + + +class FixedGQAHeadTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + [ + "TOTAL_EMBEDDING_DIM", + "TOTAL_HEAD_DIM", + "TOTAL_SEQUENCE_DIM", + "COMPUTE_EMBEDDING_DIM", + "COMPUTE_HEAD_DIM", + "COMPUTE_SEQUENCE_DIM", + "Q_ACT_WIDTH", + "Q_ACT_FRAC_WIDTH", + "Q_WEIGHT_WIDTH", + "Q_WEIGHT_FRAC_WIDTH", + "K_ACT_WIDTH", + "K_ACT_FRAC_WIDTH", + "V_ACT_WIDTH", + "V_ACT_FRAC_WIDTH", + "OUT_ACT_WIDTH", + "OUT_ACT_FRAC_WIDTH", + "Q_OUT_WIDTH", + "Q_OUT_FRAC_WIDTH", + "QK_OUT_WIDTH", + "QK_OUT_FRAC_WIDTH", + "SOFTERMAX_POW2_WIDTH", + "SOFTERMAX_OUT_WIDTH", + "SOFTERMAX_OUT_FRAC_WIDTH", + "EMBEDDING_DEPTH", + "HEAD_DEPTH", + "SEQUENCE_DEPTH", + ] + ) + + # Additional Params + self.q_act_num_iters = self.EMBEDDING_DEPTH * self.SEQUENCE_DEPTH + self.q_weight_num_iters = self.HEAD_DEPTH * self.EMBEDDING_DEPTH + self.k_transpose_num_iters = self.SEQUENCE_DEPTH * self.HEAD_DEPTH + self.v_act_num_iters = self.HEAD_DEPTH * self.SEQUENCE_DEPTH + + self.q_act_dims = dict( + total_dim0=self.TOTAL_EMBEDDING_DIM, + total_dim1=self.TOTAL_SEQUENCE_DIM, + compute_dim0=self.COMPUTE_EMBEDDING_DIM, + compute_dim1=self.COMPUTE_SEQUENCE_DIM, + ) + self.q_weight_dims = dict( + total_dim0=self.TOTAL_HEAD_DIM, + total_dim1=self.TOTAL_EMBEDDING_DIM, + compute_dim0=self.COMPUTE_HEAD_DIM, + compute_dim1=self.COMPUTE_EMBEDDING_DIM, + ) + self.k_transpose_act_dims = dict( + total_dim0=self.TOTAL_SEQUENCE_DIM, + total_dim1=self.TOTAL_HEAD_DIM, + compute_dim0=self.COMPUTE_SEQUENCE_DIM, + compute_dim1=self.COMPUTE_HEAD_DIM, + ) + self.v_act_dims = dict( + total_dim0=self.TOTAL_HEAD_DIM, + total_dim1=self.TOTAL_SEQUENCE_DIM, + compute_dim0=self.COMPUTE_HEAD_DIM, + compute_dim1=self.COMPUTE_SEQUENCE_DIM, + ) + self.out_act_dims = dict( + total_dim0=self.TOTAL_HEAD_DIM, + total_dim1=self.TOTAL_SEQUENCE_DIM, + compute_dim0=self.COMPUTE_HEAD_DIM, + compute_dim1=self.COMPUTE_SEQUENCE_DIM, + ) + + self.q_act_widths = dict( + width=self.Q_ACT_WIDTH, + frac_width=self.Q_ACT_FRAC_WIDTH, + ) + self.q_weight_widths = dict( + width=self.Q_WEIGHT_WIDTH, + frac_width=self.Q_WEIGHT_FRAC_WIDTH, + ) + self.k_transpose_act_widths = dict( + width=self.K_ACT_WIDTH, + frac_width=self.K_ACT_FRAC_WIDTH, + ) + self.v_act_widths = dict( + width=self.V_ACT_WIDTH, + frac_width=self.V_ACT_FRAC_WIDTH, + ) + + # Driver/Monitors + self.q_act_driver = StreamDriver( + dut.clk, + dut.q_act_data, + dut.q_act_valid, + dut.q_act_ready, + ) + self.q_weight_driver = StreamDriver( + dut.clk, + dut.q_weight_data, + dut.q_weight_valid, + dut.q_weight_ready, + ) + self.k_tranposed_act_driver = StreamDriver( + dut.clk, + dut.k_transposed_act_data, + dut.k_transposed_act_valid, + dut.k_transposed_act_ready, + ) + self.v_act_driver = StreamDriver( + dut.clk, + dut.v_act_data, + dut.v_act_valid, + dut.v_act_ready, + ) + + # Specify Error Threshold + self.percentage_error = 0.05 # 5% + self.error_threshold_bits = ceil( + self.percentage_error * (2**self.OUT_ACT_WIDTH) + ) + + self.output_monitor = ErrorThresholdStreamMonitor( + dut.clk, + dut.out_act_data, + dut.out_act_valid, + dut.out_act_ready, + width=self.OUT_ACT_WIDTH, + signed=True, + error_bits=1, # 1 bit rounding error + log_error=True, + check=True, + ) + + def generate_inputs(self, batches=10): + q_act = [] + q_weight = [] + k_transpose_act = [] + v_act = [] + + for _ in range(batches): + q_act.extend( + gen_random_matrix_input(**self.q_act_dims, **self.q_act_widths) + ) + q_weight.extend( + gen_random_matrix_input(**self.q_weight_dims, **self.q_weight_widths) + ) + k_transpose_act.extend( + gen_random_matrix_input( + **self.k_transpose_act_dims, **self.k_transpose_act_widths + ) + ) + v_act.extend( + gen_random_matrix_input(**self.v_act_dims, **self.v_act_widths) + ) + + return { + "q_act": q_act, + "q_weight": q_weight, + "k_transpose_act": k_transpose_act, + "v_act": v_act, + } + + def model(self, inputs: dict[str, list]): + + def _reconstruct( + input_list, + num_iters, + total_dim0, + total_dim1, + compute_dim0, + compute_dim1, + width, + frac_width, + ): + matrix_list = [] + for mat in batched(input_list, n=num_iters): + matrix_list.append( + rebuild_matrix( + x=mat, + total_dim0=total_dim0, + total_dim1=total_dim1, + compute_dim0=compute_dim0, + compute_dim1=compute_dim1, + ) + ) + matrix_t = torch.stack(matrix_list) + signed_matrix = sign_extend_t(matrix_t, bits=width) + scaled_matrix = signed_matrix.float() / (2**frac_width) + return scaled_matrix + + q_act = _reconstruct( + input_list=inputs["q_act"], + num_iters=self.q_act_num_iters, + **self.q_act_dims, + **self.q_act_widths, + ) + q_weight = _reconstruct( + input_list=inputs["q_weight"], + num_iters=self.q_weight_num_iters, + **self.q_weight_dims, + **self.q_weight_widths, + ) + k_transpose_act = _reconstruct( + input_list=inputs["k_transpose_act"], + num_iters=self.k_transpose_num_iters, + **self.k_transpose_act_dims, + **self.k_transpose_act_widths, + ) + v_act = _reconstruct( + input_list=inputs["v_act"], + num_iters=self.v_act_num_iters, + **self.v_act_dims, + **self.v_act_widths, + ) + + logger.debug("q_act: %s" % q_act) + logger.debug("q_weight: %s" % q_weight) + logger.debug("k_transpose_act: %s" % k_transpose_act) + logger.debug("v_act: %s" % v_act) + + q_out = torch.matmul(q_act, q_weight) + q_out = integer_floor_quantizer( + x=q_out, + width=self.Q_OUT_WIDTH, + frac_width=self.Q_OUT_FRAC_WIDTH, + is_signed=True, + ) + + qk_out = torch.matmul(q_out, k_transpose_act) + qk_out = integer_floor_quantizer( + x=qk_out, + width=self.QK_OUT_WIDTH, + frac_width=self.QK_OUT_FRAC_WIDTH, + is_signed=True, + ) + + softermax_out = fixed_softermax( + input=qk_out, + q_config={ + "width": self.QK_OUT_WIDTH, + "frac_width": self.QK_OUT_FRAC_WIDTH, + }, + dim=2, + ) + softermax_out = integer_floor_quantizer( + x=softermax_out, + width=self.SOFTERMAX_OUT_WIDTH, + frac_width=self.SOFTERMAX_OUT_FRAC_WIDTH, + is_signed=False, + ) + + attention_out = torch.matmul(softermax_out, v_act) + attention_out = integer_floor_quantizer( + x=attention_out, + width=self.OUT_ACT_WIDTH, + frac_width=self.OUT_ACT_FRAC_WIDTH, + is_signed=True, + ) + + logger.debug("q_out: %s" % q_out) + logger.debug("qk_out: %s" % qk_out) + logger.debug("softermax_out: %s" % softermax_out) + logger.debug("attention_out: %s" % attention_out) + + # Process output + rounded_atten = integer_floor_quantizer( + x=attention_out, + width=self.OUT_ACT_WIDTH, + frac_width=self.OUT_ACT_FRAC_WIDTH, + is_signed=True, + ) + atten_int = (rounded_atten * (2**self.OUT_ACT_FRAC_WIDTH)).int() + atten_uint = signed_to_unsigned(atten_int, bits=self.OUT_ACT_WIDTH) + logger.debug("rounded_atten: %s" % rounded_atten) + logger.debug("atten_int: %s" % atten_int) + logger.debug("atten_uint: %s" % atten_uint) + + exp_out = [] + for output_matrix in atten_uint: + exp_out.extend(split_matrix(output_matrix, **self.out_act_dims)) + return exp_out + + async def run_test(self, batches, us): + inputs = self.generate_inputs(batches) + # Load Drivers + self.q_act_driver.load_driver(inputs["q_act"]) + self.q_weight_driver.load_driver(inputs["q_weight"]) + self.k_tranposed_act_driver.load_driver(inputs["k_transpose_act"]) + self.v_act_driver.load_driver(inputs["v_act"]) + # Get expectation from model + exp_out = self.model(inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, "us") + assert self.output_monitor.recv_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = FixedGQAHeadTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=3, us=10) + + +@cocotb.test() +async def stream(dut): + tb = FixedGQAHeadTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=200, us=2000) + + +@cocotb.test() +async def backpressure(dut): + tb = FixedGQAHeadTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + await tb.reset() + await tb.run_test(batches=200, us=2000) + + +@cocotb.test() +async def valid(dut): + tb = FixedGQAHeadTB(dut) + tb.output_monitor.ready.value = 1 + tb.q_act_driver.set_valid_prob(0.5) + tb.q_weight_driver.set_valid_prob(0.5) + tb.k_tranposed_act_driver.set_valid_prob(0.5) + tb.v_act_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=200, us=2000) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = FixedGQAHeadTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.q_act_driver.set_valid_prob(0.5) + tb.q_weight_driver.set_valid_prob(0.5) + tb.k_tranposed_act_driver.set_valid_prob(0.5) + tb.v_act_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=200, us=2000) + + +if __name__ == "__main__": + + def width_cfgs(prefix: str, cfgs: list[dict]): + new_cfgs = [] + for cfg in cfgs: + new_cfgs.append({**cfg, f"{prefix}_WIDTH": 8, f"{prefix}_FRAC_WIDTH": 4}) + new_cfgs.append({**cfg, f"{prefix}_WIDTH": 16, f"{prefix}_FRAC_WIDTH": 8}) + return new_cfgs + + def dimension_cfgs(cfgs: list[dict]): + new_cfgs = [] + for cfg in cfgs: + new_cfgs.append( + { + **cfg, + "TOTAL_EMBEDDING_DIM": 64, + "TOTAL_HEAD_DIM": 16, + "TOTAL_SEQUENCE_DIM": 8, + } + ) + new_cfgs.append( + { + **cfg, + "TOTAL_EMBEDDING_DIM": 16, + "TOTAL_HEAD_DIM": 8, + "TOTAL_SEQUENCE_DIM": 16, + } + ) + return new_cfgs + + def compute_dim_cfgs(cfgs: list[dict]): + new_cfgs = [] + for cfg in cfgs: + new_cfgs.append( + { + **cfg, + "COMPUTE_EMBEDDING_DIM": 2, + "COMPUTE_HEAD_DIM": 2, + "COMPUTE_SEQUENCE_DIM": 2, + } + ) + new_cfgs.append( + { + **cfg, + "COMPUTE_EMBEDDING_DIM": 4, + "COMPUTE_HEAD_DIM": 4, + "COMPUTE_SEQUENCE_DIM": 4, + } + ) + return new_cfgs + + DEFAULT = { + # Dimensions + "TOTAL_EMBEDDING_DIM": 16, + "TOTAL_HEAD_DIM": 4, + "TOTAL_SEQUENCE_DIM": 4, # Number of tokens + "COMPUTE_EMBEDDING_DIM": 2, + "COMPUTE_HEAD_DIM": 2, + "COMPUTE_SEQUENCE_DIM": 2, + # Input Widths + "Q_ACT_WIDTH": 8, + "Q_ACT_FRAC_WIDTH": 4, + "Q_WEIGHT_WIDTH": 8, + "Q_WEIGHT_FRAC_WIDTH": 4, + "K_ACT_WIDTH": 8, + "K_ACT_FRAC_WIDTH": 2, + "V_ACT_WIDTH": 8, + "V_ACT_FRAC_WIDTH": 4, + # Output Widths + "OUT_ACT_WIDTH": 8, + "OUT_ACT_FRAC_WIDTH": 2, + # Intermediate Widths + "Q_OUT_WIDTH": 16, + "Q_OUT_FRAC_WIDTH": 4, + "QK_OUT_WIDTH": 16, + "QK_OUT_FRAC_WIDTH": 4, + "SOFTERMAX_POW2_WIDTH": 16, + "SOFTERMAX_OUT_WIDTH": 16, + "SOFTERMAX_OUT_FRAC_WIDTH": 15, + } + + cfgs = [DEFAULT] + cfgs = dimension_cfgs(cfgs) + cfgs = compute_dim_cfgs(cfgs) + for prefix in ["Q_ACT", "Q_WEIGHT", "K_ACT", "V_ACT", "OUT_ACT"]: + cfgs = width_cfgs(prefix, cfgs) + + cfgs = [ + { + "TOTAL_EMBEDDING_DIM": 64, + "TOTAL_HEAD_DIM": 16, + "TOTAL_SEQUENCE_DIM": 8, + "COMPUTE_EMBEDDING_DIM": 2, + "COMPUTE_HEAD_DIM": 2, + "COMPUTE_SEQUENCE_DIM": 2, + "Q_ACT_WIDTH": 8, + "Q_ACT_FRAC_WIDTH": 4, + "Q_WEIGHT_WIDTH": 8, + "Q_WEIGHT_FRAC_WIDTH": 4, + "K_ACT_WIDTH": 8, + "K_ACT_FRAC_WIDTH": 4, + "V_ACT_WIDTH": 8, + "V_ACT_FRAC_WIDTH": 4, + "OUT_ACT_WIDTH": 8, + "OUT_ACT_FRAC_WIDTH": 4, + "Q_OUT_WIDTH": 16, + "Q_OUT_FRAC_WIDTH": 4, + "QK_OUT_WIDTH": 16, + "QK_OUT_FRAC_WIDTH": 4, + "SOFTERMAX_POW2_WIDTH": 16, + "SOFTERMAX_OUT_WIDTH": 16, + "SOFTERMAX_OUT_FRAC_WIDTH": 15, + } + ] + print(f"Running Tests on {len(cfgs)} Configs...") + + mase_runner( + module_param_list=cfgs, + # trace=True, + jobs=12, + ) diff --git a/src/mase_components/attention/test/fixed_self_att_tb.py b/src/mase_components/attention/test/fixed_self_att_tb.py deleted file mode 100644 index 0005fefb4..000000000 --- a/src/mase_components/attention/test/fixed_self_att_tb.py +++ /dev/null @@ -1,517 +0,0 @@ -#!/usr/bin/env python3 - -import os, logging -import torch - -# from torchsummary import summary -from einops import rearrange - -import cocotb -from cocotb.triggers import Timer -from cocotb.triggers import FallingEdge -from cocotb.clock import Clock - -from mase_components.ViT.test.helpers.ha_softmax import ( - generate_table_hardware, - generate_table_div_hardware, -) -from mase_components.ViT.test.helpers.pvt_quant import QuantizedAttention - -from mase_cocotb.runner import mase_runner -from mase_cocotb.random_test import RandomSource, RandomSink, check_results -from mase_cocotb.z_qlayers import quantize_to_int as q2i - -debug = True - -logger = logging.getLogger("tb_signals") -if debug: - logger.setLevel(logging.DEBUG) - - -# DUT test specifications -class VerificationCase: - def __init__(self, samples=1): - # width config - self.samples = samples - self.data_in_width = 8 - self.data_in_frac_width = 5 - self.weight_q_width = 6 - self.weight_q_frac_width = 6 - self.weight_k_width = 6 - self.weight_k_frac_width = 6 - self.weight_v_width = 6 - self.weight_v_frac_width = 6 - - self.bias_q_width = 6 - self.bias_q_frac_width = 6 - self.bias_k_width = 6 - self.bias_k_frac_width = 6 - self.bias_v_width = 6 - self.bias_v_frac_width = 6 - - self.data_q_width = 8 - self.data_q_frac_width = 6 - self.data_k_width = 8 - self.data_k_frac_width = 6 - self.data_v_width = 8 - self.data_v_frac_width = 6 - self.data_s_width = 8 - self.data_s_frac_width = 6 - self.exp_width = 8 - self.exp_frac_width = 5 - self.div_width = 10 - self.data_s_softmax_width = 8 - self.data_s_softmax_width = 7 - self.data_z_width = 8 - self.data_z_frac_width = 6 - self.w_config = { - "q_proj": { - "name": "integer", - "weight_width": self.weight_q_width, - "weight_frac_width": self.weight_q_frac_width, - "data_in_width": self.data_in_width, - "data_in_frac_width": self.data_in_frac_width, - "bias_width": self.bias_q_width, - "bias_frac_width": self.bias_q_frac_width, - }, - "kv_proj": { - "name": "integer", - "weight_width": self.weight_k_width, - "weight_frac_width": self.weight_k_frac_width, - "data_in_width": self.data_in_width, - "data_in_frac_width": self.data_in_frac_width, - "bias_width": self.bias_k_width, - "bias_frac_width": self.bias_k_frac_width, - }, - "attn_matmul": { - "name": "integer", - "data_in_width": self.data_q_width, - "data_in_frac_width": self.data_q_frac_width, - "weight_width": self.data_k_width, - "weight_frac_width": self.data_k_frac_width, - }, - "z_matmul": { - "name": "integer", - "data_in_width": self.data_s_softmax_width, - "data_in_frac_width": self.data_s_softmax_width, - "weight_width": self.data_v_width, - "weight_frac_width": self.data_v_frac_width, - }, - "softmax": { - "exp_width": self.exp_width, - "exp_frac_width": self.exp_frac_width, - "div_width": self.div_width, - "data_in_width": self.data_s_width, - "data_in_frac_width": self.data_s_frac_width, - "data_out_width": self.data_s_softmax_width, - "data_out_frac_width": self.data_s_softmax_width, - }, - } - - self.in_parallelism = 1 - self.in_num_parallelism = 2 - - self.in_size = 4 - self.in_depth = 2 - - self.w_parallelism = 4 - self.w_num_parallelism = 2 - ( - test_in, - test_wq, - test_wk, - test_wv, - test_bq, - test_bk, - test_bv, - ) = self.att_data_generate() - self.soft_max_data_generate(self.att.scale) - self.data_in = RandomSource( - name="data_in", - samples=samples * self.in_depth * self.in_num_parallelism, - num=self.in_parallelism * self.in_size, - max_stalls=2 * samples * self.in_depth * self.in_num_parallelism, - data_specify=test_in, - debug=debug, - ) - self.weight_q = RandomSource( - name="weight_q", - samples=samples * self.in_depth * self.w_num_parallelism, - num=self.w_parallelism * self.in_size, - max_stalls=2 * samples * self.in_depth * self.w_num_parallelism, - data_specify=test_wq, - debug=debug, - ) - self.weight_k = RandomSource( - name="weight_k", - samples=samples * self.in_depth * self.w_num_parallelism, - num=self.w_parallelism * self.in_size, - max_stalls=2 * samples * self.in_depth * self.w_num_parallelism, - data_specify=test_wk, - debug=debug, - ) - self.weight_v = RandomSource( - name="weight_v", - samples=samples * self.in_depth * self.w_num_parallelism, - num=self.w_parallelism * self.in_size, - max_stalls=2 * samples * self.in_depth * self.w_num_parallelism, - data_specify=test_wv, - debug=debug, - ) - self.bias_q = RandomSource( - name="bias_q", - samples=samples * self.w_num_parallelism, - num=self.w_parallelism, - max_stalls=2 * samples, - data_specify=test_bq, - debug=debug, - ) - self.bias_k = RandomSource( - name="bias_k", - samples=samples * self.w_num_parallelism, - num=self.w_parallelism, - max_stalls=2 * samples, - data_specify=test_bk, - debug=debug, - ) - self.bias_v = RandomSource( - name="bias_v", - samples=samples * self.w_num_parallelism, - num=self.w_parallelism, - max_stalls=2 * samples, - data_specify=test_bv, - debug=debug, - ) - - ## remain modification - self.outputs = RandomSink( - samples=samples * self.in_num_parallelism * self.w_num_parallelism, - max_stalls=2 * samples * self.in_num_parallelism * self.w_num_parallelism, - debug=debug, - ) - self.samples = samples - self.ref = self.sw_compute() - - def get_dut_parameters(self): - return { - "DATA_WIDTH": self.data_in_width, - "DATA_FRAC_WIDTH": self.data_in_frac_width, - "WQ_WIDTH": self.weight_q_width, - "WQ_FRAC_WIDTH": self.weight_q_frac_width, - "WK_WIDTH": self.weight_k_width, - "WK_FRAC_WIDTH": self.weight_k_frac_width, - "WV_WIDTH": self.weight_v_width, - "WV_FRAC_WIDTH": self.weight_v_frac_width, - "BQ_WIDTH": self.bias_q_width, - "BQ_FRAC_WIDTH": self.bias_q_frac_width, - "BK_WIDTH": self.bias_k_width, - "BK_FRAC_WIDTH": self.bias_k_frac_width, - "BV_WIDTH": self.bias_v_width, - "BV_FRAC_WIDTH": self.bias_v_frac_width, - "DQ_WIDTH": self.data_q_width, - "DQ_FRAC_WIDTH": self.data_q_frac_width, - "DK_WIDTH": self.data_k_width, - "DK_FRAC_WIDTH": self.data_k_frac_width, - "DV_WIDTH": self.data_v_width, - "DV_FRAC_WIDTH": self.data_v_frac_width, - "DS_WIDTH": self.w_config["softmax"]["data_in_width"], - "DS_FRAC_WIDTH": self.w_config["softmax"]["data_in_frac_width"], - "EXP_WIDTH": self.w_config["softmax"]["exp_width"], - "EXP_FRAC_WIDTH": self.w_config["softmax"]["exp_frac_width"], - "DIV_WIDTH": self.w_config["softmax"]["div_width"], - "DS_SOFTMAX_WIDTH": self.w_config["softmax"]["data_out_width"], - "DS_SOFTMAX_FRAC_WIDTH": self.w_config["softmax"]["data_out_frac_width"], - "DZ_WIDTH": self.data_z_width, - "DZ_FRAC_WIDTH": self.data_z_frac_width, - "IN_PARALLELISM": self.in_parallelism, - "IN_NUM_PARALLELISM": self.in_num_parallelism, - "W_PARALLELISM": self.w_parallelism, - "W_NUM_PARALLELISM": self.w_num_parallelism, - "IN_SIZE": self.in_size, - "IN_DEPTH": self.in_depth, - } - - def sw_compute(self): - # get the matrix out result - # from M[num_parallelism][depth], - # and the element in M is m[parallelism][size] - # to M_out[in1_num_parallelism][in2_num_parallelism] - # the element in M_out is m_out[in1_parallelism][in2_parallelism] - - # collect all the input - # breakpoint() - data_out = self.att(self.x) - output = self.data_pack( - q2i(data_out, self.data_z_width, self.data_z_frac_width), - self.in_num_parallelism, - self.w_num_parallelism, - self.in_parallelism, - self.w_parallelism, - ) - return output - - def att_data_generate(self): - samples = self.samples - # torch.manual_seed(0) - in_y = self.in_num_parallelism * self.in_parallelism - in_x = self.in_size * self.in_depth - w_y = self.w_num_parallelism * self.w_parallelism - self.x = torch.randn((samples, in_y, in_x)) - self.att = QuantizedAttention( - dim=in_x, - num_heads=1, - qkv_bias=True, - attn_drop=0.0, - proj_drop=0.0, - config=self.w_config, - ) - - input_tensor = q2i(self.x, self.data_in_width, self.data_in_frac_width) - wq = q2i( - self.att.q.weight, self.weight_q_width, self.weight_q_frac_width - ).repeat(samples, 1, 1) - wkv = self.att.kv.weight.reshape(2, w_y, w_y) - wk, wv = wkv[0], wkv[1] - wk = q2i(wk, self.weight_k_width, self.weight_k_frac_width).repeat( - samples, 1, 1 - ) - wv = q2i(wv, self.weight_v_width, self.weight_v_frac_width).repeat( - samples, 1, 1 - ) - - bq = q2i(self.att.q.bias, self.weight_q_width, self.weight_q_frac_width).repeat( - samples, 1 - ) - bkv = self.att.kv.bias.reshape(2, w_y) - bk, bv = bkv[0], bkv[1] - bk = q2i(bk, self.bias_k_width, self.bias_k_frac_width).repeat(samples, 1) - bv = q2i(bv, self.bias_v_width, self.bias_v_frac_width).repeat(samples, 1) - - in_num_parallelism = self.in_num_parallelism - in_depth = self.in_depth - in_parallelism = self.in_parallelism - in_size = self.in_size - w_parallelism = self.w_parallelism - w_num_parallelism = self.w_num_parallelism - - data_in = self.data_pack( - input_tensor, in_num_parallelism, in_depth, in_parallelism, in_size - ) - wq_in = self.data_pack(wq, w_num_parallelism, in_depth, w_parallelism, in_size) - wk_in = self.data_pack(wk, w_num_parallelism, in_depth, w_parallelism, in_size) - wv_in = self.data_pack(wv, w_num_parallelism, in_depth, w_parallelism, in_size) - - bq_in = self.data_pack(bq, 1, w_num_parallelism, 1, w_parallelism) - bk_in = self.data_pack(bk, 1, w_num_parallelism, 1, w_parallelism) - bv_in = self.data_pack(bv, 1, w_num_parallelism, 1, w_parallelism) - data_in.reverse() - wq_in.reverse() - wk_in.reverse() - wv_in.reverse() - bq_in.reverse() - bk_in.reverse() - bv_in.reverse() - return ( - data_in, - wq_in, - wk_in, - wv_in, - bq_in, - bk_in, - bv_in, - ) - - def soft_max_data_generate(self, scale): - # generate mem_init - exp_table = generate_table_hardware( - scale, - self.w_config["softmax"]["data_in_width"], - self.w_config["softmax"]["data_in_frac_width"], - self.w_config["softmax"]["exp_width"], - self.w_config["softmax"]["exp_frac_width"], - ).tolist() - div_table = generate_table_div_hardware( - self.w_config["softmax"]["div_width"], - self.w_config["softmax"]["data_out_width"], - self.w_config["softmax"]["data_out_frac_width"], - ).tolist() - with open(r"exp_init.mem", "w") as fp: - for item in exp_table: - # write each item on a new lineformat(addr[i] ,f'0{width}b' - fp.write( - "%s\n" - % format(item, f'0{self.w_config["softmax"]["exp_width"]//4}x') - ) - with open(r"div_init.mem", "w") as fp: - for item in div_table: - # write each item on a new line - fp.write( - "%s\n" - % format(item, f'0{self.w_config["softmax"]["data_out_width"]//4}x') - ) - - def data_pack(self, in_temp, np, d, p, s): - # assum in_temp.shape = (samples, batch = 1, N,dim) - in_temp = in_temp.to(torch.int).reshape(self.samples, np * p, d * s) - ref = [] - for i in range(self.samples): - re_tensor = rearrange( - in_temp[i], "(np p) (d s) -> np (p d) s", np=np, d=d, p=p, s=s - ) - ex_tensor = torch.zeros(np, d * p, s, dtype=int) - for b in range(np): - for i in range(d): - for j in range(p): - ex_tensor[b][i * p + j] = re_tensor[b][j * d + i] - - output_tensor = rearrange( - ex_tensor, "np (d p) s -> (np d) (p s)", np=np, d=d, p=p, s=s - ) - output = output_tensor.tolist() - ref = ref + output - return ref - - -def debug_state(dut, state): - logger.debug( - "{} State: (wq_ready,wq_valid,wk_ready,wk_valid,wv_ready,wv_valid,in_ready,in_valid,data_out_ready,data_out_valid) = ({},{},{},{},{},{},{},{},{},{})".format( - state, - dut.weight_q_ready.value, - dut.weight_q_valid.value, - dut.weight_k_ready.value, - dut.weight_k_valid.value, - dut.weight_v_ready.value, - dut.weight_v_valid.value, - dut.data_in_ready.value, - dut.data_in_valid.value, - dut.data_out_ready.value, - dut.data_out_valid.value, - ) - ) - - -@cocotb.test() -async def cocotb_test_att(dut): - """Test integer based vector mult""" - samples = 30 - test_case = VerificationCase(samples=samples) - # Reset cycle - await Timer(20, units="ns") - dut.rst.value = 1 - await Timer(100, units="ns") - dut.rst.value = 0 - - # Create a 10ns-period clock on port clk - clock = Clock(dut.clk, 10, units="ns") - # Start the clock - cocotb.start_soon(clock.start()) - await Timer(500, units="ns") - - # Synchronize with the clock - dut.weight_q_valid.value = 0 - dut.weight_k_valid.value = 0 - dut.weight_v_valid.value = 0 - dut.bias_q_valid.value = 0 - dut.bias_k_valid.value = 0 - dut.bias_v_valid.value = 0 - dut.data_in_valid.value = 0 - dut.data_out_ready.value = 1 - debug_state(dut, "Pre-clk") - await FallingEdge(dut.clk) - debug_state(dut, "Post-clk") - debug_state(dut, "Pre-clk") - await FallingEdge(dut.clk) - debug_state(dut, "Post-clk") - done = False - # Set a timeout to avoid deadlock - for i in range(samples * 100): - await FallingEdge(dut.clk) - # breakpoint() - dut.weight_q_valid.value = test_case.weight_q.pre_compute() - dut.weight_k_valid.value = test_case.weight_k.pre_compute() - dut.weight_v_valid.value = test_case.weight_v.pre_compute() - dut.bias_q_valid.value = test_case.bias_q.pre_compute() - dut.bias_k_valid.value = test_case.bias_k.pre_compute() - dut.bias_v_valid.value = test_case.bias_v.pre_compute() - dut.data_in_valid.value = test_case.data_in.pre_compute() - await Timer(1, units="ns") - dut.data_out_ready.value = test_case.outputs.pre_compute( - dut.data_out_valid.value - ) - await Timer(1, units="ns") - debug_state(dut, "in compute") - dut.weight_q_valid.value, dut.weight_q.value = test_case.weight_q.compute( - dut.weight_q_ready.value - ) - dut.weight_k_valid.value, dut.weight_k.value = test_case.weight_k.compute( - dut.weight_k_ready.value - ) - dut.weight_v_valid.value, dut.weight_v.value = test_case.weight_v.compute( - dut.weight_v_ready.value - ) - - dut.bias_q_valid.value, dut.bias_q.value = test_case.bias_q.compute( - dut.bias_q_ready.value - ) - dut.bias_k_valid.value, dut.bias_k.value = test_case.bias_k.compute( - dut.bias_k_ready.value - ) - dut.bias_v_valid.value, dut.bias_v.value = test_case.bias_v.compute( - dut.bias_v_ready.value - ) - - dut.data_in_valid.value, dut.data_in.value = test_case.data_in.compute( - dut.data_in_ready.value - ) - await Timer(1, units="ns") - dut.data_out_ready.value = test_case.outputs.compute( - dut.data_out_valid.value, dut.data_out.value - ) - await Timer(1, units="ns") - # wave_check(dut) - if ( - test_case.weight_q.is_empty() - and test_case.weight_k.is_empty() - and test_case.weight_v.is_empty() - and test_case.bias_q.is_empty() - and test_case.bias_k.is_empty() - and test_case.bias_v.is_empty() - and test_case.data_in.is_empty() - and test_case.outputs.is_full() - ): - done = True - break - assert ( - done - ), "Deadlock detected or the simulation reaches the maximum cycle limit (fixed it by adjusting the loop trip count)" - - check_results(test_case.outputs.data, test_case.ref) - - -def wave_check(dut): - logger.debug( - "wave of in_out:\n\ - {},{},data_in = {} \n\ - {},{},data_out = {}\n\ - ".format( - dut.data_in_valid.value, - dut.data_in_valid.value, - [int(i) for i in dut.data_in.value], - dut.data_out_valid.value, - dut.data_out_valid.value, - [int(i) for i in dut.data_out.value], - ) - ) - - -import pytest - - -@pytest.mark.skip(reason="Needs to be fixed.") -def test_fixed_self_att(): - tb = VerificationCase() - mase_runner(module_param_list=[tb.get_dut_parameters()]) - - -if __name__ == "__main__": - test_fixed_self_att() diff --git a/src/mase_components/attention/test/fixed_self_attention_head_tb.py b/src/mase_components/attention/test/fixed_self_attention_head_tb.py new file mode 100644 index 000000000..6c1223d67 --- /dev/null +++ b/src/mase_components/attention/test/fixed_self_attention_head_tb.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 + +import os + +import torch +import logging +from functools import partial + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import Timer + +from transformers.models.bert.configuration_bert import BertConfig + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.runner import mase_runner + +# from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner +from chop.nn.quantized import BertSelfAttentionHeadInteger +from chop.passes.graph.transforms.quantize.quantizers import integer_quantizer + +from mase_components.activations.test import generate_memory + + +class FixedSelfAttentionHeadTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.DEBUG) + + # * QKV drivers + self.query_driver = StreamDriver( + dut.clk, dut.query, dut.query_valid, dut.query_ready + ) + self.key_driver = StreamDriver(dut.clk, dut.key, dut.key_valid, dut.key_ready) + self.value_driver = StreamDriver( + dut.clk, dut.value, dut.value_valid, dut.value_ready + ) + + self.out_monitor = StreamMonitor( + dut.clk, + dut.out, + dut.out_valid, + dut.out_ready, + check=False, + ) + + # Model + self.config = BertConfig() + self.head_size = self.config.hidden_size // self.config.num_attention_heads + + self.q_config = { + "width": self.get_parameter("IN_DATA_PRECISION_0"), + "frac_width": self.get_parameter("IN_DATA_PRECISION_1"), + } + self.model = BertSelfAttentionHeadInteger( + config=self.config, + q_config=self.q_config, + ) + + # Set verbosity of driver and monitor loggers to debug + # self.query_driver.log.setLevel(logging.DEBUG) + # self.key_driver.log.setLevel(logging.DEBUG) + # self.value_driver.log.setLevel(logging.DEBUG) + # self.out_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self, seq_len=20): + return { + "query_layer": torch.randn((seq_len, self.head_size)), + "key_layer": torch.randn((seq_len, self.head_size)), + "value_layer": torch.randn((seq_len, self.head_size)), + } + + def preprocess_tensor(self, tensor, config, parallelism): + if len(tensor.shape) == 1: + tensor = tensor.unsqueeze(0) + + # Quantize + quantizer = partial(integer_quantizer, **config) + q_tensor = quantizer(tensor) + self.log.debug(f"Quantized tensor: {q_tensor}") + + # Convert to integer format + q_tensor = (q_tensor * 2 ** config["frac_width"]).int() + self.log.debug(f"Tensor in integer format: {q_tensor}") + + # Split into chunks according to parallelism in each dimension + # parallelism[0]: along rows, parallelism[1]: along columns + dim_0_split = q_tensor.split(parallelism[0], dim=0) + dim_1_split = [x.split(parallelism[1], dim=1) for x in dim_0_split] + blocks = [] + # Flatten the list of blocks + for i in range(len(dim_1_split)): + for j in range(len(dim_1_split[i])): + blocks.append(dim_1_split[i][j].flatten().tolist()) + return blocks + + async def run_test(self): + await self.reset() + self.log.info(f"Reset finished") + self.out_monitor.ready.value = 1 + + inputs = self.generate_inputs( + seq_len=self.get_parameter("IN_DATA_TENSOR_SIZE_DIM_1") + ) + exp_out = self.model(**inputs) + + parallelism = [ + self.get_parameter("IN_DATA_PARALLELISM_DIM_1"), + self.get_parameter("IN_DATA_PARALLELISM_DIM_0"), + ] + + # * Load the query driver + self.log.info(f"Processing query inputs: {inputs['query_layer']}") + query_inputs = self.preprocess_tensor( + tensor=inputs["query_layer"], + config=self.q_config, + parallelism=parallelism, + ) + self.query_driver.load_driver(query_inputs) + + # * Load the key driver + self.log.info(f"Processing key inputs: {inputs['key_layer']}") + key_inputs = self.preprocess_tensor( + tensor=inputs["key_layer"], + config=self.q_config, + parallelism=parallelism, + ) + self.key_driver.load_driver(key_inputs) + + # * Load the value driver + self.log.info(f"Processing value inputs: {inputs['value_layer']}") + value_inputs = self.preprocess_tensor( + tensor=inputs["value_layer"], + config=self.q_config, + parallelism=parallelism, + ) + self.value_driver.load_driver(value_inputs) + + # * Load the output monitor + self.log.info(f"Processing outputs: {exp_out}") + outs = self.preprocess_tensor( + tensor=exp_out, + config={ + "width": self.get_parameter("OUT_DATA_PRECISION_0"), + "frac_width": self.get_parameter("OUT_DATA_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("OUT_DATA_PARALLELISM_DIM_1"), + self.get_parameter("OUT_DATA_PARALLELISM_DIM_0"), + ], + ) + self.out_monitor.load_monitor(outs) + + await Timer(1, units="ms") + if not self.out_monitor.exp_queue.empty(): + raise RuntimeError( + "Reached the end of the test, but the output monitor is not empty." + ) + + +@cocotb.test() +async def cocotb_test(dut): + tb = FixedSelfAttentionHeadTB(dut) + await tb.run_test() + + +def get_fixed_self_attention_head_config(kwargs={}): + config = { + "IN_DATA_TENSOR_SIZE_DIM_0": 64, + "IN_DATA_TENSOR_SIZE_DIM_1": 32, + "IN_DATA_PARALLELISM_DIM_0": 2, + "IN_DATA_PARALLELISM_DIM_1": 2, + "IN_DATA_PRECISION_0": 16, + "IN_DATA_PRECISION_1": 3, + "OUT_DATA_TENSOR_SIZE_DIM_0": 64, + "OUT_DATA_TENSOR_SIZE_DIM_1": 32, + "OUT_DATA_PARALLELISM_DIM_0": 2, + "OUT_DATA_PARALLELISM_DIM_1": 2, + "OUT_DATA_PRECISION_0": 16, + "OUT_DATA_PRECISION_1": 3, + } + config.update(kwargs) + return config + + +def test_fixed_self_attention_head_smoke(): + """ + Some quick tests to check if the module is working. + """ + + # * Generate exponential LUT for softmax + generate_memory.generate_sv_lut( + "exp", + 16, + 3, + 16, + 3, + ) + mase_runner( + trace=True, + module_param_list=[ + get_fixed_self_attention_head_config(), + ], + skip_build=False, + ) + + +if __name__ == "__main__": + test_fixed_self_attention_head_smoke() diff --git a/src/mase_components/attention/test/fixed_self_attention_tb.py b/src/mase_components/attention/test/fixed_self_attention_tb.py new file mode 100644 index 000000000..4df1c87b3 --- /dev/null +++ b/src/mase_components/attention/test/fixed_self_attention_tb.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 + +import os + +import torch +import logging +from functools import partial + +import cocotb +from cocotb.log import SimLog +from cocotb.triggers import Timer + +from transformers.models.bert.configuration_bert import BertConfig + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.runner import mase_runner + +# from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner +from chop.nn.quantized import BertSelfAttentionInteger, fixed_softermax +from chop.passes.graph.transforms.quantize.quantized_funcs import matmul_integer + +from mase_cocotb.utils import fixed_preprocess_tensor + + +class FixedSelfAttentionTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + + if not hasattr(self, "log"): + self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.DEBUG) + + self.data_in_0_driver = StreamDriver( + dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready + ) + + # * Weight drivers + self.weight_query_driver = StreamDriver( + dut.clk, dut.weight_query, dut.weight_query_valid, dut.weight_query_ready + ) + self.weight_key_driver = StreamDriver( + dut.clk, dut.weight_key, dut.weight_key_valid, dut.weight_key_ready + ) + self.weight_value_driver = StreamDriver( + dut.clk, dut.weight_value, dut.weight_value_valid, dut.weight_value_ready + ) + + if self.get_parameter("HAS_BIAS") == 1: + self.bias_query_driver = StreamDriver( + dut.clk, dut.biasquery_, dut.bias_query_valid, dut.bias_query_ready + ) + self.bias_key_driver = StreamDriver( + dut.clk, dut.bias_key, dut.bias_key_valid, dut.bias_key_ready + ) + self.bias_value_driver = StreamDriver( + dut.clk, dut.bias_value, dut.bias_value_valid, dut.bias_value_ready + ) + self.bias_query_driver.log.setLevel(logging.DEBUG) + self.bias_key_driver.log.setLevel(logging.DEBUG) + self.bias_value_driver.log.setLevel(logging.DEBUG) + + self.data_out_0_monitor = StreamMonitor( + dut.clk, + dut.data_out_0, + dut.data_out_0_valid, + dut.data_out_0_ready, + check=False, + ) + + # Model + self.config = BertConfig() + self.config.hidden_size = self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0") + self.config.num_attention_heads = self.get_parameter("NUM_HEADS") + self.q_config = { + "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "data_in_frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + "weight_width": self.get_parameter("WEIGHT_PRECISION_0"), + "weight_frac_width": self.get_parameter("WEIGHT_PRECISION_1"), + "bias_width": self.get_parameter("BIAS_PRECISION_0"), + "bias_frac_width": self.get_parameter("BIAS_PRECISION_1"), + } + self.out_q_config = { + "data_out_width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "data_out_frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + } + self.model = BertSelfAttentionInteger( + config=self.config, + q_config=self.q_config, + out_q_config=self.out_q_config, + bias=self.get_parameter("HAS_BIAS"), + floor=True, + ) + # * Replace softmax with fixed softermax + if self.get_parameter("ACTIVATION") == 0: + self.model.softmax = partial( + fixed_softermax, + q_config={ + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + }, + out_q_config={ + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + }, + ) + + # Set verbosity of driver and monitor loggers to debug + self.data_in_0_driver.log.setLevel(logging.DEBUG) + self.weight_query_driver.log.setLevel(logging.DEBUG) + self.weight_key_driver.log.setLevel(logging.DEBUG) + self.weight_value_driver.log.setLevel(logging.DEBUG) + self.data_out_0_monitor.log.setLevel(logging.DEBUG) + + def generate_inputs(self, batch_size=1): + return torch.randn( + ( + batch_size, + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_1"), + self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"), + ) + ) + + async def run_test(self): + await self.reset() + self.log.info(f"Reset finished") + self.data_out_0_monitor.ready.value = 1 + + inputs = self.generate_inputs() + exp_out = self.model(inputs)[0] + + # * Load the inputs driver + self.log.info(f"Processing inputs: {inputs}") + inputs = fixed_preprocess_tensor( + tensor=inputs, + q_config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], + ) + self.data_in_0_driver.load_driver(inputs) + + # * Load the weights driver + for projection in ["query", "key", "value"]: + + if self.get_parameter("WEIGHTS_PRE_TRANSPOSED") == 1: + weights = getattr(self.model, projection).weight.transpose(0, 1) + else: + weights = getattr(self.model, projection).weight + + self.log.info(f"Processing {projection} weights: {weights}") + weights = fixed_preprocess_tensor( + tensor=weights, + q_config={ + "width": self.get_parameter("WEIGHT_PRECISION_0"), + "frac_width": self.get_parameter("WEIGHT_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("WEIGHT_PARALLELISM_DIM_1"), + self.get_parameter("WEIGHT_PARALLELISM_DIM_0"), + ], + ) + getattr(self, f"weight_{projection}_driver").load_driver(weights) + + # * Load the bias driver + if self.get_parameter("HAS_BIAS") == 1: + bias = getattr(self.model, projection).bias + self.log.info(f"Processing {projection} bias: {bias}") + bias = fixed_preprocess_tensor( + tensor=bias, + q_config={ + "width": self.get_parameter("BIAS_PRECISION_0"), + "frac_width": self.get_parameter("BIAS_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("BIAS_PARALLELISM_DIM_1"), + self.get_parameter("BIAS_PARALLELISM_DIM_0"), + ], + ) + getattr(self, f"bias_{projection}_driver").load_driver(bias) + + # * Load the output monitor + self.log.info(f"Processing outputs: {exp_out}") + outs = fixed_preprocess_tensor( + tensor=exp_out, + q_config={ + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), + ], + ) + self.data_out_0_monitor.load_monitor(outs) + + await Timer(1, units="ms") + assert self.data_out_0_monitor.exp_queue.empty() + + +@cocotb.test() +async def cocotb_test(dut): + tb = FixedSelfAttentionTB(dut) + await tb.run_test() + + +def get_config(kwargs={}): + config = { + "NUM_HEADS": 1, + "ACTIVATION": 0, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 4, + "DATA_IN_0_TENSOR_SIZE_DIM_1": 4, + "DATA_IN_0_PARALLELISM_DIM_0": 2, + "DATA_IN_0_PARALLELISM_DIM_1": 2, + "DATA_IN_0_PRECISION_0": 16, + "DATA_IN_0_PRECISION_1": 8, + "WEIGHTS_PRE_TRANSPOSED": 1, + "WEIGHT_TENSOR_SIZE_DIM_0": 4, + "WEIGHT_TENSOR_SIZE_DIM_1": 4, + "WEIGHT_PARALLELISM_DIM_0": 2, + "WEIGHT_PARALLELISM_DIM_1": 2, + "WEIGHT_PRECISION_0": 16, + "WEIGHT_PRECISION_1": 8, + "HAS_BIAS": 0, + "BIAS_TENSOR_SIZE_DIM_0": 4, + "BIAS_TENSOR_SIZE_DIM_1": 4, + "BIAS_PARALLELISM_DIM_0": 2, + "BIAS_PARALLELISM_DIM_1": 2, + "BIAS_PRECISION_0": 16, + "BIAS_PRECISION_1": 8, + "DATA_OUT_0_TENSOR_SIZE_DIM_0": 4, + "DATA_OUT_0_TENSOR_SIZE_DIM_1": 4, + "DATA_OUT_0_PARALLELISM_DIM_0": 2, + "DATA_OUT_0_PARALLELISM_DIM_1": 2, + "DATA_OUT_0_PRECISION_0": 16, + "DATA_OUT_0_PRECISION_1": 8, + } + config.update(kwargs) + return config + + +def test_fixed_linear_smoke(): + """ + Some quick tests to check if the module is working. + """ + mase_runner(trace=True, module_param_list=[get_config()], skip_build=True) + + +if __name__ == "__main__": + test_fixed_linear_smoke() diff --git a/src/mase_components/attention/test/test_lint_attention.py b/src/mase_components/attention/test/test_lint_attention.py index bcb51ba47..a9f039e8b 100644 --- a/src/mase_components/attention/test/test_lint_attention.py +++ b/src/mase_components/attention/test/test_lint_attention.py @@ -3,7 +3,6 @@ import pytest -@pytest.mark.skip(reason="Needs to be fixed.") def test_lint_attention(): run_lint("attention") diff --git a/src/mase_components/attention/test/test_synth_attention.py b/src/mase_components/attention/test/test_synth_attention.py new file mode 100644 index 000000000..bc70065f2 --- /dev/null +++ b/src/mase_components/attention/test/test_synth_attention.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_attention(): + run_synth("attention") + + +if __name__ == "__main__": + test_synth_attention() diff --git a/src/mase_components/axi/test/test_synth_axi.py b/src/mase_components/axi/test/test_synth_axi.py new file mode 100644 index 000000000..6eb5763ea --- /dev/null +++ b/src/mase_components/axi/test/test_synth_axi.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_axi(): + run_synth("axi") + + +if __name__ == "__main__": + test_synth_axi() diff --git a/src/mase_components/binary_arith/rtl/binary_activation_binary_adder_tree.sv b/src/mase_components/binary_arith/rtl/binary_activation_binary_adder_tree.sv index 1a08e4065..08af47697 100644 --- a/src/mase_components/binary_arith/rtl/binary_activation_binary_adder_tree.sv +++ b/src/mase_components/binary_arith/rtl/binary_activation_binary_adder_tree.sv @@ -79,16 +79,16 @@ module binary_activation_binary_adder_tree #( assign cast_sum[(IN_WIDTH+i+1)*j+(IN_WIDTH+i):(IN_WIDTH+i+1)*j] = sum[j]; register_slice #( - .IN_WIDTH($bits(sum)) + .DATA_WIDTH($bits(sum)) ) register_slice ( .clk (clk), .rst (rst), .data_in_valid (vars[i].valid), .data_in_ready (vars[i].ready), - .data_in_data (cast_sum), + .data_in (cast_sum), .data_out_valid(vars[i+1].valid), .data_out_ready(vars[i+1].ready), - .data_out_data (cast_data) + .data_out (cast_data) ); // Casting array for vars[i+1].data diff --git a/src/mase_components/binary_arith/rtl/binary_activation_binary_vector_mult.sv b/src/mase_components/binary_arith/rtl/binary_activation_binary_vector_mult.sv index f3ccad12d..1c6cae30e 100644 --- a/src/mase_components/binary_arith/rtl/binary_activation_binary_vector_mult.sv +++ b/src/mase_components/binary_arith/rtl/binary_activation_binary_vector_mult.sv @@ -59,16 +59,16 @@ module binary_activation_binary_vector_mult #( assign product_data_in[PRODUCT_WIDTH*i+PRODUCT_WIDTH-1:PRODUCT_WIDTH*i] = product_vector[i]; register_slice #( - .IN_WIDTH($bits(product_vector)) + .DATA_WIDTH($bits(product_vector)) ) register_slice ( .clk (clk), .rst (rst), .data_in_valid (product_data_in_valid), .data_in_ready (product_data_in_ready), - .data_in_data (product_data_in), + .data_in (product_data_in), .data_out_valid(product_data_out_valid), .data_out_ready(product_data_out_ready), - .data_out_data (product_data_out) + .data_out (product_data_out) ); // Casting array for product vector diff --git a/src/mase_components/binary_arith/rtl/fixed_activation_binary_vector_mult.sv b/src/mase_components/binary_arith/rtl/fixed_activation_binary_vector_mult.sv index 1938dec90..58df4d5e4 100644 --- a/src/mase_components/binary_arith/rtl/fixed_activation_binary_vector_mult.sv +++ b/src/mase_components/binary_arith/rtl/fixed_activation_binary_vector_mult.sv @@ -61,16 +61,16 @@ module fixed_activation_binary_vector_mult #( assign product_data_in[PRODUCT_WIDTH*i+PRODUCT_WIDTH-1:PRODUCT_WIDTH*i] = product_vector[i]; register_slice #( - .IN_WIDTH($bits(product_vector)) + .DATA_WIDTH($bits(product_vector)) ) register_slice ( .clk (clk), .rst (rst), .data_in_valid (product_data_in_valid), .data_in_ready (product_data_in_ready), - .data_in_data (product_data_in), + .data_in (product_data_in), .data_out_valid(product_data_out_valid), .data_out_ready(product_data_out_ready), - .data_out_data (product_data_out) + .data_out (product_data_out) ); // Casting array for product vector diff --git a/src/mase_components/binary_arith/test/test_synth_binary_arith.py b/src/mase_components/binary_arith/test/test_synth_binary_arith.py new file mode 100644 index 000000000..c5cce0022 --- /dev/null +++ b/src/mase_components/binary_arith/test/test_synth_binary_arith.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_binary_arith(): + run_synth("binary_arith") + + +if __name__ == "__main__": + test_synth_binary_arith() diff --git a/src/mase_components/buffers/test/test_synth_buffers.py b/src/mase_components/buffers/test/test_synth_buffers.py new file mode 100644 index 000000000..cb815e214 --- /dev/null +++ b/src/mase_components/buffers/test/test_synth_buffers.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_buffers(): + run_synth("buffers") + + +if __name__ == "__main__": + test_synth_buffers() diff --git a/src/mase_components/cast/rtl/fixed_round.sv b/src/mase_components/cast/rtl/fixed_round.sv index 779ab0867..27be929b5 100644 --- a/src/mase_components/cast/rtl/fixed_round.sv +++ b/src/mase_components/cast/rtl/fixed_round.sv @@ -28,7 +28,8 @@ module fixed_round #( always_comb begin lsb_below[2] = (IN_FRAC_WIDTH >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH] : 0; lsb_below[1] = (IN_FRAC_WIDTH-1 >= OUT_FRAC_WIDTH) ? input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-1] : 0; - lsb_below[0] = (IN_FRAC_WIDTH-2 >= OUT_FRAC_WIDTH) ? |(input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-2:0]): 0; + // lsb_below[0] = (IN_FRAC_WIDTH-2 >= OUT_FRAC_WIDTH) ? |(input_data[IN_FRAC_WIDTH-OUT_FRAC_WIDTH-2:0]): 0; + lsb_below[0] = '0; // to do: fix end always_comb begin if ((IN_FRAC_WIDTH - OUT_FRAC_WIDTH) >= 0) diff --git a/src/mase_components/cast/rtl/fixed_unsigned_cast.sv b/src/mase_components/cast/rtl/fixed_unsigned_cast.sv new file mode 100644 index 000000000..774fd7976 --- /dev/null +++ b/src/mase_components/cast/rtl/fixed_unsigned_cast.sv @@ -0,0 +1,67 @@ +/* +Module : fixed_unsigned_cast +Description : Cast a fixed point unsigned number into another. + + Types of rounding when OUT_FRAC_WIDTH < IN_FRAC_WIDTH: + - Floor +*/ + +`timescale 1ns / 1ps + +module fixed_unsigned_cast #( + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 4, + + // Rounding types for when OUT_FRAC_WIDTH < IN_FRAC_WIDTH + // One of these needs to be set to 1 + parameter ROUND_FLOOR = 0, + parameter ROUND_TRUNCATE = 0, + parameter ROUND_NEAREST_INT_HALF_EVEN = 0 +) ( + input logic signed [ IN_WIDTH-1:0] in_data, + output logic signed [OUT_WIDTH-1:0] out_data +); + + initial begin + assert (IN_WIDTH > 0); + assert (OUT_WIDTH > 0); + assert (IN_FRAC_WIDTH <= IN_WIDTH); + assert (IN_FRAC_WIDTH >= 0); + assert (OUT_FRAC_WIDTH <= OUT_WIDTH); + assert (OUT_FRAC_WIDTH >= 0); + assert (ROUND_FLOOR + ROUND_TRUNCATE + ROUND_NEAREST_INT_HALF_EVEN == 1); + + // Currently only supports floor rounding + assert (ROUND_FLOOR == 1); + end + + localparam MAX_WIDTH = IN_WIDTH > OUT_WIDTH ? IN_WIDTH : OUT_WIDTH; + + localparam ROUND_OUT_WIDTH = (OUT_FRAC_WIDTH > IN_FRAC_WIDTH) ? + MAX_WIDTH + (OUT_FRAC_WIDTH - IN_FRAC_WIDTH) : + MAX_WIDTH; + + logic [ROUND_OUT_WIDTH-1:0] round_out; + + floor_round #( + .IN_WIDTH(IN_WIDTH), + .OUT_WIDTH(ROUND_OUT_WIDTH), + .IN_FRAC_WIDTH(IN_FRAC_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) floor_round_inst ( + .in_data (in_data), + .out_data(round_out) + ); + + // Unsigned clamp + always_comb begin + if (&round_out[ROUND_OUT_WIDTH-1:ROUND_OUT_WIDTH-OUT_WIDTH]) begin + out_data = '1; + end else begin + out_data = round_out[OUT_WIDTH-1:0]; + end + end + +endmodule diff --git a/src/mase_components/cast/rtl/floor_round.sv b/src/mase_components/cast/rtl/floor_round.sv index 37e575b62..e7912ccd4 100644 --- a/src/mase_components/cast/rtl/floor_round.sv +++ b/src/mase_components/cast/rtl/floor_round.sv @@ -22,11 +22,11 @@ module floor_round #( end generate - if (OUT_FRAC_WIDTH > IN_FRAC_WIDTH) begin + if (OUT_FRAC_WIDTH > IN_FRAC_WIDTH) begin : gen_out_frac_larger assign out_data = in_data <<< (OUT_FRAC_WIDTH - IN_FRAC_WIDTH); - end else if (OUT_FRAC_WIDTH == IN_FRAC_WIDTH) begin + end else if (OUT_FRAC_WIDTH == IN_FRAC_WIDTH) begin : gen_out_frac_same assign out_data = in_data; - end else begin // OUT_FRAC_WIDTH < IN_FRAC_WIDTH + end else begin : gen_out_frac_smaller // OUT_FRAC_WIDTH < IN_FRAC_WIDTH assign out_data = in_data >>> (IN_FRAC_WIDTH - OUT_FRAC_WIDTH); end endgenerate diff --git a/src/mase_components/cast/test/fixed_unsigned_cast_tb.py b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py new file mode 100644 index 000000000..cfeca8318 --- /dev/null +++ b/src/mase_components/cast/test/fixed_unsigned_cast_tb.py @@ -0,0 +1,188 @@ +import logging +from math import trunc, floor + +import torch +import cocotb +from cocotb.triggers import * + +from mase_cocotb.testbench import Testbench +from mase_cocotb.runner import mase_runner + + +# def _fixed_signed_cast_model( +# float_input, out_width, out_frac_width, symmetric, rounding_mode +# ): +# scaled_float = float_input * (2**out_frac_width) +# if rounding_mode == "floor": +# out_int = my_floor(scaled_float) +# elif rounding_mode == "round_nearest_half_even": +# out_int = my_round(scaled_float) +# else: +# raise Exception("Rounding mode not recognised.") +# out_int = my_clamp( +# out_int, +# -(2 ** (out_width - 1)) + 1 if symmetric else -(2 ** (out_width - 1)), +# (2 ** (out_width - 1)) - 1, +# ) +# out_float = out_int / (2**out_frac_width) +# # out_uint is a non-differentiable path +# out_uint = signed_to_unsigned(out_int.int(), out_width) +# return out_float, out_uint + + +logger = logging.getLogger("testbench") +logger.setLevel(logging.INFO) + + +class FixedUnsignedCastTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut) + + self.assign_self_params( + [ + "IN_WIDTH", + "IN_FRAC_WIDTH", + "OUT_WIDTH", + "OUT_FRAC_WIDTH", + "ROUND_FLOOR", + "ROUND_TRUNCATE", + "ROUND_NEAREST_INT_HALF_EVEN", + ] + ) + + def generate_inputs(self): + return torch.arange(2**self.IN_WIDTH) + + def rounding_mode(self): + if self.ROUND_FLOOR: + return "floor" + elif self.ROUND_TRUNCATE: + return "trunc" + elif self.ROUND_NEAREST_INT_HALF_EVEN: + return "round_nearest_half_even" + else: + raise Exception("Rounding mode not recognised.") + + def model(self, inputs): + float_input = inputs / (2**self.IN_FRAC_WIDTH) + scaled_float = float_input * (2**self.OUT_FRAC_WIDTH) + rounded = torch.floor(scaled_float) + model_out = torch.clamp(rounded, 0, (2**self.OUT_WIDTH - 1)) + return model_out + + async def run_test(self): + inputs = self.generate_inputs() + exp_output = self.model(inputs) + + logger.info(inputs) + logger.info(exp_output) + + for i in range(inputs.shape[0]): + x = inputs[i].item() + exp_y = exp_output[i].item() + + self.dut.in_data.value = x + await Timer(10, "ns") + got_y = int(self.dut.out_data.value) + + assert ( + got_y == exp_output[i] + ), f"Output did not match! Got {got_y}, Exp {exp_y}" + + +@cocotb.test() +async def sweep(dut): + tb = FixedUnsignedCastTB(dut) + await tb.run_test() + + +if __name__ == "__main__": + DEFAULT_CONFIG = { + "IN_WIDTH": 8, + "IN_FRAC_WIDTH": 2, + "OUT_WIDTH": 8, + "OUT_FRAC_WIDTH": 2, + "ROUND_FLOOR": 1, + "ROUND_TRUNCATE": 0, + "ROUND_NEAREST_INT_HALF_EVEN": 0, + } + + def gen_width_change_configs(cfg_list): + l = list() + for cfg in cfg_list: + l.extend( + [ + {**cfg, "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] + 1}, + {**cfg, "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] - 1}, + {**cfg, "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] + 1}, + {**cfg, "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] - 1}, + { + **cfg, + "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] + 1, + "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] - 2, + }, + { + **cfg, + "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] + 1, + "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] + 2, + }, + { + **cfg, + "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] - 1, + "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] - 2, + }, + { + **cfg, + "OUT_WIDTH": DEFAULT_CONFIG["OUT_WIDTH"] - 1, + "OUT_FRAC_WIDTH": DEFAULT_CONFIG["OUT_FRAC_WIDTH"] + 2, + }, + ] + ) + return l + + def gen_rounding(cfg_list): + l = list() + for cfg in cfg_list: + l.extend( + [ + { + **cfg, + "ROUND_FLOOR": 1, + "ROUND_TRUNCATE": 0, + "ROUND_NEAREST_INT_HALF_EVEN": 0, + }, + { + **cfg, + "ROUND_FLOOR": 0, + "ROUND_TRUNCATE": 1, + "ROUND_NEAREST_INT_HALF_EVEN": 0, + }, + { + **cfg, + "ROUND_FLOOR": 0, + "ROUND_TRUNCATE": 0, + "ROUND_NEAREST_INT_HALF_EVEN": 1, + }, + ] + ) + return l + + cfg_list = [DEFAULT_CONFIG] + cfg_list = gen_width_change_configs(cfg_list) + # Other rounding modes not supported yet + # cfg_list = gen_rounding(cfg_list) + + mase_runner( + module_param_list=[ + # DEFAULT_CONFIG, + *cfg_list, + # { + # **DEFAULT_CONFIG, + # "IN_WIDTH": 10, + # "IN_FRAC_WIDTH": 2, + # "OUT_WIDTH": 8, + # "OUT_FRAC_WIDTH": 1, + # }, + ], + trace=True, + ) diff --git a/src/mase_components/cast/test/test_synth_cast.py b/src/mase_components/cast/test/test_synth_cast.py new file mode 100644 index 000000000..5ce963c92 --- /dev/null +++ b/src/mase_components/cast/test/test_synth_cast.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_cast(): + run_synth("cast") + + +if __name__ == "__main__": + test_synth_cast() diff --git a/src/mase_components/common/rtl/buffer.sv b/src/mase_components/common/rtl/buffer.sv new file mode 100644 index 000000000..85dae0c79 --- /dev/null +++ b/src/mase_components/common/rtl/buffer.sv @@ -0,0 +1,83 @@ +`timescale 1ns / 1ps + +module buffer #( + /* verilator lint_off UNUSEDPARAM */ + parameter SELECT = 0, + + // Input 1 + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, // must equal WEIGHT_PARALLELISM_DIM_1 + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + parameter IN_0_DEPTH_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + parameter IN_0_DEPTH_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1, + + // Input 2 + parameter DATA_IN_1_PRECISION_0 = 16, + parameter DATA_IN_1_PRECISION_1 = 3, + parameter DATA_IN_1_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_1_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_1_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_0 = 4, // must equal WEIGHT_PARALLELISM_DIM_1 + parameter DATA_IN_1_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_2 = 1, + parameter IN_1_DEPTH_DIM_0 = DATA_IN_1_TENSOR_SIZE_DIM_0 / DATA_IN_1_PARALLELISM_DIM_0, + parameter IN_1_DEPTH_DIM_1 = DATA_IN_1_TENSOR_SIZE_DIM_1 / DATA_IN_1_PARALLELISM_DIM_1, + + // Output 1 + parameter DATA_OUT_0_PRECISION_0 = 16, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2, + + // Chosen input + localparam CHOSEN_DATA_IN_PRECISION_0 = (SELECT == 0) ? DATA_IN_0_PRECISION_0 : DATA_IN_1_PRECISION_0, + localparam CHOSEN_DATA_IN_PRECISION_1 = (SELECT == 0) ? DATA_IN_0_PRECISION_1 : DATA_IN_1_PRECISION_1, + localparam CHOSEN_DATA_IN_TENSOR_SIZE_DIM_0 = (SELECT == 0) ? DATA_IN_0_TENSOR_SIZE_DIM_0 : DATA_IN_1_TENSOR_SIZE_DIM_0, + localparam CHOSEN_DATA_IN_TENSOR_SIZE_DIM_1 = (SELECT == 0) ? DATA_IN_0_TENSOR_SIZE_DIM_1 : DATA_IN_1_TENSOR_SIZE_DIM_1, + localparam CHOSEN_DATA_IN_TENSOR_SIZE_DIM_2 = (SELECT == 0) ? DATA_IN_0_TENSOR_SIZE_DIM_2 : DATA_IN_1_TENSOR_SIZE_DIM_2, + localparam CHOSEN_DATA_IN_PARALLELISM_DIM_0 = (SELECT == 0) ? DATA_IN_0_PARALLELISM_DIM_0 : DATA_IN_1_PARALLELISM_DIM_0, + localparam CHOSEN_DATA_IN_PARALLELISM_DIM_1 = (SELECT == 0) ? DATA_IN_0_PARALLELISM_DIM_1 : DATA_IN_1_PARALLELISM_DIM_1, + localparam CHOSEN_DATA_IN_PARALLELISM_DIM_2 = (SELECT == 0) ? DATA_IN_0_PARALLELISM_DIM_2 : DATA_IN_1_PARALLELISM_DIM_2, + parameter CHOSEN_DEPTH_DIM_0 = (SELECT == 0) ? IN_0_DEPTH_DIM_0 : IN_1_DEPTH_DIM_0, + parameter CHOSEN_DEPTH_DIM_1 = (SELECT == 0) ? IN_0_DEPTH_DIM_1 : IN_1_DEPTH_DIM_1 +) ( + input logic clk, + input logic rst, + + // Input 0 + input logic [CHOSEN_DATA_IN_PRECISION_0-1:0] data_in_0 [CHOSEN_DATA_IN_PARALLELISM_DIM_0*CHOSEN_DATA_IN_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + +unpacked_fifo #( + .DEPTH (CHOSEN_DEPTH_DIM_0 * CHOSEN_DEPTH_DIM_1), + .DATA_WIDTH (CHOSEN_DATA_IN_PRECISION_0), + .IN_NUM (CHOSEN_DATA_IN_PARALLELISM_DIM_0 * CHOSEN_DATA_IN_PARALLELISM_DIM_1) +) buffer_i ( + .clk, + .rst, + + .data_in (data_in_0), + .data_in_valid (data_in_0_valid), + .data_in_ready (data_in_0_ready), + + .data_out (data_out_0), + .data_out_valid (data_out_0_valid), + .data_out_ready (data_out_0_ready) +); + +endmodule \ No newline at end of file diff --git a/src/mase_components/common/rtl/comparator_accumulator.sv b/src/mase_components/common/rtl/comparator_accumulator.sv new file mode 100644 index 000000000..b0e5decd4 --- /dev/null +++ b/src/mase_components/common/rtl/comparator_accumulator.sv @@ -0,0 +1,124 @@ +/* +Module : comparator_accumulator +Description : This module implements an comparator accumulation. + + Can do signed/unsigned max/min comparisons. +*/ + +`timescale 1ns / 1ps + +module comparator_accumulator #( + parameter DATA_WIDTH = 8, + parameter DEPTH = 8, + parameter MAX1_MIN0 = 1, // MAX = 1, MIN = 0 + parameter SIGNED = 0 +) ( + input logic clk, + input logic rst, + + input logic [DATA_WIDTH-1:0] in_data, + input logic in_valid, + output logic in_ready, + + output logic [DATA_WIDTH-1:0] out_data, + output logic out_valid, + input logic out_ready +); + + localparam COUNTER_WIDTH = $clog2(DEPTH) + 1; + localparam RESET_VAL = SIGNED ? (1 << (DATA_WIDTH - 1)) : 0; + + struct { + logic [COUNTER_WIDTH-1:0] count; + logic [DATA_WIDTH-1:0] data; + } + self, next_self; + + logic [DATA_WIDTH-1:0] left, right, result; + + logic [DATA_WIDTH-1:0] output_data; + logic output_valid, output_ready; + + // Comparator instance + generate + if (MAX1_MIN0) begin + if (SIGNED) begin + assign result = $signed(left) > $signed(right) ? left : right; + end else begin + assign result = left > right ? left : right; + end + end else begin + if (SIGNED) begin + assign result = $signed(left) < $signed(right) ? left : right; + end else begin + assign result = left < right ? left : right; + end + end + endgenerate + + // Output Register Instance + skid_buffer #( + .DATA_WIDTH(DATA_WIDTH) + ) out_reg ( + .clk(clk), + .rst(rst), + .data_in(output_data), + .data_in_valid(output_valid), + .data_in_ready(output_ready), + .data_out(out_data), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + + + always_comb begin + next_self = self; + + in_ready = (self.count != DEPTH) && !((self.count == DEPTH - 1) && !output_ready); + output_data = self.data; + + left = in_data; + right = self.data; + + if (self.count == DEPTH) begin + output_valid = 1; + if (output_ready) begin + output_data = result; + next_self.data = RESET_VAL; + next_self.count = 0; + end + end else if (self.count == DEPTH - 1) begin + if (in_valid && in_ready) begin + output_valid = 1; + if (output_ready) begin + // Redirect accumulation into outreg + output_data = result; + // Reset + next_self.data = RESET_VAL; + next_self.count = 0; + end else begin + next_self.count = self.count + 1; + end + end else begin + output_valid = 0; + end + end else begin + output_valid = 0; + if (in_valid && in_ready) begin + next_self.data = result; + next_self.count = self.count + 1; + end + end + + end + + always_ff @(posedge clk) begin + if (rst) begin + self <= '{0, RESET_VAL}; + end else begin + self <= next_self; + end + end + + +endmodule diff --git a/src/mase_components/common/rtl/comparator_tree.sv b/src/mase_components/common/rtl/comparator_tree.sv new file mode 100644 index 000000000..fc1e15a7e --- /dev/null +++ b/src/mase_components/common/rtl/comparator_tree.sv @@ -0,0 +1,115 @@ +/* +Module : comparator_tree +Description : This module implements a maximum number comparator tree. + + Can do signed/unsigned max/min reductions. +*/ + +`timescale 1ns / 1ps + +module comparator_tree #( + parameter SIZE = 8, // Only supports powers of 2 + parameter DATA_WIDTH = 8, + parameter MAX1_MIN0 = 1, // MAX = 1, MIN = 0 + parameter SIGNED = 0 +) ( + input logic clk, + input logic rst, + + input logic [DATA_WIDTH-1:0] in_data [SIZE-1:0], + input logic in_valid, + output logic in_ready, + + output logic [DATA_WIDTH-1:0] out_data, + output logic out_valid, + input logic out_ready +); + + localparam LEVELS = $clog2(SIZE); + + initial begin + assert (2 ** LEVELS == SIZE); // Only support power of 2 + end + + for (genvar level = 0; level <= LEVELS; level++) begin : vars + logic [DATA_WIDTH-1:0] data[(2**(LEVELS-level))-1:0]; + logic valid; + logic ready; + end + + + for (genvar level = 0; level < LEVELS; level++) begin : element_handshake + logic [2 ** (LEVELS - level - 1) - 1 : 0] element_input_valid; + logic [2 ** (LEVELS - level - 1) - 1 : 0] element_input_ready; + + logic [2 ** (LEVELS - level - 1) - 1 : 0] element_output_valid; + logic [2 ** (LEVELS - level - 1) - 1 : 0] element_output_ready; + end : element_handshake + + for (genvar i = 0; i < LEVELS; i++) begin : level + + for (genvar c = 0; c < 2 ** (LEVELS - i - 1); c++) begin : comparator + + logic [DATA_WIDTH-1:0] left, right, result; + assign left = vars[i].data[2*c]; + assign right = vars[i].data[2*c+1]; + + if (MAX1_MIN0) begin + if (SIGNED) begin + assign result = $signed(left) > $signed(right) ? left : right; + end else begin + assign result = left > right ? left : right; + end + end else begin + if (SIGNED) begin + assign result = $signed(left) < $signed(right) ? left : right; + end else begin + assign result = left < right ? left : right; + end + end + + skid_buffer #( + .DATA_WIDTH(DATA_WIDTH) + ) max_reg ( + .clk(clk), + .rst(rst), + .data_in(result), + .data_in_valid(element_handshake[i].element_input_valid[c]), + .data_in_ready(element_handshake[i].element_input_ready[c]), + .data_out(vars[i+1].data[c]), + .data_out_valid(element_handshake[i].element_output_valid[c]), + .data_out_ready(element_handshake[i].element_output_ready[c]) + ); + end + + // Join handshake signals from each skid buffer into a single + // handshake interface to drive the next level + split_n #( + .N (2 ** (LEVELS - i - 1)) + ) handshake_split ( + .data_in_valid (vars[i].valid), + .data_in_ready (vars[i].ready), + .data_out_valid (element_handshake[i].element_input_valid), + .data_out_ready (element_handshake[i].element_input_ready) + ); + + join_n #( + .NUM_HANDSHAKES (2 ** (LEVELS - i - 1)) + ) handshake_join ( + .data_in_valid (element_handshake[i].element_output_valid), + .data_in_ready (element_handshake[i].element_output_ready), + .data_out_valid (vars[i+1].valid), + .data_out_ready (vars[i+1].ready) + ); + end + + // Connect up first and last layer wires + assign vars[0].data = in_data; + assign vars[0].valid = in_valid; + assign in_ready = vars[0].ready; + + assign out_data = vars[LEVELS].data[0]; + assign out_valid = vars[LEVELS].valid; + assign vars[LEVELS].ready = out_ready; + +endmodule diff --git a/src/mase_components/common/rtl/convert_parallelism.sv b/src/mase_components/common/rtl/convert_parallelism.sv index fa7f94099..966c8fd15 100644 --- a/src/mase_components/common/rtl/convert_parallelism.sv +++ b/src/mase_components/common/rtl/convert_parallelism.sv @@ -8,12 +8,12 @@ module convert_parallelism #( input rst, input logic [DATA_WIDTH-1:0] data_in [DATA_IN_PARALLELISM-1:0], - input data_in_valid, - output data_in_ready, + input logic data_in_valid, + output logic data_in_ready, output logic [DATA_WIDTH-1:0] data_out [DATA_OUT_PARALLELISM-1:0], - output data_out_valid, - input data_out_ready + output logic data_out_valid, + input logic data_out_ready ); // if (DATA_OUT_PARALLELISM == DATA_IN_PARALLELISM) begin // always_comb begin diff --git a/src/mase_components/common/rtl/df_split.sv b/src/mase_components/common/rtl/df_split.sv new file mode 100644 index 000000000..6504acb9c --- /dev/null +++ b/src/mase_components/common/rtl/df_split.sv @@ -0,0 +1,67 @@ +`timescale 1ns / 1ps + +module df_split #( + /* verilator lint_off UNUSEDPARAM */ + + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, // must equal WEIGHT_PARALLELISM_DIM_1 + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + parameter IN_0_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + + // Output 1 + parameter DATA_OUT_0_PRECISION_0 = 16, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2, + + // Output 2 + parameter DATA_OUT_1_PRECISION_0 = 16, + parameter DATA_OUT_1_PRECISION_1 = 3, + parameter DATA_OUT_1_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_1_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_1_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2, + parameter DATA_OUT_1_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_OUT_1_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_1_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2 +) ( + input logic clk, + input logic rst, + + // input port for data_inivations + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready, + + output logic [DATA_OUT_1_PRECISION_0-1:0] data_out_1 [DATA_OUT_1_PARALLELISM_DIM_0*DATA_OUT_1_PARALLELISM_DIM_1-1:0], + output logic data_out_1_valid, + input logic data_out_1_ready +); + +split2 split_i ( + .data_in_valid (data_in_0_valid), + .data_in_ready (data_in_0_ready), + .data_out_valid ({data_out_0_valid, data_out_1_valid}), + .data_out_ready ({data_out_0_ready, data_out_1_ready}) +); + +for (genvar i=0; i 1) + else $error("Use single_element_repeat module for SIZE=1"); + end + localparam REPS_WIDTH = $clog2(REPEAT); localparam ADDR_WIDTH = SIZE == 1 ? 1 : $clog2(SIZE); localparam PTR_WIDTH = ADDR_WIDTH + 1; @@ -65,7 +70,7 @@ module repeat_circular_buffer #( next_self = self; // Input side ready - in_ready = self.size != SIZE && !(self.rep == REPEAT - 1 && self.write_ptr == self.read_ptr); + in_ready = (self.size != SIZE) && !(self.rep == REPEAT - 1 && self.write_ptr == self.read_ptr); // Pause reading when there is (no transfer on this cycle) AND the registers are full. pause_reads = !out_ready && (self.out_reg.valid || self.extra_reg.valid); diff --git a/src/mase_components/common/rtl/single_element_repeat.sv b/src/mase_components/common/rtl/single_element_repeat.sv new file mode 100644 index 000000000..7a0d14902 --- /dev/null +++ b/src/mase_components/common/rtl/single_element_repeat.sv @@ -0,0 +1,100 @@ +/* +Module : single_element_repeat +Description : This module receives data and repeats it N times. + + This module has 2 cycle latency due to output buffering. +*/ + +`timescale 1ns / 1ps + +module single_element_repeat #( + parameter DATA_WIDTH = 32, + parameter REPEAT = 2 +) ( + input logic clk, + input logic rst, + + // Input streaming port + input logic [DATA_WIDTH-1:0] in_data, + input logic in_valid, + output logic in_ready, + + // Output streaming port + output logic [DATA_WIDTH-1:0] out_data, + output logic out_valid, + input logic out_ready +); + + initial begin + assert (REPEAT > 1); + end + + localparam CTR_WIDTH = $clog2(REPEAT); + + logic [DATA_WIDTH-1:0] output_buffer_data; + logic output_buffer_valid, output_buffer_ready; + + typedef struct packed { + // Data element + logic [DATA_WIDTH-1:0] buffer_data; + logic buffer_valid; + + // Counters + logic [CTR_WIDTH-1:0] count; + } SELF_T; + + SELF_T self, next_self; + + + always_comb begin + + next_self = self; + + in_ready = !self.buffer_valid || + (self.buffer_valid && output_buffer_ready && self.count == REPEAT-1); + output_buffer_data = self.buffer_data; + output_buffer_valid = self.buffer_valid; + + if (in_valid && in_ready) begin + next_self.buffer_data = in_data; + next_self.buffer_valid = 1; + end + + if (output_buffer_valid && output_buffer_ready) begin + if (self.count == REPEAT - 1) begin + next_self.count = 0; + if (in_valid && in_ready) begin + next_self.buffer_data = in_data; + next_self.buffer_valid = 1; + end else begin + next_self.buffer_valid = 0; + end + end else begin + next_self.count = self.count + 1; + end + end + + end + + skid_buffer #( + .DATA_WIDTH(DATA_WIDTH) + ) output_buffer ( + .clk(clk), + .rst(rst), + .data_in(output_buffer_data), + .data_in_valid(output_buffer_valid), + .data_in_ready(output_buffer_ready), + .data_out(out_data), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + + always_ff @(posedge clk) begin + if (rst) begin + self <= '{default: '0}; + end else begin + self <= next_self; + end + end + +endmodule diff --git a/src/mase_components/common/rtl/skid_buffer.sv b/src/mase_components/common/rtl/skid_buffer.sv index be3950a24..ca1e76f27 100644 --- a/src/mase_components/common/rtl/skid_buffer.sv +++ b/src/mase_components/common/rtl/skid_buffer.sv @@ -19,13 +19,17 @@ module skid_buffer #( logic [DATA_WIDTH - 1:0] data_buffer_out; logic data_buffer_wren; + logic data_out_wren; + logic use_buffered_data; + logic [DATA_WIDTH - 1:0] selected_data; + logic insert, remove; + logic load, flow, fill, flush, unload; + always_ff @(posedge clk) begin if (rst) data_buffer_out <= 0; else if (data_buffer_wren) data_buffer_out <= data_in; end - logic data_out_wren; - logic use_buffered_data; - logic [DATA_WIDTH - 1:0] selected_data; + assign selected_data = (use_buffered_data) ? data_buffer_out : data_in; always_ff @(posedge clk) begin if (rst) data_out <= 0; @@ -55,13 +59,11 @@ module skid_buffer #( /* verilator lint_on WIDTH */ end end - logic insert, remove; always_comb begin insert = (data_in_valid && data_in_ready); remove = (data_out_valid && data_out_ready); end - logic load, flow, fill, flush, unload; always_comb begin load = (state == EMPTY) && ({insert, remove} == 2'b10); flow = (state == BUSY) && ({insert, remove} == 2'b11); diff --git a/src/mase_components/common/rtl/split_n.sv b/src/mase_components/common/rtl/split_n.sv new file mode 100644 index 000000000..5b5de93e2 --- /dev/null +++ b/src/mase_components/common/rtl/split_n.sv @@ -0,0 +1,45 @@ +/* +Module : splitn +Description : This module implements a 1-to-N streaming interface handshake. +*/ + +`timescale 1ns / 1ps + +module split_n #( + parameter N = 10 +) ( + input logic [0:0] data_in_valid, + output logic [0:0] data_in_ready, + output logic [N-1:0] data_out_valid, + input logic [N-1:0] data_out_ready +); + +logic [N-1:0] ready_intermediate; + +if (N == 1) begin + + assign data_out_valid = data_in_valid; + assign data_in_ready = data_out_ready; + +end else begin + + logic [N-1:0] [N-1:0] debug1; + logic [N-1:0] [N-1:0] debug2; + for (genvar i = 0; i < N; i++) begin : handshake + assign debug1 [i] = (1 << i); + assign debug2 [i] = data_out_ready | debug1[i]; + // We should wait to drive the output valid until all other ports are ready, without checking the current port + // since this leads to a combinatorial loop + assign ready_intermediate[i] = (i == 0) ? &data_out_ready[N-1:1] + : (i == N-1) ? &data_out_ready[N-2:0] + : &debug2[i]; + + assign data_out_valid[i] = data_in_valid && ready_intermediate[i]; + end + + // Apply backpressure until all output ports are ready + assign data_in_ready = &data_out_ready; + +end + +endmodule diff --git a/src/mase_components/common/rtl/unpacked_fifo.sv b/src/mase_components/common/rtl/unpacked_fifo.sv index 8f53d9011..c67fa2413 100644 --- a/src/mase_components/common/rtl/unpacked_fifo.sv +++ b/src/mase_components/common/rtl/unpacked_fifo.sv @@ -20,7 +20,7 @@ module unpacked_fifo #( assign data_in_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH] = data_in[i]; end fifo #( - .SIZE(DEPTH), + .DEPTH(DEPTH), .DATA_WIDTH(DATA_WIDTH * IN_NUM) ) ff_inst ( .clk(clk), diff --git a/src/mase_components/common/rtl/unpacked_register_slice.sv b/src/mase_components/common/rtl/unpacked_register_slice.sv new file mode 100644 index 000000000..6c5de96da --- /dev/null +++ b/src/mase_components/common/rtl/unpacked_register_slice.sv @@ -0,0 +1,36 @@ +`timescale 1ns / 1ps +module unpacked_register_slice #( + parameter DATA_WIDTH = 32, + parameter IN_SIZE = 16, + parameter type MYDATA = logic [DATA_WIDTH-1:0] +) ( + input logic clk, + input logic rst, + + input MYDATA data_in [IN_SIZE-1:0], + input logic data_in_valid, + output logic data_in_ready, + + output MYDATA data_out [IN_SIZE-1:0], + output logic data_out_valid, + input logic data_out_ready +); + logic [DATA_WIDTH * IN_SIZE - 1 : 0] data_in_flatten; + logic [DATA_WIDTH * IN_SIZE - 1 : 0] data_out_flatten; + for (genvar i = 0; i < IN_SIZE; i++) begin + assign data_in_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH] = data_in[i]; + assign data_out[i] = data_out_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH]; + end + register_slice #( + .DATA_WIDTH(DATA_WIDTH * IN_SIZE) + ) register_slice ( + .clk (clk), + .rst (rst), + .data_in_valid (data_in_valid), + .data_in_ready (data_in_ready), + .data_in (data_in_flatten), + .data_out_valid(data_out_valid), + .data_out_ready(data_out_ready), + .data_out (data_out_flatten) + ); +endmodule diff --git a/src/mase_components/common/rtl/unpacked_register_slice_quick.sv b/src/mase_components/common/rtl/unpacked_register_slice_quick.sv index 8db78a4a8..4699e716a 100644 --- a/src/mase_components/common/rtl/unpacked_register_slice_quick.sv +++ b/src/mase_components/common/rtl/unpacked_register_slice_quick.sv @@ -26,11 +26,11 @@ module unpacked_register_slice_quick #( ) register_slice ( .clk (clk), .rst (rst), - .in_valid (in_valid), - .in_ready (in_ready), - .in_data (data_in_flatten), - .out_valid(out_valid), - .out_ready(out_ready), - .out_data (data_out_flatten) + .data_in_valid (in_valid), + .data_in_ready (in_ready), + .data_in (data_in_flatten), + .data_out_valid(out_valid), + .data_out_ready(out_ready), + .data_out (data_out_flatten) ); endmodule diff --git a/src/mase_components/common/rtl/unpacked_repeat_circular_buffer.sv b/src/mase_components/common/rtl/unpacked_repeat_circular_buffer.sv new file mode 100644 index 000000000..36627eed4 --- /dev/null +++ b/src/mase_components/common/rtl/unpacked_repeat_circular_buffer.sv @@ -0,0 +1,58 @@ +/* +Module : repeat_circular_buffer +Description : This module is a repeating circular buffer. +*/ + +`timescale 1ns / 1ps + +module unpacked_repeat_circular_buffer #( + parameter DATA_WIDTH = 32, + parameter IN_NUM = 1, + parameter REPEAT = 2, + parameter SIZE = 4 +) ( + input logic clk, + input logic rst, + + // Input streaming port + input logic [DATA_WIDTH-1:0] in_data [IN_NUM-1:0], + input logic in_valid, + output logic in_ready, + + // Output streaming port + output logic [DATA_WIDTH-1:0] out_data [IN_NUM-1:0], + output logic out_valid, + input logic out_ready +); + + logic [DATA_WIDTH * IN_NUM - 1:0] data_in_flatten; + logic [DATA_WIDTH * IN_NUM - 1:0] data_out_flatten; + + for (genvar i = 0; i < IN_NUM; i++) begin : reshape + assign data_in_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH] = in_data[i]; + end + + repeat_circular_buffer #( + .DATA_WIDTH(DATA_WIDTH * IN_NUM), + .REPEAT(REPEAT), + .SIZE(SIZE) + ) buffer_inst ( + .clk(clk), + .rst(rst), + + // Input streaming port + .in_data(data_in_flatten), + .in_valid(in_valid), + .in_ready(in_ready), + + // Output streaming port + .out_data(data_out_flatten), + .out_valid(out_valid), + .out_ready(out_ready) + ); + + for (genvar i = 0; i < IN_NUM; i++) begin : unreshape + assign out_data[i] = data_out_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH]; + end + +endmodule diff --git a/src/mase_components/common/test/comparator_accumulator_tb.py b/src/mase_components/common/test/comparator_accumulator_tb.py new file mode 100644 index 000000000..fbd8af937 --- /dev/null +++ b/src/mase_components/common/test/comparator_accumulator_tb.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +import os, logging +from random import randint +from pathlib import Path + +import torch + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver, sign_extend, signed_to_unsigned, batched +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + StreamMonitor, +) + +import cocotb +from cocotb.triggers import * + +logger = logging.getLogger("testbench") +logger.setLevel("DEBUG") + + +class ComparatorAccumulatorTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params(["DATA_WIDTH", "DEPTH", "MAX1_MIN0", "SIGNED"]) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + self.output_monitor = StreamMonitor( + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, check=False + ) + + def generate_inputs(self, batches=3): + return [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.DEPTH * batches)] + + def model(self, inputs): + + batched_in = batched(inputs, self.DEPTH) + + if self.SIGNED: + batched_in = [ + [sign_extend(x, self.DATA_WIDTH) for x in l] for l in batched_in + ] + + exp_out = [] + for l in batched_in: + if self.MAX1_MIN0: + exp_out.append(max(l)) + else: + exp_out.append(min(l)) + + if self.SIGNED: + exp_out = [signed_to_unsigned(x, self.DATA_WIDTH) for x in exp_out] + + return exp_out + + async def run_test(self, batches, us): + inputs = self.generate_inputs(batches=batches) + exp_out = self.model(inputs) + + # Log the first batch + logger.info("First Batch") + logger.info("Input : %s" % inputs[: self.DEPTH]) + logger.info("Expect: %s" % exp_out[: self.DEPTH]) + + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + + await Timer(us, "us") + assert self.output_monitor.exp_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = ComparatorAccumulatorTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=10, us=10) + + +@cocotb.test() +async def stream(dut): + tb = ComparatorAccumulatorTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=500, us=500) + + +@cocotb.test() +async def backpressure(dut): + tb = ComparatorAccumulatorTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + await tb.reset() + await tb.run_test(batches=100, us=100) + + +@cocotb.test() +async def valid(dut): + tb = ComparatorAccumulatorTB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.6) + await tb.reset() + await tb.run_test(batches=100, us=100) + + +@cocotb.test() +async def backpressure_valid(dut): + tb = ComparatorAccumulatorTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=1000, us=1000) + + +if __name__ == "__main__": + + def depth_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + for d in [1, 4, 7, 13]: + out.append({**cfg, "DEPTH": d}) + return out + + def width_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + for size in [2, 4, 8, 16]: + out.append({**cfg, "DATA_WIDTH": size}) + return out + + def signed_max_min_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + out.append({**cfg, "MAX1_MIN0": 0, "SIGNED": 0}) + out.append({**cfg, "MAX1_MIN0": 0, "SIGNED": 1}) + out.append({**cfg, "MAX1_MIN0": 1, "SIGNED": 0}) + out.append({**cfg, "MAX1_MIN0": 1, "SIGNED": 1}) + return out + + DEFAULT = { + "DATA_WIDTH": 8, + "DEPTH": 8, + "MAX1_MIN0": 1, + "SIGNED": 0, + } + + cfgs = [DEFAULT] + cfgs = width_cfgs(cfgs) + cfgs = depth_cfgs(cfgs) + cfgs = signed_max_min_cfgs(cfgs) + + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=12, + ) diff --git a/src/mase_components/common/test/comparator_tree_tb.py b/src/mase_components/common/test/comparator_tree_tb.py new file mode 100644 index 000000000..5637791b3 --- /dev/null +++ b/src/mase_components/common/test/comparator_tree_tb.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +import os, logging +from random import randint +from pathlib import Path + +import torch + +from mase_cocotb.runner import mase_runner +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import bit_driver, sign_extend, signed_to_unsigned +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + StreamMonitor, +) + +import cocotb +from cocotb.triggers import * + +logger = logging.getLogger("testbench") +logger.setLevel("DEBUG") + + +class ComparatorTreeTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params(["SIZE", "DATA_WIDTH", "MAX1_MIN0", "SIGNED"]) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + self.output_monitor = StreamMonitor( + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, check=True + ) + + def generate_inputs(self, batches=3): + return [ + [randint(0, 2**self.DATA_WIDTH - 1) for _ in range(self.SIZE)] + for _ in range(batches) + ] + + def model(self, inputs): + if self.SIGNED: + inputs = [[sign_extend(x, self.DATA_WIDTH) for x in l] for l in inputs] + + exp_out = [] + for l in inputs: + if self.MAX1_MIN0: + exp_out.append(max(l)) + else: + exp_out.append(min(l)) + + if self.SIGNED: + exp_out = [signed_to_unsigned(x, self.DATA_WIDTH) for x in exp_out] + + return exp_out + + async def run_test(self, batches, us): + inputs = self.generate_inputs(batches=batches) + exp_out = self.model(inputs) + + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + + await Timer(us, "us") + assert self.output_monitor.exp_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = ComparatorTreeTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=10, us=20) + + +@cocotb.test() +async def stream(dut): + tb = ComparatorTreeTB(dut) + tb.output_monitor.ready.value = 1 + await tb.reset() + await tb.run_test(batches=500, us=100) + + +@cocotb.test() +async def backpressure(dut): + tb = ComparatorTreeTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + await tb.reset() + await tb.run_test(batches=100, us=50) + + +@cocotb.test() +async def valid(dut): + tb = ComparatorTreeTB(dut) + tb.output_monitor.ready.value = 1 + tb.in_driver.set_valid_prob(0.6) + await tb.reset() + await tb.run_test(batches=100, us=50) + + +@cocotb.test() +async def backpressure_valid(dut): + tb = ComparatorTreeTB(dut) + cocotb.start_soon(bit_driver(tb.output_monitor.ready, tb.clk, 0.5)) + tb.in_driver.set_valid_prob(0.5) + await tb.reset() + await tb.run_test(batches=2000, us=200) + + +if __name__ == "__main__": + + def size_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + for size in [2, 4, 8, 16]: + out.append({**cfg, "SIZE": size}) + return out + + def signed_max_min_cfgs(cfglist: list): + out = [] + for cfg in cfglist: + out.append({**cfg, "MAX1_MIN0": 0, "SIGNED": 0}) + out.append({**cfg, "MAX1_MIN0": 0, "SIGNED": 1}) + out.append({**cfg, "MAX1_MIN0": 1, "SIGNED": 0}) + out.append({**cfg, "MAX1_MIN0": 1, "SIGNED": 1}) + return out + + DEFAULT = { + "SIZE": 8, + "DATA_WIDTH": 8, + "MAX1_MIN0": 1, + "SIGNED": 0, + } + + cfgs = size_cfgs([DEFAULT]) + cfgs = signed_max_min_cfgs(cfgs) + + mase_runner( + module_param_list=cfgs, + trace=True, + ) diff --git a/src/mase_components/common/test/fifo_tb.py b/src/mase_components/common/test/fifo_tb.py index aba18e646..23a4d98b2 100644 --- a/src/mase_components/common/test/fifo_tb.py +++ b/src/mase_components/common/test/fifo_tb.py @@ -15,14 +15,17 @@ class FifoTB(Testbench): def __init__(self, dut) -> None: super().__init__(dut, dut.clk, dut.rst) - self.assign_self_params(["DATA_WIDTH", "SIZE", "ADDR_WIDTH"]) + self.assign_self_params(["DATA_WIDTH", "DEPTH"]) # Driver/Monitor self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) self.output_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, unsigned=True ) + # self.in_driver.log.setLevel("DEBUG") + # self.output_monitor.log.setLevel("DEBUG") + def generate_inputs(self, num=20): return [randint(0, (2**self.DATA_WIDTH) - 1) for _ in range(num)] @@ -47,11 +50,11 @@ async def cocotb_test_large_buffering(dut): await tb.reset() tb.output_monitor.ready.value = 1 - inputs = tb.generate_inputs(num=1000) + inputs = tb.generate_inputs(num=100) tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(inputs) - await Timer(40, "us") + await Timer(100, "us") assert tb.output_monitor.exp_queue.empty() @@ -66,7 +69,7 @@ async def cocotb_test_valid(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(inputs) - await Timer(20, "us") + await Timer(100, "us") assert tb.output_monitor.exp_queue.empty() @@ -80,7 +83,7 @@ async def cocotb_test_backpressure(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(inputs) - await Timer(20, "us") + await Timer(100, "us") assert tb.output_monitor.exp_queue.empty() @@ -95,7 +98,7 @@ async def cocotb_test_valid_backpressure(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(inputs) - await Timer(20, "us") + await Timer(100, "us") assert tb.output_monitor.exp_queue.empty() @@ -110,12 +113,20 @@ async def cocotb_test_soak(dut): tb.in_driver.load_driver(inputs) tb.output_monitor.load_monitor(inputs) - await Timer(1000, "us") + await Timer(2000, "us") assert tb.output_monitor.exp_queue.empty() def test_fifo(): - mase_runner(seed=0, trace=True) + mase_runner( + module_param_list=[ + {"DEPTH": 1}, + {"DEPTH": 7}, + {"DEPTH": 8}, + {"DEPTH": 81}, + ], + trace=True, + ) if __name__ == "__main__": diff --git a/src/mase_components/common/test/repeat_circular_buffer_tb.py b/src/mase_components/common/test/repeat_circular_buffer_tb.py index a6fad2cc8..78d3750a0 100644 --- a/src/mase_components/common/test/repeat_circular_buffer_tb.py +++ b/src/mase_components/common/test/repeat_circular_buffer_tb.py @@ -156,7 +156,9 @@ def test_repeat_circular_buffer(): {"DATA_WIDTH": 32, "REPEAT": 2, "SIZE": 7}, # Purely random params *[generate_random_params() for _ in range(5)], - ] + ], + trace=True, + jobs=8, ) diff --git a/src/mase_components/common/test/single_element_repeat_tb.py b/src/mase_components/common/test/single_element_repeat_tb.py new file mode 100644 index 000000000..247cc2b27 --- /dev/null +++ b/src/mase_components/common/test/single_element_repeat_tb.py @@ -0,0 +1,109 @@ +import logging +import random + +import cocotb +from cocotb.triggers import * + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.runner import mase_runner +from mase_cocotb.utils import bit_driver, batched + + +logger = logging.getLogger("testbench") +logger.setLevel(logging.INFO) + + +class SingleElementRepeatTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params(["DATA_WIDTH", "REPEAT"]) + + # Driver/Monitor + self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) + self.output_monitor = StreamMonitor( + dut.clk, dut.out_data, dut.out_valid, dut.out_ready + ) + + def generate_inputs(self, num=10): + return [random.randint(0, 2**self.DATA_WIDTH - 1) for _ in range(num)] + + def model(self, inputs): + exp_out = [] + for x in inputs: + exp_out.extend([x for _ in range(self.REPEAT)]) + return exp_out + + async def run_test(self, batches, us): + await self.reset() + inputs = self.generate_inputs(num=batches) + exp_out = self.model(inputs) + self.in_driver.load_driver(inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, units="us") + assert self.output_monitor.exp_queue.empty() + + +@cocotb.test() +async def basic(dut): + tb = SingleElementRepeatTB(dut) + tb.output_monitor.ready.value = 1 + await tb.run_test(batches=2, us=1) + + +@cocotb.test() +async def stream(dut): + tb = SingleElementRepeatTB(dut) + tb.output_monitor.ready.value = 1 + await tb.run_test(batches=4000, us=1000) + + +@cocotb.test() +async def backpressure(dut): + tb = SingleElementRepeatTB(dut) + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.5)) + await tb.run_test(batches=1000, us=1000) + + +@cocotb.test() +async def valid_backpressure(dut): + tb = SingleElementRepeatTB(dut) + tb.in_driver.set_valid_prob(0.5) + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.5)) + await tb.run_test(batches=1000, us=1500) + + +@cocotb.test() +async def valid_backpressure_more_in(dut): + tb = SingleElementRepeatTB(dut) + tb.in_driver.set_valid_prob(0.7) + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.3)) + await tb.run_test(batches=1000, us=1500) + + +@cocotb.test() +async def valid_backpressure_more_out(dut): + tb = SingleElementRepeatTB(dut) + tb.in_driver.set_valid_prob(0.3) + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.7)) + await tb.run_test(batches=1000, us=1500) + + +if __name__ == "__main__": + + def generate_random_params(): + return { + "DATA_WIDTH": random.randint(1, 32), + "REPEAT": random.randint(2, 6), + } + + cfgs = [ + {"DATA_WIDTH": 8, "REPEAT": 4}, + *[generate_random_params() for _ in range(16)], + ] + + mase_runner( + module_param_list=cfgs, + trace=True, + jobs=8, + ) diff --git a/src/mase_components/common/test/test_synth_common.py b/src/mase_components/common/test/test_synth_common.py new file mode 100644 index 000000000..2a1778a48 --- /dev/null +++ b/src/mase_components/common/test/test_synth_common.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_common(): + run_synth("common") + + +if __name__ == "__main__": + test_synth_common() diff --git a/src/mase_components/conv/test/test_synth_conv.py b/src/mase_components/conv/test/test_synth_conv.py new file mode 100644 index 000000000..c323d0ed5 --- /dev/null +++ b/src/mase_components/conv/test/test_synth_conv.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_conv(): + run_synth("conv") + + +if __name__ == "__main__": + test_synth_conv() diff --git a/src/mase_components/deps.py b/src/mase_components/deps.py index 88de9e50e..f8b6903ca 100644 --- a/src/mase_components/deps.py +++ b/src/mase_components/deps.py @@ -30,16 +30,43 @@ "conv", "activations", ], + "activations/fixed_softermax": [ + "common", + "cast", + "fixed_arithmetic", + "conv", + "matmul", + "activations", + ], # Attention - "attention/fixed_self_att": [ + "attention/fixed_self_attention": [ "attention", + "activations", + "arbiters", + "cast", + "common", + "fixed_arithmetic", + "linear", + "matmul", + ], + "attention/fixed_self_attention_head": [ + "attention", + "cast", + "common", + "fixed_arithmetic", + "linear", + "matmul", + "activations", + ], + "attention/fixed_self_attention_single_precision_wrapper": [ + "attention", + "activations", + "arbiters", "cast", "common", - "conv", "fixed_arithmetic", "linear", "matmul", - "ViT", ], "arithmetic/mac": ["fixed_arithmetic", "float_arithmetic"], # Binary arithmetic @@ -77,7 +104,6 @@ "fixed_arithmetic/fixed_lut_index": [], "fixed_arithmetic/fixed_range_augmentation": [], "fixed_arithmetic/fixed_mult": [], - "fixed_arithmetic/fixed_adder_tree_layer": [], "fixed_arithmetic/fixed_accumulator": ["common"], "fixed_arithmetic/fixed_adder_tree": ["fixed_arithmetic", "common"], "fixed_arithmetic/fixed_vector_mult": ["fixed_arithmetic", "common"], @@ -136,8 +162,9 @@ "conv/sliding_window": ["cast", "conv", "linear", "common", "fixed_arithmetic"], "conv/padding": ["cast", "conv", "linear", "common", "fixed_arithmetic"], # Matmul - "matmul/simple_matmul": ["common", "linear", "cast", "fixed_arithmetic"], + "matmul/simple_matmul": ["common", "linear", "cast", "fixed_arithmetic", "matmul"], "matmul/fixed_matmul": ["common", "linear", "cast", "fixed_arithmetic", "matmul"], + "matmul/matmul": ["common", "linear", "cast", "fixed_arithmetic", "matmul"], "matmul/test_chain_matmul": [ "common", "linear", diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_adder.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_adder.sv new file mode 100644 index 000000000..767af3d2b --- /dev/null +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_adder.sv @@ -0,0 +1,115 @@ +`timescale 1ns / 1ps + +/* + * Simple registered adder between two inputs. + * Currently doesn't support parallelism conversion. + */ + +module fixed_adder #( + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + + parameter DATA_IN_1_PRECISION_0 = 16, + parameter DATA_IN_1_PRECISION_1 = 3, + parameter DATA_IN_1_TENSOR_SIZE_DIM_0 = 4, + parameter DATA_IN_1_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_1_TENSOR_SIZE_DIM_2 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_0 = 4, + parameter DATA_IN_1_PARALLELISM_DIM_1 = 1, + parameter DATA_IN_1_PARALLELISM_DIM_2 = 1, + + parameter DATA_OUT_0_PRECISION_0 = 16, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_2, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_2 +) ( + input clk, + input rst, + + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + input logic [DATA_IN_1_PRECISION_0-1:0] data_in_1 [DATA_IN_1_PARALLELISM_DIM_0*DATA_IN_1_PARALLELISM_DIM_1-1:0], + input logic data_in_1_valid, + output logic data_in_1_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + +localparam MAX_PRECISION_0 = DATA_IN_0_PRECISION_0 > DATA_IN_1_PRECISION_0 ? DATA_IN_0_PRECISION_0 : DATA_IN_1_PRECISION_0; + +localparam SUM_PRECISION_0 = MAX_PRECISION_0 + 1; + +// ! TO DO: check if this is correct +localparam SUM_PRECISION_1 = DATA_IN_0_PRECISION_1; + +// * Declarations +// * --------------------------------------------------------------------------------------------------- + +logic joined_input_valid; +logic joined_input_ready; +logic [SUM_PRECISION_0-1:0] add_result [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0]; +logic [DATA_OUT_0_PRECISION_0-1:0] cast_out [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0]; + +// * Instances +// * --------------------------------------------------------------------------------------------------- + +// * Wait until both inputs are available +join2 join_inst ( + .data_in_valid ({data_in_0_valid, data_in_1_valid}), + .data_in_ready ({data_in_0_ready, data_in_1_ready}), + .data_out_valid(joined_input_valid), + .data_out_ready(joined_input_ready) +); + +// * Cast the sum to the requested output precision +fixed_cast #( + .IN_SIZE (DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1), + .IN_WIDTH (SUM_PRECISION_0), + .IN_FRAC_WIDTH (SUM_PRECISION_1), + .OUT_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) +) bias_cast_i ( + .data_in (add_result), + .data_out(cast_out) +); + +// * Register the output +unpacked_register_slice #( + .DATA_WIDTH(DATA_OUT_0_PRECISION_0), + .IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1) +) register_slice_i ( + .clk(clk), + .rst(rst), + + .data_in (cast_out), + .data_in_valid(joined_input_valid), + .data_in_ready(joined_input_ready), + + .data_out (data_out_0), + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) +); + +// * Logic +// * --------------------------------------------------------------------------------------------------- + +// * Do the sum +for (genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1; i++) begin + assign add_result[i] = data_in_0[i] + data_in_1[i]; +end + +endmodule \ No newline at end of file diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree.sv index 5c1a77dfd..99d6f409f 100644 --- a/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree.sv +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree.sv @@ -1,6 +1,9 @@ `timescale 1ns / 1ps + +// TODO: Add signed param. fixed_adder_tree_layer already supports signedness + module fixed_adder_tree #( - parameter IN_SIZE = 1, + parameter IN_SIZE = 2, parameter IN_WIDTH = 32, parameter OUT_WIDTH = $clog2(IN_SIZE) + IN_WIDTH ) ( @@ -18,82 +21,69 @@ module fixed_adder_tree #( localparam LEVELS = $clog2(IN_SIZE); - // Declare intermediate values at each level - for (genvar i = 0; i <= LEVELS; i++) begin : vars - // The number of inputs at each level - // level_num = ceil(num/(2^i)) - localparam LEVEL_IN_SIZE = (IN_SIZE + ((1 << i) - 1)) >> i; - // The input data array at each level - // When i = 0, data is the input of the adder tree. - // When i = level, data is the output of the adder tree. - logic [(IN_WIDTH + i)-1:0] data[LEVEL_IN_SIZE-1:0]; - // Each level has a pair of handshake signals - // When i = 0, they are the handshake logic of the input. - // When i = level, they are the handshake logic of the output. - logic valid; - logic ready; + initial begin + assert (IN_SIZE > 0); end - // Generate adder for each layer - for (genvar i = 0; i < LEVELS; i++) begin : level - // The number of inputs at each level - localparam LEVEL_IN_SIZE = (IN_SIZE + ((1 << i) - 1)) >> i; - // The number of adders needed at each level - // which is the number of the inputs at next level - localparam NEXT_LEVEL_IN_SIZE = (LEVEL_IN_SIZE + 1) / 2; - // The sum array is the output of the adders - logic [(IN_WIDTH + i):0] sum[NEXT_LEVEL_IN_SIZE-1:0]; - - // The width of the data increases by 1 for the next - // level in order to keep the carry bit from the addition - fixed_adder_tree_layer #( - .IN_SIZE (LEVEL_IN_SIZE), - .IN_WIDTH(IN_WIDTH + i) - ) layer ( - .data_in (vars[i].data), - .data_out(sum) - ); - - // Cocotb/verilator does not support array flattening, so - // we need to manually add some reshaping process. - - // Casting array for sum - logic [$bits(sum)-1:0] cast_sum; - logic [$bits(sum)-1:0] cast_data; - for (genvar j = 0; j < NEXT_LEVEL_IN_SIZE; j++) begin : reshape_in - assign cast_sum[(IN_WIDTH+i+1)*j+(IN_WIDTH+i):(IN_WIDTH+i+1)*j] = sum[j]; - end + generate + if (LEVELS == 0) begin : gen_skip_adder_tree - skid_buffer #( - .DATA_WIDTH($bits(sum)) - ) register_slice ( - .clk (clk), - .rst (rst), - .data_in (cast_sum), - .data_in_valid (vars[i].valid), - .data_in_ready (vars[i].ready), - .data_out (cast_data), - .data_out_valid(vars[i+1].valid), - .data_out_ready(vars[i+1].ready) - ); - - // Casting array for vars[i+1].data - for (genvar j = 0; j < NEXT_LEVEL_IN_SIZE; j++) begin : reshape_out - assign vars[i+1].data[j] = cast_data[(IN_WIDTH+i+1)*j+(IN_WIDTH+i):(IN_WIDTH+i+1)*j]; - end + assign data_out = data_in[0]; + assign data_out_valid = data_in_valid; + assign data_in_ready = data_out_ready; - end + end else begin : gen_adder_tree + + // data & sum wires are oversized on purpose for vivado. + logic [OUT_WIDTH*IN_SIZE-1:0] data[LEVELS:0]; + logic [OUT_WIDTH*IN_SIZE-1:0] sum[LEVELS-1:0]; + logic valid[IN_SIZE-1:0]; + logic ready[IN_SIZE-1:0]; + + // Generate adder for each layer + for (genvar i = 0; i < LEVELS; i++) begin : level + + localparam LEVEL_IN_SIZE = (IN_SIZE + ((1 << i) - 1)) >> i; + localparam LEVEL_OUT_SIZE = (LEVEL_IN_SIZE + 1) / 2; + localparam LEVEL_IN_WIDTH = IN_WIDTH + i; + localparam LEVEL_OUT_WIDTH = LEVEL_IN_WIDTH + 1; + + fixed_adder_tree_layer #( + .IN_SIZE (LEVEL_IN_SIZE), + .IN_WIDTH(LEVEL_IN_WIDTH) + ) layer ( + .data_in (data[i]), // flattened LEVEL_IN_SIZE * LEVEL_IN_WIDTH + .data_out(sum[i]) // flattened LEVEL_OUT_SIZE * LEVEL_OUT_WIDTH + ); + + skid_buffer #( + .DATA_WIDTH(LEVEL_OUT_SIZE * LEVEL_OUT_WIDTH) + ) register_slice ( + .clk (clk), + .rst (rst), + .data_in (sum[i]), + .data_in_valid (valid[i]), + .data_in_ready (ready[i]), + .data_out (data[i+1]), + .data_out_valid(valid[i+1]), + .data_out_ready(ready[i+1]) + ); + + end + + for (genvar i = 0; i < IN_SIZE; i++) begin : gen_input_assign + assign data[0][(i+1)*IN_WIDTH-1 : i*IN_WIDTH] = data_in[i]; + end + + assign valid[0] = data_in_valid; + assign data_in_ready = ready[0]; + + assign data_out = data[LEVELS][OUT_WIDTH-1:0]; + assign data_out_valid = valid[LEVELS]; + assign ready[LEVELS] = data_out_ready; + + end + endgenerate - // it will zero-extend automatically - // for (genvar j = 0; j < IN_SIZE; j++) begin : layer_0 - // assign vars[0].data[j] = data_in[j]; - // end - assign vars[0].data = data_in; - assign vars[0].valid = data_in_valid; - assign data_in_ready = vars[0].ready; - - assign data_out = vars[LEVELS].data[0]; - assign data_out_valid = vars[LEVELS].valid; - assign vars[LEVELS].ready = data_out_ready; endmodule diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree_layer.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree_layer.sv index 6bffe933d..7d2e38861 100644 --- a/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree_layer.sv +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_adder_tree_layer.sv @@ -1,20 +1,42 @@ `timescale 1ns / 1ps module fixed_adder_tree_layer #( parameter IN_SIZE = 2, - parameter IN_WIDTH = 32 + parameter IN_WIDTH = 16, + parameter SIGNED = 1, + + localparam OUT_WIDTH = IN_WIDTH + 1, + localparam OUT_SIZE = (IN_SIZE + 1) / 2 ) ( - input logic [IN_WIDTH-1:0] data_in [ IN_SIZE-1:0], - output logic [ IN_WIDTH:0] data_out[(IN_SIZE+1)/2-1:0] + input logic [ IN_SIZE*IN_WIDTH-1:0] data_in, + output logic [OUT_SIZE*OUT_WIDTH-1:0] data_out ); - generate - for (genvar i = 0; i < IN_SIZE / 2; i++) begin : pair - assign data_out[i] = {data_in[i][IN_WIDTH-1], data_in[i]} + {data_in[IN_SIZE-1-i][IN_WIDTH-1], data_in[IN_SIZE-1-i]}; + logic [ IN_WIDTH-1:0] data_in_unflat [ IN_SIZE-1:0]; + logic [OUT_WIDTH-1:0] data_out_unflat[OUT_SIZE-1:0]; + + for (genvar i = 0; i < IN_SIZE; i++) begin : in_unflat + assign data_in_unflat[i] = data_in[(i+1)*IN_WIDTH-1 : i*IN_WIDTH]; + end + + for (genvar i = 0; i < IN_SIZE / 2; i++) begin : pair + if (SIGNED) begin + assign data_out_unflat[i] = $signed(data_in_unflat[2*i]) + $signed(data_in_unflat[2*i+1]); + end else begin + assign data_out_unflat[i] = data_in_unflat[2*i] + data_in_unflat[2*i+1]; end + end - if (IN_SIZE % 2 != 0) begin : left - assign data_out[IN_SIZE/2] = {data_in[IN_SIZE/2][IN_WIDTH-1], data_in[IN_SIZE/2]}; + if (IN_SIZE % 2 != 0) begin : left + if (SIGNED) begin + assign data_out_unflat[OUT_SIZE-1] = $signed(data_in_unflat[IN_SIZE-1]); + end else begin + assign data_out_unflat[OUT_SIZE-1] = {1'b0, data_in_unflat[IN_SIZE-1]}; end - endgenerate + end + + for (genvar i = 0; i < OUT_SIZE; i++) begin : out_flat + assign data_out[(i+1)*OUT_WIDTH-1 : i*OUT_WIDTH] = data_out_unflat[i]; + end + endmodule diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_dot_product.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_dot_product.sv index a6dc0c98f..8bd12f739 100644 --- a/src/mase_components/fixed_arithmetic/rtl/fixed_dot_product.sv +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_dot_product.sv @@ -33,9 +33,14 @@ module fixed_dot_product #( localparam PRODUCT_WIDTH = IN_WIDTH + WEIGHT_WIDTH; - logic [PRODUCT_WIDTH-1:0] pv [IN_SIZE-1:0]; + logic [PRODUCT_WIDTH-1:0] pv [IN_SIZE-1:0]; logic pv_valid; logic pv_ready; + + logic [ OUT_WIDTH-1:0] sum; + logic sum_valid; + logic sum_ready; + fixed_vector_mult #( .IN_WIDTH(IN_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), @@ -56,9 +61,6 @@ module fixed_dot_product #( // sum the products - logic [OUT_WIDTH-1:0] sum; - logic sum_valid; - logic sum_ready; // sum = sum(pv) fixed_adder_tree #( .IN_SIZE (IN_SIZE), diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_isqrt.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_isqrt.sv new file mode 100644 index 000000000..052515f54 --- /dev/null +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_isqrt.sv @@ -0,0 +1,158 @@ +`timescale 1ns / 1ps + +/* verilator lint_off UNUSEDSIGNAL */ +module fixed_isqrt #( + parameter IN_WIDTH = 16, + parameter IN_FRAC_WIDTH = 7, + parameter LUT_POW = 5, + // TODO: how to use these? Will the output width not always be the same as + // the input width? + // parameter OUT_WIDTH = 16, + // parameter OUT_FRAC_WIDTH = 7, + // TODO: the design is stateless therefore no cycles needed. + // if the critical path is too large for this module then it can be + // pipelined. + // parameter PIPELINE_CYCLES = 0, + + // LUT parameters + parameter string LUT_MEMFILE = "", + + localparam MAX_NUM = (1 << IN_WIDTH) - 1, + localparam MSB_WIDTH = IN_WIDTH == 1 ? 1 : $clog2(IN_WIDTH), + localparam ONE = 1 << (IN_WIDTH - 1) // FORMAT: Q1.(WIDTH-1) +) ( + // TODO: stateless design would not need these pins. + input logic clk, + input logic rst, + + input logic [2*IN_WIDTH-1:0] in_data, + // TODO: usage of these pins depends on whether or not the design is + // pipelined whether. + input logic in_valid, + output logic in_ready, + + output logic [2*IN_WIDTH-1:0] out_data, + // TODO: usage of these pins depends on whether or not the design is + // pipelined whether. + output logic out_valid, + input logic out_ready +); + + logic [2*IN_WIDTH-1:0] x_reduced[3:0]; + logic [MSB_WIDTH-1:0] msb_index[3:0]; + logic [2*IN_WIDTH-1:0] lut_index; + logic [2*IN_WIDTH-1:0] lut_value[2:1]; + logic [2*IN_WIDTH-1:0] y[3:3]; + logic [2*IN_WIDTH-1:0] y_or_one; + logic [2*IN_WIDTH-1:0] y_aug; + + logic pipe_valid[3:1]; + logic pipe_ready[3:1]; + + logic [2*IN_WIDTH-1:0] isqrt_data_out; + + fixed_range_reduction #( + .WIDTH(IN_WIDTH) + ) fixed_range_reduction_inst ( + .data_a(in_data), + .data_out(x_reduced[0]), + .msb_index(msb_index[0]), + .not_found() + ); + + skid_buffer #( + .DATA_WIDTH(2 * IN_WIDTH + MSB_WIDTH) + ) pipe_reg_0 ( + .clk(clk), + .rst(rst), + .data_in({x_reduced[0], msb_index[0]}), + .data_in_valid(in_valid), + .data_in_ready(in_ready), + .data_out({x_reduced[1], msb_index[1]}), + .data_out_valid(pipe_valid[1]), + .data_out_ready(pipe_ready[1]) + ); + + fixed_lut_index #( + .WIDTH (IN_WIDTH), + .LUT_POW(LUT_POW) + ) fixed_lut_index_inst ( + .data_a (x_reduced[1]), + .data_b (msb_index[1]), + .data_out(lut_index) + ); + + lut #( + .DATA_WIDTH(IN_WIDTH), + .SIZE(2 ** LUT_POW), + .OUTPUT_REG(0), + .MEM_FILE(LUT_MEMFILE) + ) fixed_lut_inst ( + .clk('0), // Tie offclock + .addr(lut_index), + .data(lut_value[1]) + ); + + skid_buffer #( + .DATA_WIDTH(2 * IN_WIDTH + MSB_WIDTH + IN_WIDTH) + ) pipe_reg_1 ( + .clk(clk), + .rst(rst), + .data_in({x_reduced[1], msb_index[1], lut_value[1]}), + .data_in_valid(pipe_valid[1]), + .data_in_ready(pipe_ready[1]), + .data_out({x_reduced[2], msb_index[2], lut_value[2]}), + .data_out_valid(pipe_valid[2]), + .data_out_ready(pipe_ready[2]) + ); + + fixed_nr_stage #( + .WIDTH(IN_WIDTH), + .MSB_WIDTH(MSB_WIDTH) + ) fixed_nr_stage_inst_1 ( + .clk(clk), + .rst(rst), + .data_a(x_reduced[2]), + .data_b(lut_value[2]), + .data_in_msb(msb_index[2]), + .data_in_valid(pipe_valid[2]), + .data_in_ready(pipe_ready[2]), + .data_out(y[3]), + .data_out_x_reduced(x_reduced[3]), + .data_out_msb(msb_index[3]), + .data_out_valid(pipe_valid[3]), + .data_out_ready(pipe_ready[3]) + ); + + assign y_or_one = (x_reduced[3] == ONE) ? ONE : y[3]; + + fixed_range_augmentation #( + .WIDTH(IN_WIDTH), + .FRAC_WIDTH(IN_FRAC_WIDTH) + ) fixed_range_augmentation_inst ( + .data_a (y_or_one), + .data_b (msb_index[3]), + .data_out(y_aug) + ); + + assign isqrt_data_out = + // Fishing for 0s. + (x_reduced[3] == 0) ? MAX_NUM : ( + // Fishing for overflows. + (y_aug > MAX_NUM) ? MAX_NUM : y_aug); + + skid_buffer #( + .DATA_WIDTH(2 * IN_WIDTH) + ) output_reg ( + .clk(clk), + .rst(rst), + .data_in(isqrt_data_out), + .data_in_valid(pipe_valid[3]), + .data_in_ready(pipe_ready[3]), + .data_out(out_data), + .data_out_valid(out_valid), + .data_out_ready(out_ready) + ); + +endmodule +/* verilator lint_on UNUSEDSIGNAL */ diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_range_reduction.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_range_reduction.sv index a71f17719..7b250b350 100644 --- a/src/mase_components/fixed_arithmetic/rtl/fixed_range_reduction.sv +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_range_reduction.sv @@ -1,4 +1,11 @@ +/* +Module : fixed_range_reduction +Description : This module finds the MSB of the number. If there is no MSB, then + the "not_found" wire will be driven HIGH. +*/ + `timescale 1ns / 1ps + module fixed_range_reduction #( parameter WIDTH = 16, localparam MSB_WIDTH = $clog2(WIDTH) @@ -8,7 +15,8 @@ module fixed_range_reduction #( // Reduced x output logic [WIDTH-1:0] data_out, // FORMAT: Q1.(WIDTH-1). // msb_index - output logic [MSB_WIDTH-1:0] msb_index + output logic [MSB_WIDTH-1:0] msb_index, + output logic not_found ); // Find MSB index. Rightmost position = 0 @@ -27,6 +35,7 @@ module fixed_range_reduction #( /* verilator lint_on LATCH */ // Shift by the correct amount to set format to Q1.(WIDTH-1) - assign data_out = data_a << (WIDTH - 1 - msb_index); + assign data_out = data_a << (WIDTH - 1 - msb_index); + assign not_found = data_a == '0; endmodule diff --git a/src/mase_components/fixed_arithmetic/rtl/fixed_vector_mult.sv b/src/mase_components/fixed_arithmetic/rtl/fixed_vector_mult.sv index a93654900..2a6b0757a 100644 --- a/src/mase_components/fixed_arithmetic/rtl/fixed_vector_mult.sv +++ b/src/mase_components/fixed_arithmetic/rtl/fixed_vector_mult.sv @@ -33,6 +33,13 @@ module fixed_vector_mult #( // pv[i] = data_in[i] * w[i] logic [PRODUCT_WIDTH-1:0] product_vector[IN_SIZE-1:0]; + logic product_data_in_valid; + logic product_data_in_ready; + logic product_data_out_valid; + logic product_data_out_ready; + logic [$bits(product_vector)-1:0] product_data_in; + logic [$bits(product_vector)-1:0] product_data_out; + for (genvar i = 0; i < IN_SIZE; i = i + 1) begin : parallel_mult fixed_mult #( .IN_A_WIDTH(IN_WIDTH), @@ -44,12 +51,7 @@ module fixed_vector_mult #( ); end - logic product_data_in_valid; - logic product_data_in_ready; - logic product_data_out_valid; - logic product_data_out_ready; - logic [$bits(product_vector)-1:0] product_data_in; - logic [$bits(product_vector)-1:0] product_data_out; + join2 #() join_inst ( .data_in_ready ({weight_ready, data_in_ready}), diff --git a/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_layer_tb.py b/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_layer_tb.py index 2ff5b39d7..598668d9f 100644 --- a/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_layer_tb.py +++ b/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_layer_tb.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -# This script tests the fixed point adder tree layer +# This script is never run since the rtl has been changed. + import random, os import cocotb @@ -11,7 +12,7 @@ class VerificationCase: def __init__(self, samples=2): self.in_width = 32 - self.num = 17 # random.randint(2, 33) + self.num = 16 # random.randint(2, 33) self.inputs, self.outputs = [], [] for _ in range(samples): i, o = self.single_run() diff --git a/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_tb.py b/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_tb.py index 41eed83ce..9b3efa176 100644 --- a/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_tb.py +++ b/src/mase_components/fixed_arithmetic/test/fixed_adder_tree_tb.py @@ -4,6 +4,7 @@ import os, math, logging from mase_cocotb.random_test import RandomSource, RandomSink, check_results +from mase_cocotb.testbench import Testbench from mase_cocotb.runner import mase_runner import cocotb @@ -19,11 +20,13 @@ # DUT test specifications -class VerificationCase: - def __init__(self, samples=10): - self.data_in_width = 32 - self.num = 9 - self.data_out_width = math.ceil(math.log2(self.num)) + 32 +class VerificationCase(Testbench): + def __init__(self, dut, samples=10): + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params(["IN_SIZE", "IN_WIDTH"]) + self.data_in_width = self.IN_WIDTH + self.num = self.IN_SIZE + self.data_out_width = math.ceil(math.log2(self.num)) + self.data_in_width self.inputs = RandomSource( samples=samples, num=self.num, max_stalls=2 * samples, debug=debug ) @@ -33,11 +36,11 @@ def __init__(self, samples=10): self.samples = samples self.ref = self.sw_compute() - def get_dut_parameters(self): - return { - "IN_SIZE": self.num, - "IN_WIDTH": self.data_in_width, - } + # def get_dut_parameters(self): + # return { + # "IN_SIZE": self.num, + # "IN_WIDTH": self.data_in_width, + # } def sw_compute(self): ref = [] @@ -61,7 +64,7 @@ def is_impossible_state(data_in_ready, data_in_valid, data_out_ready, data_out_v async def cocotb_test_fixed_adder_tree(dut): """Test integer based adder tree""" samples = 20 - test_case = VerificationCase(samples=samples) + test_case = VerificationCase(dut, samples=samples) # Reset cycle await Timer(20, units="ns") @@ -151,8 +154,24 @@ async def cocotb_test_fixed_adder_tree(dut): def test_fixed_adder_tree(): - tb = VerificationCase() - mase_runner(module_param_list=[tb.get_dut_parameters()]) + mase_runner( + module_param_list=[ + # Power of 2's + {"IN_SIZE": 8, "IN_WIDTH": 32}, + {"IN_SIZE": 4, "IN_WIDTH": 32}, + {"IN_SIZE": 2, "IN_WIDTH": 32}, + {"IN_SIZE": 16, "IN_WIDTH": 64}, + {"IN_SIZE": 32, "IN_WIDTH": 7}, + # 1 size edge case + {"IN_SIZE": 1, "IN_WIDTH": 32}, + # Odd sizes + {"IN_SIZE": 3, "IN_WIDTH": 32}, + {"IN_SIZE": 9, "IN_WIDTH": 8}, + {"IN_SIZE": 7, "IN_WIDTH": 8}, + {"IN_SIZE": 5, "IN_WIDTH": 8}, + ], + trace=True, + ) if __name__ == "__main__": diff --git a/src/mase_components/fixed_arithmetic/test/fixed_isqrt_tb.py b/src/mase_components/fixed_arithmetic/test/fixed_isqrt_tb.py new file mode 100644 index 000000000..3398f3fdc --- /dev/null +++ b/src/mase_components/fixed_arithmetic/test/fixed_isqrt_tb.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +# This script tests the fixed point inverse square root. +import random, os +from pathlib import Path +from os import makedirs + +import cocotb +from cocotb.triggers import Timer +from mase_cocotb.runner import mase_runner +import math +from mase_cocotb.testbench import Testbench +from mase_cocotb.utils import verilator_str_param, bit_driver +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_components.fixed_arithmetic.test.isqrt_sw import ( + isqrt_sw2, + int_to_float, + make_lut, +) +from mase_components.common.test.lut_tb import write_memb + + +class VerificationCase(Testbench): + def __init__(self, dut): + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + [ + "IN_WIDTH", + "IN_FRAC_WIDTH", + "LUT_POW", + ] + ) + + self.input_driver = StreamDriver( + dut.clk, dut.in_data, dut.in_valid, dut.in_ready + ) + self.output_monitor = StreamMonitor( + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + name="Output ISQRT", + ) + + def generate_inputs(self, num=10000): + maxnum = (2**self.IN_WIDTH) - 1 + return [random.randint(0, maxnum) for _ in range(num)], num + + def model(self, data_in): + ref = [] + lut_size = 2**self.LUT_POW + lut = make_lut(lut_size, self.IN_WIDTH) + for x in data_in: + expected = isqrt_sw2( + x, self.IN_WIDTH, self.IN_FRAC_WIDTH, self.LUT_POW, lut + ) + ref.append(expected) + return ref + + +def debug(dut, i, f): + print( + f"X : {dut.in_data.value} {int_to_float(dut.in_data.value.integer, i, f)}" + ) + print( + f"X red : {dut.x_reduced.value} {int_to_float(dut.x_reduced.value.integer, 1, 15)}" + ) + print(f"MSB index: {dut.msb_index.value.integer}") + print(f"lut_index: {dut.lut_index.value.integer}") + print( + f"LUT value: {dut.lut_value.value} {int_to_float(dut.lut_value.value.integer, 1, 15)}" + ) + print(f"Y : {dut.y.value} {int_to_float(dut.y.value.integer, 1, 15)}") + print( + f"Y aug : {dut.y_aug.value} {int_to_float(dut.y_aug.value.integer, i, f)}" + ) + + +CLK_NS = 25 + + +@cocotb.test() +async def sweep(dut): + """Test for inverse square root""" + tb = VerificationCase(dut) + await tb.reset() + tb.output_monitor.ready.value = 1 + inputs, samples = tb.generate_inputs() + exp_out = tb.model(inputs) + tb.input_driver.load_driver(inputs) + tb.output_monitor.load_monitor(exp_out) + await Timer(1000, "us") + assert tb.output_monitor.exp_queue.empty() + + +@cocotb.test() +async def backpressure(dut): + """Test for inverse square root""" + tb = VerificationCase(dut) + await tb.reset() + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.5)) + inputs, samples = tb.generate_inputs() + exp_out = tb.model(inputs) + tb.input_driver.load_driver(inputs) + tb.output_monitor.load_monitor(exp_out) + await Timer(1000, "us") + assert tb.output_monitor.exp_queue.empty() + + +@cocotb.test() +async def valid_backpressure(dut): + """Test for inverse square root""" + tb = VerificationCase(dut) + await tb.reset() + tb.input_driver.set_valid_prob(0.5) + cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.5)) + inputs, samples = tb.generate_inputs() + exp_out = tb.model(inputs) + tb.input_driver.load_driver(inputs) + tb.output_monitor.load_monitor(exp_out) + await Timer(1000, "us") + assert tb.output_monitor.exp_queue.empty() + + +if __name__ == "__main__": + mem_dir = Path(__file__).parent / "build" / "fixed_isqrt" / "mem" + makedirs(mem_dir, exist_ok=True) + + def single_cfg(width, frac_width, lut_pow, str_id): + lut_size = 2**lut_pow + lut = make_lut(lut_size, width) + mem_path = mem_dir / f"lutmem-{str_id}.mem" + write_memb(mem_path, lut, width) + return { + "IN_WIDTH": width, + "IN_FRAC_WIDTH": frac_width, + "LUT_POW": lut_pow, + "LUT_MEMFILE": verilator_str_param(str(mem_path)), + } + + def full_sweep(): + parameter_list = [] + lut_pow = 5 + for int_width in range(1, 9): + for frac_width in range(0, 9): + width = int_width + frac_width + parameters = single_cfg( + width, frac_width, lut_pow, str_id=f"{int_width}-{frac_width}" + ) + parameter_list.append(parameters) + return parameter_list + + parameter_list = [ + # A use case in group_norm + *full_sweep(), + # single_cfg(35, 14, 7, 0) + ] + mase_runner(module_param_list=parameter_list, trace=True) diff --git a/src/mase_components/fixed_arithmetic/test/fixed_nr_stage_tb.py b/src/mase_components/fixed_arithmetic/test/fixed_nr_stage_tb.py new file mode 100644 index 000000000..8c0b70ed1 --- /dev/null +++ b/src/mase_components/fixed_arithmetic/test/fixed_nr_stage_tb.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +# This script tests the fixed point inverse square root. +import random, os + +import cocotb +from cocotb.triggers import Timer +from mase_cocotb.testbench import Testbench +from mase_cocotb.runner import mase_runner +import math +from .isqrt_sw import ( + nr_stage_sw, + float_to_int, + int_to_float, + fixed_lut_index_sw, + make_lut, + range_reduction_sw, +) + + +class VerificationCase(Testbench): + def __init__(self, dut): + super().__init__(dut) + self.assign_self_params(["WIDTH"]) + + def generate_inputs(self, lut_pow): + samples = 2**self.WIDTH + int_width = 1 + frac_width = self.WIDTH - 1 + data_x = [] + initial_guesses = [] + lut_size = 2**lut_pow + lut = make_lut(lut_size, self.WIDTH) + # NOTE: since negative values are not supported by fixed formats since + # isqrt only outputs positive results we cannot test every single com- + # bination of x and initial guesses. + for x in range(samples): + # Create inputs. + x_red = range_reduction_sw(x, self.WIDTH) + lut_index = fixed_lut_index_sw(x_red, self.WIDTH, lut_pow) + lut_value = lut[lut_index] + # Add inputs. + data_x.append(x_red) + initial_guesses.append(lut_value) + return data_x, initial_guesses, samples + + def model(self, data_x, initial_guesses): + ref = [] + for x, lut in zip(data_x, initial_guesses): + expected = nr_stage_sw(x, self.WIDTH, lut) + ref.append(expected) + return ref + + +@cocotb.test() +async def test_fixed_nr_stage(dut): + """Test for the Newton Raphson stage for isqrt""" + testcase = VerificationCase(dut) + lut_pow = 5 + data_x, initial_guesses, samples = testcase.generate_inputs(lut_pow) + ref = testcase.model(data_x, initial_guesses) + width = testcase.WIDTH + + for i in range(samples): + # Set up module data. + data_a = data_x[i] + data_b = initial_guesses[i] + + # Force module data. + dut.data_a.value = data_a + dut.data_b.value = data_b + # Wait for processing. + await Timer(10, units="ns") + + # Exepected result. + expected = ref[i] + + # Error + error = abs( + int_to_float(dut.data_out.value.integer, 1, width) + - int_to_float(expected, 1, width) + ) + + # Check the output. + assert ( + error == 0 + ), f""" + <<< --- Test failed --- >>> + Input: + X : {int_to_float(data_a, 1, width-1)} + LUT: {int_to_float(data_b, 1, width-1)} + + Output: + Out: {int_to_float(dut.data_out.value.integer, 1, width-1)} + + Expected: + {int_to_float(expected, 1, width-1)} + + Error: + {error} + """ + + +if __name__ == "__main__": + + def full_sweep(): + parameter_list = [] + for width in range(1, 17): + parameter_list.append({"WIDTH": width}) + return parameter_list + + parameter_list = full_sweep() + mase_runner(module_param_list=parameter_list) diff --git a/src/mase_components/fixed_arithmetic/test/isqrt_sw.py b/src/mase_components/fixed_arithmetic/test/isqrt_sw.py new file mode 100644 index 000000000..c4723f971 --- /dev/null +++ b/src/mase_components/fixed_arithmetic/test/isqrt_sw.py @@ -0,0 +1,287 @@ +import math + + +def find_msb(x: int, width: int) -> int: + msb_index = width - 1 + for i in range(1, width + 1): + power = 2 ** (width - i) + if power <= x: + return width - i + return msb_index + + +def float_to_int(x: float, int_width: int, frac_width: int) -> int: + integer = int(x) + x -= integer + res = integer * (2**frac_width) + for i in range(1, frac_width + 1): + power = 2 ** (-i) + if power <= x: + x -= power + res += 2 ** (frac_width - i) + return res + + +def int_to_float(x: int, int_width: int, frac_width: int) -> float: + integer = x / (2**frac_width) + fraction = x - integer * 2**frac_width + res = integer + + for i in range(1, frac_width + 1): + power = 2 ** (frac_width - i) + if power < fraction: + res += 2 ** (-i) + fraction -= power + return res + + +def range_reduction_sw(x: int, width: int) -> int: + """model of range reduction for isqrt""" + # Find MSB + # NOTE: if the input is 0 then consider msb index as width-1. + msb_index = find_msb(x, width) + res = x + if msb_index < (width - 1): + res = x * 2 ** (width - 1 - msb_index) + return res + + +def range_augmentation_sw( + x_red: int, msb_index: int, width: int, frac_width: int +) -> int: + const_len = 16 + ISQRT2 = float_to_int(1 / math.sqrt(2), 1, const_len - 1) + SQRT2 = float_to_int(math.sqrt(2), 1, const_len - 1) + """model of range augmentation for isqrt""" + shifted_amount = frac_width - msb_index + shift_amount = None + res = None + + if shifted_amount > 0: + if shifted_amount % 2 == 0: + shift_amount = shifted_amount // 2 + res = x_red + else: + shift_amount = (shifted_amount - 1) // 2 + res = (x_red * SQRT2) >> (const_len - 1) + res = res * 2 ** (shift_amount) + elif shifted_amount < 0: + if shifted_amount % 2 == 0: + shift_amount = -shifted_amount // 2 + res = x_red + else: + shift_amount = (-shifted_amount - 1) // 2 + res = x_red * ISQRT2 // 2 ** (const_len - 1) + res = res // 2 ** (shift_amount) + else: + res = x_red + res = res >> (width - 1 - frac_width) + return res + + +def fixed_lut_index_sw(x_red: int, width: int, lut_pow: int) -> int: + """model for finding the lut index for lut isqrt value""" + if width == 1 or x_red == 0: + res = 0 + else: + res = x_red - 2 ** (width - 1) + res = res * 2**lut_pow + res = res / 2 ** (width - 1) + # FORMAT OUTPUT: Q(WIDTH).0 + return int(res) + + +def make_lut(lut_size, width): + lut_step = 1 / (lut_size + 1) + x = 1 + lut_step + lut = [] + for i in range(lut_size): + value = 1 / math.sqrt(x) + value = float_to_int(value, 1, width - 1) + lut.append(value) + x += lut_step + + return lut + + +def nr_stage_sw(x_red: int, in_width: int, initial_guess: int) -> int: + """model of newton raphson stage""" + # NOTE: if width is 1 then set output to 0 always because this part gets ignored by logic. + if in_width < 2: + threehalfs = 0 + else: + threehalfs = 3 * 2 ** (in_width - 2) + + y = initial_guess + x_red = x_red >> 1 + + yy = (y * y) >> (in_width - 1) + mult = (yy * x_red) >> (in_width - 1) + sub = threehalfs - mult + y = (y * sub) >> (in_width - 1) + + return y + + +def isqrt_sw2( + x: int, in_width: int, frac_width: int, lut_pow: int, lut: list, debug=False +) -> int: + int_width = in_width - frac_width + MAX_NUM = (1 << in_width) - 1 + + if x == 0: + return MAX_NUM + msb_index = find_msb(x, in_width) + + x_red = range_reduction_sw(x, in_width) + if debug: + print("MSB index: ", msb_index) + print("X red: ", int_to_float(x_red, 1, in_width - 1)) + + ONE = float_to_int(1, 1, in_width - 1) + if x_red == ONE: + out = range_augmentation_sw(x_red, msb_index, in_width, frac_width) + if debug: + print("OUT: ", int_to_float(out, int_width, frac_width)) + if out > MAX_NUM: + if debug: + print("MAX NUM") + return MAX_NUM + return out + lut_index = fixed_lut_index_sw(x_red, in_width, lut_pow) + if lut_index > 31: + print("X: ", x) + print("MSB index: ", msb_index) + print("X red: ", int_to_float(x_red, 1, in_width - 1)) + print("INT WIDTH: ", int_width) + print("FRAC WIDTH: ", frac_width) + initial_guess = lut[lut_index] + + y = nr_stage_sw(x_red, in_width, initial_guess) + y = range_augmentation_sw(y, msb_index, in_width, frac_width) + + if debug: + print("LUT index: ", lut_index) + print("LUT value: ", int_to_float(initial_guess, 1, in_width - 1)) + print("YY : ", int_to_float(yy, 1, in_width)) + print("MULT : ", int_to_float(mult, 1, in_width)) + print("SUB : ", int_to_float(sub, 1, in_width)) + print("Result : ", int_to_float(y, int_width, frac_width)) + + if y > MAX_NUM: + return MAX_NUM + return y + + +def single_test( + val: int, verbose: bool, int_width, frac_width, lut_pow, lut, debug=False +) -> float: + val_f = int_to_float(val, int_width, frac_width) + width = int_width + frac_width + MAX_NUM_INT = (1 << width) - 1 + MAX_NUM_FLOAT = int_to_float(MAX_NUM_INT, int_width, frac_width) + expected_f = None + if val_f == 0: + expected_f = MAX_NUM_FLOAT + else: + expected_f = 1 / math.sqrt(val_f) + if expected_f > MAX_NUM_FLOAT: + expected_f = MAX_NUM_FLOAT + expected_int = float_to_int(expected_f, int_width, frac_width) + expected_f = int_to_float(expected_int, int_width, frac_width) + + output = isqrt_sw2(val, width, frac_width, lut_pow, lut, debug) + output_f = int_to_float(output, int_width, frac_width) + error = abs(expected_f - output_f) + + if verbose: + print(f"sqrt({val_f}) = {output_f} | Exp: {expected_f} | Error: {error}") + return error + + +def test_sw_model_format(num_bits, sweep, int_width, frac_width, lut_pow): + max_error = 0 + allowed_error = 2 ** (-frac_width) * num_bits + width = int_width + frac_width + lut_size = 1 << lut_pow + lut = make_lut(lut_size, width) + verbose = False + for val in range(sweep): + error = single_test(val, verbose, int_width, frac_width, lut_pow, lut) + max_error = max(error, max_error) + if error > allowed_error: + print( + f""" +FAIL + +Input: +X : {val} +INT WIDTH : {int_width} +FRAC WIDTH: {frac_width} + +Max error allowed: +{allowed_error} + +Max error observed: +{max_error} + """ + ) + print("ERROR Excedded!") + return -1 + print( + f""" +PASS +Test: Q{int_width}.{frac_width} + +Max error allowed: +{allowed_error} + +Max error observed: +{max_error} + """ + ) + + +def test_sw_model(): + num_bits = 3 # Number of error bits. + for frac_width in range(1, 9): + for int_width in range(1, 9): + width = int_width + frac_width + sweep = (1 << width) - 1 + lut_pow = 5 + error_code = test_sw_model_format( + num_bits, sweep, int_width, frac_width, lut_pow + ) + if error_code == -1: + return + + +def debug_single(): + lut_pow = 5 + lut_size = 2**lut_pow + int_width = 2 + frac_width = 1 + width = int_width + frac_width + lut = make_lut(lut_size, width) + val = 1 + verbose = True + debug = True + error = single_test(val, verbose, int_width, frac_width, lut_pow, lut, debug) + + +def lut_parameter_dict(lut_size: int, width: int, lut_prefix: str = "LUT"): + lut = make_lut(lut_size, width) + parameters = {} + for i in range(lut_size): + if i < 10: + lut_suffix = "0" + str(i) + else: + lut_suffix = str(i) + name = lut_prefix + lut_suffix + parameters |= {name: lut[i]} + return parameters + + +if __name__ == "__main__": + # debug_single() + test_sw_model() diff --git a/src/mase_components/fixed_arithmetic/test/test_synth_fixed_arithmetic.py b/src/mase_components/fixed_arithmetic/test/test_synth_fixed_arithmetic.py new file mode 100644 index 000000000..187ffd793 --- /dev/null +++ b/src/mase_components/fixed_arithmetic/test/test_synth_fixed_arithmetic.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_fixed_arithmetic(): + run_synth("fixed_arithmetic") + + +if __name__ == "__main__": + test_synth_fixed_arithmetic() diff --git a/src/mase_components/fixed_math/test/test_synth_fixed_math.py b/src/mase_components/fixed_math/test/test_synth_fixed_math.py new file mode 100644 index 000000000..900e3cc06 --- /dev/null +++ b/src/mase_components/fixed_math/test/test_synth_fixed_math.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_fixed_math(): + run_synth("fixed_math") + + +if __name__ == "__main__": + test_synth_fixed_math() diff --git a/src/mase_components/float_arithmetic/test/test_synth_floath_arithmetic.py b/src/mase_components/float_arithmetic/test/test_synth_floath_arithmetic.py new file mode 100644 index 000000000..8494be54c --- /dev/null +++ b/src/mase_components/float_arithmetic/test/test_synth_floath_arithmetic.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_float_arithmetic(): + run_synth("float_arithmetic") + + +if __name__ == "__main__": + test_synth_float_arithmetic() diff --git a/src/mase_components/hls/rtl/__init__.py b/src/mase_components/hls/rtl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mase_components/linear/rtl/binary_activation_binary_linear.sv b/src/mase_components/linear/rtl/binary_activation_binary_linear.sv index 1c797e645..20749e896 100644 --- a/src/mase_components/linear/rtl/binary_activation_binary_linear.sv +++ b/src/mase_components/linear/rtl/binary_activation_binary_linear.sv @@ -169,16 +169,16 @@ module binary_activation_binary_linear #( /* verilator lint_off UNUSEDSIGNAL */ logic dout_valid; register_slice #( - .IN_WIDTH(OUT_WIDTH) + .DATA_WIDTH(OUT_WIDTH) ) register_slice ( .clk (clk), .rst (rst), .data_in_valid (acc_join_valid), .data_in_ready (reg_ready[i]), - .data_in_data (add), + .data_in (add), .data_out_valid(dout_valid), .data_out_ready(data_out_ready), - .data_out_data (data_out[i]) + .data_out (data_out[i]) ); end assign data_out_valid = add_bias[0].dout_valid; diff --git a/src/mase_components/linear/rtl/fixed_activation_binary_linear.sv b/src/mase_components/linear/rtl/fixed_activation_binary_linear.sv index 55741ffae..3440ad1bc 100644 --- a/src/mase_components/linear/rtl/fixed_activation_binary_linear.sv +++ b/src/mase_components/linear/rtl/fixed_activation_binary_linear.sv @@ -156,16 +156,16 @@ module fixed_activation_binary_linear #( /* verilator lint_off UNUSEDSIGNAL */ logic dout_valid; register_slice #( - .IN_WIDTH(OUT_WIDTH) + .DATA_WIDTH(OUT_WIDTH) ) register_slice ( .clk (clk), .rst (rst), .data_in_valid (acc_join_valid), .data_in_ready (reg_ready[i]), - .data_in_data (add), + .data_in (add), .data_out_valid(dout_valid), .data_out_ready(data_out_ready), - .data_out_data (data_out[i]) + .data_out (data_out[i]) ); end assign data_out_valid = add_bias[0].dout_valid; diff --git a/src/mase_components/linear/rtl/fixed_linear.sv b/src/mase_components/linear/rtl/fixed_linear.sv index 6e03d47ec..d988e9599 100644 --- a/src/mase_components/linear/rtl/fixed_linear.sv +++ b/src/mase_components/linear/rtl/fixed_linear.sv @@ -1,77 +1,148 @@ `timescale 1ns / 1ps +/* + * + * The fixed_linear module implements torch.nn.functional.linear, which + * computes Y = X @ W^T + b + * + * Weight tensor is assumed to have shape (out_features, in_features) + * Data tensor is assumed to have shape (batch_size, in_features) + * Bias tensor is assumed to have shape (out_features) + * + * If WEIGHTS_PRE_TRANSPOSED is set to 0, the module will transpose the incoming + * weight matrix. Otherwise, it will assume that the incoming weight matrix is + * already transposed. + * + */ + module fixed_linear #( /* verilator lint_off UNUSEDPARAM */ - parameter HAS_BIAS = 0, + parameter HAS_BIAS = 1, + parameter WEIGHTS_PRE_TRANSPOSED = 0, parameter DATA_IN_0_PRECISION_0 = 16, parameter DATA_IN_0_PRECISION_1 = 3, - parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 4, - parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 20, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 20, + parameter DATA_IN_0_TENSOR_SIZE_DIM_2 = 1, parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, // must equal WEIGHT_PARALLELISM_DIM_1 - parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, - parameter IN_0_DEPTH = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + parameter DATA_IN_0_PARALLELISM_DIM_2 = 1, + localparam IN_0_DEPTH_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + localparam IN_0_DEPTH_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1, parameter WEIGHT_PRECISION_0 = 16, parameter WEIGHT_PRECISION_1 = 3, - parameter WEIGHT_TENSOR_SIZE_DIM_0 = 32, - parameter WEIGHT_TENSOR_SIZE_DIM_1 = 1, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 20, + parameter WEIGHT_TENSOR_SIZE_DIM_1 = 20, parameter WEIGHT_PARALLELISM_DIM_0 = 4, - parameter WEIGHT_PARALLELISM_DIM_1 = 4, // must equal DATA_IN_0_PARALLELISM_DIM_0 - - parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( - DATA_IN_0_TENSOR_SIZE_DIM_0 - ) + $clog2( - IN_0_DEPTH - ) + HAS_BIAS, - parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1, - parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 4, - parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, + parameter WEIGHT_PARALLELISM_DIM_1 = 4, + + // Inferred precision of the output data + parameter DATA_OUT_0_PRECISION_0 = 16, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = WEIGHT_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_2 = DATA_IN_0_TENSOR_SIZE_DIM_1, parameter DATA_OUT_0_PARALLELISM_DIM_0 = WEIGHT_PARALLELISM_DIM_0, - parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_2 = DATA_IN_0_PARALLELISM_DIM_1, parameter BIAS_PRECISION_0 = 16, parameter BIAS_PRECISION_1 = 3, parameter BIAS_TENSOR_SIZE_DIM_0 = DATA_OUT_0_TENSOR_SIZE_DIM_0, parameter BIAS_TENSOR_SIZE_DIM_1 = 1, - parameter BIAS_PARALLELISM_DIM_0 = 1, - parameter BIAS_PARALLELISM_DIM_1 = 1 + parameter BIAS_PARALLELISM_DIM_0 = 4, + parameter BIAS_PARALLELISM_DIM_1 = 1, + localparam BIAS_DEPTH_DIM_0 = BIAS_TENSOR_SIZE_DIM_0 / BIAS_PARALLELISM_DIM_0, + localparam BIAS_DEPTH_DIM_1 = BIAS_TENSOR_SIZE_DIM_1 / BIAS_PARALLELISM_DIM_1 + ) ( input clk, input rst, // input port for data_inivations - input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], - input data_in_0_valid, - output data_in_0_ready, + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, // input port for weight - input [WEIGHT_PRECISION_0-1:0] weight[WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], - input weight_valid, - output weight_ready, + input logic [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_valid, + output logic weight_ready, /* verilator lint_off UNUSEDSIGNAL */ - input [BIAS_PRECISION_0-1:0] bias[BIAS_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_0-1:0], - input bias_valid, + input logic [BIAS_PRECISION_0-1:0] bias[BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_valid, /* verilator lint_on UNUSEDSIGNAL */ - output bias_ready, + output logic bias_ready, - output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], - output data_out_0_valid, - input data_out_0_ready + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready ); + // The TENSOR_SIZE and PARALLELISM parameters for the weights are set by emit verilog according to the real + // tensor values. Here we account for the change when the weights are pre-transposed + localparam REAL_WEIGHT_TENSOR_SIZE_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_TENSOR_SIZE_DIM_1 : WEIGHT_TENSOR_SIZE_DIM_0; + localparam REAL_WEIGHT_TENSOR_SIZE_DIM_1 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_TENSOR_SIZE_DIM_0 : WEIGHT_TENSOR_SIZE_DIM_1; + localparam REAL_WEIGHT_PARALLELISM_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_PARALLELISM_DIM_1 : WEIGHT_PARALLELISM_DIM_0; + localparam REAL_WEIGHT_PARALLELISM_DIM_1 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_PARALLELISM_DIM_0 : WEIGHT_PARALLELISM_DIM_1; + + // * Declarations + // * --------------------------------------------------------------------------------------------------- + + logic [WEIGHT_PRECISION_0-1:0] weight_transposed [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0]; + logic weight_transposed_valid; + logic weight_transposed_ready; + + logic [DATA_OUT_0_PRECISION_0-1:0] matmul_out [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic matmul_out_valid; + logic matmul_out_ready; + + logic [DATA_OUT_0_PRECISION_0-1:0] bias_buffered [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic bias_buffered_valid, bias_buffered_ready; + + logic [DATA_OUT_0_PRECISION_0-1:0] bias_casted [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic [DATA_OUT_0_PRECISION_0-1:0] add_bias_in [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic add_bias_in_valid; + logic add_bias_in_ready; + + // * Instances + // * --------------------------------------------------------------------------------------------------- + + if (WEIGHTS_PRE_TRANSPOSED == 0) begin + matrix_stream_transpose #( + .TOTAL_DIM0 (WEIGHT_TENSOR_SIZE_DIM_0), + .TOTAL_DIM1 (WEIGHT_TENSOR_SIZE_DIM_1), + .COMPUTE_DIM0(WEIGHT_PARALLELISM_DIM_0), + .COMPUTE_DIM1(WEIGHT_PARALLELISM_DIM_1), + .DATA_WIDTH (WEIGHT_PRECISION_0) + ) weight_matrix_transpose_i ( + .clk, + .rst, + + .in_data (weight), + .in_valid(weight_valid), + .in_ready(weight_ready), + + .out_data (weight_transposed), + .out_valid(weight_transposed_valid), + .out_ready(weight_transposed_ready) + ); + end + matmul #( // Total dimensions .A_TOTAL_DIM0(DATA_IN_0_TENSOR_SIZE_DIM_0), .A_TOTAL_DIM1(DATA_IN_0_TENSOR_SIZE_DIM_1), - .B_TOTAL_DIM0(WEIGHT_TENSOR_SIZE_DIM_0), - .B_TOTAL_DIM1(WEIGHT_TENSOR_SIZE_DIM_1), + .B_TOTAL_DIM0(REAL_WEIGHT_TENSOR_SIZE_DIM_0), + .B_TOTAL_DIM1(REAL_WEIGHT_TENSOR_SIZE_DIM_1), .A_COMPUTE_DIM0(DATA_IN_0_PARALLELISM_DIM_0), .A_COMPUTE_DIM1(DATA_IN_0_PARALLELISM_DIM_1), - .B_COMPUTE_DIM0(WEIGHT_PARALLELISM_DIM_0), - .B_COMPUTE_DIM1(WEIGHT_PARALLELISM_DIM_1), + .B_COMPUTE_DIM0(REAL_WEIGHT_PARALLELISM_DIM_0), + .B_COMPUTE_DIM1(REAL_WEIGHT_PARALLELISM_DIM_1), .A_WIDTH (DATA_IN_0_PRECISION_0), .A_FRAC_WIDTH(DATA_IN_0_PRECISION_1), @@ -89,16 +160,99 @@ module fixed_linear #( .a_valid(data_in_0_valid), .a_ready(data_in_0_ready), - .b_data (weight), - .b_valid(weight_valid), - .b_ready(weight_ready), + .b_data (weight_transposed), + .b_valid(weight_transposed_valid), + .b_ready(weight_transposed_ready), - .out_data (data_out_0), - .out_valid(data_out_0_valid), - .out_ready(data_out_0_ready) + .out_data (matmul_out), + .out_valid(matmul_out_valid), + .out_ready(matmul_out_ready) ); - // ! TO DO: add bias - assign bias_ready = '0; + // Bias output + if (HAS_BIAS == 1) begin + + join2 join2_matmul_bias_i ( + .data_in_valid ({matmul_out_valid, bias_buffered_valid}), + .data_in_ready ({matmul_out_ready, bias_buffered_ready}), + .data_out_valid(add_bias_in_valid), + .data_out_ready(add_bias_in_ready) + ); + + unpacked_repeat_circular_buffer #( + .DATA_WIDTH (BIAS_PRECISION_0), + .IN_NUM (BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1), + .REPEAT (IN_0_DEPTH_DIM_1), + .SIZE (BIAS_DEPTH_DIM_0) + ) bias_buffer_inst ( + .clk, + .rst, + + // Input streaming port + .in_data (bias), + .in_valid (bias_valid), + .in_ready (bias_ready), + + // Output streaming port + .out_data (bias_buffered), + .out_valid (bias_buffered_valid), + .out_ready (bias_buffered_ready) + ); + + unpacked_register_slice #( + .DATA_WIDTH(DATA_OUT_0_PRECISION_0), + .IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1) + ) register_slice_i ( + .clk(clk), + .rst(rst), + + .data_in (add_bias_in), + .data_in_valid(add_bias_in_valid), + .data_in_ready(add_bias_in_ready), + + .data_out (data_out_0), + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) + ); + end + + // * Logic + // * --------------------------------------------------------------------------------------------------- + + if (WEIGHTS_PRE_TRANSPOSED == 1) begin + always_comb begin + weight_transposed_valid = weight_valid; + weight_ready = weight_transposed_ready; + end + + for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1; i++) begin + assign weight_transposed[i] = weight[i]; + end + end + + // * Add bias + if (HAS_BIAS == 1) begin + fixed_cast #( + .IN_SIZE (BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1), + .IN_WIDTH (BIAS_PRECISION_0), + .IN_FRAC_WIDTH (BIAS_PRECISION_1), + .OUT_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) bias_cast_i ( + .data_in (bias_buffered), + .data_out(bias_casted) + ); + + for (genvar i_0 = 0; i_0 < DATA_OUT_0_PARALLELISM_DIM_0 ; i_0++) begin + for (genvar i_1 = 0; i_1 < DATA_OUT_0_PARALLELISM_DIM_1 ; i_1++) begin + assign add_bias_in [i_1 * DATA_OUT_0_PARALLELISM_DIM_0 + i_0] = $signed(matmul_out[i_1 * DATA_OUT_0_PARALLELISM_DIM_0 + i_0]) + $signed(bias_casted[i_0]); + end + end + + end else begin + assign data_out_0 = matmul_out; + assign data_out_0_valid = matmul_out_valid; + assign matmul_out_ready = data_out_0_ready; + end endmodule diff --git a/src/mase_components/linear/test/fixed_linear_tb.py b/src/mase_components/linear/test/fixed_linear_tb.py index 0a0264567..244f6c3b2 100644 --- a/src/mase_components/linear/test/fixed_linear_tb.py +++ b/src/mase_components/linear/test/fixed_linear_tb.py @@ -1,32 +1,35 @@ #!/usr/bin/env python3 -# This script tests the fixed point linear -import os, logging +import os + +import torch +import logging +from functools import partial import cocotb from cocotb.log import SimLog -from cocotb.triggers import * +from cocotb.triggers import Timer from mase_cocotb.testbench import Testbench -from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor -from mase_cocotb.z_qlayers import quantize_to_int +from mase_cocotb.interfaces.streaming import ( + StreamDriver, + StreamMonitor, + ErrorThresholdStreamMonitor, +) from mase_cocotb.runner import mase_runner -from mase_cocotb.utils import bit_driver, sign_extend_t - -from chop.passes.graph.transforms.quantize.quantized_modules import LinearInteger - -import torch -logger = logging.getLogger("testbench") -logger.setLevel(logging.DEBUG) +# from mase_cocotb import Testbench, StreamDriver, StreamMonitor, mase_runner +from chop.nn.quantized.modules.linear import LinearInteger +from chop.nn.quantizers import integer_quantizer class LinearTB(Testbench): - def __init__(self, dut, in_features=4, out_features=4) -> None: + def __init__(self, dut) -> None: super().__init__(dut, dut.clk, dut.rst) if not hasattr(self, "log"): self.log = SimLog("%s" % (type(self).__qualname__)) + self.log.setLevel(logging.DEBUG) self.data_in_0_driver = StreamDriver( dut.clk, dut.data_in_0, dut.data_in_0_valid, dut.data_in_0_ready @@ -35,134 +38,233 @@ def __init__(self, dut, in_features=4, out_features=4) -> None: dut.clk, dut.weight, dut.weight_valid, dut.weight_ready ) - if int(dut.HAS_BIAS) == 1: + if self.get_parameter("HAS_BIAS") == 1: self.bias_driver = StreamDriver( dut.clk, dut.bias, dut.bias_valid, dut.bias_ready ) + self.bias_driver.log.setLevel(logging.DEBUG) + + # self.data_out_0_monitor = StreamMonitor( + # dut.clk, + # dut.data_out_0, + # dut.data_out_0_valid, + # dut.data_out_0_ready, + # check=True, + # ) - self.data_out_0_monitor = StreamMonitor( + self.data_out_0_monitor = ErrorThresholdStreamMonitor( dut.clk, dut.data_out_0, dut.data_out_0_valid, dut.data_out_0_ready, - check=False, + width=self.get_parameter("DATA_OUT_0_PRECISION_0"), + signed=True, + error_bits=1, + check=True, ) + # Model self.model = LinearInteger( - in_features=in_features, - out_features=out_features, - bias=False, + in_features=self.get_parameter("DATA_IN_0_TENSOR_SIZE_DIM_0"), + out_features=self.get_parameter("DATA_OUT_0_TENSOR_SIZE_DIM_0"), + bias=True if self.get_parameter("HAS_BIAS") == 1 else False, config={ - "data_in_width": 16, - "data_in_frac_width": 3, - "weight_width": 16, - "weight_frac_width": 3, - "bias_width": 16, - "bias_frac_width": 3, + "data_in_width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "data_in_frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + "weight_width": self.get_parameter("WEIGHT_PRECISION_0"), + "weight_frac_width": self.get_parameter("WEIGHT_PRECISION_1"), + "bias_width": self.get_parameter("BIAS_PRECISION_0"), + "bias_frac_width": self.get_parameter("BIAS_PRECISION_1"), }, ) + # Set verbosity of driver and monitor loggers to debug + self.data_in_0_driver.log.setLevel(logging.DEBUG) + self.weight_driver.log.setLevel(logging.DEBUG) + self.data_out_0_monitor.log.setLevel(logging.DEBUG) + def generate_inputs(self): return torch.randn((1, self.model.in_features)) - def preprocess_tensor(self, tensor, quantizer, config, parallelism): - tensor = quantizer(tensor) - tensor = (tensor * 2 ** config["frac_width"]).int() - logger.info(f"Tensor in int format: {tensor}") - tensor = tensor.reshape(-1, parallelism).tolist() - return tensor + def preprocess_tensor(self, tensor, config, parallelism): + if len(tensor.shape) == 1: + tensor = tensor.unsqueeze(0) + + # Quantize + quantizer = partial(integer_quantizer, **config) + q_tensor = quantizer(tensor) + self.log.debug(f"Quantized tensor: {q_tensor}") + + # Convert to integer format + q_tensor = (q_tensor * 2 ** config["frac_width"]).int() + self.log.debug(f"Tensor in integer format: {q_tensor}") + + # Split into chunks according to parallelism in each dimension + # parallelism[0]: along rows, parallelism[1]: along columns + dim_0_split = q_tensor.split(parallelism[0], dim=0) + dim_1_split = [x.split(parallelism[1], dim=1) for x in dim_0_split] + blocks = [] + # Flatten the list of blocks + for i in range(len(dim_1_split)): + for j in range(len(dim_1_split[i])): + blocks.append(dim_1_split[i][j].flatten().tolist()) + return blocks async def run_test(self): await self.reset() - logger.info(f"Reset finished") + self.log.info(f"Reset finished") self.data_out_0_monitor.ready.value = 1 inputs = self.generate_inputs() exp_out = self.model(inputs) - # Load the inputs driver - logger.info(f"Processing inputs") + # * Load the inputs driver + self.log.info(f"Processing inputs: {inputs}") inputs = self.preprocess_tensor( - inputs, - self.model.x_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), + tensor=inputs, + config={ + "width": self.get_parameter("DATA_IN_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_IN_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_IN_0_PARALLELISM_DIM_0"), + ], ) self.data_in_0_driver.load_driver(inputs) - # Load the weights driver - logger.info(f"Processing weights") + # * Load the weights driver + if self.get_parameter("WEIGHTS_PRE_TRANSPOSED") == 1: + weights = self.model.weight.transpose(0, 1) + else: + weights = self.model.weight + + self.log.info(f"Processing weights: {weights}") weights = self.preprocess_tensor( - self.model.weight, - self.model.w_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.WEIGHT_PARALLELISM_DIM_0) - * int(self.dut.DATA_IN_0_PARALLELISM_DIM_0), + tensor=weights, + config={ + "width": self.get_parameter("WEIGHT_PRECISION_0"), + "frac_width": self.get_parameter("WEIGHT_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("WEIGHT_PARALLELISM_DIM_1"), + self.get_parameter("WEIGHT_PARALLELISM_DIM_0"), + ], ) self.weight_driver.load_driver(weights) - # Load the output monitor - logger.info(f"Processing outputs: {exp_out}") - # To do: need to quantize output to a different precision + # * Load the bias driver + if self.get_parameter("HAS_BIAS") == 1: + bias = self.model.bias + self.log.info(f"Processing bias: {bias}") + bias = self.preprocess_tensor( + tensor=bias, + config={ + "width": self.get_parameter("BIAS_PRECISION_0"), + "frac_width": self.get_parameter("BIAS_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("BIAS_PARALLELISM_DIM_1"), + self.get_parameter("BIAS_PARALLELISM_DIM_0"), + ], + ) + self.bias_driver.load_driver(bias) + + # * Load the output monitor + self.log.info(f"Processing outputs: {exp_out}") outs = self.preprocess_tensor( - exp_out, - self.model.x_quantizer, - {"widht": 16, "frac_width": 3}, - int(self.dut.DATA_OUT_0_PARALLELISM_DIM_0), + tensor=exp_out, + config={ + "width": self.get_parameter("DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter("DATA_OUT_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"), + ], ) self.data_out_0_monitor.load_monitor(outs) - await Timer(1000, units="us") + await Timer(1, units="ms") assert self.data_out_0_monitor.exp_queue.empty() @cocotb.test() -async def cocotb_test_20x20(dut): - tb = LinearTB(dut, in_features=20, out_features=20) +async def cocotb_test(dut): + tb = LinearTB(dut) await tb.run_test() -import pytest +def get_fixed_linear_config(kwargs={}): + config = { + "HAS_BIAS": 0, + "WEIGHTS_PRE_TRANSPOSED": 1, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, + "DATA_IN_0_TENSOR_SIZE_DIM_1": 20, + "DATA_IN_0_PARALLELISM_DIM_0": 4, + "WEIGHT_TENSOR_SIZE_DIM_0": 20, + "WEIGHT_TENSOR_SIZE_DIM_1": 20, + "WEIGHT_PARALLELISM_DIM_0": 4, + "WEIGHT_PARALLELISM_DIM_1": 4, + "BIAS_TENSOR_SIZE_DIM_0": 20, + "BIAS_PARALLELISM_DIM_0": 4, + } + config.update(kwargs) + return config -@pytest.mark.skip(reason="Needs to be fixed.") -def test_fixed_linear(): +def test_fixed_linear_smoke(): + """ + Some quick tests to check if the module is working. + """ mase_runner( trace=True, module_param_list=[ - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 2, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 4, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, - { - "DATA_IN_0_TENSOR_SIZE_DIM_0": 20, - "DATA_IN_0_PARALLELISM_DIM_0": 5, - "WEIGHT_TENSOR_SIZE_DIM_0": 20, - "WEIGHT_TENSOR_SIZE_DIM_1": 20, - "WEIGHT_PARALLELISM_DIM_0": 20, - "DATA_OUT_0_TENSOR_SIZE_DIM_0": 20, - "DATA_OUT_0_PARALLELISM_DIM_0": 20, - "BIAS_TENSOR_SIZE_DIM_0": 20, - }, + get_fixed_linear_config(), + get_fixed_linear_config({"WEIGHTS_PRE_TRANSPOSED": 0}), + get_fixed_linear_config({"HAS_BIAS": 1}), + get_fixed_linear_config({"HAS_BIAS": 1, "WEIGHTS_PRE_TRANSPOSED": 0}), + ], + ) + + +def test_fixed_linear_regression(): + """ + More extensive tests to check realistic parameter sizes. + """ + mase_runner( + trace=True, + module_param_list=[ + get_fixed_linear_config( + { + "DATA_IN_0_TENSOR_SIZE_DIM_0": 768, + "DATA_IN_0_PARALLELISM_DIM_0": 32, + "WEIGHT_TENSOR_SIZE_DIM_0": 768, + "WEIGHT_TENSOR_SIZE_DIM_1": 768, + "WEIGHT_PARALLELISM_DIM_0": 32, + "WEIGHT_PARALLELISM_DIM_1": 32, + "BIAS_TENSOR_SIZE_DIM_0": 768, + "BIAS_PARALLELISM_DIM_0": 32, + } + ), + get_fixed_linear_config( + { + "HAS_BIAS": 1, + "WEIGHTS_PRE_TRANSPOSED": 0, + "DATA_IN_0_TENSOR_SIZE_DIM_0": 768, + "DATA_IN_0_PARALLELISM_DIM_0": 32, + "WEIGHT_TENSOR_SIZE_DIM_0": 768, + "WEIGHT_TENSOR_SIZE_DIM_1": 768, + "WEIGHT_PARALLELISM_DIM_0": 32, + "WEIGHT_PARALLELISM_DIM_1": 32, + "BIAS_TENSOR_SIZE_DIM_0": 768, + "BIAS_PARALLELISM_DIM_0": 32, + } + ), ], ) if __name__ == "__main__": - test_fixed_linear() + test_fixed_linear_smoke() + # test_fixed_linear_regression() diff --git a/src/mase_components/linear/test/test_synth_linear.py b/src/mase_components/linear/test/test_synth_linear.py new file mode 100644 index 000000000..be98e2046 --- /dev/null +++ b/src/mase_components/linear/test/test_synth_linear.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_linear(): + run_synth("linear") + + +if __name__ == "__main__": + test_synth_linear() diff --git a/src/mase_components/linter.py b/src/mase_components/linter.py index 428f7214e..5de03967f 100644 --- a/src/mase_components/linter.py +++ b/src/mase_components/linter.py @@ -56,6 +56,8 @@ def run_lint(group): file_path, ] + include_files + logger.info(f"Executing {cmd}") + result = subprocess.run(cmd, capture_output=True, text=True) # * Process result diff --git a/src/mase_components/llm/test/test_synth_llm.py b/src/mase_components/llm/test/test_synth_llm.py new file mode 100644 index 000000000..7103e1996 --- /dev/null +++ b/src/mase_components/llm/test/test_synth_llm.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_llm(): + run_synth("llm") + + +if __name__ == "__main__": + test_synth_llm() diff --git a/src/mase_components/matmul/rtl/matmul.sv b/src/mase_components/matmul/rtl/matmul.sv index de112e1d9..441c11bdf 100644 --- a/src/mase_components/matmul/rtl/matmul.sv +++ b/src/mase_components/matmul/rtl/matmul.sv @@ -19,7 +19,6 @@ Description : This module does a matrix multiplcation between matrices X & Y. `timescale 1ns / 1ps -// TODO: REMOVE THIS AFTER DONE /* verilator lint_off UNUSEDPARAM */ module matmul #( // Total dimensions @@ -29,10 +28,10 @@ module matmul #( parameter B_TOTAL_DIM1 = 4, // must equal A_TOTAL_DIM0 // Compute dimensions - parameter A_COMPUTE_DIM0 = 2, // 4 - parameter A_COMPUTE_DIM1 = 2, // 1 - parameter B_COMPUTE_DIM0 = 2, // 4 - parameter B_COMPUTE_DIM1 = 2, // must equal A_COMPUTE_DIM0 // 1 + parameter A_COMPUTE_DIM0 = 2, + parameter A_COMPUTE_DIM1 = 2, + parameter B_COMPUTE_DIM0 = 2, + parameter B_COMPUTE_DIM1 = 2, // must equal A_COMPUTE_DIM0 // Input fixed point widths parameter A_WIDTH = 8, @@ -98,19 +97,78 @@ module matmul #( else $fatal("B_DIM1 compute is not divisible!"); end + // ----- + // Params + // ----- + + localparam A_FLAT_WIDTH = A_WIDTH * A_COMPUTE_DIM0 * A_COMPUTE_DIM1; + localparam B_FLAT_WIDTH = B_WIDTH * B_COMPUTE_DIM0 * B_COMPUTE_DIM1; + + localparam SM_OUT_WIDTH = A_WIDTH + B_WIDTH + $clog2(A_COMPUTE_DIM0); + localparam SM_OUT_FRAC_WIDTH = A_FRAC_WIDTH + B_FRAC_WIDTH; + + localparam MAT_ACC_PTR_WIDTH = C_DEPTH_DIM0 == 1 ? 1 : $clog2(C_DEPTH_DIM0); + localparam MAT_ACC_OUT_WIDTH = $clog2(B_DEPTH_DIM1) + SM_OUT_WIDTH; + + // ----- + // Wires + // ----- + // Buffer unflatten out logic a_buffer_out_valid, a_buffer_out_ready; logic [A_WIDTH-1:0] a_buffer_out_data[A_COMPUTE_DIM0*A_COMPUTE_DIM1-1:0]; + // Repeat each submatrix in Matrix A stream B_DEPTH_DIM0 times + // Only if (B_DEPTH_DIM0 > 1) + logic [A_FLAT_WIDTH-1:0] a_data_flat; + logic [A_FLAT_WIDTH-1:0] a_buffer_out_data_flat; + + // We need to buffer the B matrix + // TODO: unless A_DEPTH_DIM1 == 1 + + logic [B_FLAT_WIDTH-1:0] b_data_flat; + + // Buffer outputs + logic [B_FLAT_WIDTH-1:0] b_buffer_out_data_flat; + logic b_buffer_out_valid, b_buffer_out_ready; + + // Matrix unflatten output + logic [B_WIDTH-1:0] b_buffer_out_data[B_COMPUTE_DIM0*B_COMPUTE_DIM1-1:0]; + + logic [SM_OUT_WIDTH-1:0] sm_out_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1]; + logic sm_out_valid, sm_out_ready; + + logic [C_DEPTH_DIM0-1:0] acc_in_valid; + logic [C_DEPTH_DIM0-1:0] acc_in_ready; + logic [C_DEPTH_DIM0-1:0] acc_out_valid; + logic [C_DEPTH_DIM0-1:0] acc_out_ready; + logic [MAT_ACC_OUT_WIDTH-1:0] acc_out_data[C_DEPTH_DIM0-1:0][C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; + + logic [MAT_ACC_OUT_WIDTH-1:0] cast_in_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; + + + // ----- + // State + // ----- + + struct { + // Points to which matrix accumulator should store the simple_matmul output + logic [MAT_ACC_PTR_WIDTH-1:0] matrix_acc_ptr; + // Points at which output accumulator should be connected to the out stream + logic [MAT_ACC_PTR_WIDTH-1:0] output_acc_ptr; + } + self, next_self; + + + // ----- + // Logic + // ----- + generate - if (B_DEPTH_DIM0 > 1) begin - // Repeat each submatrix in Matrix A stream B_DEPTH_DIM0 times - localparam A_FLAT_WIDTH = A_WIDTH * A_COMPUTE_DIM0 * A_COMPUTE_DIM1; - logic [A_FLAT_WIDTH-1:0] a_data_flat; + // B matrix Buffers - // Buffer outputs - logic [A_FLAT_WIDTH-1:0] a_buffer_out_data_flat; + if (B_DEPTH_DIM0 > 1) begin matrix_flatten #( .DATA_WIDTH(A_WIDTH), @@ -121,11 +179,10 @@ module matmul #( .data_out(a_data_flat) ); - repeat_circular_buffer #( + single_element_repeat #( .DATA_WIDTH(A_FLAT_WIDTH), // Repeat for number of rows in matrix A - .REPEAT (B_DEPTH_DIM0), - .SIZE (1) + .REPEAT (B_DEPTH_DIM0) ) input_stream_buffer ( .clk (clk), .rst (rst), @@ -151,25 +208,10 @@ module matmul #( assign a_buffer_out_valid = a_valid; assign a_ready = a_buffer_out_ready; end - endgenerate - // We need to buffer the B matrix - // TODO: unless A_DEPTH_DIM1 == 1 - - localparam B_FLAT_WIDTH = B_WIDTH * B_COMPUTE_DIM0 * B_COMPUTE_DIM1; - - // Buffer outputs - logic b_buffer_out_valid, b_buffer_out_ready; - - // Matrix unflatten output - logic [B_WIDTH-1:0] b_buffer_out_data[B_COMPUTE_DIM0*B_COMPUTE_DIM1-1:0]; - - generate + // A matrix Buffers if (A_DEPTH_DIM1 > 1) begin - logic [B_FLAT_WIDTH-1:0] b_data_flat; - logic [B_FLAT_WIDTH-1:0] b_buffer_out_data_flat; - matrix_flatten #( .DATA_WIDTH(B_WIDTH), .DIM0 (B_COMPUTE_DIM0), @@ -215,16 +257,11 @@ module matmul #( // Simple matrix multiply block's accumulator width // We do not round at simple_matmul level as we want to keep high precision // and round ourselves after the output accumulation in this matmul module. - localparam SM_OUT_WIDTH = A_WIDTH + B_WIDTH + $clog2(A_COMPUTE_DIM0); - localparam SM_OUT_FRAC_WIDTH = A_FRAC_WIDTH + B_FRAC_WIDTH; - - logic [SM_OUT_WIDTH-1:0] sm_out_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1]; - logic sm_out_valid, sm_out_ready; simple_matmul #( - .N (A_COMPUTE_DIM1), //1 - .M (A_COMPUTE_DIM0), // == B_COMPUTE_DIM1 // 4 - .K (B_COMPUTE_DIM0), // 4 + .N (A_COMPUTE_DIM1), + .M (A_COMPUTE_DIM0), // == B_COMPUTE_DIM1 + .K (B_COMPUTE_DIM0), .X_WIDTH (A_WIDTH), .X_FRAC_WIDTH (A_FRAC_WIDTH), .Y_WIDTH (B_WIDTH), @@ -247,14 +284,6 @@ module matmul #( ); // Direct the result of the simple matmul to the correct matrix_accumulator - localparam MAT_ACC_PTR_WIDTH = C_DEPTH_DIM0 == 1 ? 1 : $clog2(C_DEPTH_DIM0); - localparam MAT_ACC_OUT_WIDTH = $clog2(B_DEPTH_DIM1) + SM_OUT_WIDTH; - - logic [C_DEPTH_DIM0-1:0] acc_in_valid; - logic [C_DEPTH_DIM0-1:0] acc_in_ready; - logic [C_DEPTH_DIM0-1:0] acc_out_valid; - logic [C_DEPTH_DIM0-1:0] acc_out_ready; - logic [MAT_ACC_OUT_WIDTH-1:0] acc_out_data[C_DEPTH_DIM0-1:0][C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; for (genvar i = 0; i < C_DEPTH_DIM0; i++) begin : accumulators matrix_accumulator #( @@ -274,8 +303,6 @@ module matmul #( ); end - logic [MAT_ACC_OUT_WIDTH-1:0] cast_in_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; - for (genvar i = 0; i < C_DEPTH_DIM0; i++) begin // Change which accumulator the output of simple_matmul goes to assign acc_in_valid[i] = self.matrix_acc_ptr == i ? sm_out_valid : 0; @@ -301,15 +328,6 @@ module matmul #( end // Logic to handle accumulator selection & output selection. - - struct { - // Points to which matrix accumulator should store the simple_matmul output - logic [MAT_ACC_PTR_WIDTH-1:0] matrix_acc_ptr; - // Points at which output accumulator should be connected to the out stream - logic [MAT_ACC_PTR_WIDTH-1:0] output_acc_ptr; - } - self, next_self; - always_comb begin next_self = self; @@ -319,7 +337,7 @@ module matmul #( out_valid = acc_out_valid[self.output_acc_ptr]; // Change accumulator pointer - if (sm_out_valid) begin + if (sm_out_valid && sm_out_ready) begin if (self.matrix_acc_ptr == C_DEPTH_DIM0 - 1) begin next_self.matrix_acc_ptr = 0; end else begin diff --git a/src/mase_components/matmul/rtl/matrix_fifo.sv b/src/mase_components/matmul/rtl/matrix_fifo.sv index ab59b66bf..898a3eda1 100644 --- a/src/mase_components/matmul/rtl/matrix_fifo.sv +++ b/src/mase_components/matmul/rtl/matrix_fifo.sv @@ -40,8 +40,8 @@ module matrix_fifo #( ); fifo #( - .SIZE(FIFO_SIZE), - .DATA_WIDTH(FLAT_DATA_WIDTH) + .DATA_WIDTH(FLAT_DATA_WIDTH), + .DEPTH(FIFO_SIZE) ) input_fifo_inst ( .clk(clk), .rst(rst), diff --git a/src/mase_components/matmul/rtl/matrix_stream_transpose.sv b/src/mase_components/matmul/rtl/matrix_stream_transpose.sv index 66040f8c3..86e7a93bc 100644 --- a/src/mase_components/matmul/rtl/matrix_stream_transpose.sv +++ b/src/mase_components/matmul/rtl/matrix_stream_transpose.sv @@ -45,19 +45,27 @@ module matrix_stream_transpose #( else $fatal("DIM1 compute is not divisible!"); end + // ----- // Parameters - let max(a, b) = (a > b) ? a : b; + // ----- + // let max(a, b) = (a > b) ? a : b; localparam IN_DEPTH_DIM0 = TOTAL_DIM0 / COMPUTE_DIM0; localparam IN_DEPTH_DIM1 = TOTAL_DIM1 / COMPUTE_DIM1; localparam OUT_DEPTH_DIM0 = IN_DEPTH_DIM1; localparam OUT_DEPTH_DIM1 = IN_DEPTH_DIM0; - localparam IN_ROW_COUNTER_WIDTH = max($clog2(IN_DEPTH_DIM1), 1); - localparam IN_COL_COUNTER_WIDTH = max($clog2(IN_DEPTH_DIM0), 1); - localparam OUT_ROW_COUNTER_WIDTH = max($clog2(OUT_DEPTH_DIM1), 1); - localparam OUT_COL_COUNTER_WIDTH = max($clog2(OUT_DEPTH_DIM0), 1); + localparam IN_ROW_COUNTER_WIDTH = $clog2(IN_DEPTH_DIM1) > 1 ? $clog2(IN_DEPTH_DIM1) : 1; + localparam IN_COL_COUNTER_WIDTH = $clog2(IN_DEPTH_DIM0) > 1 ? $clog2(IN_DEPTH_DIM0) : 1; + localparam OUT_ROW_COUNTER_WIDTH = $clog2(OUT_DEPTH_DIM1) > 1 ? $clog2(OUT_DEPTH_DIM1) : 1; + localparam OUT_COL_COUNTER_WIDTH = $clog2(OUT_DEPTH_DIM0) > 1 ? $clog2(OUT_DEPTH_DIM0) : 1; + localparam FIFO_DEPTH = IN_DEPTH_DIM1; + localparam FIFO_DATA_WIDTH = DATA_WIDTH * COMPUTE_DIM0 * COMPUTE_DIM1; + + // ----- // State + // ----- + struct { // Current row & col that the window is at for the input logic [IN_ROW_COUNTER_WIDTH-1:0] in_row_count; @@ -68,16 +76,32 @@ module matrix_stream_transpose #( } self, next_self; + // ----- + // Wires + // ----- + + logic [FIFO_DATA_WIDTH-1:0] in_data_flat; + logic [FIFO_DATA_WIDTH-1:0] fifo_in_data[IN_DEPTH_DIM0-1:0]; + logic fifo_in_valid[IN_DEPTH_DIM0-1:0]; + logic fifo_in_ready[IN_DEPTH_DIM0-1:0]; + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat[IN_DEPTH_DIM0-1:0]; + logic fifo_out_valid[IN_DEPTH_DIM0-1:0]; + logic fifo_out_ready[IN_DEPTH_DIM0-1:0]; + + logic fifo_data_readys[IN_DEPTH_DIM0-1:0]; + + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_in[IN_DEPTH_DIM0-1:0]; + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_out; + logic fifo_out_valids[IN_DEPTH_DIM0-1:0]; + + logic [DATA_WIDTH-1:0] transpose_data_in[COMPUTE_DIM0*COMPUTE_DIM1-1:0]; + // FIFOs // We want to generate IN_DEPTH_DIM0 FIFOs to buffer the input chunks. // Each FIFO will need to be IN_DEPTH_DIM1 elements deep and each element will // be flattened to be size (DATA_WIDTH * COMPUTE_DIM0 * COMPUTE_DIM1) - localparam FIFO_DEPTH = IN_DEPTH_DIM1; - localparam FIFO_DATA_WIDTH = DATA_WIDTH * COMPUTE_DIM0 * COMPUTE_DIM1; - logic [FIFO_DATA_WIDTH-1:0] in_data_flat; - matrix_flatten #( .DATA_WIDTH(DATA_WIDTH), .DIM0 (COMPUTE_DIM0), @@ -88,40 +112,29 @@ module matrix_stream_transpose #( ); for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin : fifos - - // FIFO Inputs - logic [FIFO_DATA_WIDTH-1:0] fifo_in_data; - logic fifo_in_valid, fifo_in_ready; - - // FIFO Output, data needs to be unflattened - logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat; - logic fifo_out_valid, fifo_out_ready; - fifo #( - .SIZE (FIFO_DEPTH), + .DEPTH (FIFO_DEPTH), .DATA_WIDTH(FIFO_DATA_WIDTH) ) fifo_inst ( .clk (clk), .rst (rst), - .in_data (fifo_in_data), - .in_valid (fifo_in_valid), - .in_ready (fifo_in_ready), - .out_data (fifo_out_data_flat), - .out_valid(fifo_out_valid), - .out_ready(fifo_out_ready), + .in_data (fifo_in_data[i]), + .in_valid (fifo_in_valid[i]), + .in_ready (fifo_in_ready[i]), + .out_data (fifo_out_data_flat[i]), + .out_valid(fifo_out_valid[i]), + .out_ready(fifo_out_ready[i]), .empty (), .full () ); - end // Connect up wires to write to all of the fifos using in_col_count as index // The valid and ready signals will be used to select which one is written to - logic fifo_data_readys[IN_DEPTH_DIM0-1:0]; for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin - assign fifos[i].fifo_in_data = in_data_flat; - assign fifos[i].fifo_in_valid = (self.in_col_count == i) ? in_valid : 0; - assign fifo_data_readys[i] = fifos[i].fifo_in_ready; + assign fifo_in_data[i] = in_data_flat; + assign fifo_in_valid[i] = (self.in_col_count == i) ? in_valid : 0; + assign fifo_data_readys[i] = fifo_in_ready[i]; end generate @@ -141,13 +154,11 @@ module matrix_stream_transpose #( // Connect up wires to read from all of the fifos using out_row_count to index // into the column fifos which buffer the matrix - logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_in[IN_DEPTH_DIM0-1:0]; - logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_out; - logic fifo_out_valids[IN_DEPTH_DIM0-1:0]; + for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin - assign fifo_out_data_flat_mux_in[i] = fifos[i].fifo_out_data_flat; - assign fifo_out_valids[i] = fifos[i].fifo_out_valid; - assign fifos[i].fifo_out_ready = (self.out_row_count == i) ? out_ready : 0; + assign fifo_out_data_flat_mux_in[i] = fifo_out_data_flat[i]; + assign fifo_out_valids[i] = fifo_out_valid[i]; + assign fifo_out_ready[i] = (self.out_row_count == i) ? out_ready : 0; end generate @@ -175,7 +186,6 @@ module matrix_stream_transpose #( endgenerate // Unflatten FIFO data - logic [DATA_WIDTH-1:0] transpose_data_in[COMPUTE_DIM0*COMPUTE_DIM1-1:0]; matrix_unflatten #( .DATA_WIDTH(DATA_WIDTH), .DIM0 (COMPUTE_DIM0), diff --git a/src/mase_components/matmul/rtl/simple_matmul.sv b/src/mase_components/matmul/rtl/simple_matmul.sv index 1cf50b4ed..703fa3648 100644 --- a/src/mase_components/matmul/rtl/simple_matmul.sv +++ b/src/mase_components/matmul/rtl/simple_matmul.sv @@ -29,7 +29,7 @@ module simple_matmul #( // then out_width & out_frac_width must match accumulator widths parameter OUTPUT_ROUNDING = 1, parameter OUT_WIDTH = 16, - parameter OUT_FRAC_WIDTH = 0 + parameter OUT_FRAC_WIDTH = 2 ) ( input logic clk, input logic rst, @@ -50,6 +50,10 @@ module simple_matmul #( input logic out_ready ); + // ----- + // Params + // ----- + // Accumulator widths in linear layer localparam ACC_WIDTH = X_WIDTH + Y_WIDTH + $clog2(M); localparam ACC_FRAC_WIDTH = X_FRAC_WIDTH + Y_FRAC_WIDTH; @@ -63,85 +67,90 @@ module simple_matmul #( end end - logic [N*K-1:0] dot_product_ready; - logic [N*K-1:0] dot_product_valid; - assign dot_product_ready = {(N * K) {out_ready}}; - - generate - for (genvar i = 0; i < N; i++) begin : multi_row - for (genvar j = 0; j < K; j++) begin : multi_col - - // Slice a single row of x - logic [X_WIDTH-1:0] row_x[M-1:0]; - assign row_x = x_data[(i+1)*M-1 : i*M]; - - // Slice a column of y - logic [Y_WIDTH-1:0] col_y[M-1:0]; - for (genvar m = 0; m < M; m++) begin : col_assign - assign col_y[m] = y_data[m*K+j]; - end - - // Input ready signal - /* verilator lint_off UNUSEDSIGNAL */ - logic sync_ready; - /* verilator lint_on UNUSEDSIGNAL */ - - // Linear output - logic [ACC_WIDTH-1:0] dot_product_data_out; - - fixed_dot_product #( - .IN_WIDTH (X_WIDTH), - .IN_SIZE (M), - .WEIGHT_WIDTH(Y_WIDTH) - ) linear_inst ( - .clk (clk), - .rst (rst), - .data_in (row_x), - .data_in_valid (sync_valid), - .data_in_ready (sync_ready), - .weight (col_y), - .weight_valid (sync_valid), - /* verilator lint_off PINCONNECTEMPTY */ - // This pin is the same as data_in_ready pin - .weight_ready (), - /* verilator lint_on PINCONNECTEMPTY */ - .data_out (dot_product_data_out), - .data_out_valid(dot_product_valid[i*K+j]), - .data_out_ready(dot_product_ready[i*K+j]) - ); - if (OUTPUT_ROUNDING) begin : rounding - // Rounded output - logic [OUT_WIDTH-1:0] rounded_dot_product; - fixed_round #( - .IN_WIDTH (ACC_WIDTH), - .IN_FRAC_WIDTH (ACC_FRAC_WIDTH), - .OUT_WIDTH (OUT_WIDTH), - .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) - ) round_inst ( - .data_in (dot_product_data_out), - .data_out(rounded_dot_product) - ); - assign out_data[i*K+j] = rounded_dot_product; - end else begin : no_rounding - assign out_data[i*K+j] = dot_product_data_out; - end + // ----- + // Wires + // ----- + + logic [Y_WIDTH-1:0] y_data_transpose[K*M-1:0]; + logic dot_product_ready; + logic inputs_valid, inputs_ready; + + logic [N*K-1:0] dot_product_valid; + logic [N*K-1:0] sync_ready; + logic [ACC_WIDTH-1:0] dot_product_data_out[N*K-1:0]; + logic [OUT_WIDTH-1:0] rounded_dot_product[N*K-1:0]; - end - end - endgenerate - // Need to synchronise backpressure/valid signals - logic sync_valid, join_sync_ready; - assign join_sync_ready = multi_row[0].multi_col[0].sync_ready; + // ----- + // Logic + // ----- - join2 #() sync_handshake ( + // Need to synchronise x & y inputs + assign inputs_ready = sync_ready[0]; + join2 sync_handshake ( .data_in_valid ({x_valid, y_valid}), .data_in_ready ({x_ready, y_ready}), - .data_out_valid(sync_valid), - .data_out_ready(join_sync_ready) + .data_out_valid(inputs_valid), + .data_out_ready(inputs_ready) ); - assign out_valid = &dot_product_valid; + // Transpose y to make column assignment easier, this module is just a rewire + // so it shouldn't contribute anything to comb path. + transpose #( + .WIDTH(Y_WIDTH), + .DIM0 (K), + .DIM1 (M) + ) y_transpose ( + .in_data (y_data), + .out_data(y_data_transpose) + ); + + // Instantiate N-by-K number of dot products + for (genvar i = 0; i < N; i++) begin : multi_row + for (genvar j = 0; j < K; j++) begin : multi_col + + fixed_dot_product #( + .IN_WIDTH (X_WIDTH), + .IN_SIZE (M), + .WEIGHT_WIDTH(Y_WIDTH) + ) dot_product_inst ( + .clk (clk), + .rst (rst), + .data_in (x_data[((i+1)*M)-1 : i*M]), + .data_in_valid (inputs_valid), + .data_in_ready (sync_ready[i*K+j]), + .weight (y_data_transpose[((j+1)*M)-1 : j*M]), + .weight_valid (inputs_valid), + /* verilator lint_off PINCONNECTEMPTY */ + // This pin is the same as data_in_ready pin + .weight_ready (), + /* verilator lint_on PINCONNECTEMPTY */ + .data_out (dot_product_data_out[i*K+j]), + .data_out_valid(dot_product_valid[i*K+j]), + .data_out_ready(dot_product_ready) + ); + + if (OUTPUT_ROUNDING) begin : rounding + // Rounded output + fixed_round #( + .IN_WIDTH (ACC_WIDTH), + .IN_FRAC_WIDTH (ACC_FRAC_WIDTH), + .OUT_WIDTH (OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) round_inst ( + .data_in (dot_product_data_out[i*K+j]), + .data_out(rounded_dot_product[i*K+j]) + ); + assign out_data[i*K+j] = rounded_dot_product[i*K+j]; + end else begin : no_rounding + assign out_data[i*K+j] = dot_product_data_out[i*K+j]; + end + + end + end + + assign out_valid = dot_product_valid[0]; + assign dot_product_ready = out_ready; endmodule diff --git a/src/mase_components/matmul/test/matmul_tb.py b/src/mase_components/matmul/test/matmul_tb.py index 6adfb6112..6a7ec7e91 100644 --- a/src/mase_components/matmul/test/matmul_tb.py +++ b/src/mase_components/matmul/test/matmul_tb.py @@ -14,7 +14,7 @@ logger = logging.getLogger("testbench") -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) class MatmulTB(Testbench): @@ -48,7 +48,12 @@ def __init__(self, dut) -> None: self.a_driver = StreamDriver(dut.clk, dut.a_data, dut.a_valid, dut.a_ready) self.b_driver = StreamDriver(dut.clk, dut.b_data, dut.b_valid, dut.b_ready) self.output_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready, check=True + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + check=True, + unsigned=True, ) def generate_inputs(self): @@ -95,87 +100,47 @@ def model(self, A_inputs, B_inputs): B_inputs, ) + async def run_test(self, batches, us): + await self.reset() + for _ in range(batches): + A_inputs, B_inputs = self.generate_inputs() + exp_out = self.model(A_inputs, B_inputs) + # Setup drivers and monitors + self.a_driver.load_driver(A_inputs) + self.b_driver.load_driver(B_inputs) + self.output_monitor.load_monitor(exp_out) + await Timer(us, units="us") + assert self.output_monitor.exp_queue.empty() + @cocotb.test() async def single_mult(dut): tb = MatmulTB(dut) - await tb.reset() tb.output_monitor.ready.value = 1 - A_inputs, B_inputs = tb.generate_inputs() - exp_out = tb.model(A_inputs, B_inputs) - - # Setup drivers and monitors - for a in A_inputs: - tb.a_driver.append(a) - for b in B_inputs: - tb.b_driver.append(b) - for o in exp_out: - tb.output_monitor.expect(o) - - await Timer(100, units="us") - assert tb.output_monitor.exp_queue.empty() + await tb.run_test(batches=1, us=100) @cocotb.test() async def repeated_mult(dut): tb = MatmulTB(dut) - await tb.reset() tb.output_monitor.ready.value = 1 - - for _ in range(100): - A, B = tb.generate_inputs() - e_out = tb.model(A, B) - for a in A: - tb.a_driver.append(a) - for b in B: - tb.b_driver.append(b) - for o in e_out: - tb.output_monitor.expect(o) - - await Timer(100, units="us") - assert tb.output_monitor.exp_queue.empty() + await tb.run_test(batches=1000, us=2000) @cocotb.test() async def repeated_mult_backpressure(dut): tb = MatmulTB(dut) - await tb.reset() cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.6)) - - for _ in range(100): - A, B = tb.generate_inputs() - e_out = tb.model(A, B) - for a in A: - tb.a_driver.append(a) - for b in B: - tb.b_driver.append(b) - for o in e_out: - tb.output_monitor.expect(o) - - await Timer(100, units="us") - assert tb.output_monitor.exp_queue.empty() + await tb.run_test(batches=500, us=2000) @cocotb.test() async def repeated_mult_valid_backpressure(dut): tb = MatmulTB(dut) - await tb.reset() tb.a_driver.set_valid_prob(0.7) tb.b_driver.set_valid_prob(0.7) cocotb.start_soon(bit_driver(dut.out_ready, dut.clk, 0.6)) - - for _ in range(100): - A, B = tb.generate_inputs() - e_out = tb.model(A, B) - for a in A: - tb.a_driver.append(a) - for b in B: - tb.b_driver.append(b) - for o in e_out: - tb.output_monitor.expect(o) - - await Timer(100, units="us") - assert tb.output_monitor.exp_queue.empty() + await tb.run_test(batches=500, us=2000) def gen_random_dimensions(): @@ -215,7 +180,6 @@ def generate_random_dimension_cfg(cfg_list, multiple=3): import pytest -@pytest.mark.skip(reason="Needs to be fixed.") def test_matmul(): # Default is a square matrix mult # 4x4 4x4 matrix multiplication done using 2x2 window @@ -275,8 +239,8 @@ def test_matmul(): # Dimensions *generate_random_dimension_cfg([DEFAULT_CONFIG]), ], - seed=1705250706, trace=True, + jobs=12, ) diff --git a/src/mase_components/matmul/test/matrix_stream_transpose_tb.py b/src/mase_components/matmul/test/matrix_stream_transpose_tb.py index 8bcebd227..4dd7f58b1 100644 --- a/src/mase_components/matmul/test/matrix_stream_transpose_tb.py +++ b/src/mase_components/matmul/test/matrix_stream_transpose_tb.py @@ -29,7 +29,12 @@ def __init__(self, dut) -> None: self.in_driver = StreamDriver(dut.clk, dut.in_data, dut.in_valid, dut.in_ready) self.out_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready, check=True + dut.clk, + dut.out_data, + dut.out_valid, + dut.out_ready, + check=True, + unsigned=True, ) def generate_inputs(self): @@ -67,7 +72,7 @@ async def single_transpose(dut): tb.in_driver.append(x) for y in Y: tb.out_monitor.expect(y) - await Timer(1, units="us") + await Timer(100, units="us") assert tb.out_monitor.exp_queue.empty() @@ -83,7 +88,7 @@ async def multiple_transpose(dut): tb.in_driver.append(x) for y in Y: tb.out_monitor.expect(y) - await Timer(100, units="us") + await Timer(1000, units="us") assert tb.out_monitor.exp_queue.empty() @@ -99,7 +104,7 @@ async def multiple_transpose_backpressure(dut): tb.in_driver.append(x) for y in Y: tb.out_monitor.expect(y) - await Timer(100, units="us") + await Timer(1000, units="us") assert tb.out_monitor.exp_queue.empty() @@ -116,7 +121,7 @@ async def multiple_transpose_valid_backpressure(dut): tb.in_driver.append(x) for y in Y: tb.out_monitor.expect(y) - await Timer(200, units="us") + await Timer(2000, units="us") assert tb.out_monitor.exp_queue.empty() @@ -158,7 +163,8 @@ def test_matrix_stream_transpose(): }, # Random test *[gen_random_params() for _ in range(5)], - ] + ], + trace=True, ) diff --git a/src/mase_components/matmul/test/simple_matmul_tb.py b/src/mase_components/matmul/test/simple_matmul_tb.py index 75125d057..1e346359a 100644 --- a/src/mase_components/matmul/test/simple_matmul_tb.py +++ b/src/mase_components/matmul/test/simple_matmul_tb.py @@ -45,7 +45,7 @@ def __init__(self, dut) -> None: self.x_driver = StreamDriver(dut.clk, dut.x_data, dut.x_valid, dut.x_ready) self.y_driver = StreamDriver(dut.clk, dut.y_data, dut.y_valid, dut.y_ready) self.output_monitor = StreamMonitor( - dut.clk, dut.out_data, dut.out_valid, dut.out_ready + dut.clk, dut.out_data, dut.out_valid, dut.out_ready, unsigned=True ) def generate_inputs(self, random=False): diff --git a/src/mase_components/matmul/test/test_chain_matmul_tb.py b/src/mase_components/matmul/test/test_chain_matmul_tb.py new file mode 100644 index 000000000..6b4a5fa7b --- /dev/null +++ b/src/mase_components/matmul/test/test_chain_matmul_tb.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 + +import logging +import torch +from torch import Tensor + +import random +from random import randint +from math import ceil, log2 +from copy import copy + +import numpy as np + +import cocotb +from cocotb.triggers import * + +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor +from mase_cocotb.z_qlayers import quantize_to_int +from mase_cocotb.runner import mase_runner +from mase_cocotb.matrix_tools import gen_random_matrix_input, matrix_mult_model + +logger = logging.getLogger("testbench") +logger.setLevel(logging.DEBUG) + + +class ChainMatmulTB(Testbench): + def __init__(self, dut) -> None: + super().__init__(dut, dut.clk, dut.rst) + self.assign_self_params( + [ + "N", + "M", + "K", + "Z", + "IN_WIDTH", + "IN_FRAC_WIDTH", + "INT_WIDTH", + "INT_FRAC_WIDTH", + "OUT_WIDTH", + "OUT_FRAC_WIDTH", + "COMPUTE_DIM0", + "COMPUTE_DIM1", + "SYMMETRIC", + ] + ) + + # Drivers & Monitors + self.a_driver = StreamDriver(dut.clk, dut.a_data, dut.a_valid, dut.a_ready) + self.b_driver = StreamDriver(dut.clk, dut.b_data, dut.b_valid, dut.b_ready) + self.c_driver = StreamDriver(dut.clk, dut.c_data, dut.c_valid, dut.c_ready) + self.d_monitor = StreamMonitor( + dut.clk, dut.d_data, dut.d_valid, dut.d_ready, check=True + ) + + def cleanup(self): + self.a_driver.kill() + self.b_driver.kill() + self.c_driver.kill() + self.d_monitor.kill() + + # Dimensions for chain matmul are: (nm * mk) * kz = nz + def generate_inputs(self): + A_inputs = gen_random_matrix_input( + self.M, + self.N, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + ) + B_inputs = gen_random_matrix_input( + self.K, + self.M, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + ) + C_inputs = gen_random_matrix_input( + self.Z, + self.K, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + ) + return A_inputs, B_inputs, C_inputs + + def model(self, A_inputs, B_inputs, C_inputs): + # (nm * mk) -> nk + intermediate_matrix = matrix_mult_model( + self.M, + self.N, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.K, + self.M, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.K, + self.N, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + self.INT_WIDTH, + self.INT_FRAC_WIDTH, + self.SYMMETRIC, + A_inputs, + B_inputs, + ) + # (nk * kz) -> nz + output = matrix_mult_model( + self.K, + self.N, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.Z, + self.K, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.Z, + self.N, + self.COMPUTE_DIM0, + self.COMPUTE_DIM1, + self.INT_WIDTH, + self.INT_FRAC_WIDTH, + self.IN_WIDTH, + self.IN_FRAC_WIDTH, + self.OUT_WIDTH, + self.OUT_FRAC_WIDTH, + self.SYMMETRIC, + intermediate_matrix, + C_inputs, + ) + return output + + +@cocotb.test() +async def basic(dut): + tb = ChainMatmulTB(dut) + tb.d_monitor.ready.value = 1 + await tb.reset() + A_inputs, B_inputs, C_inputs = tb.generate_inputs() + exp_out = tb.model(A_inputs, B_inputs, C_inputs) + + # Setup drivers and monitors + for a in A_inputs: + tb.a_driver.append(a) + for b in B_inputs: + tb.b_driver.append(b) + for c in C_inputs: + tb.c_driver.append(c) + for o in exp_out: + tb.d_monitor.expect(o) + await Timer(100, units="us") + assert tb.d_monitor.exp_queue.empty() + tb.cleanup() + + +if __name__ == "__main__": + mase_runner( + module_param_list=[ + {"N": 2, "M": 2, "K": 2, "Z": 2}, + {"N": 4, "M": 2, "K": 4, "Z": 2}, + {"N": 2, "M": 4, "K": 2, "Z": 8}, + {"N": 8, "M": 2, "K": 8, "Z": 2}, + {"N": 8, "M": 4, "K": 4, "Z": 2}, + ] + ) diff --git a/src/mase_components/matmul/test/test_synth_matmul.py b/src/mase_components/matmul/test/test_synth_matmul.py new file mode 100644 index 000000000..22b68d1f3 --- /dev/null +++ b/src/mase_components/matmul/test/test_synth_matmul.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_matmul(): + run_synth("matmul") + + +if __name__ == "__main__": + test_synth_matmul() diff --git a/src/mase_components/memory/test/test_synth_memory.py b/src/mase_components/memory/test/test_synth_memory.py new file mode 100644 index 000000000..be98e2046 --- /dev/null +++ b/src/mase_components/memory/test/test_synth_memory.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_linear(): + run_synth("linear") + + +if __name__ == "__main__": + test_synth_linear() diff --git a/src/mase_components/norm/rtl/channel_selection.sv b/src/mase_components/norm/rtl/channel_selection.sv index f183e03cb..b72d633a9 100644 --- a/src/mase_components/norm/rtl/channel_selection.sv +++ b/src/mase_components/norm/rtl/channel_selection.sv @@ -10,7 +10,11 @@ Description : This module a double counter which outputs which channel the module channel_selection #( parameter NUM_CHANNELS = 2, // Number of blocks in spatial dimensions (usually = depth_dim0 * depth_dim1) - parameter NUM_SPATIAL_BLOCKS = 4 + parameter NUM_SPATIAL_BLOCKS = 4, + + // Channel and spatial state widths + localparam C_STATE_WIDTH = (NUM_CHANNELS == 1) ? 1 : $clog2(NUM_CHANNELS), + localparam S_STATE_WIDTH = (NUM_SPATIAL_BLOCKS == 1) ? 1 : $clog2(NUM_SPATIAL_BLOCKS) ) ( input logic clk, input logic rst, @@ -18,10 +22,6 @@ module channel_selection #( output logic [C_STATE_WIDTH-1:0] channel ); - // Channel and spatial state widths - localparam C_STATE_WIDTH = (NUM_CHANNELS == 1) ? 1 : $clog2(NUM_CHANNELS); - localparam S_STATE_WIDTH = (NUM_SPATIAL_BLOCKS == 1) ? 1 : $clog2(NUM_SPATIAL_BLOCKS); - generate if (NUM_CHANNELS == 1) begin assign channel = 0; diff --git a/src/mase_components/norm/rtl/group_norm_2d.sv b/src/mase_components/norm/rtl/group_norm_2d.sv index 8cbeb871d..3cbc41239 100644 --- a/src/mase_components/norm/rtl/group_norm_2d.sv +++ b/src/mase_components/norm/rtl/group_norm_2d.sv @@ -217,10 +217,9 @@ module group_norm_2d #( assign mu_acc_div = ($signed(mu_acc) * $signed({1'b0, INV_NUMVALUES_0})) >>> ACC_OUT_WIDTH; assign mu_in = mu_acc_div[IN_WIDTH-1:0]; - repeat_circular_buffer #( + single_element_repeat #( .DATA_WIDTH(IN_WIDTH), - .REPEAT(NUM_ITERS), - .SIZE(1) + .REPEAT(NUM_ITERS) ) mu_buffer ( .clk(clk), .rst(rst), @@ -358,10 +357,9 @@ module group_norm_2d #( ); - repeat_circular_buffer #( + single_element_repeat #( .DATA_WIDTH(ISQRT_WIDTH), - .REPEAT(NUM_ITERS), - .SIZE(1) + .REPEAT(NUM_ITERS) ) isqrt_var_circ_buffer ( .clk(clk), .rst(rst), diff --git a/src/mase_components/norm/rtl/rms_norm_2d.sv b/src/mase_components/norm/rtl/rms_norm_2d.sv index 3af06f4c7..40be3bed1 100644 --- a/src/mase_components/norm/rtl/rms_norm_2d.sv +++ b/src/mase_components/norm/rtl/rms_norm_2d.sv @@ -239,10 +239,9 @@ module rms_norm_2d #( .out_ready(inv_sqrt_ready) ); - repeat_circular_buffer #( + single_element_repeat #( .DATA_WIDTH(ISQRT_WIDTH), - .REPEAT(NUM_ITERS), - .SIZE(1) + .REPEAT(NUM_ITERS) ) inv_sqrt_circ_buffer ( .clk(clk), .rst(rst), diff --git a/src/mase_components/norm/test/test_synth_norm.py b/src/mase_components/norm/test/test_synth_norm.py new file mode 100644 index 000000000..77c66e9d8 --- /dev/null +++ b/src/mase_components/norm/test/test_synth_norm.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_norm(): + run_synth("norm") + + +if __name__ == "__main__": + test_synth_norm() diff --git a/src/mase_components/synth_runner.py b/src/mase_components/synth_runner.py new file mode 100644 index 000000000..bb852f78f --- /dev/null +++ b/src/mase_components/synth_runner.py @@ -0,0 +1,104 @@ +import subprocess +from pathlib import Path +import os, sys + +from chop.tools import get_logger, set_logging_verbosity +import mase_components +from mase_components.deps import MASE_HW_DEPS + +logger = get_logger(f"linter") +set_logging_verbosity("debug") + +COMPONENTS_PATH = Path(__file__).parents[0] + + +def generate_tcl_script(group, module_name, include_groups, synth_project_path): + os.makedirs(synth_project_path, exist_ok=True) + tcl_script_template = f""" +set_param board.repoPaths {{{str(Path.home())}/shared/board-files}} +create_project synth_project_{group}_{module_name} {synth_project_path} -part xcu280-fsvh2892-2L-e +set_property board_part xilinx.com:au280:part0:1.1 [current_project] +""" + for include_group in include_groups: + tcl_script_template += f"""\nadd_files {include_group}""" + + tcl_script_template += f"\n\nset_property top {module_name} [current_fileset]" + + tcl_script_template += """ +update_compile_order -fileset sources_1 +launch_runs synth_1 +wait_on_runs synth_1 +""" + + with open(f"{synth_project_path}/build.tcl", "w") as file: + file.write(tcl_script_template) + + +def run_synth(group): + comp_path = COMPONENTS_PATH / group / "rtl" + rtl_files = [ + file + for file in os.listdir(comp_path) + if file.endswith(".sv") or file.endswith(".v") + ] + + successes = [] + failures = [] + + for rtl_file in rtl_files: + file_path = comp_path / rtl_file + logger.info(f"Synthesizing {file_path}") + logger.info(f"----------------------------") + + module_name = rtl_file.replace(".sv", "") + module_path = f"{group}/{module_name}" + + if module_path not in MASE_HW_DEPS.keys(): + logger.warning( + f"Module {module_path} is not included in dependencies file." + ) + + # * List include files + include_groups = [ + f"{COMPONENTS_PATH / group / 'rtl'}" + for group in mase_components.get_modules() + if group != "vivado" + ] + + synth_project_path = ( + f"{COMPONENTS_PATH}/{group}/synth/synth_project_{group}_{module_name}" + ) + + logger.debug(f"Include files: {include_groups}") + + logger.info(f"Generating build TCL script for module: {module_path}") + generate_tcl_script(group, module_name, include_groups, synth_project_path) + + logger.info(f"Launching Vivado synthesis for module: {module_path}") + cmd = [ + "vivado", + "-mode", + "batch", + "-log", + f"{synth_project_path}/vivado.log", + "-source", + f"{synth_project_path}/build.tcl", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + + # * Process result + if result.stderr == "": + successes.append(rtl_file) + else: + logger.error(result.stderr) + failures.append(rtl_file) + + # * Print summary + logger.info(f"=========== SUMMARY ===========") + logger.info( + f"PASS: {len(successes)}/{len(rtl_files)}, FAIL: {len(failures)}/{len(rtl_files)}" + ) + + if len(failures) > 0: + logger.error(f"Failed synthesizing the following modules: {failures}") + sys.exit(1) diff --git a/src/mase_components/systolic_arrays/test/test_synth_systolic_arrays.py b/src/mase_components/systolic_arrays/test/test_synth_systolic_arrays.py new file mode 100644 index 000000000..773dfab17 --- /dev/null +++ b/src/mase_components/systolic_arrays/test/test_synth_systolic_arrays.py @@ -0,0 +1,9 @@ +from mase_components.synth_runner import run_synth + + +def test_synth_systolic_arrays(): + run_synth("systolic_arrays") + + +if __name__ == "__main__": + test_synth_systolic_arrays() diff --git a/src/mase_components/vivado/constraints.xdc b/src/mase_components/vivado/constraints.xdc new file mode 100644 index 000000000..60faaa465 --- /dev/null +++ b/src/mase_components/vivado/constraints.xdc @@ -0,0 +1 @@ +create_clock -period 20.000 -name clk -waveform {0.000 10.000} [get_ports clk] \ No newline at end of file diff --git a/test/nn/quantized/modules/attention_head.py b/test/nn/quantized/modules/attention_head.py new file mode 100644 index 000000000..e0f52a61e --- /dev/null +++ b/test/nn/quantized/modules/attention_head.py @@ -0,0 +1,22 @@ +from chop.nn.quantized.modules.attention_head import ( + BertSelfAttentionHeadInteger, +) +from transformers import AutoConfig +import torch + + +def test_quantized_attention_head(): + cf = AutoConfig.from_pretrained("bert-base-uncased") + head = BertSelfAttentionHeadInteger(cf, q_config={"width": 8, "frac_width": 4}) + + inputs = { + "query_layer": torch.randn((20, 64)), + "key_layer": torch.randn((20, 64)), + "value_layer": torch.randn((20, 64)), + } + + _ = head(**inputs) + + +if __name__ == "__main__": + test_quantized_attention_head() diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_bert.py b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py new file mode 100644 index 000000000..b0982601f --- /dev/null +++ b/test/passes/graph/analysis/autosharding/test_autosharding_bert.py @@ -0,0 +1,67 @@ +import sys, pdb, traceback + +import torch +import torch.nn as nn + +from chop.ir import MaseGraph +from chop.distributed import MaseLauncher +import chop.passes as passes +from chop.tools import get_logger + +from transformers.models.bert import BertConfig, BertModel + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +WORLD_SIZE = 8 +DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] + +def test_autosharding(): + + # Define config + config = BertConfig() + config.num_hidden_layers = 3 + config.hidden_size = 96 + config.intermediate_size = 384 + config._attn_implementation = "eager" + config_sequence_length = 4 + + # Initialize model and MaseGraph + model = BertModel(config) + mg = MaseGraph(model) + mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "bert.txt"}) + mg, _ = passes.add_common_metadata_analysis_pass( + mg, + pass_args={ + "dummy_in": { + "input_ids": torch.randint(0, 10, (1, config_sequence_length)), + }, + "add_value": False, + }, + ) + + # Run autosharding pass to decide sharding configuration + mg, module_map = passes.autosharding_analysis_pass( + mg, + pass_args = { + "mesh_shape": (2, 4), + "inter_node_bandwidth": 10e9, + "intra_node_bandwidth": 100e9 + }) + + # Insert resharding wrappers around each module to handle inter-operator communication + mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) + + # dump print model to a file + with open("model.txt", "w") as f: + print(mg.model, file=f) + + # Launch model in distributed cluster + launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) + inputs = [torch.randint(0, 10, (1, config_sequence_length))] + launcher.run(module_map, inputs) + + +if __name__ == "__main__": + test_autosharding() diff --git a/test/passes/graph/analysis/autosharding/test_autosharding_linear.py b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py new file mode 100644 index 000000000..a337942ca --- /dev/null +++ b/test/passes/graph/analysis/autosharding/test_autosharding_linear.py @@ -0,0 +1,65 @@ +import sys, pdb, traceback, os + +import torch +import torch.nn as nn + +from chop.ir import MaseGraph +from chop.distributed import MaseLauncher +import chop.passes as passes +from chop.tools import get_logger + +def excepthook(exc_type, exc_value, exc_traceback): + traceback.print_exception(exc_type, exc_value, exc_traceback) + print("\nEntering debugger...") + pdb.post_mortem(exc_traceback) + + +# Set the custom exception hook +sys.excepthook = excepthook + +logger = get_logger(__name__) +logger.setLevel("DEBUG") + +WORLD_SIZE = 8 +DEVICE_MESH = [[0, 1, 2, 3], [4, 5, 6, 7]] + +class MLP(nn.Module): + def __init__(self, in_features=64, hidden_dimension=128, out_features=64): + super().__init__() + self.l1 = nn.Linear(in_features, hidden_dimension) + self.l2 = nn.Linear(hidden_dimension, out_features) + + def forward(self, x): + out = self.l1(x) + return self.l2(out) + +def test_autosharding(): + + # Initialize model and MaseGraph + model = MLP() + mg = MaseGraph(model) + mg, _ = passes.init_metadata_analysis_pass(mg) + mg, _ = passes.add_common_metadata_analysis_pass( + mg, pass_args={"dummy_in": {"x": torch.randn((16, 64))}, "add_value": False} + ) + + # Run autosharding pass to decide sharding configuration + mg, module_map = passes.autosharding_analysis_pass( + mg, + pass_args = { + "mesh_shape": (2, 4), + "inter_node_bandwidth": 10e9, + "intra_node_bandwidth": 100e9 + }) + + # Insert resharding wrappers around each module to handle inter-operator communication + mg, _ = passes.resharding_transform_pass(mg, pass_args={"module_map": module_map, "device_mesh": DEVICE_MESH}) + + # Launch model in distributed cluster + launcher = MaseLauncher(mg, world_size=WORLD_SIZE, device_mesh=DEVICE_MESH) + inputs = [torch.randn((16, 64))] + launcher.run(module_map, inputs) + + +if __name__ == "__main__": + test_autosharding() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py new file mode 100644 index 000000000..a1e367bd5 --- /dev/null +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py @@ -0,0 +1,241 @@ +import sys, os + +import torch +import torch.nn as nn + +from transformers.activations import GELUActivation + +import chop.passes as passes +import chop.actions as actions +from chop.ir import MaseGraph +from chop.models.patched.bert import BertConfig, BertModel +from chop.models.patched.bert.modeling_bert import BertSelfAttention +from chop.passes.graph.utils import deepsetattr +from chop.nn.quantized import ( + BertSelfAttentionInteger, + LinearInteger, + LayerNormInteger, + GELUInteger, +) +from chop.tools import get_logger, set_excepthook + +from mase_components import get_module_dependencies +from mase_components.activations.test.generate_memory import generate_sv_lut + +import operator +from functools import partial + +logger = get_logger(__name__) +logger.setLevel("DEBUG") +set_excepthook() + +# * Define custom ops (leaf submodules during tracing) +# * This is useful so we can write a single optimised verilog file for self attention, +# * instead of relying on emit_verilog to instantiate each submodule +BERT_CUSTOM_OPS = { + "modules": { + BertSelfAttentionInteger: { + "args": { + "hidden_states": "data_in", + "attention_mask": None, + "head_mask": None, + "encoder_hidden_states": None, + "encoder_attention_mask": None, + "past_key_value": None, + "output_attentions": "config", + }, + "toolchain": "INTERNAL_RTL", + "module": "fixed_self_attention_single_precision_wrapper", + "dependence_files": get_module_dependencies( + "attention/fixed_self_attention_single_precision_wrapper" + ), + }, + }, + "functions": {}, +} + + +def bert_module_level_quantize(model, model_config, q_config): + for module in model.named_modules(): + if isinstance(module[1], BertSelfAttention): + new_module = BertSelfAttentionInteger( + model_config, q_config, output_tensor_only=True + ) + elif isinstance(module[1], nn.Linear): + new_module = LinearInteger( + in_features=module[1].in_features, + out_features=module[1].out_features, + bias=module[1].bias is not None, + config=q_config, + ) + elif isinstance(module[1], nn.LayerNorm): + new_module = LayerNormInteger( + normalized_shape=module[1].normalized_shape, + eps=module[1].eps, + config=q_config, + ) + elif isinstance(module[1], GELUActivation): + new_module = GELUInteger(config=q_config) + else: + continue + logger.info(f"Replacing module: {module[0]}") + deepsetattr(model, module[0], new_module) + return model + + +def bert_update_metadata(mg, q_config): + """ + The following processing is a temporary hot fix to get emit verilog working on the bert model. We + update the type and precision for the add, getitem and split (fork) nodes which are currently + inserted in the patched model code. In the (near) future, inserting forking nodes and setting their + precision correctly will be handled automatedly as a preprocessing step for the emit verilog pass, + so this function will be unnecessary. + """ + for node in mg.fx_graph.nodes: + + # Update args + if ( + node.target == operator.add + or node.target == operator.getitem + or node.meta["mase"]["common"]["mase_op"] == "df_split" + ): + node.meta["mase"]["common"]["args"]["data_in_0"]["type"] = "fixed" + node.meta["mase"]["common"]["args"]["data_in_0"]["precision"] = [ + q_config["data_in_width"], + q_config["data_in_frac_width"], + ] + if "data_in_1" in node.meta["mase"]["common"]["args"]: + node.meta["mase"]["common"]["args"]["data_in_1"]["type"] = "fixed" + node.meta["mase"]["common"]["args"]["data_in_1"]["precision"] = [ + q_config["data_in_width"], + q_config["data_in_frac_width"], + ] + + # Update results + if ( + node.target == operator.add + or node.target == operator.getitem + or node.meta["mase"]["common"]["mase_op"] == "df_split" + or node.op == "placeholder" + or node.op == "output" + ): + node.meta["mase"]["common"]["results"]["data_out_0"]["type"] = "fixed" + node.meta["mase"]["common"]["results"]["data_out_0"]["precision"] = [ + q_config["data_out_width"], + q_config["data_out_frac_width"], + ] + if "data_out_1" in node.meta["mase"]["common"]["results"]: + node.meta["mase"]["common"]["results"]["data_out_1"]["type"] = "fixed" + node.meta["mase"]["common"]["results"]["data_out_1"]["precision"] = [ + q_config["data_out_width"], + q_config["data_out_frac_width"], + ] + + # Set one of the args to none according to the select value + if node.target == operator.getitem: + select = 0 if node.args[1] == 1 else 1 + node.meta["mase"]["common"]["args"][f"data_in_{select}"] = None + + return mg, {} + + +def emit_verilog_bert( + config, q_config, config_sequence_length, wait_count=15, wait_unit="ms", max_parallelism=4 +): + # * Get model and quantize self attention, linear and layer norm layers + model = BertModel(config) + model = bert_module_level_quantize(model, config, q_config) + logger.info(f"Quantized BERT model: {model}") + + # * Trace the model + mg = MaseGraph(model, custom_ops=BERT_CUSTOM_OPS) + mg, _ = passes.init_metadata_analysis_pass(mg) + + mg, _ = passes.report_graph_analysis_pass(mg, pass_args={"file_name": "bert.txt"}) + + # * Add metadata analysis passes + mg, _ = passes.add_common_metadata_analysis_pass( + mg, + pass_args={ + "dummy_in": { + "input_ids": torch.randn( + (1, config_sequence_length, config.hidden_size) + ) + }, + "add_value": False, + }, + ) + + mg, _ = bert_update_metadata(mg, q_config) + + mg, _ = passes.add_hardware_metadata_analysis_pass( + mg, + pass_args={ + "max_parallelism": [max_parallelism] * 4, + }, + ) + + # * Save the metadata to a file for debugging + mg, _ = passes.report_node_meta_param_analysis_pass( + mg, + pass_args={ + "which": ["common", "hardware"], + "save_path": "graph_meta_params.txt", + }, + ) + + mg, _ = passes.emit_verilog_top_transform_pass(mg) + mg, _ = passes.emit_bram_transform_pass(mg) + mg, _ = passes.emit_internal_rtl_transform_pass(mg) + mg, _ = passes.emit_cocotb_transform_pass( + mg, + pass_args={ + "wait_time": wait_count, + "wait_unit": wait_unit, + }, + ) + mg, _ = passes.emit_vivado_project_transform_pass(mg) + + # Temporary: fix data coherency checks + os.environ["COCOTB_RESOLVE_X"] = "ZEROS" + + actions.simulate(skip_build=False, skip_test=False, gui=False, waves=False, simulator="questa") + + +def get_default_qconfig(): + return { + "data_in_width": 8, + "data_in_frac_width": 3, + "weight_width": 8, + "weight_frac_width": 3, + "bias_width": 8, + "bias_frac_width": 3, + "data_out_width": 8, + "data_out_frac_width": 3, + } + + +def test_emit_verilog_bert_smoke(): + config = BertConfig() + config.num_hidden_layers = 3 + config.hidden_size = 96 + config.intermediate_size = 384 + config_sequence_length = 4 + q_config = get_default_qconfig() + emit_verilog_bert(config, q_config, config_sequence_length, wait_count=10, max_parallelism=2) + + +def test_emit_verilog_bert_regression(): + config = BertConfig() + config.num_hidden_layers = 3 + config.hidden_size = 384 + config.intermediate_size = 1536 + config_sequence_length = 128 + q_config = get_default_qconfig() + emit_verilog_bert(config, q_config, config_sequence_length, wait_count=15, max_parallelism=16) + + +if __name__ == "__main__": + generate_sv_lut("gelu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) + test_emit_verilog_bert_smoke() + test_emit_verilog_bert_regression() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index 34fdf9092..3e7ee06b3 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -8,7 +8,6 @@ import chop as chop import chop.passes as passes -from chop.tools.utils import execute_cli from pathlib import Path @@ -39,10 +38,9 @@ class MLP(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.fc1 = nn.Linear(28 * 28, 28 * 28, bias=False) + self.fc1 = nn.Linear(768, 768, bias=True) def forward(self, x): - x = torch.flatten(x, start_dim=1, end_dim=-1) x = torch.nn.functional.relu(self.fc1(x)) return x @@ -52,8 +50,8 @@ def test_emit_verilog_linear(): mg = chop.MaseGraph(model=mlp) # Provide a dummy input for the graph so it can use for tracing - batch_size = 1 - x = torch.randn((batch_size, 28, 28)) + batch_size = 20 + x = torch.randn((batch_size, 768)) dummy_in = {"x": x} mg, _ = passes.init_metadata_analysis_pass(mg, None) @@ -105,15 +103,18 @@ def test_emit_verilog_linear(): 10 * torch.randn(mg.model.fc1.weight.shape) ) - mg, _ = passes.add_hardware_metadata_analysis_pass(mg) + mg, _ = passes.add_hardware_metadata_analysis_pass( + mg, pass_args={"max_parallelism": [2]*4} + ) mg, _ = passes.report_node_hardware_type_analysis_pass(mg) # pretty print mg, _ = passes.emit_verilog_top_transform_pass(mg) mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) - mg, _ = passes.emit_cocotb_transform_pass(mg) + mg, _ = passes.emit_cocotb_transform_pass(mg, pass_args={"wait_time": 100, "wait_unit": "ms", "batch_size": batch_size}) + mg, _ = passes.emit_vivado_project_transform_pass(mg) - simulate(skip_build=False, skip_test=True) + simulate(skip_build=False, skip_test=False) if __name__ == "__main__": diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py new file mode 100644 index 000000000..f5363e92a --- /dev/null +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_llama.py @@ -0,0 +1,201 @@ +import os, operator + +import torch +import torch.nn as nn + +from chop import AutoPipelineForEmitVerilog +import chop.passes as passes +import chop.actions as actions +from chop.ir import MaseGraph +from chop.passes.graph.utils import deepsetattr + +from chop.models.patched.llama import LlamaConfig, LlamaModel +from chop.models.patched.llama.modeling_llama import LlamaSdpaAttention, LlamaRMSNorm + +from chop.nn.quantized import ( + LlamaSdpaAttentionInteger, + LinearInteger, + RMSNormInteger, + SiLUInteger, +) + +from chop.tools import get_logger, set_excepthook + +from mase_components import get_module_dependencies +from mase_components.activations.test.generate_memory import generate_sv_lut + +logger = get_logger(__name__) +logger.setLevel("DEBUG") +set_excepthook() + +# Temporary: fix data coherency checks +os.environ["COCOTB_RESOLVE_X"] = "ZEROS" + +SMOKE_TEST_SCALE_FACTOR = 8 + +# * Define custom ops (leaf submodules during tracing) +# * This is useful so we can write a single optimised verilog file for self attention, +# * instead of relying on emit_verilog to instantiate each submodule +LLAMA_CUSTOM_OPS = { + "modules": { + LlamaSdpaAttention: { + "args": { + "hidden_states": "data_in", + "attention_mask": None, + "position_ids": None, + "past_key_value": None, + "output_attentions": None, + "use_cache": None, + "cache_position": None, + }, + "toolchain": "INTERNAL_RTL", + "module": "fixed_self_attention_single_precision_wrapper", + "dependence_files": get_module_dependencies( + "attention/fixed_self_attention_single_precision_wrapper" + ), + }, + RMSNormInteger: { + "args": { + "hidden_states": "data_in", + }, + "toolchain": "INTERNAL_RTL", + "module": "norm", + "dependence_files": get_module_dependencies("norm/norm"), + }, + SiLUInteger: { + "args": { + "input": "data_in", + }, + "toolchain": "INTERNAL_RTL", + "module": "silu", + "dependence_files": get_module_dependencies("silu/silu"), + }, + }, + "functions": {}, +} + + +def llama_module_level_quantize(model, model_config, q_config): + for name, module in model.named_modules(): + if isinstance(module, LlamaSdpaAttention): + new_module = LlamaSdpaAttentionInteger( + config=model_config, + q_config=q_config, + output_tensor_only=True, + ) + elif isinstance(module, nn.Linear): + new_module = LinearInteger( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + config=q_config, + ) + elif isinstance(module, LlamaRMSNorm): + new_module = RMSNormInteger( + normalized_shape=None, + eps=module.variance_epsilon, + config=q_config, + ) + elif isinstance(module, nn.SiLU): + new_module = SiLUInteger( + inplace=module.inplace, + config=q_config, + ) + else: + continue + logger.info(f"Replacing module: {name}") + deepsetattr(model, name, new_module) + return model + + +def emit_verilog_llama( + config, + q_config, + config_sequence_length, + config_batch_size, + wait_count=15, + wait_unit="ms", + max_parallelism=4, +): + # * Get model and quantize self attention, linear and layer norm layers + model = LlamaModel(config) + model = llama_module_level_quantize(model, config, q_config) + logger.info(f"Quantized Llama model: {model}") + + # * Trace the model + mg = MaseGraph(model, custom_ops=LLAMA_CUSTOM_OPS) + + pipeline = AutoPipelineForEmitVerilog() + mg = pipeline( + mg, + pass_args={ + "report_graph_analysis_pass": {"file_name": "llama.txt"}, + "add_common_metadata_analysis_pass": { + "dummy_in": { + "input_ids": torch.randn( + (config_batch_size, config_sequence_length, config.hidden_size) + ) + }, + "add_value": False, + }, + "patch_metadata_transform_pass": { + "q_config": q_config, + }, + "add_hardware_metadata_analysis_pass": { + "max_parallelism": [max_parallelism] * 4, + }, + "report_node_meta_param_analysis_pass": { + "which": ["common", "hardware"], + "save_path": "llama_graph_meta_params.txt", + }, + "emit_cocotb_transform_pass": { + "wait_time": wait_count, + "wait_unit": wait_unit, + }, + }, + ) + + actions.simulate( + skip_build=False, skip_test=False, gui=True, waves=False, simulator="questa" + ) + + +def get_default_qconfig(): + return { + "data_in_width": 8, + "data_in_frac_width": 3, + "weight_width": 8, + "weight_frac_width": 3, + "bias_width": 8, + "bias_frac_width": 3, + "data_out_width": 8, + "data_out_frac_width": 3, + } + + +def test_emit_verilog_llama_smoke(): + config = LlamaConfig() + config.num_hidden_layers = 1 + config.hidden_size //= SMOKE_TEST_SCALE_FACTOR + config.intermediate_size //= SMOKE_TEST_SCALE_FACTOR + config.max_position_embeddings = 4096 + config.rms_norm_eps = 1e-5 + + config_batch_size = 5 + config_sequence_length = 4 + + q_config = get_default_qconfig() + + emit_verilog_llama( + config, + q_config, + config_sequence_length, + config_batch_size, + wait_count=10, + max_parallelism=2, + ) + + +if __name__ == "__main__": + generate_sv_lut("silu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) + test_emit_verilog_llama_smoke() diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py new file mode 100644 index 000000000..39ec6079d --- /dev/null +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_mistral.py @@ -0,0 +1,151 @@ +import sys, os + +import torch +import torch.nn as nn + +from transformers.activations import GELUActivation + +import chop.passes as passes +import chop.actions as actions +from chop.ir import MaseGraph +from chop.models.patched.mistral import MistralConfig, MistralModel +from chop.models.patched.mistral.modeling_mistral import MistralAttention +from chop.passes.graph.utils import deepsetattr + +# from chop.nn.quantized import MistralAttentionInteger +from chop.tools import get_logger, set_excepthook + +from mase_components import get_module_dependencies +from mase_components.activations.test.generate_memory import generate_sv_lut + +import operator +from functools import partial + +logger = get_logger(__name__) +logger.setLevel("DEBUG") +set_excepthook() + +# * Define custom ops (leaf submodules during tracing) +# * This is useful so we can write a single optimised verilog file for self attention, +# * instead of relying on emit_verilog to instantiate each submodule +MISTRAL_CUSTOM_OPS = { + "modules": {}, + "functions": {}, +} + + +def mistral_module_level_quantize(model, model_config, q_config): + return model + + +def mistral_update_metadata(mg, q_config): + """ + The following processing is a temporary hot fix to get emit verilog working on the mistral model. We + update the type and precision for the add, getitem and split (fork) nodes which are currently + inserted in the patched model code. In the (near) future, inserting forking nodes and setting their + precision correctly will be handled automatedly as a preprocessing step for the emit verilog pass, + so this function will be unnecessary. + """ + return mg, {} + + +def emit_verilog_mistral( + config, + q_config, + config_sequence_length, + wait_count=15, + wait_unit="ms", + max_parallelism=4, +): + # * Get model and quantize self attention, linear and layer norm layers + model = MistralModel(config) + model = mistral_module_level_quantize(model, config, q_config) + logger.info(f"Quantized mistral model: {model}") + + # * Trace the model + mg = MaseGraph(model, custom_ops=MISTRAL_CUSTOM_OPS) + mg, _ = passes.init_metadata_analysis_pass(mg) + + mg, _ = passes.report_graph_analysis_pass( + mg, pass_args={"file_name": "mistral.txt"} + ) + + # * Add metadata analysis passes + mg, _ = passes.add_common_metadata_analysis_pass( + mg, + pass_args={ + "dummy_in": { + "input_ids": torch.randn( + (1, config_sequence_length, config.hidden_size) + ) + }, + "add_value": False, + }, + ) + + mg, _ = mistral_update_metadata(mg, q_config) + + mg, _ = passes.add_hardware_metadata_analysis_pass( + mg, + pass_args={ + "max_parallelism": [max_parallelism] * 4, + }, + ) + + # * Save the metadata to a file for debugging + mg, _ = passes.report_node_meta_param_analysis_pass( + mg, + pass_args={ + "which": ["common", "hardware"], + "save_path": "mistral_graph_meta_params.txt", + }, + ) + + mg, _ = passes.emit_verilog_top_transform_pass(mg) + mg, _ = passes.emit_bram_transform_pass(mg) + mg, _ = passes.emit_internal_rtl_transform_pass(mg) + mg, _ = passes.emit_cocotb_transform_pass( + mg, + pass_args={ + "wait_time": wait_count, + "wait_unit": wait_unit, + }, + ) + mg, _ = passes.emit_vivado_project_transform_pass(mg) + + # Temporary: fix data coherency checks + os.environ["COCOTB_RESOLVE_X"] = "ZEROS" + + actions.simulate( + skip_build=False, skip_test=False, gui=False, waves=False, simulator="questa" + ) + + +def get_default_qconfig(): + return { + "data_in_width": 8, + "data_in_frac_width": 3, + "weight_width": 8, + "weight_frac_width": 3, + "bias_width": 8, + "bias_frac_width": 3, + "data_out_width": 8, + "data_out_frac_width": 3, + } + + +def test_emit_verilog_mistral_smoke(): + config = MistralConfig() + config.num_hidden_layers = 3 + config.hidden_size = 96 + config.intermediate_size = 384 + config_sequence_length = 4 + q_config = get_default_qconfig() + emit_verilog_mistral( + config, q_config, config_sequence_length, wait_count=10, max_parallelism=2 + ) + + +if __name__ == "__main__": + generate_sv_lut("silu", 8, 3, data_width=8, f_width=3, path_with_dtype=False) + test_emit_verilog_mistral_smoke()