From 289fecb03cf57c3e49e798c5e93b6f012877c338 Mon Sep 17 00:00:00 2001 From: Cheng Zhang Date: Tue, 24 Feb 2026 13:38:04 +0000 Subject: [PATCH] remove onn from mase --- docs/source/modules/chop/passes_module.rst | 4 - docs/source/modules/chop/transform/onn.rst | 207 --- .../getting_started/Get-started-using-uv.md | 2 +- .../newcompute/onn/1-transform.ipynb | 411 ------ .../newcompute/onn/2-finetuning.ipynb | 1224 ----------------- docs/tutorials/newcompute/onn/README.md | 138 -- .../module/transforms/bitflip/__init__.py | 3 - .../transforms/bitflip/bitflip_transform.py | 90 -- .../passes/module/transforms/onn/__init__.py | 3 - .../module/transforms/onn/layers/__init__.py | 0 .../module/transforms/onn/layers/attn.py | 358 ----- .../module/transforms/onn/layers/linear.py | 30 - .../passes/module/transforms/onn/transform.py | 180 --- .../onn/test_optical_transformer.py | 121 -- 14 files changed, 1 insertion(+), 2770 deletions(-) delete mode 100644 docs/source/modules/chop/transform/onn.rst delete mode 100644 docs/tutorials/newcompute/onn/1-transform.ipynb delete mode 100644 docs/tutorials/newcompute/onn/2-finetuning.ipynb delete mode 100644 docs/tutorials/newcompute/onn/README.md delete mode 100644 src/chop/passes/module/transforms/bitflip/__init__.py delete mode 100644 src/chop/passes/module/transforms/bitflip/bitflip_transform.py delete mode 100644 src/chop/passes/module/transforms/onn/__init__.py delete mode 100644 src/chop/passes/module/transforms/onn/layers/__init__.py delete mode 100644 src/chop/passes/module/transforms/onn/layers/attn.py delete mode 100644 src/chop/passes/module/transforms/onn/layers/linear.py delete mode 100644 src/chop/passes/module/transforms/onn/transform.py delete mode 100644 test/passes/module/transforms/onn/test_optical_transformer.py diff --git a/docs/source/modules/chop/passes_module.rst b/docs/source/modules/chop/passes_module.rst index e949b72cb..821b4d9f1 100644 --- a/docs/source/modules/chop/passes_module.rst +++ b/docs/source/modules/chop/passes_module.rst @@ -35,13 +35,9 @@ Summary of Mase Module Transform Passes * - :py:meth:`~chop.passes.module.transforms.quantize.quantize_module_transform_pass` - `test_module_quantize `_ - Apply quantization transformation to the given nn.Module - * - :py:meth:`~chop.passes.module.transforms.onn.optical_transformer_module_transform_pass` - - See :doc:`transform/onn` - - Transform modules to Optical Neural Network (ONN) equivalents .. toctree:: :maxdepth: 2 :caption: Full list of module-level transform passes module_transform/quantization - transform/onn diff --git a/docs/source/modules/chop/transform/onn.rst b/docs/source/modules/chop/transform/onn.rst deleted file mode 100644 index d9f319dba..000000000 --- a/docs/source/modules/chop/transform/onn.rst +++ /dev/null @@ -1,207 +0,0 @@ -chop.passes.module.transforms.onn -================================== - -This module provides transformation passes for converting standard neural network -modules into Optical Neural Network (ONN) equivalents. The optical transformer -implementation is based on the `Optical Transformers paper `_. - -Optical neural networks leverage photonic hardware to perform matrix multiplications -with reduced power consumption. This transform simulates the quantization effects -and constraints of optical compute hardware, enabling model development and evaluation -before deployment on physical optical accelerators. - -.. note:: - - This module requires the ``mase-triton`` package to be installed. 
- Install via: ``pip install mase-triton`` - - -Transform Pass --------------- - -optical\_transformer\_module\_transform\_pass -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: chop.passes.module.transforms.onn.optical_transformer_module_transform_pass - - -Configuration -------------- - -The transform pass accepts configuration through the ``pass_args`` dictionary. -Layer matching can be done by exact name or regex patterns. - -Example configuration: - -.. code-block:: python - - pass_args = { - "by": "regex_name", # or "name" for exact matching - "default": { - "q_levels": 256, - "q_lut_min": 0.020040, - "q_smooth_factor": 0.9, - "q_init_seed": 0, - "q_bypass": False, - }, - # Override for specific layers using regex - ".*mlp.*": { - "q_levels": 128, - "q_bypass": False, - }, - } - - -Configuration Parameters -^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. list-table:: - :header-rows: 1 - :widths: 20 15 15 50 - - * - Parameter - - Type - - Default - - Description - * - ``q_levels`` - - int - - 256 - - Number of quantization levels for optical simulation - * - ``q_lut_min`` - - float - - 0.020040 - - Minimum value for the lookup table used in quantization - * - ``q_quantiles`` - - tuple[float, float] or None - - None - - Quantile range for min/max statistics. If None, uses absolute min/max - * - ``q_smooth_factor`` - - float - - 0.9 - - Exponential moving average factor for updating running statistics - * - ``q_init_seed`` - - int - - 0 - - Random seed for quantization noise initialization - * - ``q_bypass`` - - bool - - False - - If True, bypass optical quantization (useful for debugging) - - -Layers ------- - -OtLinear -^^^^^^^^ - -.. py:data:: chop.passes.module.transforms.onn.layers.linear.OtLinear - - Optical Transformer Linear layer. - - This is an alias to ``mase_triton.optical_compute.layers.OpticalTransformerLinear``. - It replaces standard ``torch.nn.Linear`` layers with quantized optical transformer - equivalents that simulate optical neural network hardware constraints. - - The layer applies quantization to both the input activations and weights during - matrix multiplication, and tracks running min/max statistics for calibration. - - **Class method:** - - .. py:method:: from_linear(linear, **kwargs) - :classmethod: - - Create an OtLinear from an existing ``torch.nn.Linear`` layer. - - :param linear: Source linear layer - :type linear: torch.nn.Linear - :param kwargs: Quantization parameters (q_levels, q_lut_min, q_smooth_factor, q_init_seed, q_bypass, etc.) - :return: Optical transformer linear layer with copied weights - - -OtLlamaAttention -^^^^^^^^^^^^^^^^ - -.. autoclass:: chop.passes.module.transforms.onn.layers.attn.OtLlamaAttention - :members: - :undoc-members: - :show-inheritance: - - -Functional API --------------- - -ot\_eager\_attention\_forward -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: chop.passes.module.transforms.onn.layers.attn.ot_eager_attention_forward - - -Usage Example -------------- - -Basic usage with a LLaMA model: - -.. 
code-block:: python - - from transformers import AutoModelForCausalLM - from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass - - # Load a pretrained model - model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - - # Define transformation configuration - pass_args = { - "by": "regex_name", - "default": { - "q_levels": 256, - "q_lut_min": 0.020040, - "q_smooth_factor": 0.9, - "q_init_seed": 0, - "q_bypass": False, - }, - } - - # Apply the optical transformer transform - model = optical_transformer_module_transform_pass(model, pass_args) - - # The model now uses OtLinear and OtLlamaAttention layers - # Continue with training or inference as usual - - -Selective Layer Transformation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Transform only specific layers using regex patterns: - -.. code-block:: python - - pass_args = { - "by": "regex_name", - # Only transform attention layers - ".*self_attn.*": { - "q_levels": 256, - "q_bypass": False, - }, - # Transform MLP with different settings - ".*mlp.*": { - "q_levels": 128, - "q_bypass": False, - }, - } - - -Bypass Mode for Debugging -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Use ``q_bypass=True`` to disable quantization while keeping the module structure: - -.. code-block:: python - - pass_args = { - "by": "regex_name", - "default": { - "q_levels": 256, - "q_bypass": True, # Disable quantization - }, - } diff --git a/docs/source/modules/documentation/getting_started/Get-started-using-uv.md b/docs/source/modules/documentation/getting_started/Get-started-using-uv.md index 81d00673b..2845975f3 100644 --- a/docs/source/modules/documentation/getting_started/Get-started-using-uv.md +++ b/docs/source/modules/documentation/getting_started/Get-started-using-uv.md @@ -49,7 +49,7 @@ In the `uv` workflow, the standard way to execute commands is via `uv run`. This 2. 
**Running tests**: You can run the test suite while ignoring tests that require heavy hardware dependencies (like Verilator) or platform-specific packages (like `mase-triton`): ```bash - uv run pytest test/ --ignore=test/passes/graph/transforms/verilog --ignore=test/passes/module/transforms/onn/test_optical_transformer.py + uv run pytest test/ --ignore=test/passes/graph/transforms/verilog ``` ## Test your installation diff --git a/docs/tutorials/newcompute/onn/1-transform.ipynb b/docs/tutorials/newcompute/onn/1-transform.ipynb deleted file mode 100644 index 5d04e5ca9..000000000 --- a/docs/tutorials/newcompute/onn/1-transform.ipynb +++ /dev/null @@ -1,411 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "1a79ad55", - "metadata": {}, - "source": [ - "# Optical Transformer Transform Pass\n", - "\n", - "This tutorial provides minimal documentation for the Optical Neural Network (ONN) transform pass and layer classes in MASE.\n", - "\n", - "The optical transformer implementation is based on the [Optical Transformers paper](https://arxiv.org/abs/2302.10360).\n", - "\n", - "## Overview\n", - "\n", - "The ONN transform pass replaces standard PyTorch modules with their optical transformer equivalents:\n", - "\n", - "| Original Module | Optical Equivalent |\n", - "|-----------------|--------------------|\n", - "| `torch.nn.Linear` | `OtLinear` |\n", - "| `LlamaAttention` | `OtLlamaAttention` |\n", - "\n", - "## Requirements\n", - "\n", - "The `mase-triton` package is required for ONN transforms:\n", - "\n", - "```bash\n", - "pip install mase-triton\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "984ca459", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/zz7522/miniconda3/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig\n", - "\n", - "from chop.passes.module.transforms.onn.transform import (\n", - " OtLinear,\n", - " OtLlamaAttention,\n", - " OtTransformConfig,\n", - " optical_transformer_module_transform_pass,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0fc4e439", - "metadata": {}, - "source": [ - "## Configuration\n", - "\n", - "Use `OtTransformConfig` to configure the optical transform parameters:\n", - "\n", - "| Parameter | Type | Default | Description |\n", - "|-----------|------|---------|-------------|\n", - "| `q_levels` | int | 256 | Number of quantization levels, $2^n$ for n-bit quantization. 
|\n", - "| `q_lut_min` | float | 0.020040 | Minimum LUT value for quantization |\n", - "| `q_smooth_factor` | float | 0.9 | Smoothing factor for statistics updates in the training mode |\n", - "| `q_init_seed` | int | 0 | Random seed for initialization (only used in triton kernels) |\n", - "| `q_bypass` | bool | False | If True, bypass optical quantization |" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9a7c1bbd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Default ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.9, 'q_init_seed': 0, 'q_bypass': False}\n", - "Modified ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.1, 'q_init_seed': 0, 'q_bypass': False}\n" - ] - } - ], - "source": [ - "# Create default configuration\n", - "onn_config = OtTransformConfig.create_default()\n", - "print(\"Default ONN config:\", onn_config)\n", - "\n", - "# Customize configuration\n", - "onn_config[\"q_levels\"] = 256 # 8-bit quantization\n", - "onn_config[\"q_smooth_factor\"] = 0.1\n", - "print(\"Modified ONN config:\", onn_config)" - ] - }, - { - "cell_type": "markdown", - "id": "de082853", - "metadata": {}, - "source": [ - "## OtLinear: Optical Linear Layer\n", - "\n", - "`OtLinear` is the optical equivalent of `torch.nn.Linear`. It applies quantized matrix multiplication that simulates optical computing behavior." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "7e8b9c1d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Original output shape: torch.Size([2, 64])\n", - "Optical output shape: torch.Size([2, 64])\n", - "Max absolute difference: 0.035205\n" - ] - } - ], - "source": [ - "# Create a standard linear layer\n", - "linear = torch.nn.Linear(in_features=32, out_features=64)\n", - "\n", - "# Convert to optical linear layer\n", - "onn_config = OtTransformConfig.create_default()\n", - "linear_onn = OtLinear.from_linear(linear, **onn_config)\n", - "\n", - "# Compare outputs\n", - "x = torch.randn(2, 32)\n", - "y = linear(x)\n", - "y_onn = linear_onn(x)\n", - "\n", - "print(f\"Original output shape: {y.shape}\")\n", - "print(f\"Optical output shape: {y_onn.shape}\")\n", - "print(f\"Max absolute difference: {(y - y_onn).abs().max().item():.6f}\")" - ] - }, - { - "cell_type": "markdown", - "id": "5d4e368f", - "metadata": {}, - "source": [ - "## OtLlamaAttention: Optical Llama Attention\n", - "\n", - "`OtLlamaAttention` replaces the HuggingFace `LlamaAttention` with an optical-aware implementation that uses quantized scaled dot-product attention." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "6e7a6261", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Original output shape: torch.Size([1, 16, 384])\n", - "Optical output shape: torch.Size([1, 16, 384])\n" - ] - } - ], - "source": [ - "# Setup Llama configuration\n", - "model_name = \"AICrossSim/clm-60m\"\n", - "hf_config = LlamaConfig.from_pretrained(model_name)\n", - "\n", - "batch_size = 1\n", - "seq_len = 16\n", - "head_dim = hf_config.hidden_size // hf_config.num_attention_heads\n", - "\n", - "# Create standard attention layer\n", - "attn = LlamaAttention(config=hf_config, layer_idx=0)\n", - "\n", - "# Convert to optical attention\n", - "onn_config = OtTransformConfig.create_default()\n", - "onn_config[\"q_levels\"] = 512\n", - "attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=0, **onn_config)\n", - "\n", - "# Test forward pass\n", - "pos_emb = torch.ones(batch_size, seq_len, head_dim)\n", - "x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)\n", - "\n", - "y, _ = attn(x, (pos_emb, pos_emb), None)\n", - "attn_onn.train() # Enable statistics updates\n", - "y_onn, _ = attn_onn(x, (pos_emb, pos_emb), None)\n", - "\n", - "print(f\"Original output shape: {y.shape}\")\n", - "print(f\"Optical output shape: {y_onn.shape}\")" - ] - }, - { - "cell_type": "markdown", - "id": "1ac49d15", - "metadata": {}, - "source": [ - "## Transform Pass: Network-Level Transformation\n", - "\n", - "Use `optical_transformer_module_transform_pass` to transform an entire network. The pass replaces modules based on name matching.\n", - "\n", - "### Pass Arguments\n", - "\n", - "| Key | Description |\n", - "|-----|-------------|\n", - "| `by` | Matching mode: `\"name\"` (exact) or `\"regex_name\"` (regex pattern) |\n", - "| `` | Configuration dict for layers matching the name/pattern |\n", - "| `default` | Fallback configuration if no pattern matches |" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "bdee282d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Original network:\n", - "SimpleNetwork(\n", - " (attn): LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " )\n", - " (linear): Linear(in_features=384, out_features=384, bias=True)\n", - ")\n" - ] - } - ], - "source": [ - "# Define a simple network with attention and linear layers\n", - "class SimpleNetwork(torch.nn.Module):\n", - " def __init__(self, hf_config):\n", - " super().__init__()\n", - " self.attn = LlamaAttention(config=hf_config, layer_idx=0)\n", - " self.linear = torch.nn.Linear(\n", - " in_features=hf_config.hidden_size,\n", - " out_features=hf_config.hidden_size,\n", - " )\n", - "\n", - " def forward(self, x, pos_emb):\n", - " attn_output, _ = self.attn(x, (pos_emb, pos_emb), None)\n", - " output = self.linear(attn_output)\n", - " return output\n", - "\n", - "network = SimpleNetwork(hf_config)\n", - "print(\"Original network:\")\n", - "print(network)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "ffcb8d3d", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 
'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=True) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Transformed network:\n", - "SimpleNetwork(\n", - " (attn): OtLlamaAttention(\n", - " (q_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - " (k_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - " (v_proj): 
OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - " (o_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - " )\n", - " (linear): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - ")\n" - ] - } - ], - "source": [ - "# Configure the transform pass with regex patterns\n", - "onn_config = OtTransformConfig.create_default()\n", - "onn_config[\"q_levels\"] = 512\n", - "\n", - "pass_args = {\n", - " \"by\": \"regex_name\", # Use regex matching\n", - " \"attn\": onn_config, # Transform the attention layer\n", - " \"linear\": onn_config, # Transform the linear layer\n", - " r\"attn\\.(q|k|v|o)_proj\": onn_config, # Transform Q/K/V/O projections inside attention\n", - "}\n", - "\n", - "# Apply the transform\n", - "network_onn = optical_transformer_module_transform_pass(network, pass_args)\n", - "\n", - "print(\"\\nTransformed network:\")\n", - "print(network_onn)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "3c3557ec", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Verification:\n", - " attn is OtLlamaAttention: True\n", - " linear is OtLinear: True\n", - " attn.q_proj is OtLinear: True\n", - " attn.k_proj is OtLinear: True\n", - " attn.v_proj is OtLinear: True\n", - " attn.o_proj is OtLinear: True\n" - ] - } - ], - "source": [ - "# Verify the transformation\n", - "print(\"Verification:\")\n", - "print(f\" attn is OtLlamaAttention: {isinstance(network_onn.attn, OtLlamaAttention)}\")\n", - "print(f\" linear is OtLinear: {isinstance(network_onn.linear, OtLinear)}\")\n", - "print(f\" attn.q_proj is OtLinear: {isinstance(network_onn.attn.q_proj, OtLinear)}\")\n", - "print(f\" attn.k_proj is OtLinear: {isinstance(network_onn.attn.k_proj, OtLinear)}\")\n", - "print(f\" attn.v_proj is OtLinear: {isinstance(network_onn.attn.v_proj, OtLinear)}\")\n", - "print(f\" attn.o_proj is OtLinear: {isinstance(network_onn.attn.o_proj, OtLinear)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "355b1f3c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Output shape: torch.Size([1, 16, 384])\n", - "Max output error: 0.029137\n", - "Output is finite: True\n" - ] - } - ], - "source": [ - "# Test the transformed network\n", - "network_onn.train() # Enable statistics updates\n", - "\n", - "pos_emb = torch.ones(batch_size, seq_len, head_dim)\n", - "x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)\n", - "\n", - "y = network(x, pos_emb)\n", - "y_onn = network_onn(x, pos_emb)\n", - "print(f\"Output shape: {y_onn.shape}\")\n", - "print(f\"Max output error: {(y - y_onn).abs().max().item():.6f}\")\n", - "print(f\"Output is finite: {y_onn.isfinite().all().item()}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mase", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - 
"mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/newcompute/onn/2-finetuning.ipynb b/docs/tutorials/newcompute/onn/2-finetuning.ipynb deleted file mode 100644 index 10aaf438f..000000000 --- a/docs/tutorials/newcompute/onn/2-finetuning.ipynb +++ /dev/null @@ -1,1224 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "6ba64c6d", - "metadata": {}, - "source": [ - "# Fine-tuning HuggingFace Llama with Optical Transformer Transform\n", - "\n", - "This tutorial demonstrates how to:\n", - "1. Load a pretrained HuggingFace Llama model\n", - "2. Transform it using the optical transformer pass from MASE\n", - "3. Run continual fine-tuning on the transformed model\n", - "\n", - "## Overview\n", - "\n", - "The optical transformer transform replaces standard PyTorch modules with their optical equivalents that simulate optical computing behavior. This enables:\n", - "- Quantized matrix multiplication that models optical hardware\n", - "- Noise-aware training for robust optical neural network deployment\n", - "\n", - "## Requirements\n", - "\n", - "```bash\n", - "pip install mase-triton transformers datasets accelerate\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "943f547b", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/zz7522/miniconda3/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from transformers import (\n", - " AutoConfig,\n", - " AutoModelForCausalLM,\n", - " AutoTokenizer,\n", - " get_scheduler,\n", - " default_data_collator,\n", - ")\n", - "from datasets import load_dataset\n", - "from torch.utils.data import DataLoader\n", - "from itertools import chain\n", - "from tqdm.auto import tqdm\n", - "\n", - "from chop.passes.module.transforms.onn.transform import (\n", - " OtLinear,\n", - " OtLlamaAttention,\n", - " OtTransformConfig,\n", - " optical_transformer_module_transform_pass,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "dea4c060", - "metadata": {}, - "source": [ - "## 1. Load Pretrained HuggingFace Llama Model\n", - "\n", - "We'll use a small Llama model for demonstration. You can replace this with any Llama-based model." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "38b797b7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: cuda\n" - ] - } - ], - "source": [ - "# Model configuration\n", - "MODEL_NAME = \"AICrossSim/clm-60m\" # Small Llama model for demo\n", - "BLOCK_SIZE = 128 # Sequence length (use smaller value for demo)\n", - "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - "print(f\"Using device: {DEVICE}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8b6ed596", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model type: LlamaForCausalLM\n", - "Number of parameters: 82,101,120\n" - ] - } - ], - "source": [ - "# Load model and tokenizer\n", - "config = AutoConfig.from_pretrained(MODEL_NAME)\n", - "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " MODEL_NAME,\n", - " config=config,\n", - " attn_implementation=\"eager\", # Use eager attention for compatibility\n", - ")\n", - "\n", - "print(f\"Model type: {type(model).__name__}\")\n", - "print(f\"Number of parameters: {sum(p.numel() for p in model.parameters()):,}\")" - ] - }, - { - "cell_type": "markdown", - "id": "e99e16d4", - "metadata": {}, - "source": [ - "## 2. Configure and Apply Optical Transform\n", - "\n", - "We configure the optical transform with quantization parameters and apply it to the model." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "968be7f6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ONN Configuration:\n", - " q_levels: 256\n", - " q_lut_min: 0.02004\n", - " q_smooth_factor: 0.9\n", - " q_init_seed: 0\n", - " q_bypass: False\n" - ] - } - ], - "source": [ - "# Create ONN configuration\n", - "onn_config = OtTransformConfig.create_default()\n", - "\n", - "# Customize configuration (optional)\n", - "onn_config[\"q_levels\"] = 256 # Number of quantization levels\n", - "onn_config[\"q_lut_min\"] = 0.020040 # Minimum LUT value\n", - "onn_config[\"q_smooth_factor\"] = 0.9 # Statistics smoothing factor\n", - "onn_config[\"q_bypass\"] = False # Set to True to bypass optical quantization\n", - "\n", - "print(\"ONN Configuration:\")\n", - "for k, v in onn_config.items():\n", - " print(f\" {k}: {v}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "c4c54884", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Transform pass arguments configured\n" - ] - } - ], - "source": [ - "# Configure the transform pass\n", - "# Use regex patterns to match layer names for transformation\n", - "pass_args = {\n", - " \"by\": \"regex_name\",\n", - " # Transform all attention layers\n", - " r\"model\\.layers\\.\\d+\\.self_attn\": onn_config,\n", - " # Transform attention projections (Q, K, V, O)\n", - " r\"model\\.layers\\.\\d+\\.self_attn\\.(q|k|v|o)_proj\": onn_config,\n", - " # Transform MLP layers\n", - " r\"model\\.layers\\.\\d+\\.mlp\\.(gate|up|down)_proj\": onn_config,\n", - "}\n", - "\n", - "print(\"Transform pass arguments configured\")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "183195e2", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 
'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): 
Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 
'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, 
bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): 
Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ") to OtLlamaAttention(\n", - " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", - " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", - " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", - ")\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 
0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Transforming model...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, 
q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n",
 [... the same "Missing keys when loading state_dict" warning repeats here once for every nn.Linear layer replaced by the pass; the duplicate log lines are omitted ...]
 -        "WARNING:root:Missing keys when 
loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", - "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model transformed successfully!\n" - ] - } - ], - "source": [ - "# Apply the optical transformer transform\n", - "print(\"Transforming model...\")\n", - "model = optical_transformer_module_transform_pass(model, pass_args)\n", - "print(\"Model transformed successfully!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "57dd533d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Transformed layers:\n", - " OtLinear: 154\n", - " OtLlamaAttention: 22\n" - ] - } - ], - "source": [ - "# Verify the transformation\n", - "def count_transformed_layers(model):\n", - " ot_linear_count = 0\n", - " ot_attn_count = 0\n", - " for name, module in model.named_modules():\n", - " if isinstance(module, OtLinear):\n", - " ot_linear_count += 1\n", - " elif isinstance(module, OtLlamaAttention):\n", - " ot_attn_count += 1\n", - " return ot_linear_count, ot_attn_count\n", - "\n", - "ot_linear, ot_attn = count_transformed_layers(model)\n", - "print(f\"Transformed layers:\")\n", - "print(f\" OtLinear: {ot_linear}\")\n", - "print(f\" OtLlamaAttention: {ot_attn}\")" - ] - }, - { - "cell_type": "markdown", - "id": "066772ea", - "metadata": {}, - "source": [ - "## 3. Prepare Dataset for Fine-tuning\n", - "\n", - "We'll use a small subset of a text dataset for demonstration." 
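The heading above promises a small subset of the data, while the next cell loads the full WikiText-2 splits. A minimal sketch of how the raw splits could be trimmed before tokenization for an even quicker demo run, assuming the same `load_dataset("wikitext", "wikitext-2-raw-v1")` call used in the following cell and hypothetical subset sizes (this is not one of the original notebook cells):

```python
# Minimal sketch only -- not a cell from the deleted notebook.
# Optionally shrink the raw WikiText-2 splits so the demo runs faster.
from datasets import load_dataset

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
raw_datasets["train"] = raw_datasets["train"].select(range(2000))        # hypothetical subset size
raw_datasets["validation"] = raw_datasets["validation"].select(range(500))
print({split: len(ds) for split, ds in raw_datasets.items()})
```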
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a52002e3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train samples: 36718\n", - "Validation samples: 3760\n" - ] - } - ], - "source": [ - "# Load a small dataset for demonstration\n", - "raw_datasets = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\")\n", - "\n", - "print(f\"Train samples: {len(raw_datasets['train'])}\")\n", - "print(f\"Validation samples: {len(raw_datasets['validation'])}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "a12cb5b5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tokenization complete\n" - ] - } - ], - "source": [ - "# Tokenize the dataset\n", - "def tokenize_function(examples):\n", - " return tokenizer(examples[\"text\"])\n", - "\n", - "tokenized_datasets = raw_datasets.map(\n", - " tokenize_function,\n", - " batched=True,\n", - " remove_columns=[\"text\"],\n", - " desc=\"Tokenizing\",\n", - ")\n", - "\n", - "print(\"Tokenization complete\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "05df197a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train chunks: 19848\n", - "Validation chunks: 2074\n" - ] - } - ], - "source": [ - "# Group texts into chunks of block_size\n", - "def group_texts(examples):\n", - " # Concatenate all texts\n", - " concatenated_examples = {k: list(chain(*examples[k])) for k in examples}\n", - " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", - " # Drop the remainder\n", - " total_length = (total_length // BLOCK_SIZE) * BLOCK_SIZE\n", - " # Split into chunks\n", - " result = {\n", - " k: [t[i : i + BLOCK_SIZE] for i in range(0, total_length, BLOCK_SIZE)]\n", - " for k, t in concatenated_examples.items()\n", - " }\n", - " result[\"labels\"] = result[\"input_ids\"].copy()\n", - " return result\n", - "\n", - "lm_datasets = tokenized_datasets.map(\n", - " group_texts,\n", - " batched=True,\n", - " desc=f\"Grouping texts in chunks of {BLOCK_SIZE}\",\n", - ")\n", - "\n", - "print(f\"Train chunks: {len(lm_datasets['train'])}\")\n", - "print(f\"Validation chunks: {len(lm_datasets['validation'])}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "cee790dd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train batches: 4962\n", - "Eval batches: 519\n" - ] - } - ], - "source": [ - "# Create DataLoaders\n", - "BATCH_SIZE = 4\n", - "\n", - "train_dataloader = DataLoader(\n", - " lm_datasets[\"train\"],\n", - " shuffle=True,\n", - " collate_fn=default_data_collator,\n", - " batch_size=BATCH_SIZE,\n", - ")\n", - "\n", - "eval_dataloader = DataLoader(\n", - " lm_datasets[\"validation\"],\n", - " collate_fn=default_data_collator,\n", - " batch_size=BATCH_SIZE,\n", - ")\n", - "\n", - "print(f\"Train batches: {len(train_dataloader)}\")\n", - "print(f\"Eval batches: {len(eval_dataloader)}\")" - ] - }, - { - "cell_type": "markdown", - "id": "9796a8b2", - "metadata": {}, - "source": [ - "## 4. Setup Training\n", - "\n", - "Configure optimizer, scheduler, and training parameters." 
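Before the optimizer is configured, a quick sanity check of the freshly built dataloaders can help: each collated batch should typically contain `input_ids`, `attention_mask`, and `labels` of shape `(BATCH_SIZE, BLOCK_SIZE)`. The sketch below assumes the `train_dataloader`, `BATCH_SIZE`, and `BLOCK_SIZE` defined in the preceding cells and is not one of the original notebook cells:

```python
# Minimal sketch (not an original notebook cell): inspect one collated batch.
# Assumes train_dataloader, BATCH_SIZE and BLOCK_SIZE from the cells above.
batch = next(iter(train_dataloader))
for key, value in batch.items():
    print(key, tuple(value.shape))          # expected: (BATCH_SIZE, BLOCK_SIZE)
assert batch["input_ids"].shape == batch["labels"].shape
```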
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "d69f8a60", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training configuration:\n", - " Learning rate: 0.0002\n", - " Weight decay: 0.01\n", - " Max train steps: 100\n" - ] - } - ], - "source": [ - "# Training hyperparameters\n", - "LEARNING_RATE = 2e-4\n", - "WEIGHT_DECAY = 0.01\n", - "NUM_EPOCHS = 1 # Use 1 epoch for demo\n", - "MAX_TRAIN_STEPS = 100 # Limit steps for demo\n", - "WARMUP_STEPS = 10\n", - "\n", - "print(\"Training configuration:\")\n", - "print(f\" Learning rate: {LEARNING_RATE}\")\n", - "print(f\" Weight decay: {WEIGHT_DECAY}\")\n", - "print(f\" Max train steps: {MAX_TRAIN_STEPS}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0e2230cd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Trainable parameters: 82,101,120\n" - ] - } - ], - "source": [ - "# Move model to device\n", - "model = model.to(DEVICE)\n", - "\n", - "# Set all parameters trainable\n", - "for param in model.parameters():\n", - " param.requires_grad = True\n", - "\n", - "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "print(f\"Trainable parameters: {trainable_params:,}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "e2dd98fe", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Optimizer configured\n" - ] - } - ], - "source": [ - "# Setup optimizer with weight decay\n", - "no_decay = [\"bias\", \"layer_norm.weight\"]\n", - "optimizer_grouped_parameters = [\n", - " {\n", - " \"params\": [\n", - " p for n, p in model.named_parameters()\n", - " if not any(nd in n for nd in no_decay) and p.requires_grad\n", - " ],\n", - " \"weight_decay\": WEIGHT_DECAY,\n", - " },\n", - " {\n", - " \"params\": [\n", - " p for n, p in model.named_parameters()\n", - " if any(nd in n for nd in no_decay) and p.requires_grad\n", - " ],\n", - " \"weight_decay\": 0.0,\n", - " },\n", - "]\n", - "\n", - "optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)\n", - "print(\"Optimizer configured\")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0a8803df", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "LR scheduler configured\n" - ] - } - ], - "source": [ - "# Setup learning rate scheduler\n", - "lr_scheduler = get_scheduler(\n", - " name=\"linear\",\n", - " optimizer=optimizer,\n", - " num_warmup_steps=WARMUP_STEPS,\n", - " num_training_steps=MAX_TRAIN_STEPS,\n", - ")\n", - "\n", - "print(\"LR scheduler configured\")" - ] - }, - { - "cell_type": "markdown", - "id": "dfe3d417", - "metadata": {}, - "source": [ - "## 5. Training Loop\n", - "\n", - "Run the fine-tuning loop with the transformed optical model." - ] - }, - { - "cell_type": "markdown", - "id": "f58d6dd2", - "metadata": {}, - "source": [ - "### Quantization Statistics Warmup\n", - "\n", - "**Important:** The optical transformer layers require calibration of their quantization statistics (min/max values) before they can work correctly. Without this warmup:\n", - "- The statistics are initialized to `[inf, -inf]`\n", - "- The quantized matmul operations produce NaN values\n", - "- Loss and perplexity become NaN\n", - "\n", - "We run a few forward passes in **training mode** to let the layers collect statistics from the data." 
- ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "3653ac53", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running warmup to initialize quantization statistics...\n", - " Warmup batch 1/5\n", - " Warmup batch 2/5\n", - " Warmup batch 3/5\n", - " Warmup batch 4/5\n", - " Warmup batch 5/5\n", - "Warmup complete! Quantization statistics initialized.\n" - ] - } - ], - "source": [ - "# Warmup: Run a few forward passes in training mode to initialize quantization statistics\n", - "# This is necessary because the optical transformer layers need to calibrate their\n", - "# min/max statistics before they can perform quantized operations correctly.\n", - "\n", - "print(\"Running warmup to initialize quantization statistics...\")\n", - "model.train() # Must be in training mode to update stats\n", - "num_warmup_batches = 5\n", - "\n", - "with torch.no_grad(): # No need for gradients during warmup\n", - " for i, batch in enumerate(train_dataloader):\n", - " if i >= num_warmup_batches:\n", - " break\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - " _ = model(**batch)\n", - " print(f\" Warmup batch {i+1}/{num_warmup_batches}\")\n", - "\n", - "print(\"Warmup complete! Quantization statistics initialized.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "2300b1a1", - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "\n", - "def evaluate(model, eval_dataloader, device):\n", - " \"\"\"Evaluate model and return perplexity.\"\"\"\n", - " model.eval()\n", - " losses = []\n", - " for batch in eval_dataloader:\n", - " batch = {k: v.to(device) for k, v in batch.items()}\n", - " with torch.no_grad():\n", - " outputs = model(**batch)\n", - " losses.append(outputs.loss.item())\n", - "\n", - " avg_loss = sum(losses) / len(losses)\n", - " perplexity = math.exp(avg_loss)\n", - " return avg_loss, perplexity" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "dd599352", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Evaluating before training...\n", - "Initial - Loss: 7.0303, Perplexity: 1130.35\n" - ] - } - ], - "source": [ - "# Evaluate before training\n", - "print(\"Evaluating before training...\")\n", - "eval_loss, eval_ppl = evaluate(model, eval_dataloader, DEVICE)\n", - "print(f\"Initial - Loss: {eval_loss:.4f}, Perplexity: {eval_ppl:.2f}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "c9132b0a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Starting training...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Training: 100%|██████████| 100/100 [00:24<00:00, 4.15it/s, loss=5.9711]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Training completed! 
Steps: 100\n" - ] - } - ], - "source": [ - "# Training loop\n", - "print(\"\\nStarting training...\")\n", - "model.train()\n", - "completed_steps = 0\n", - "train_losses = []\n", - "\n", - "progress_bar = tqdm(range(MAX_TRAIN_STEPS), desc=\"Training\")\n", - "\n", - "for epoch in range(NUM_EPOCHS):\n", - " for step, batch in enumerate(train_dataloader):\n", - " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", - "\n", - " # Forward pass\n", - " outputs = model(**batch)\n", - " loss = outputs.loss\n", - " train_losses.append(loss.item())\n", - "\n", - " # Backward pass\n", - " loss.backward()\n", - " optimizer.step()\n", - " lr_scheduler.step()\n", - " optimizer.zero_grad()\n", - "\n", - " progress_bar.update(1)\n", - " completed_steps += 1\n", - "\n", - " # Log progress\n", - " if completed_steps % 20 == 0:\n", - " avg_loss = sum(train_losses[-20:]) / min(20, len(train_losses))\n", - " progress_bar.set_postfix({\"loss\": f\"{avg_loss:.4f}\"})\n", - "\n", - " if completed_steps >= MAX_TRAIN_STEPS:\n", - " break\n", - "\n", - " if completed_steps >= MAX_TRAIN_STEPS:\n", - " break\n", - "\n", - "print(f\"\\nTraining completed! Steps: {completed_steps}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "18251766", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Evaluating after training...\n", - "Final - Loss: 5.9571, Perplexity: 386.50\n" - ] - } - ], - "source": [ - "# Evaluate after training\n", - "print(\"\\nEvaluating after training...\")\n", - "eval_loss, eval_ppl = evaluate(model, eval_dataloader, DEVICE)\n", - "print(f\"Final - Loss: {eval_loss:.4f}, Perplexity: {eval_ppl:.2f}\")" - ] - }, - { - "cell_type": "markdown", - "id": "36f3e09b", - "metadata": {}, - "source": [ - "## 6. Save the Fine-tuned Model (Optional)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "472d0991", - "metadata": {}, - "outputs": [], - "source": [ - "# Uncomment to save the model\n", - "# OUTPUT_DIR = \"./ot-llama-finetuned\"\n", - "# model.save_pretrained(OUTPUT_DIR)\n", - "# tokenizer.save_pretrained(OUTPUT_DIR)\n", - "# print(f\"Model saved to {OUTPUT_DIR}\")" - ] - }, - { - "cell_type": "markdown", - "id": "24de471e", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "This notebook demonstrated:\n", - "\n", - "1. **Loading a HuggingFace Llama model** using `AutoModelForCausalLM`\n", - "2. **Configuring the optical transform** with `OtTransformConfig`\n", - "3. **Applying the transform pass** using `optical_transformer_module_transform_pass`\n", - "4. **Preparing a dataset** for causal language modeling\n", - "5. 
**Running fine-tuning** with the transformed optical model\n", - "\n", - "### Key Points\n", - "\n", - "- Use `attn_implementation=\"eager\"` when loading the model for compatibility\n", - "- The transform pass uses regex patterns to match layer names\n", - "- Training mode (`model.train()`) enables statistics updates in optical layers\n", - "- The optical quantization adds noise that the model learns to be robust against\n", - "\n", - "### References\n", - "\n", - "- [Optical Transformers Paper](https://arxiv.org/abs/2302.10360)\n", - "- MASE ONN Transform: `src/chop/passes/module/transforms/onn/`" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mase", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/newcompute/onn/README.md b/docs/tutorials/newcompute/onn/README.md deleted file mode 100644 index a2d813af3..000000000 --- a/docs/tutorials/newcompute/onn/README.md +++ /dev/null @@ -1,138 +0,0 @@ -# Optical Neural Network (ONN) Transform API - -This module provides tools for transforming PyTorch neural networks to simulate optical computing behavior, based on the [Optical Transformers paper](https://arxiv.org/abs/2302.10360). - -## Installation - -```bash -pip install mase-triton -``` - -## Quick Start - -```python -from chop.passes.module.transforms.onn.transform import ( - OtTransformConfig, - optical_transformer_module_transform_pass, -) - -# Create configuration -config = OtTransformConfig.create_default() - -# Transform a model -pass_args = { - "by": "regex_name", - r"model\.layers\.\d+\.self_attn": config, - r"model\.layers\.\d+\.mlp\..*_proj": config, -} -model = optical_transformer_module_transform_pass(model, pass_args) -``` - -## API Reference - -### `OtTransformConfig` - -Configuration dictionary for optical transform parameters. - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `q_levels` | int | 256 | Number of quantization levels ($2^n$ for n-bit) | -| `q_lut_min` | float | 0.020040 | Minimum LUT value for quantization | -| `q_smooth_factor` | float | 0.9 | Smoothing factor for statistics updates | -| `q_init_seed` | int | 0 | Random seed for Triton kernels | -| `q_bypass` | bool | False | Bypass optical quantization if True | - -```python -# Create default config -config = OtTransformConfig.create_default() - -# Customize -config["q_levels"] = 512 # 9-bit quantization -config["q_smooth_factor"] = 0.1 -``` - -### `optical_transformer_module_transform_pass` - -Transform supported modules in a network to their optical equivalents. 
- -```python -optical_transformer_module_transform_pass(network, pass_args) -> torch.nn.Module -``` - -**Parameters:** -- `network`: The PyTorch model to transform -- `pass_args`: Configuration dictionary with: - - `by`: Matching mode - `"name"` (exact) or `"regex_name"` (regex pattern) - - Layer patterns mapped to `OtTransformConfig` dicts - - `default`: Optional fallback config - -**Supported Transformations:** - -| Original Module | Optical Equivalent | -|-----------------|-------------------| -| `torch.nn.Linear` | `OtLinear` | -| `LlamaAttention` | `OtLlamaAttention` | - -### `OtLinear` - -Optical equivalent of `torch.nn.Linear` with quantized matrix multiplication. - -```python -from chop.passes.module.transforms.onn.transform import OtLinear - -# Convert from existing linear layer -linear_onn = OtLinear.from_linear(linear, **config) -``` - -### `OtLlamaAttention` - -Optical equivalent of HuggingFace's `LlamaAttention` with quantized scaled dot-product attention. - -```python -from chop.passes.module.transforms.onn.transform import OtLlamaAttention - -# Convert from existing attention layer -attn_onn = OtLlamaAttention.from_pretrained(attn, **config) -``` - -## Important Notes - -### Quantization Statistics Warmup - -Optical layers require calibration before use. Run a few forward passes in **training mode** first: - -```python -model.train() -with torch.no_grad(): - for batch in warmup_batches: - _ = model(**batch) -``` - -Without warmup, statistics are `[inf, -inf]` and outputs will be NaN. - -### Training vs Evaluation Mode - -- **Training mode** (`model.train()`): Statistics are updated with each forward pass -- **Evaluation mode** (`model.eval()`): Statistics are frozen - -### Attention Implementation - -When loading HuggingFace models, use eager attention for compatibility: - -```python -model = AutoModelForCausalLM.from_pretrained( - model_name, - attn_implementation="eager", -) -``` - - -## Source Code - -- Transform pass: `src/chop/passes/module/transforms/onn/transform.py` -- Linear layer: `src/chop/passes/module/transforms/onn/layers/linear.py` -- Attention layer: `src/chop/passes/module/transforms/onn/layers/attn.py` - -## References - -- [Optical Transformers: End-to-end Optical Training of Transformer Models](https://arxiv.org/abs/2302.10360) diff --git a/src/chop/passes/module/transforms/bitflip/__init__.py b/src/chop/passes/module/transforms/bitflip/__init__.py deleted file mode 100644 index dd03ca3b0..000000000 --- a/src/chop/passes/module/transforms/bitflip/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .bitflip_transform import bitflip_module_transform_pass - -__all__ = ["bitflip_module_transform_pass"] diff --git a/src/chop/passes/module/transforms/bitflip/bitflip_transform.py b/src/chop/passes/module/transforms/bitflip/bitflip_transform.py deleted file mode 100644 index 61d3cb66d..000000000 --- a/src/chop/passes/module/transforms/bitflip/bitflip_transform.py +++ /dev/null @@ -1,90 +0,0 @@ -try: - from mase_triton.random_bitflip.layers import RandomBitFlipLinear - - MASE_TRITON_AVAILABLE = True -except ImportError: - MASE_TRITON_AVAILABLE = False - -import torch - -from ...module_modify_helper import replace_by_name -from ...state_dict_map import match_a_pattern - - -def get_config_by_name(config: dict, name: str): - if name in config: - return config[name] - else: - if "default" in config: - return config["default"] - else: - return None - - -def get_config_by_regex_name(config: dict, name: str): - matched_pattern = match_a_pattern(name, config.keys()) - if 
matched_pattern is None: - if "default" in config: - return config["default"] - else: - return None - else: - return config[matched_pattern] - - -def get_layer_config( - layer_name_to_config: dict[str, dict], use_regex: bool, layer_name: str -) -> dict | None: - if use_regex: - config = get_config_by_regex_name(layer_name_to_config, layer_name) - else: - config = get_config_by_name(layer_name_to_config, layer_name) - return config - - -if MASE_TRITON_AVAILABLE: - BITFLIP_CLS_MAP = { - torch.nn.Linear: RandomBitFlipLinear, - } - - def bitflip_module_transform_pass( - network: torch.nn.Module, pass_args: dict - ) -> torch.nn.Module: - """ - Apply bitflip module transform pass to the network. - - :param network: The network to be transformed. - :type network: torch.nn.Module - :param pass_args: The arguments for the transformation. - :type pass_args: dict - :return: The transformed network. - :rtype: torch.nn.Module - :raises AssertionError: If the `by` argument is not in ["name", "regex"]. - """ - target_classes = tuple(BITFLIP_CLS_MAP.keys()) - by = pass_args.pop("by", "regex_name") - assert by in [ - "name", - "regex_name", - ], f"by should be in ['name', 'regex_name'], but got {by}" - - for m_name, m in network.named_modules(): - if not isinstance(m, target_classes): - continue - m_config = get_layer_config( - pass_args, use_regex=by == "regex_name", layer_name=m_name - ) - if m_config is None: - continue - new_m_cls = BITFLIP_CLS_MAP[type(m)] - new_m = new_m_cls.from_linear(m, **m_config) - replace_by_name(network, name=m_name, module=new_m) - - return network - -else: - - def bitflip_module_transform_pass( - network: torch.nn.Module, pass_args: dict - ) -> torch.nn.Module: - raise RuntimeError("mase-triton is not available, please install it first.") diff --git a/src/chop/passes/module/transforms/onn/__init__.py b/src/chop/passes/module/transforms/onn/__init__.py deleted file mode 100644 index 19f18a190..000000000 --- a/src/chop/passes/module/transforms/onn/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .transform import optical_transformer_module_transform_pass - -__all__ = ["optical_transformer_module_transform_pass"] diff --git a/src/chop/passes/module/transforms/onn/layers/__init__.py b/src/chop/passes/module/transforms/onn/layers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/chop/passes/module/transforms/onn/layers/attn.py b/src/chop/passes/module/transforms/onn/layers/attn.py deleted file mode 100644 index 1f4ccc16e..000000000 --- a/src/chop/passes/module/transforms/onn/layers/attn.py +++ /dev/null @@ -1,358 +0,0 @@ -from typing import Optional - -import torch -from mase_triton.optical_compute import OpticalTransformerFunctions as OTFunctions -from mase_triton.optical_compute.layers import OpticalTransformerLinear as OTLinear -from mase_triton.optical_compute.layers import optical_transformer_update_qstats -from mase_triton.utils.torch_module import get_layer_name, set_layer_by_name -from torch import Tensor, nn -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaConfig, - LlamaDecoderLayer, - LlamaForCausalLM, - apply_rotary_pos_emb, - eager_attention_forward, - repeat_kv, -) - - -def ot_eager_attention_forward( - module: "OtLlamaAttention", - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - """ - Optical Transformer Scaled Dot-Product Attention. 
- - Computes scaled dot-product attention with quantized matrix multiplications - to simulate optical neural network hardware constraints. This function applies - quantization to both the query-key and attention-value matrix products. - - The quantization statistics (min/max values) are updated in-place during training - using an exponential moving average controlled by ``q_smooth_factor``. - - Args: - query (Tensor): Query tensor of shape ``(batch, heads, seq_len, head_dim)``. - key (Tensor): Key tensor of shape ``(batch, kv_heads, seq_len, head_dim)``. - value (Tensor): Value tensor of shape ``(batch, kv_heads, seq_len, head_dim)``. - attention_mask (Tensor, optional): Attention mask. Default: None. - dropout (float): Dropout probability. Default: 0.0. - scaling (float, optional): Scaling factor. If None, uses ``1/sqrt(head_dim)``. - - Returns: - Tensor: Attention output of shape ``(batch, heads, seq_len, head_dim)``. - """ - with torch.no_grad(): - query_min_max_ = optical_transformer_update_qstats( - query, - module.query_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.query_min_max.copy_(query_min_max_) - key_min_max_ = optical_transformer_update_qstats( - key, - module.key_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.key_min_max.copy_(key_min_max_) - key_states = repeat_kv(key, module.num_key_value_groups) - if not module.qk_min_max.isfinite().all(): - attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling - qk_min_max_ = optical_transformer_update_qstats( - attn_weights, - module.qk_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.qk_min_max.copy_(qk_min_max_) - - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - # attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - attn_weights, _ = OTFunctions.quantized_matmul_fn( - a=query.contiguous(), - b=key_states.transpose(2, 3).contiguous(), - a_min=module.query_min_max[0], - a_max=module.query_min_max[1], - b_min=module.key_min_max[0], - b_max=module.key_min_max[1], - b_lut_min=module.q_lut_min, - o_min=module.qk_min_max[0], - o_max=module.qk_min_max[1], - q_levels=module.q_levels, - q_seed=module.seed.item(), - skip_quantize=False, - ) - attn_weights = attn_weights * scaling - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=dropout, training=module.training - ) - # attn_output = torch.matmul(attn_weights, value_states) - - with torch.no_grad(): - attn_min_max_ = optical_transformer_update_qstats( - attn_weights, - module.attn_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.attn_min_max.copy_(attn_min_max_) - value_min_max_ = optical_transformer_update_qstats( - value_states, - module.value_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.value_min_max.copy_(value_min_max_) - attn_ = torch.matmul(attn_weights, value_states) - av_min_max_ = optical_transformer_update_qstats( - attn_, - module.av_min_max, - module.q_min_max_quantiles, - module.stat_smooth_factor, - ) - module.av_min_max.copy_(av_min_max_) - - attn_output, _ = OTFunctions.quantized_matmul_fn( - a=attn_weights.contiguous(), - 
b=value_states.contiguous(), - a_min=module.attn_min_max[0], - a_max=module.attn_min_max[1], - b_min=module.value_min_max[0], - b_max=module.value_min_max[1], - b_lut_min=module.q_lut_min, - o_min=module.av_min_max[0], - o_max=module.av_min_max[1], - q_levels=module.q_levels, - q_seed=module.seed.item(), - skip_quantize=module.bypass, - ) - with torch.no_grad(): - module.seed += 1 - - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class OtLlamaAttention(nn.Module): - """ - Optical Transformer attention module for LLaMA models. - - This module replaces the standard HuggingFace LlamaAttention with an optical - transformer equivalent that simulates quantized matrix multiplications as would - occur in optical neural network hardware. The implementation is based on the - `Optical Transformers paper `_. - - The attention computation uses optical transformer scaled dot-product attention - (SDPA) which applies quantization to the query-key and attention-value matrix - multiplications to simulate optical compute constraints. - - Args: - config: HuggingFace LLaMA configuration object. - layer_idx (int): Index of this attention layer in the model. - q_levels (int): Number of quantization levels for optical simulation. Default: 256. - q_lut_min (float): Minimum value for the lookup table used in quantization. Default: 0.020040. - q_quantiles (tuple[float, float], optional): Quantile range for min/max statistics. - If None, uses absolute min/max. Default: None. - q_smooth_factor (float): Exponential moving average factor for updating - running min/max statistics during training. Default: 0.9. - q_init_seed (int): Random seed for quantization noise initialization. Default: 0. - q_bypass (bool): If True, bypasses optical quantization and uses standard - PyTorch attention. Useful for debugging or comparison. Default: False. - - Attributes: - query_min_max (Tensor): Running min/max statistics for query tensors. - key_min_max (Tensor): Running min/max statistics for key tensors. - value_min_max (Tensor): Running min/max statistics for value tensors. - qk_min_max (Tensor): Running min/max statistics for query-key products. - attn_min_max (Tensor): Min/max range for attention weights (fixed at [0, 1]). - av_min_max (Tensor): Running min/max statistics for attention-value products. - seed (Tensor): Current random seed state for quantization. - - Example: - .. 
code-block:: python - - from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention - - # Create from existing HuggingFace attention layer - ot_attn = OtLlamaAttention.from_pretrained( - hf_attention_layer, - layer_idx=0, - q_levels=256, - q_bypass=False, - ) - """ - - def __init__( - self, - config: LlamaConfig, - layer_idx: int, - q_levels: int = 256, - q_lut_min: float = 0.020040, - q_quantiles: tuple[float, float] | None = None, - q_smooth_factor: float = 0.9, - q_init_seed: int = 0, - q_bypass: bool = False, - ): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - self.num_key_value_groups = ( - config.num_attention_heads // config.num_key_value_heads - ) - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, - ) - - self.q_levels = q_levels - self.q_lut_min = q_lut_min - if q_quantiles is None: - self.q_min_max_quantiles = None - else: - self.register_buffer("q_min_max_quantiles", torch.tensor(q_quantiles)) - self.register_buffer( - "query_min_max", torch.tensor([float("inf"), float("-inf")]) - ) - self.register_buffer("key_min_max", torch.tensor([float("inf"), float("-inf")])) - self.register_buffer("qk_min_max", torch.tensor([float("inf"), float("-inf")])) - self.register_buffer("attn_min_max", torch.tensor([float(0), float(1)])) - self.register_buffer( - "value_min_max", torch.tensor([float("inf"), float("-inf")]) - ) - self.register_buffer("av_min_max", torch.tensor([float("inf"), float("-inf")])) - self.register_buffer("seed", torch.tensor(q_init_seed, dtype=torch.int64)) - self.stat_smooth_factor = q_smooth_factor - self.bypass = q_bypass - - self.query_min_max: Tensor - self.key_min_max: Tensor - self.qk_min_max: Tensor - self.attn_min_max: Tensor - self.value_min_max: Tensor - self.av_min_max: Tensor - self.seed: Tensor - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value=None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, 
cache_kwargs - ) - - if self.bypass: - attn_output, attn_weights = eager_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - else: - attn_output, attn_weights = ot_eager_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - self.seed += 1 - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - @classmethod - def from_pretrained( - cls, - attn: LlamaAttention, - layer_idx: int, - q_levels: int = 256, - q_lut_min: float = 0.020040, - q_quantiles: tuple[float, float] | None = None, - q_smooth_factor: float = 0.9, - q_init_seed: int = 0, - q_bypass: bool = False, - ) -> "OtLlamaAttention": - assert isinstance(attn, LlamaAttention) - ot_attn = cls( - attn.config, - layer_idx, - q_levels, - q_lut_min, - q_quantiles, - q_smooth_factor, - q_init_seed, - q_bypass, - ) - ot_attn.to(attn.o_proj.weight.dtype) - ot_attn.load_state_dict(attn.state_dict(), strict=False) - return ot_attn diff --git a/src/chop/passes/module/transforms/onn/layers/linear.py b/src/chop/passes/module/transforms/onn/layers/linear.py deleted file mode 100644 index 91dcf57d0..000000000 --- a/src/chop/passes/module/transforms/onn/layers/linear.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Optical Transformer Linear Layer. - -This module provides the optical transformer linear layer implementation -by importing from the mase-triton package. -""" - -from mase_triton.optical_compute import layers as OTLayers - -#: Optical Transformer Linear layer. -#: -#: This is an alias to ``mase_triton.optical_compute.layers.OpticalTransformerLinear``. -#: It replaces standard ``torch.nn.Linear`` layers with quantized optical transformer -#: equivalents that simulate optical neural network hardware constraints. -#: -#: The layer applies quantization to both the input activations and weights during -#: matrix multiplication, and tracks running min/max statistics for calibration. -#: -#: Use the ``from_linear`` class method to convert an existing ``torch.nn.Linear``: -#: -#: .. 
code-block:: python -#: -#: from chop.passes.module.transforms.onn.layers.linear import OtLinear -#: -#: ot_linear = OtLinear.from_linear( -#: linear_layer, -#: q_levels=256, -#: q_lut_min=0.020040, -#: ) -OtLinear = OTLayers.OpticalTransformerLinear diff --git a/src/chop/passes/module/transforms/onn/transform.py b/src/chop/passes/module/transforms/onn/transform.py deleted file mode 100644 index edbb9eff5..000000000 --- a/src/chop/passes/module/transforms/onn/transform.py +++ /dev/null @@ -1,180 +0,0 @@ -try: - import mase_triton - - MASE_TRITON_IS_AVAILABLE = True -except ImportError: - MASE_TRITON_IS_AVAILABLE = False - -from typing import TypedDict - -import torch -from transformers.models.llama.modeling_llama import LlamaAttention as HfLlamaAttention - -from ...module_modify_helper import replace_by_name -from ...state_dict_map import match_a_pattern -from .layers.attn import OtLlamaAttention -from .layers.linear import OtLinear - - -def get_config_by_name(config: dict, name: str): - if name in config: - return config[name] - else: - if "default" in config: - return config["default"] - else: - return None - - -def get_config_by_regex_name(config: dict, name: str): - matched_pattern = match_a_pattern(name, config.keys()) - if matched_pattern is None: - if "default" in config: - return config["default"] - else: - return None - else: - return config[matched_pattern] - - -def get_layer_config( - layer_name_to_config: dict[str, dict], use_regex: bool, layer_name: str -) -> dict | None: - if use_regex: - config = get_config_by_regex_name(layer_name_to_config, layer_name) - else: - config = get_config_by_name(layer_name_to_config, layer_name) - return config - - -class OtTransformConfig(TypedDict): - q_levels: int - q_lut_min: float - q_smooth_factor: float - q_init_seed: int - q_bypass: bool - - @classmethod - def create_default(cls) -> "OtTransformConfig": - return cls( - q_levels=256, - q_lut_min=0.020040, - q_smooth_factor=0.9, - q_init_seed=0, - q_bypass=False, - ) - - -if MASE_TRITON_IS_AVAILABLE: - _SUPPORTED_MODULE_CLS = (torch.nn.Linear, HfLlamaAttention) - - def optical_transformer_module_transform_pass( - network: torch.nn.Module, pass_args: dict - ) -> torch.nn.Module: - """ - Transform a neural network by replacing supported modules with their optical transformer equivalents. - - This pass simulates optical neural network (ONN) computation by replacing standard PyTorch - modules with quantized optical transformer layers. The optical transformer model is based on - the `Optical Transformers paper `_. - - Supported module replacements: - - - ``torch.nn.Linear`` → ``OtLinear`` - - ``transformers.models.llama.modeling_llama.LlamaAttention`` → ``OtLlamaAttention`` - - Args: - network (torch.nn.Module): The input network to be transformed. - pass_args (dict): A dictionary containing transformation configurations. - - - ``by`` (str): Layer matching strategy. Either ``'name'`` for exact name matching - or ``'regex_name'`` for regex-based pattern matching. Defaults to ``'regex_name'``. - - ``default`` (dict, optional): Default configuration applied to all matching layers. - - ```` (dict): Per-layer configuration. Each layer config - can contain the following keys: - - - ``q_levels`` (int): Number of quantization levels. Default: 256. - - ``q_lut_min`` (float): Minimum value for lookup table. Default: 0.020040. - - ``q_smooth_factor`` (float): Smoothing factor for running statistics. Default: 0.9. - - ``q_init_seed`` (int): Random seed for quantization initialization. Default: 0. 
-            - ``q_bypass`` (bool): If True, bypass optical quantization. Default: False.
-
-        Returns:
-            torch.nn.Module: The transformed network with optical transformer modules.
-
-        Raises:
-            RuntimeError: If ``mase-triton`` is not installed.
-
-        Example:
-            .. code-block:: python
-
-                from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass
-
-                # Transform all linear layers with default config
-                pass_args = {
-                    "by": "regex_name",
-                    "default": {
-                        "q_levels": 256,
-                        "q_lut_min": 0.020040,
-                        "q_bypass": False,
-                    }
-                }
-                transformed_model = optical_transformer_module_transform_pass(model, pass_args)
-
-        Note:
-            This pass requires the ``mase-triton`` package to be installed.
-            Install via ``pip install mase-triton``.
-        """
-        by = pass_args.pop("by", "regex_name")
-        assert by in [
-            "name",
-            "regex_name",
-        ], f"`by` can be either 'name' or 'regex_name', but got {by}"
-        # replace attn layers if any
-        for m_name, m in network.named_modules():
-            if not isinstance(m, HfLlamaAttention):
-                continue
-            m_config = get_layer_config(
-                pass_args, use_regex=by == "regex_name", layer_name=m_name
-            )
-            if m_config is None:
-                continue
-            if isinstance(m, HfLlamaAttention):
-                new_m = OtLlamaAttention.from_pretrained(
-                    m, layer_idx=m.layer_idx, **m_config
-                )
-            elif isinstance(m, _SUPPORTED_MODULE_CLS):
-                continue
-            else:
-                raise NotImplementedError(
-                    f"ONN transform for type {type(m)} is not supported"
-                )
-            replace_by_name(network, name=m_name, module=new_m)
-        # replace linear layers if any
-        for m_name, m in network.named_modules():
-            if not isinstance(m, torch.nn.Linear):
-                continue
-            m_config = get_layer_config(
-                pass_args, use_regex=by == "regex_name", layer_name=m_name
-            )
-            if m_config is None:
-                continue
-            if isinstance(m, torch.nn.Linear):
-                new_m = OtLinear.from_linear(m, **m_config)
-            elif isinstance(m, _SUPPORTED_MODULE_CLS):
-                continue
-            else:
-                raise NotImplementedError(
-                    f"ONN transform for type {type(m)} is not supported"
-                )
-            replace_by_name(network, name=m_name, module=new_m)
-        return network
-
-else:
-
-    def optical_transformer_module_transform_pass(
-        network: torch.nn.Module, pass_args: dict
-    ) -> torch.nn.Module:
-        raise RuntimeError(
-            "`mase-triton` is needed for ONN transform. Install via `pip install mase-triton`."
- ) diff --git a/test/passes/module/transforms/onn/test_optical_transformer.py b/test/passes/module/transforms/onn/test_optical_transformer.py deleted file mode 100644 index 6e99719d8..000000000 --- a/test/passes/module/transforms/onn/test_optical_transformer.py +++ /dev/null @@ -1,121 +0,0 @@ -import unittest - -import torch -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig - -from chop.passes.module.transforms.onn.transform import ( - OtLinear, - OtLlamaAttention, - OtTransformConfig, - optical_transformer_module_transform_pass, -) - - -def _calculate_snr(x, noisy_x): - noise = noisy_x - x - - signal_power = torch.sum(x**2) - noise_power = torch.sum(noise**2) - - snr = signal_power / noise_power - snr_db = 10 * torch.log10(snr) - return snr_db.item() - - -class TestOnnTransform(unittest.TestCase): - def test_ot_linear_layer(self): - linear = torch.nn.Linear(in_features=32, out_features=64) - onn_cfg = OtTransformConfig.create_default() - linear_onn = OtLinear.from_linear(linear, **onn_cfg) - - x = torch.randn(2, 32) - y = linear(x) - y_onn = linear_onn(x) - - snr = _calculate_snr(y, y_onn) - assert snr > 23 - - def test_ot_llama_attn_layer(self): - onn_config = OtTransformConfig.create_default() - onn_config["q_levels"] = 512 - onn_config["q_smooth_factor"] = 0.1 - model_name = "AICrossSim/clm-60m" - hf_config = LlamaConfig.from_pretrained(model_name) - batch_size = 1 - seq_len = 16 - head_dim = hf_config.hidden_size // hf_config.num_attention_heads - - attn = LlamaAttention(config=hf_config, layer_idx=0) - - pos_emb = torch.ones(batch_size, seq_len, head_dim) - x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size) - - y, _ = attn(x, (pos_emb, pos_emb), None) - y: torch.Tensor - assert y.isfinite().all() - - attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=0, **onn_config) - attn_onn.train() - for _ in range(3): - y_onn, _ = attn_onn(x, (pos_emb, pos_emb), None) - - snr = _calculate_snr(y, y_onn) - print(f"Attn SNR: {snr:.2f} dB") - assert snr > 1 - - def test_optical_transformer_module_transform_pass(self): - onn_config = OtTransformConfig.create_default() - onn_config["q_levels"] = 512 - onn_config["q_smooth_factor"] = 0.1 - model_name = "AICrossSim/clm-60m" - hf_config = LlamaConfig.from_pretrained(model_name) - batch_size = 1 - seq_len = 16 - head_dim = hf_config.hidden_size // hf_config.num_attention_heads - - class Network(torch.nn.Module): - def __init__(self): - super().__init__() - self.attn = LlamaAttention(config=hf_config, layer_idx=0) - self.linear = torch.nn.Linear( - in_features=hf_config.hidden_size, - out_features=hf_config.hidden_size, - ) - - def forward(self, x, pos_emb): - attn_output, _ = self.attn(x, (pos_emb, pos_emb), None) - output = self.linear(attn_output) - return output, None - - network = Network() - pos_emb = torch.ones(batch_size, seq_len, head_dim) - x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size) - - y, _ = network(x, pos_emb) - y: torch.Tensor - assert y.isfinite().all() - - pass_args = { - "by": "regex_name", - "attn": onn_config, - "linear": onn_config, - r"attn\.(q|k|v|o)_proj": onn_config, - } - - network_onn = optical_transformer_module_transform_pass(network, pass_args) - assert isinstance(network_onn.attn, OtLlamaAttention) - assert isinstance(network_onn.linear, OtLinear) - assert isinstance(network_onn.attn.q_proj, OtLinear) - assert isinstance(network_onn.attn.k_proj, OtLinear) - assert isinstance(network_onn.attn.v_proj, OtLinear) - assert 
isinstance(network_onn.attn.o_proj, OtLinear) - - print(network_onn) - network_onn.train() - for _ in range(3): - y_onn, _ = network_onn(x, pos_emb) - assert y_onn.isfinite().all() - - snr = _calculate_snr(y, y_onn) - assert snr > 1 - print(f"Network SNR: {snr:.2f} dB")
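
Note for downstream users (illustrative sketch, not part of the removed code): the optical-compute layers that the deleted ONN pass wrapped are provided by the separate `mase-triton` package, so equivalent behaviour should remain available by using that package directly. The layer and parameter names below mirror the removed wrapper (`OtLinear` was an alias for `OpticalTransformerLinear`); verify them against the installed `mase-triton` version before relying on them.

```python
# Minimal sketch, assuming mase-triton still exposes OpticalTransformerLinear
# with the from_linear constructor and the parameters used by the removed code.
import torch
from mase_triton.optical_compute.layers import OpticalTransformerLinear

linear = torch.nn.Linear(in_features=32, out_features=64)

# Convert a standard linear layer; these values match the removed default config.
ot_linear = OpticalTransformerLinear.from_linear(
    linear,
    q_levels=256,         # number of quantization levels
    q_lut_min=0.020040,   # minimum lookup-table value
    q_smooth_factor=0.9,  # EMA factor for running min/max statistics
    q_init_seed=0,
    q_bypass=False,
)

# The layer calibrates its min/max statistics in training mode; without a few
# warmup forward passes the statistics stay at [inf, -inf] and outputs are NaN.
ot_linear.train()
with torch.no_grad():
    _ = ot_linear(torch.randn(8, 32))
```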