diff --git a/examples/mnist/README.md b/examples/mnist/README.md new file mode 100644 index 000000000..c157551fd --- /dev/null +++ b/examples/mnist/README.md @@ -0,0 +1,68 @@ +# LogicNets for MNIST Classification + +This example shows the accuracy that is attainable using the LogicNets methodology on the MNIST hand-written character classification task. + +## Prerequisites + +* LogicNets +* numpy +* torchvision + +## Installation + +If you're using the docker image, all the above prerequisites will be already installed. +Otherwise, you can install the above dependencies with pip and/or conda. + +## Download the Dataset + +The MNIST dataset will download automatically when the training script is first run. +You only need to make sure the necessary directory has been created: + +```bash +mkdir -p data +``` + +## Usage + +To train the \"MNIST-S\", \"MNIST-M\" and \"MNIST-L\" networks, +run the following: + +```bash +python train.py --arch <arch> --log-dir ./<log-dir>/ +``` + +To then generate verilog from this trained model, run the following: + +```bash +python neq2lut.py --arch <arch> --checkpoint ./<log-dir>/best_accuracy.pth --log-dir ./<log-dir>/verilog/ --add-registers +``` + +## Results + +Your results may vary slightly, depending on your system configuration. 
+The following results are attained when training on a CPU and synthesising with Vivado 2019.2: + +| Network Architecture | Test Accuracy (%) | LUTs | Flip Flops | Fmax (Mhz) | Latency (Cycles) | +| --------------------- | ----------------- | ----- | ------------- | ------------- | ----------------- | +| MNIST-S | | | | | | +| MNIST-M | | | | | | +| MNIST-L | | | | | | + +## Citation + +If you find this work useful for your research, please consider citing +our paper below: + +```bibtex +@inproceedings{umuroglu2020logicnets, + author = {Umuroglu, Yaman and Akhauri, Yash and Fraser, Nicholas J and Blott, Michaela}, + booktitle = {Proceedings of the International Conference on Field-Programmable Logic and Applications}, + title = {LogicNets: Co-Designed Neural Networks and Circuits for Extreme-Throughput Applications}, + year = {2020}, + pages = {291-297}, + publisher = {IEEE Computer Society}, + address = {Los Alamitos, CA, USA}, + month = {sep} +} +``` + diff --git a/examples/mnist/dataset_dump.py b/examples/mnist/dataset_dump.py new file mode 100644 index 000000000..5c96d7616 --- /dev/null +++ b/examples/mnist/dataset_dump.py @@ -0,0 +1,131 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from argparse import ArgumentParser +from functools import reduce, partial + +import torch +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +from torchvision import transforms + +from logicnets.nn import generate_truth_tables, \ + lut_inference, \ + module_list_to_verilog_module +from logicnets.synthesis import synthesize_and_get_resource_counts + +from train import configs, model_config, other_options, test +from models import MnistNeqModel, MnistLutModel + +def dump_io(model, data_loader, input_file, output_file): + input_quant = model.module_list[0].input_quant + _, input_bitwidth = input_quant.get_scale_factor_bits() + input_bitwidth = int(input_bitwidth) + total_input_bits = model.module_list[0].in_features*input_bitwidth + input_quant.bin_output() + with open(input_file, 'w') as i_f, open(output_file, 'w') as o_f: + for data, target in data_loader: + x = input_quant(data) + indices = target + for i in range(x.shape[0]): + x_i = x[i,:] + xv_i = list(map(lambda z: input_quant.get_bin_str(z), x_i)) + xvc_i = reduce(lambda a,b: a+b, xv_i[::-1]) + i_f.write(f"{int(xvc_i,2):0{int(total_input_bits)}b}\n") + o_f.write(f"{int(indices[i])}\n") + +if __name__ == "__main__": + parser = ArgumentParser(description="Dump the train and test datasets (after input quantization) into text files") + parser.add_argument('--arch', type=str, choices=configs.keys(), default="mnist-s", + help="Specific the neural network model to use (default: %(default)s)") + parser.add_argument('--batch-size', type=int, default=None, metavar='N', + help="Batch size for evaluation (default: %(default)s)") + parser.add_argument('--input-bitwidth', type=int, default=None, + help="Bitwidth to use at the input (default: %(default)s)") + parser.add_argument('--hidden-bitwidth', type=int, default=None, + help="Bitwidth to use for activations in hidden layers (default: %(default)s)") + parser.add_argument('--output-bitwidth', type=int, default=None, + 
help="Bitwidth to use at the output (default: %(default)s)") + parser.add_argument('--input-fanin', type=int, default=None, + help="Fanin to use at the input (default: %(default)s)") + parser.add_argument('--hidden-fanin', type=int, default=None, + help="Fanin to use for the hidden layers (default: %(default)s)") + parser.add_argument('--output-fanin', type=int, default=None, + help="Fanin to use at the output (default: %(default)s)") + parser.add_argument('--hidden-layers', nargs='+', type=int, default=None, + help="A list of hidden layer neuron sizes (default: %(default)s)") + parser.add_argument('--log-dir', type=str, default='./log', + help="A location to store the output I/O text files (default: %(default)s)") + parser.add_argument('--checkpoint', type=str, required=True, + help="The checkpoint file which contains the model weights") + args = parser.parse_args() + defaults = configs[args.arch] + options = vars(args) + del options['arch'] + config = {} + for k in options.keys(): + config[k] = options[k] if options[k] is not None else defaults[k] # Override defaults, if specified. 
+ + if not os.path.exists(config['log_dir']): + os.makedirs(config['log_dir']) + + # Split up configuration options to be more understandable + model_cfg = {} + for k in model_config.keys(): + model_cfg[k] = config[k] + options_cfg = {} + for k in other_options.keys(): + if k == 'cuda': + continue + options_cfg[k] = config[k] + + trans = transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + transforms.Lambda(partial(torch.reshape, shape=(-1,))) + ]) + + # Fetch the datasets + dataset = {} + dataset['train'] = MNIST('./data', train=True, download=True, transform=trans) + dataset['test'] = MNIST('./data', train=False, download=True, transform=trans) + train_loader = DataLoader(dataset["train"], batch_size=config['batch_size'], shuffle=False) + test_loader = DataLoader(dataset["test"], batch_size=config['batch_size'], shuffle=False) + + # Instantiate the PyTorch model + x, y = dataset["train"][0] + model_cfg['input_length'] = len(x) + model_cfg['output_length'] = 10 + model = MnistNeqModel(model_cfg) + + # Load the model weights + checkpoint = torch.load(options_cfg['checkpoint'], map_location='cpu') + model.load_state_dict(checkpoint['model_dict']) + + # Test the PyTorch model + print("Running inference on baseline model...") + model.eval() + baseline_accuracy = test(model, test_loader, cuda=False) + print("Baseline accuracy: %f" % (baseline_accuracy)) + + # Run preprocessing on training set. 
+ train_input_file = config['log_dir'] + "/train_input.txt" + train_output_file = config['log_dir'] + "/train_output.txt" + test_input_file = config['log_dir'] + "/test_input.txt" + test_output_file = config['log_dir'] + "/test_output.txt" + print(f"Dumping train I/O to {train_input_file} and {train_output_file}") + dump_io(model, train_loader, train_input_file, train_output_file) + print(f"Dumping test I/O to {test_input_file} and {test_output_file}") + dump_io(model, test_loader, test_input_file, test_output_file) diff --git a/examples/mnist/models.py b/examples/mnist/models.py new file mode 100644 index 000000000..fcedfd932 --- /dev/null +++ b/examples/mnist/models.py @@ -0,0 +1,145 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import reduce +from os.path import realpath + +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter +from torch.nn import init + +from brevitas.core.quant import QuantType +from brevitas.core.scaling import ScalingImplType +from brevitas.nn import QuantHardTanh, QuantReLU + +from pyverilator import PyVerilator + +from logicnets.quant import QuantBrevitasActivation +from logicnets.nn import SparseLinearNeq, ScalarBiasScale, RandomFixedSparsityMask2D, DenseMask2D +from logicnets.init import random_restrict_fanin + +class MnistNeqModel(nn.Module): + def __init__(self, model_config): + super(MnistNeqModel, self).__init__() + self.model_config = model_config + self.num_neurons = [model_config["input_length"]] + model_config["hidden_layers"] + [model_config["output_length"]] + layer_list = [] + for i in range(1, len(self.num_neurons)): + in_features = self.num_neurons[i-1] + out_features = self.num_neurons[i] + bn = nn.BatchNorm1d(out_features) + nn.init.constant_(bn.weight.data, 1) + nn.init.constant_(bn.bias.data, 0) + if i == 1: + do_in = nn.Dropout(p=model_config["input_dropout"]) + bn_in = nn.BatchNorm1d(in_features) + nn.init.constant_(bn_in.weight.data, 1) + nn.init.constant_(bn_in.bias.data, 0) + input_bias = ScalarBiasScale(scale=False, bias_init=-0.25) + input_quant = QuantBrevitasActivation(QuantHardTanh(model_config["input_bitwidth"], max_val=1., narrow_range=False, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER), pre_transforms=[do_in, bn_in, input_bias]) + output_quant = QuantBrevitasActivation(QuantReLU(bit_width=model_config["hidden_bitwidth"], max_val=1.61, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER), pre_transforms=[bn]) + mask = RandomFixedSparsityMask2D(in_features, out_features, fan_in=model_config["input_fanin"]) + layer = SparseLinearNeq(in_features, out_features, input_quant=input_quant, output_quant=output_quant, sparse_linear_kws={'mask': mask}) + 
layer_list.append(layer) + elif i == len(self.num_neurons)-1: + output_bias_scale = ScalarBiasScale(bias_init=0.33) + output_quant = QuantBrevitasActivation(QuantHardTanh(bit_width=model_config["output_bitwidth"], max_val=1.33, narrow_range=False, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER), pre_transforms=[bn], post_transforms=[output_bias_scale]) + mask = RandomFixedSparsityMask2D(in_features, out_features, fan_in=model_config["output_fanin"]) + layer = SparseLinearNeq(in_features, out_features, input_quant=layer_list[-1].output_quant, output_quant=output_quant, sparse_linear_kws={'mask': mask}, apply_input_quant=False) + layer_list.append(layer) + else: + output_quant = QuantBrevitasActivation(QuantReLU(bit_width=model_config["hidden_bitwidth"], max_val=1.61, quant_type=QuantType.INT, scaling_impl_type=ScalingImplType.PARAMETER), pre_transforms=[bn]) + mask = RandomFixedSparsityMask2D(in_features, out_features, fan_in=model_config["hidden_fanin"]) + layer = SparseLinearNeq(in_features, out_features, input_quant=layer_list[-1].output_quant, output_quant=output_quant, sparse_linear_kws={'mask': mask}, apply_input_quant=False) + layer_list.append(layer) + self.module_list = nn.ModuleList(layer_list) + self.is_verilog_inference = False + self.latency = 1 + self.verilog_dir = None + self.top_module_filename = None + self.dut = None + self.logfile = None + + def verilog_inference(self, verilog_dir, top_module_filename, logfile: bool = False, add_registers: bool = False): + self.verilog_dir = realpath(verilog_dir) + self.top_module_filename = top_module_filename + self.dut = PyVerilator.build(f"{self.verilog_dir}/{self.top_module_filename}", verilog_path=[self.verilog_dir], build_dir=f"{self.verilog_dir}/verilator") + self.is_verilog_inference = True + self.logfile = logfile + if add_registers: + self.latency = len(self.num_neurons) + + def pytorch_inference(self): + self.is_verilog_inference = False + + def verilog_forward(self, x): + # Get 
integer output from the first layer + input_quant = self.module_list[0].input_quant + output_quant = self.module_list[-1].output_quant + _, input_bitwidth = self.module_list[0].input_quant.get_scale_factor_bits() + _, output_bitwidth = self.module_list[-1].output_quant.get_scale_factor_bits() + input_bitwidth, output_bitwidth = int(input_bitwidth), int(output_bitwidth) + total_input_bits = self.module_list[0].in_features*input_bitwidth + total_output_bits = self.module_list[-1].out_features*output_bitwidth + num_layers = len(self.module_list) + input_quant.bin_output() + self.module_list[0].apply_input_quant = False + y = torch.zeros(x.shape[0], self.module_list[-1].out_features) + x = input_quant(x) + self.dut.io.rst = 0 + self.dut.io.clk = 0 + for i in range(x.shape[0]): + x_i = x[i,:] + y_i = self.pytorch_forward(x[i:i+1,:])[0] + xv_i = list(map(lambda z: input_quant.get_bin_str(z), x_i)) + ys_i = list(map(lambda z: output_quant.get_bin_str(z), y_i)) + xvc_i = reduce(lambda a,b: a+b, xv_i[::-1]) + ysc_i = reduce(lambda a,b: a+b, ys_i[::-1]) + self.dut["M0"] = int(xvc_i, 2) + for j in range(self.latency + 1): + #print(self.dut.io.M5) + res = self.dut[f"M{num_layers}"] + result = f"{res:0{int(total_output_bits)}b}" + self.dut.io.clk = 1 + self.dut.io.clk = 0 + expected = f"{int(ysc_i,2):0{int(total_output_bits)}b}" + result = f"{res:0{int(total_output_bits)}b}" + assert(expected == result) + res_split = [result[i:i+output_bitwidth] for i in range(0, len(result), output_bitwidth)][::-1] + yv_i = torch.Tensor(list(map(lambda z: int(z, 2), res_split))) + y[i,:] = yv_i + # Dump the I/O pairs + if self.logfile is not None: + with open(self.logfile, "a") as f: + f.write(f"{int(xvc_i,2):0{int(total_input_bits)}b}{int(ysc_i,2):0{int(total_output_bits)}b}\n") + return y + + def pytorch_forward(self, x): + for l in self.module_list: + x = l(x) + return x + + def forward(self, x): + if self.is_verilog_inference: + return self.verilog_forward(x) + else: + return 
self.pytorch_forward(x) + +class MnistLutModel(MnistNeqModel): + pass + +class MnistVerilogModel(MnistNeqModel): + pass + diff --git a/examples/mnist/neq2lut.py b/examples/mnist/neq2lut.py new file mode 100644 index 000000000..6a3007080 --- /dev/null +++ b/examples/mnist/neq2lut.py @@ -0,0 +1,177 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from argparse import ArgumentParser +from functools import partial + +import torch +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +from torchvision import transforms + +from logicnets.nn import generate_truth_tables, \ + lut_inference, \ + module_list_to_verilog_module +from logicnets.synthesis import synthesize_and_get_resource_counts +from logicnets.util import proc_postsynth_file + +from models import MnistNeqModel, MnistLutModel, MnistVerilogModel +from train import configs, model_config, test + +other_options = { + "cuda": None, + "log_dir": None, + "checkpoint": None, + "generate_bench": False, + "add_registers": False, + "simulate_pre_synthesis_verilog": False, + "simulate_post_synthesis_verilog": False, +} + +if __name__ == "__main__": + parser = ArgumentParser(description="Synthesize convert a PyTorch trained model into verilog") + parser.add_argument('--arch', type=str, choices=configs.keys(), default="mnist-s", + help="Specific the neural network model to use (default: %(default)s)") + parser.add_argument('--batch-size', type=int, 
default=None, metavar='N', + help="Batch size for evaluation (default: %(default)s)") + parser.add_argument('--input-bitwidth', type=int, default=None, + help="Bitwidth to use at the input (default: %(default)s)") + parser.add_argument('--hidden-bitwidth', type=int, default=None, + help="Bitwidth to use for activations in hidden layers (default: %(default)s)") + parser.add_argument('--output-bitwidth', type=int, default=None, + help="Bitwidth to use at the output (default: %(default)s)") + parser.add_argument('--input-fanin', type=int, default=None, + help="Fanin to use at the input (default: %(default)s)") + parser.add_argument('--hidden-fanin', type=int, default=None, + help="Fanin to use for the hidden layers (default: %(default)s)") + parser.add_argument('--output-fanin', type=int, default=None, + help="Fanin to use at the output (default: %(default)s)") + parser.add_argument('--hidden-layers', nargs='+', type=int, default=None, + help="A list of hidden layer neuron sizes (default: %(default)s)") + parser.add_argument('--input-dropout', type=float, default=None, + help="The amount of dropout to apply at the model input (default: %(default)s)") + parser.add_argument('--clock-period', type=float, default=1.0, + help="Target clock frequency to use during Vivado synthesis (default: %(default)s)") + parser.add_argument('--dataset-split', type=str, default='test', choices=['train', 'test'], + help="Dataset to use for evaluation (default: %(default)s)") + parser.add_argument('--log-dir', type=str, default='./log', + help="A location to store the log output of the training run and the output model (default: %(default)s)") + parser.add_argument('--checkpoint', type=str, required=True, + help="The checkpoint file which contains the model weights") + parser.add_argument('--generate-bench', action='store_true', default=False, + help="Generate the truth table in BENCH format as well as verilog (default: %(default)s)") + parser.add_argument('--dump-io', action='store_true', 
default=False, + help="Dump I/O to the verilog LUT to a text file in the log directory (default: %(default)s)") + parser.add_argument('--add-registers', action='store_true', default=False, + help="Add registers between each layer in generated verilog (default: %(default)s)") + parser.add_argument('--simulate-pre-synthesis-verilog', action='store_true', default=False, + help="Simulate the verilog generated by LogicNets (default: %(default)s)") + parser.add_argument('--simulate-post-synthesis-verilog', action='store_true', default=False, + help="Simulate the post-synthesis verilog produced by vivado (default: %(default)s)") + args = parser.parse_args() + defaults = configs[args.arch] + options = vars(args) + del options['arch'] + config = {} + for k in options.keys(): + config[k] = options[k] if options[k] is not None else defaults[k] # Override defaults, if specified. + + if not os.path.exists(config['log_dir']): + os.makedirs(config['log_dir']) + + # Split up configuration options to be more understandable + model_cfg = {} + for k in model_config.keys(): + model_cfg[k] = config[k] + options_cfg = {} + for k in other_options.keys(): + if k == 'cuda': + continue + options_cfg[k] = config[k] + + trans = transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + transforms.Lambda(partial(torch.reshape, shape=(-1,))) + ]) + + # Fetch the test set + dataset = {} + dataset[args.dataset_split] = MNIST('./data', train=args.dataset_split == "train", download=True, transform=trans) + test_loader = DataLoader(dataset[args.dataset_split], batch_size=config['batch_size'], shuffle=False) + + # Instantiate the PyTorch model + x, y = dataset[args.dataset_split][0] + model_cfg['input_length'] = len(x) + model_cfg['output_length'] = 10 + model = MnistNeqModel(model_cfg) + + # Load the model weights + checkpoint = torch.load(options_cfg['checkpoint'], map_location='cpu') + model.load_state_dict(checkpoint['model_dict']) + + # Test the 
PyTorch model + print("Running inference on baseline model...") + model.eval() + baseline_accuracy = test(model, test_loader, cuda=False) + print("Baseline accuracy: %f" % (baseline_accuracy)) + + # Instantiate LUT-based model + lut_model = MnistLutModel(model_cfg) + lut_model.load_state_dict(checkpoint['model_dict']) + + # Generate the truth tables in the LUT module + print("Converting to NEQs to LUTs...") + generate_truth_tables(lut_model, verbose=True) + + # Test the LUT-based model + print("Running inference on LUT-based model...") + lut_inference(lut_model) + lut_model.eval() + lut_accuracy = test(lut_model, test_loader, cuda=False) + print("LUT-Based Model accuracy: %f" % (lut_accuracy)) + modelSave = { 'model_dict': lut_model.state_dict(), + 'test_accuracy': lut_accuracy} + + torch.save(modelSave, options_cfg["log_dir"] + "/lut_based_model.pth") + + print("Generating verilog in %s..." % (options_cfg["log_dir"])) + module_list_to_verilog_module(lut_model.module_list, "logicnet", options_cfg["log_dir"], generate_bench=options_cfg["generate_bench"], add_registers=options_cfg["add_registers"]) + print("Top level entity stored at: %s/logicnet.v ..." % (options_cfg["log_dir"])) + + if args.dump_io: + io_filename = options_cfg["log_dir"] + f"io_{args.dataset_split}.txt" + with open(io_filename, 'w') as f: + pass # Create an empty file. 
+ print(f"Dumping verilog I/O to {io_filename}...") + else: + io_filename = None + + if args.simulate_pre_synthesis_verilog: + print("Running inference simulation of Verilog-based model...") + lut_model.verilog_inference(options_cfg["log_dir"], "logicnet.v", logfile=io_filename, add_registers=options_cfg["add_registers"]) + verilog_accuracy = test(lut_model, test_loader, cuda=False) + print("Verilog-Based Model accuracy: %f" % (verilog_accuracy)) + + print("Running out-of-context synthesis") + ret = synthesize_and_get_resource_counts(options_cfg["log_dir"], "logicnet", fpga_part="xcu280-fsvh2892-2L-e", clk_period_ns=args.clock_period, post_synthesis = 1) + + if args.simulate_post_synthesis_verilog: + print("Running post-synthesis inference simulation of Verilog-based model...") + proc_postsynth_file(options_cfg["log_dir"]) + lut_model.verilog_inference(options_cfg["log_dir"]+"/post_synth", "logicnet_post_synth.v", io_filename, add_registers=options_cfg["add_registers"]) + post_synth_accuracy = test(lut_model, test_loader, cuda=False) + print("Post-synthesis Verilog-Based Model accuracy: %f" % (post_synth_accuracy)) + diff --git a/examples/mnist/requirements.txt b/examples/mnist/requirements.txt new file mode 100644 index 000000000..ac3ab55bf --- /dev/null +++ b/examples/mnist/requirements.txt @@ -0,0 +1,17 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +numpy +tensorboard +torchvision diff --git a/examples/mnist/train.py b/examples/mnist/train.py new file mode 100644 index 000000000..1f5bf7e7f --- /dev/null +++ b/examples/mnist/train.py @@ -0,0 +1,404 @@ +# Copyright (C) 2021 Xilinx, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from argparse import ArgumentParser +from functools import reduce, partial +import random + +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from torchvision.datasets import MNIST + +from models import MnistNeqModel + +# TODO: Replace default configs with YAML files. 
+configs = { + "mnist-xxs": { + "hidden_layers": [1024, 1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 8, + "hidden_fanin": 8, + "output_fanin": 8, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 0, + "checkpoint": None, + }, + "mnist-xs": { + "hidden_layers": [1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 8, + "hidden_fanin": 8, + "output_fanin": 8, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 0, + "checkpoint": None, + }, + "mnist-s": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 8, + "hidden_fanin": 8, + "output_fanin": 8, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 5, + "checkpoint": None, + }, + "mnist-s-1.1": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 6, + "hidden_fanin": 6, + "output_fanin": 6, + "input_dropout": 0.1, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 18, + "checkpoint": None, + }, + "mnist-m": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 10, + "hidden_fanin": 10, + "output_fanin": 10, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 2, + "checkpoint": None, + }, + "mnist-m-1.1": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 2, + "hidden_bitwidth": 2, + "output_bitwidth": 4, + "input_fanin": 5, + "hidden_fanin": 5, + "output_fanin": 5, + 
"input_dropout": 0.1, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 20, + "checkpoint": None, + }, + "mnist-m-1.2": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 3, + "hidden_bitwidth": 3, + "output_bitwidth": 4, + "input_fanin": 3, + "hidden_fanin": 3, + "output_fanin": 3, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 0, + "checkpoint": None, + }, + "mnist-l": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 1, + "hidden_bitwidth": 1, + "output_bitwidth": 4, + "input_fanin": 12, + "hidden_fanin": 12, + "output_fanin": 12, + "input_dropout": 0.01, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 0, + "checkpoint": None, + }, + "mnist-l-1.1": { + "hidden_layers": [1024, 1024, 1024, 1024, 1024, 128], + "input_bitwidth": 2, + "hidden_bitwidth": 2, + "output_bitwidth": 4, + "input_fanin": 6, + "hidden_fanin": 6, + "output_fanin": 6, + "input_dropout": 0.1, + "weight_decay": 1e-3, + "batch_size": 1024, + "epochs": 1000, + "learning_rate": 1e-3, + "seed": 12, + "checkpoint": None, + }, +} + +# A dictionary, so we can set some defaults if necessary +model_config = { + "hidden_layers": None, + "input_bitwidth": None, + "hidden_bitwidth": None, + "output_bitwidth": None, + "input_fanin": None, + "hidden_fanin": None, + "output_fanin": None, + "input_dropout": None, +} + +training_config = { + "weight_decay": None, + "batch_size": None, + "epochs": None, + "learning_rate": None, + "seed": None, +} + +other_options = { + "cuda": None, + "log_dir": None, + "checkpoint": None, +} + +def train(model, datasets, train_cfg, options): + # Create data loaders for training and inference: + train_loader = DataLoader(datasets["train"], batch_size=train_cfg['batch_size'], shuffle=True) + val_loader = DataLoader(datasets["valid"], 
batch_size=train_cfg['batch_size'], shuffle=False) + test_loader = DataLoader(datasets["test"], batch_size=train_cfg['batch_size'], shuffle=False) + + # Configure optimizer + weight_decay = train_cfg["weight_decay"] + decay_exclusions = ["bn", "bias", "learned_value"] # Make a list of parameters name fragments which will ignore weight decay TODO: make this list part of the train_cfg + decay_params = [] + no_decay_params = [] + for pname, params in model.named_parameters(): + if params.requires_grad: + if reduce(lambda a,b: a or b, map(lambda x: x in pname, decay_exclusions)): # check if the current label should be excluded from weight decay + #print("Disabling weight decay for %s" % (pname)) + no_decay_params.append(params) + else: + #print("Enabling weight decay for %s" % (pname)) + decay_params.append(params) + #else: + #print("Ignoring %s" % (pname)) + params = [{'params': decay_params, 'weight_decay': weight_decay}, + {'params': no_decay_params, 'weight_decay': 0.0}] + optimizer = optim.AdamW(params, lr=train_cfg['learning_rate'], betas=(0.5, 0.999), weight_decay=weight_decay) + + # Configure scheduler + steps = len(train_loader) + scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=steps*100, T_mult=1) + + # Configure criterion + criterion = nn.CrossEntropyLoss() + + # Push the model to the GPU, if necessary + if options["cuda"]: + model.cuda() + + # Setup tensorboard + writer = SummaryWriter(options["log_dir"]) + + # Main training loop + maxAcc = 0.0 + num_epochs = train_cfg["epochs"] + for epoch in range(0, num_epochs): + # Train for this epoch + model.train() + accLoss = 0.0 + correct = 0 + for batch_idx, (data, target) in enumerate(train_loader): + if options["cuda"]: + data, target = data.cuda(), target.cuda() + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + pred = output.detach().max(1, keepdim=True)[1] + target_label = target.detach().unsqueeze(1) + curCorrect = pred.eq(target_label).long().sum() 
+ curAcc = 100.0*curCorrect / len(data) + correct += curCorrect + accLoss += loss.detach()*len(data) + loss.backward() + optimizer.step() + scheduler.step() + + # Log stats to tensorboard + #writer.add_scalar('train_loss', loss.detach().cpu().numpy(), epoch*steps + batch_idx) + #writer.add_scalar('train_accuracy', curAcc.detach().cpu().numpy(), epoch*steps + batch_idx) + #g = optimizer.param_groups[0] + #writer.add_scalar('LR', g['lr'], epoch*steps + batch_idx) + + accLoss /= len(train_loader.dataset) + accuracy = 100.0*correct / len(train_loader.dataset) + print(f"Epoch: {epoch}/{num_epochs}\tTrain Acc (%): {accuracy.detach().cpu().numpy():.2f}\tTrain Loss: {accLoss.detach().cpu().numpy():.3e}") + #for g in optimizer.param_groups: + # print("LR: {:.6f} ".format(g['lr'])) + # print("LR: {:.6f} ".format(g['weight_decay'])) + writer.add_scalar('avg_train_loss', accLoss.detach().cpu().numpy(), (epoch+1)*steps) + writer.add_scalar('avg_train_accuracy', accuracy.detach().cpu().numpy(), (epoch+1)*steps) + val_accuracy = test(model, val_loader, options["cuda"]) + test_accuracy = test(model, test_loader, options["cuda"]) + modelSave = { 'model_dict': model.state_dict(), + 'optim_dict': optimizer.state_dict(), + 'val_accuracy': val_accuracy, + 'test_accuracy': test_accuracy, + 'epoch': epoch} + torch.save(modelSave, options["log_dir"] + "/checkpoint.pth") + if(maxAcc