29 changes: 29 additions & 0 deletions Handwritten Digit Recognition/README.md
@@ -0,0 +1,29 @@
# Digit Recognition WebApp, PyTorch, Flask

**Demo site: https://digit-recog-torch.uc.r.appspot.com/**

![Digit Recognition](./demo.gif)


### Model
- Handwritten digit recognition trained on the MNIST dataset
- Two convolutional layers with Batch Normalization, trained with the Adam optimizer
- The input digit is centered before inference for better recognition (see the sketch below)
- 99.3% accuracy on the validation set
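
The centering step is implemented in `server.py` further down in this diff; as a rough illustration of the idea, the same effect can be sketched with PIL's bounding-box helper (the `getbbox` shortcut and the `center_digit` name are assumptions for brevity, not the exact code used by the app):

```python
from PIL import Image, ImageChops

def center_digit(img: Image.Image) -> Image.Image:
    """Shift the drawn digit so its bounding box is centered in the canvas."""
    bbox = img.getbbox()          # (left, top, right, bottom) of the non-zero region
    if bbox is None:              # blank canvas: nothing to center
        return img
    left, top, right, bottom = bbox
    shift_x = (left + right) // 2 - img.width // 2
    shift_y = (top + bottom) // 2 - img.height // 2
    # ImageChops.offset shifts the image, wrapping pixels around the edges
    return ImageChops.offset(img, -shift_x, -shift_y)
```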

### Web Application
- Flask as the web framework
- d3.js for drawing the probability bar graph

## Requirements
- `pip install -r requirements.txt`
- Tested with Python 3.9.13, PyTorch 2.0.0, and CUDA 11.8

## Usage

- #### Run WebApp
  - `python3 server.py` (then open `http://localhost:5000`)

- #### Training Model
  - Training on CPU: `python3 train.py`
  - Training on GPU: `python3 train.py --use_gpu` (used when a GPU and CUDA are available)
Binary file added Handwritten Digit Recognition/demo.gif
40 changes: 40 additions & 0 deletions Handwritten Digit Recognition/model.py
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class Model(nn.Module):
    """Two-block CNN for 28x28 MNIST digits: (conv -> BN -> ReLU -> max-pool) x 2, then two FC layers."""

    def __init__(self):
        super(Model, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(64)

        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 28x28 -> 14x14

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 14x14 -> 7x7

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.softmax(x)

        return x
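
A minimal sketch of exercising the model on a random batch (names and shapes taken from the code above): the softmax output is a `(batch, 10)` tensor whose rows sum to roughly 1.

```python
import torch

from model import Model

model = Model()
model.eval()

# random batch of four 1-channel 28x28 images
x = torch.randn(4, 1, 28, 28)
with torch.no_grad():
    probs = model(x)

print(probs.shape)       # torch.Size([4, 10])
print(probs.sum(dim=1))  # each row sums to ~1 because of the softmax
```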
5 changes: 5 additions & 0 deletions Handwritten Digit Recognition/requirements.txt
@@ -0,0 +1,5 @@
Flask
Pillow
torch
torchvision
--extra-index-url https://download.pytorch.org/whl/cu118
80 changes: 80 additions & 0 deletions Handwritten Digit Recognition/server.py
@@ -0,0 +1,80 @@
import json

import numpy as np
import torch
from flask import Flask, render_template, request
from PIL import Image, ImageChops, ImageOps
from torchvision import transforms

from model import Model
from train import SAVE_MODEL_PATH

app = Flask(__name__)
predict = None


@app.route("/")
def index():
    return render_template("index.html")


@app.route("/DigitRecognition", methods=["POST"])
def predict_digit():
    img = Image.open(request.files["img"]).convert("L")

    # predict
    res_json = {"pred": "Err", "probs": []}
    if predict is not None:
        res = predict(img)
        res_json["pred"] = str(int(np.argmax(res)))
        # cast numpy scalars to plain Python floats so the list is JSON serializable
        res_json["probs"] = [float(p) * 100 for p in res]

    return json.dumps(res_json)


class Predict:
    def __init__(self):
        device = torch.device("cpu")
        self.model = Model().to(device)
        self.model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=device))
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def _centering_img(self, img):
        """Shift the digit so the center of its bounding box matches the center of the image."""
        w, h = img.size[:2]
        left, top, right, bottom = w, h, -1, -1
        imgpix = img.getdata()

        # find the bounding box of all non-black pixels
        for y in range(h):
            offset_y = y * w
            for x in range(w):
                if imgpix[offset_y + x] > 0:
                    left = min(left, x)
                    top = min(top, y)
                    right = max(right, x)
                    bottom = max(bottom, y)

        if right < 0:  # blank image: nothing to center
            return img

        shift_x = (left + (right - left) // 2) - w // 2
        shift_y = (top + (bottom - top) // 2) - h // 2
        return ImageChops.offset(img, -shift_x, -shift_y)

    def __call__(self, img):
        img = ImageOps.invert(img)  # MNIST digits are white on black; the canvas is black on white
        img = self._centering_img(img)
        img = img.resize((28, 28), Image.BICUBIC)  # resize to the 28x28 MNIST input size
        tensor = self.transform(img)
        tensor = tensor.unsqueeze(0)  # shape: 1x1x28x28

        self.model.eval()
        with torch.no_grad():
            preds = self.model(tensor)
            preds = preds.detach().numpy()[0]

        return preds


if __name__ == "__main__":
    import os
    assert os.path.exists(SAVE_MODEL_PATH), "no saved model"
    predict = Predict()

    app.run(host="0.0.0.0")
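
A minimal sketch of calling the endpoint from Python, assuming the server is running locally; `digit.png` is a hypothetical image of a handwritten digit, and the `requests` package is not in requirements.txt and is used here only for illustration.

```python
import json

import requests  # assumption: installed separately, not listed in requirements.txt

# "digit.png" is a hypothetical black-on-white drawing of a digit
with open("digit.png", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/DigitRecognition",
        files={"img": ("digit.png", f, "image/png")},
    )

res = json.loads(resp.text)
print("predicted digit:", res["pred"])
print("probabilities (%):", [round(p, 2) for p in res["probs"]])
```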
33 changes: 33 additions & 0 deletions Handwritten Digit Recognition/static/css/style.css
@@ -0,0 +1,33 @@
.common {
    text-align: center;
}

.title {
    font-size: x-large;
}

.boxitem {
    display: inline-block;
    vertical-align: top;
}

#inputimg {
    vertical-align: top;
    border: solid 1px black;
}

#clearbtn {
    width: 100%;
    font-size: medium;
    display: block;
}

#pred {
    width: 196px;
    height: 196px;
    line-height: 196px;
    font-size: 160px;
    border: solid 1px black;
    background-color: white;
    font-family: Century;
}
161 changes: 161 additions & 0 deletions Handwritten Digit Recognition/static/js/script.js
@@ -0,0 +1,161 @@


let cvsIn = document.getElementById("inputimg");
let ctxIn = cvsIn.getContext("2d");
let divOut = document.getElementById("pred");
let svgGraph = null;
let mouselbtn = false;


// initialize the drawing canvas and the probability graph
window.onload = () => {

    ctxIn.fillStyle = "white";
    ctxIn.fillRect(0, 0, cvsIn.width, cvsIn.height);
    ctxIn.lineWidth = 7;
    ctxIn.lineCap = "round";

    initProbGraph();
}

function initProbGraph() {

    const dummyData = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; // dummy data to initialize the graph
    const margin = { top: 10, right: 10, bottom: 10, left: 20 }
        , width = 250, height = 196;

    let yScale = d3.scaleLinear()
        .domain([9, 0])
        .range([height, 0]);

    svgGraph = d3.select("#probGraph")
        .attr("width", width + margin.left + margin.right)
        .attr("height", height + margin.top + margin.bottom)
        .append("g")
        .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

    // left axis labeled with the digits 0-9
    svgGraph.append("g")
        .attr("class", "y axis")
        .call(d3.axisLeft(yScale));

    // one horizontal bar per digit; widths are updated in showResult()
    const barHeight = 20;
    svgGraph.selectAll("svg")
        .data(dummyData)
        .enter()
        .append("rect")
        .attr("y", (d, i) => (yScale(i) - barHeight / 2))
        .attr("height", barHeight)
        .style("fill", "green")
        .attr("x", 0)
        .attr("width", (d) => d * 2);
}

cvsIn.addEventListener("mousedown", (e)=>{

if(e.button == 0){
let rect = e.target.getBoundingClientRect();
let x = e.clientX - rect.left;
let y = e.clientY - rect.top;
mouselbtn = true;
ctxIn.beginPath();
ctxIn.moveTo(x, y);
}
else if(e.button == 2){
onClear(); // clear by mouse right button
}
});

cvsIn.addEventListener("mouseup", (e)=>{
if(e.button == 0){
mouselbtn = false;
onRecognition();
}
});

cvsIn.addEventListener("mousemove", (e)=>{
let rect = e.target.getBoundingClientRect();
let x = e.clientX - rect.left;
let y = e.clientY - rect.top;
if(mouselbtn){
ctxIn.lineTo(x, y);
ctxIn.stroke();
}
});

cvsIn.addEventListener("touchstart", (e)=>{
// for touch device
if (e.targetTouches.length == 1) {
let rect = e.target.getBoundingClientRect();
let touch = e.targetTouches[0];
let x = touch.clientX - rect.left;
let y = touch.clientY - rect.top;
ctxIn.beginPath();
ctxIn.moveTo(x, y);
}
});

cvsIn.addEventListener("touchmove", (e)=>{
// for touch device
if (e.targetTouches.length == 1) {
let rect = e.target.getBoundingClientRect();
let touch = e.targetTouches[0];
let x = touch.clientX - rect.left;
let y = touch.clientY - rect.top;
ctxIn.lineTo(x, y);
ctxIn.stroke();
e.preventDefault();
}
});

cvsIn.addEventListener("touchend", (e)=>onRecognition());

cvsIn.addEventListener("contextmenu", (e)=>e.preventDefault());

document.getElementById("clearbtn").onclick = onClear;
function onClear(){
mouselbtn = false;
ctxIn.fillStyle = "white";
ctxIn.fillRect(0, 0, cvsIn.width, cvsIn.height);
ctxIn.fillStyle = "black";
}

// post the drawn digit to the server for recognition
function onRecognition() {
    console.time("time");

    cvsIn.toBlob((blob) => {
        let form = new FormData();
        form.append('img', blob, "dummy.png");

        $.ajax({
            url: "./DigitRecognition",
            type: "POST",
            data: form,
            processData: false,
            contentType: false,
        })
        .then(
            (data) => {
                console.timeEnd("time"); // stop the timer once the server has responded
                showResult(JSON.parse(data));
            },
            () => alert("error")
        );
    });
}


function showResult(res) {

    // show the predicted digit and its probability
    divOut.textContent = res.pred;

    document.getElementById("prob").innerHTML =
        "Probability : " + res.probs[res.pred].toFixed(2) + "%";

    // update the bar graph; the predicted digit's bar is drawn in blue
    svgGraph.selectAll("rect")
        .data(res.probs)
        .transition()
        .duration(300)
        .style("fill", (d, i) => i == res.pred ? "blue" : "green")
        .attr("width", (d) => d * 2);
}
35 changes: 35 additions & 0 deletions Handwritten Digit Recognition/templates/index.html
@@ -0,0 +1,35 @@
<!doctype html>
<html lang="en">
<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width,initial-scale=1">
    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/style.css') }}">
    <title>Handwritten Digit Recognition</title>
</head>

<body>
    <p class="title common">Handwritten Digit Recognition WebApp</p>
    <div class="common">
        <!-- <a href="https://github.com/nai-kon/CNN-Digit-Recognition">https://github.com/nai-kon/CNN-Digit-Recognition</a><br/> -->
        <!-- created by Katz Sasaki -->
    </div>
    <br/>
    <div class="common">
        <div class="boxitem">
            <canvas id="inputimg" width="196" height="196"></canvas>
            <button id="clearbtn">Clear</button>
        </div>
        <div class="boxitem">
            <div id="pred">0</div>
            <div id="prob">Probability : </div>
        </div>
        <div class="boxitem">
            <svg id="probGraph" width="250" height="196"></svg>
        </div>
    </div>

    <!-- load jQuery and d3 before the app script that uses them -->
    <script src="//ajax.googleapis.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
    <script src="//d3js.org/d3.v5.min.js"></script>
    <script src="{{ url_for('static', filename='js/script.js') }}"></script>

</body>
</html>