diff --git a/Handwritten Digit Recognition/README.md b/Handwritten Digit Recognition/README.md new file mode 100644 index 0000000..8d2410b --- /dev/null +++ b/Handwritten Digit Recognition/README.md @@ -0,0 +1,29 @@ +# Digit Recognition WebApp, PyTorch, Flask + +**Demo Site : https://digit-recog-torch.uc.r.appspot.com/** + + + + +### Model +- handwritten digit recognition using the MNIST dataset +- two CNN layers, Batch Normalization, Adam Optimizer +- centers the input digit for better recognition +- 99.3% validation accuracy + +### Web Application +- Flask as the web framework +- d3.js for drawing the bar graph + +## Requirements +- `pip install -r requirements.txt` +- Tested on Python 3.9.13, PyTorch 2.0.0, CUDA 11.8 + +## Usage + +- #### Run WebApp + - `python3 server.py` (then open `localhost:5000` in a browser) + +- #### Training Model + - Training on CPU: `python3 train.py` + - Training on GPU: `python3 train.py --use_gpu` (enabled when a GPU and CUDA are available) diff --git a/Handwritten Digit Recognition/__pycache__/model.cpython-39.pyc b/Handwritten Digit Recognition/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000..35f12f7 Binary files /dev/null and b/Handwritten Digit Recognition/__pycache__/model.cpython-39.pyc differ diff --git a/Handwritten Digit Recognition/__pycache__/train.cpython-39.pyc b/Handwritten Digit Recognition/__pycache__/train.cpython-39.pyc new file mode 100644 index 0000000..98411a2 Binary files /dev/null and b/Handwritten Digit Recognition/__pycache__/train.cpython-39.pyc differ diff --git a/Handwritten Digit Recognition/checkpoint/best_accuracy.pth b/Handwritten Digit Recognition/checkpoint/best_accuracy.pth new file mode 100644 index 0000000..70a6c89 Binary files /dev/null and b/Handwritten Digit Recognition/checkpoint/best_accuracy.pth differ diff --git a/Handwritten Digit Recognition/demo.gif b/Handwritten Digit Recognition/demo.gif new file mode 100644 index 0000000..7c06ed6 Binary files /dev/null and b/Handwritten Digit 
class Model(nn.Module):
    """Small CNN that maps a 1-channel 28x28 image to 10 class probabilities.

    Two identical stages (5x5 convolution with padding 2, batch norm, ReLU,
    2x2 max-pool) halve the spatial size twice (28 -> 14 -> 7).  The result
    is flattened and passed through two fully connected layers, then a
    softmax over dim 1, so each output row sums to 1.
    """

    def __init__(self):
        super().__init__()

        # Parameter-free layers shared by both convolutional stages.
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 1: 1 -> 32 channels; padding 2 keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)

        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)

        # Classifier head on the flattened 64 x 7 x 7 feature map.
        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-class probabilities for a batch shaped (N, 1, 28, 28)."""
        # 28x28 -> 14x14
        stage1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # 14x14 -> 7x7
        stage2 = self.maxpool(self.relu(self.bn2(self.conv2(stage1))))

        # Keep the batch dimension, collapse everything else.
        features = stage2.flatten(start_dim=1)
        hidden = self.relu(self.fc1(features))

        return self.softmax(self.fc2(hidden))
class Predict():
    """Callable that classifies a drawn digit with the saved ``Model`` weights.

    The instance loads the checkpoint at ``SAVE_MODEL_PATH`` onto the CPU once
    and is then called with a PIL image for every request.
    """

    def __init__(self):
        device = torch.device("cpu")
        self.model = Model().to(device)
        self.model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=device))
        # Inference only: switch batch norm to its running statistics once,
        # instead of on every call.
        self.model.eval()
        # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def _centering_img(self, img):
        """Shift *img* so the bounding box of its non-zero pixels is centred.

        Args:
            img: single-band PIL image whose background pixels are 0.

        Returns:
            The shifted image (pixels wrap around, see ``ImageChops.offset``),
            or *img* itself when it contains no non-zero pixel.
        """
        # getbbox() finds the non-zero bounding box in C instead of a
        # per-pixel Python loop; it returns None for an all-zero image.
        bbox = img.getbbox()
        if bbox is None:
            return img  # blank canvas: nothing to centre

        w, h = img.size
        left, top, right, bottom = bbox
        # getbbox() reports exclusive right/lower edges; the centring
        # arithmetic below works on inclusive pixel coordinates.
        right -= 1
        bottom -= 1

        shift_x = (left + (right - left) // 2) - w // 2
        shift_y = (top + (bottom - top) // 2) - h // 2
        return ImageChops.offset(img, -shift_x, -shift_y)

    def __call__(self, img):
        """Return the model's 10 class probabilities for *img*.

        Args:
            img: PIL image in a mode accepted by ``ImageOps.invert``
                (the request handler passes mode ``"L"``).

        Returns:
            1-D ``numpy.ndarray`` of 10 float64 probabilities.  float64 is
            used because its scalars subclass Python ``float`` and therefore
            stay JSON-serializable after the caller multiplies them by 100,
            which float32 scalars are not under NumPy 2 promotion rules.
        """
        img = ImageOps.invert(img)  # MNIST image is inverted
        img = self._centering_img(img)
        img = img.resize((28, 28), Image.BICUBIC)  # resize to 28x28
        tensor = self.transform(img).unsqueeze(0)  # 1,1,28,28

        with torch.no_grad():
            preds = self.model(tensor)

        return preds[0].numpy().astype(np.float64)
/**
 * Build the probability bar chart inside the `#probGraph` SVG element.
 *
 * Draws a left axis labelled 0-9 and one horizontal bar per digit, seeded
 * with placeholder widths.  The inner `<g>` is stored in the module-level
 * `svgGraph` so that later updates can re-bind data to the same bars.
 */
function initProbGraph(){

    const dummyData = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; // placeholder widths, one bar per digit
    const margin = { top: 10, right: 10, bottom: 10, left: 20 };
    const width = 250, height = 196;
    const barHeight = 20;

    // Digit index 0 maps to the top of the chart, 9 to the bottom.
    const yScale = d3.scaleLinear()
        .domain([9, 0])
        .range([height, 0]);

    svgGraph = d3.select("#probGraph")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)
        .append("g")
        .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

    svgGraph.append("g")
        .attr("class", "y axis")
        .call(d3.axisLeft(yScale));

    // No <rect> exists yet, so the enter selection creates all ten bars.
    // The axis is rendered only once above; it must not be called on the
    // bars themselves, which would nest axis markup inside every <rect>.
    svgGraph.selectAll("rect")
        .data(dummyData)
        .enter()
        .append("rect")
        .attr("y", (d, i) => yScale(i) - barHeight / 2) // centre each bar on its tick
        .attr("height", barHeight)
        .style("fill", "green")
        .attr("x", 0)
        .attr("width", (d) => d * 2);
}
/**
 * Render a recognition response.
 *
 * @param {{pred: string, probs: number[]}} res  Parsed server reply: `pred`
 *     is the predicted digit as a string and `probs` holds one percentage
 *     per digit.  When the server has no model loaded it replies with
 *     `pred: "Err"` and an empty `probs` array.
 */
function showResult(res){

    divOut.textContent = res.pred;

    const probElem = document.getElementById("prob");
    const probs = res.probs || [];
    const best = Number(res.pred);   // NaN for the "Err" reply
    const bestProb = probs[best];    // undefined when there is nothing to index

    // Error reply: there is no probability to format, so leave the bars
    // untouched instead of throwing on `undefined.toFixed`.
    if (typeof bestProb !== "number") {
        probElem.textContent = "Probability : -";
        return;
    }

    // Plain text only, so textContent is used rather than innerHTML.
    probElem.textContent = "Probability : " + bestProb.toFixed(2) + "%";

    svgGraph.selectAll("rect")
        .data(probs)
        .transition()
        .duration(300)
        .style("fill", (d, i) => i === best ? "blue" : "green") // highlight the predicted digit
        .attr("width", (d) => d * 2);
}
+ + + +Handwritten Digit Recognition WebApp
+