C1W2: Implementing Callbacks in TensorFlow using the MNIST Dataset
import tensorflow as tf
from tensorflow.keras import layers, losses
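MNIST dataset
The MNIST training set contains 60,000 28x28 grayscale images of the 10 handwritten digits (a separate test set of 10,000 images is also provided).
It is loaded with `tf.keras.datasets.mnist.load_data`; only the training split is used below.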
# Load MNIST and keep only the training split; the test split is discarded into `_`.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
class myCallback(tf.keras.callbacks.Callback):
    """Keras callback that stops training once training accuracy exceeds a threshold.

    Args:
        threshold: Value the epoch's ``accuracy`` metric must strictly exceed
            for training to be stopped. Defaults to 0.99 (99%).
    """

    def __init__(self, threshold=0.99):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        """Stop training if this epoch's ``accuracy`` is above ``self.threshold``.

        Args:
            epoch: Index of the epoch that just finished (not used here).
            logs: Metric values for the epoch, keyed by metric name. May be
                ``None``; if no ``accuracy`` entry is present nothing happens.
        """
        # Use None instead of a shared mutable default, and tolerate an explicit None.
        logs = logs or {}
        accuracy = logs.get('accuracy')
        if accuracy is not None and accuracy > self.threshold:
            # `:g` keeps the default message as "99%" while handling e.g. 99.5%.
            print(f"\nReached {self.threshold * 100:g}% accuracy so cancelling training!")
            # Setting this flag asks Keras to end fit() after the current epoch.
            self.model.stop_training = True
def train_mnist(x_train, y_train, epochs=10):
    """Build, compile and train a dense MNIST classifier.

    Training stops early once ``myCallback`` sees training accuracy above 99%.

    Args:
        x_train: Training images; the model's input expects samples of shape (28, 28).
        y_train: Integer class labels, consumed by sparse categorical cross-entropy.
        epochs: Maximum number of training epochs. Defaults to 10.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    stop_early = myCallback()
    model = tf.keras.Sequential([
        # Declare the input explicitly; passing `input_shape=` to a regular
        # layer is deprecated in newer Keras versions.
        tf.keras.Input(shape=(28, 28)),
        # MNIST pixels are 0-255; scale them into [0, 1].
        layers.Rescaling(1 / 255),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        # Raw logits: the softmax is folded into the loss via from_logits=True.
        layers.Dense(10),
    ])
    model.compile(optimizer='adam',
                  loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    history = model.fit(x_train, y_train, epochs=epochs, callbacks=[stop_early])
    return history
# Train on the MNIST training split and keep the returned History object.
history = train_mnist(x_train=x_train, y_train=y_train)
Epoch 1/10
1875/1875 [==============================] - 10s 4ms/step - loss: 0.1972 - accuracy: 0.9423
Epoch 2/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0796 - accuracy: 0.9759
Epoch 3/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0510 - accuracy: 0.9843
Epoch 4/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0361 - accuracy: 0.9888
Epoch 5/10
1868/1875 [============================>.] - ETA: 0s - loss: 0.0271 - accuracy: 0.9915
Reached 99% accuracy so cancelling training!
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0270 - accuracy: 0.9915