-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_ex_7_knn.py
More file actions
54 lines (46 loc) · 1.66 KB
/
tf_ex_7_knn.py
File metadata and controls
54 lines (46 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot encoded labels (images are later treated as
# flat 784-float vectors and reshaped to 28x28 for display).
# NOTE(review): '/MNIST_data' is an absolute path at the filesystem root;
# creating it may require elevated permissions -- confirm a relative
# 'MNIST_data/' directory was not intended.
data = input_data.read_data_sets('/MNIST_data', one_hot=True)

# Draw a 5000-example batch from the training split (the kNN reference set)
# and a 200-example batch from the test split (the query images).
(x_train, y_train), (x_test, y_test) = (
    data.train.next_batch(5000),
    data.test.next_batch(200),
)
# Graph inputs: the whole reference set (one flattened image per row) and a
# single flattened query image.
x_tr = tf.placeholder(tf.float32, [None, 784])
x_ts = tf.placeholder(tf.float32, [784])
# Nearest-neighbour distance using the L1 (Manhattan) distance: the sum of
# absolute pixel differences between the query and every reference row.
# x_ts (shape [784]) broadcasts against each row of x_tr (shape [None, 784]).
distance = tf.reduce_sum(tf.abs(tf.subtract(x_tr, x_ts)), axis=1)
# Prediction: index of the reference row with the smallest distance.
prediction = tf.argmin(distance, 0)
# Evaluate the 1-NN classifier on every test image, then visualise a few
# results.
# Each entry: (index of nearest training image, predicted class, true class).
preds = []
# The graph defines no variables, so this initializer is effectively a no-op.
init = tf.global_variables_initializer()
n_correct = 0
# Run the nearest-neighbour search one test image at a time.
with tf.Session() as sess:
    sess.run(init)
    for i, (test_image, test_label) in enumerate(zip(x_test, y_test)):
        # Index of the closest training image.
        nn_index = sess.run(prediction, feed_dict={x_tr: x_train, x_ts: test_image})
        # The neighbour's class is the prediction; compare it with the truth.
        predicted_class = np.argmax(y_train[nn_index])
        true_class = np.argmax(test_label)
        print("Test: {} Prediction: {} True class: {}".format(i, predicted_class, true_class))
        preds.append((nn_index, predicted_class, true_class))
        if predicted_class == true_class:
            n_correct += 1
        else:
            # Show the training image that caused the misclassification.
            plt.imshow(x_train[nn_index].reshape(28, 28))
            plt.title("True: {} Predicted: {}".format(true_class, predicted_class))
            plt.show()
# Divide once instead of summing 1/len(x_test) per hit, which accumulates
# floating-point rounding error; an empty test set yields 0.0.
accuracy = float(n_correct) / len(x_test) if len(x_test) else 0.
print("Accuracy:", accuracy)
# Visualize: show the nearest training neighbour for the first few test
# images, titled with both the true and the predicted class.
limit = 5
for index, predicted_class, true_class in preds[:limit]:
    plt.imshow(x_train[index].reshape(28, 28))
    plt.title("True: {} Predicted: {}".format(true_class, predicted_class))
    plt.show()