-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcode.py
More file actions
49 lines (36 loc) · 1.84 KB
/
code.py
File metadata and controls
49 lines (36 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
# Training accuracy threshold used by the early-stopping callback below.
DESIRED_ACCURACY = 0.99

# Handle to the Fashion-MNIST dataset module that ships with Keras.
dataset = tf.keras.datasets.fashion_mnist

# Report which TensorFlow release this script is running against.
print(tf.__version__)
class myCallBack(tf.keras.callbacks.Callback):
    """Keras callback that halts training once accuracy reaches DESIRED_ACCURACY."""

    def on_epoch_end(self, epoch, logs=None):
        """Request that training stop when the epoch's accuracy meets the target.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric results for the epoch as supplied by Keras; may be
                None or lack an accuracy entry, in which case nothing happens.
        """
        # Avoid a shared mutable default; treat a missing dict as "no metrics".
        logs = logs or {}
        # With metrics=['accuracy'] TF 2.x reports the key 'accuracy'; older
        # TF 1.x-era Keras releases reported it as 'acc', so accept either.
        accuracy = logs.get('accuracy', logs.get('acc'))
        # Guard against a missing metric: comparing None with a float raises
        # TypeError and would abort training.
        if accuracy is not None and accuracy >= DESIRED_ACCURACY:
            print("\nReached "+ str(DESIRED_ACCURACY)+" so cancelling training.\n")
            self.model.stop_training = True
# Instantiate the early-stopping callback that is passed to model.fit().
callbacks = myCallBack()

# load_data() returns two (images, labels) tuples: the training split and the
# test split.
(training_data, training_labels), (test_data, test_labels) = dataset.load_data()

# Reshape the training images into a 4D array of N*28*28*1
# (N images, width = 28, height = 28, colour info = 1 (greyscale)).
# Using -1 lets NumPy infer N instead of hard-coding the 60000-image split size.
training_data = training_data.reshape(-1, 28, 28, 1)
# Scale pixel values by 1/255 so the network sees inputs in the range [0, 1].
training_data = training_data / 255.0

# Apply the same reshape and scaling to the test images.
test_data = test_data.reshape(-1, 28, 28, 1)
test_data = test_data / 255.0
# Define the model: three Conv2D + MaxPooling2D stages followed by a dense
# classifier head with 10 softmax outputs (one per class label).
# With Keras' default 'valid' padding the spatial size shrinks
# 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, so Flatten() yields 32 features.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # Collapse the feature maps into a vector for the dense layers.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
# Configure the optimiser, the loss for integer class labels, and the metric
# that is reported after every epoch.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Print a per-layer overview of output shapes and parameter counts.
model.summary()

# Train for at most 15 epochs, passing the callback instance to Keras.
model.fit(
    x=training_data,
    y=training_labels,
    epochs=15,
    callbacks=[callbacks],
)

# Measure loss and accuracy on the held-out test split and report the accuracy.
test_loss, test_acc = model.evaluate(x=test_data, y=test_labels)
print("The accuracy of the model on test data is: " + str(test_acc))