C1W2: Implementing Callbacks in TensorFlow using the MNIST Dataset

import tensorflow as tf

Load and inspect the data

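MNIST dataset

  • 60,000 training images and 10,000 test images, each a 28x28 grayscale picture of a handwritten digit (0-9)

The data is loaded with tf.keras.datasets.mnist.load_data, which returns two tuples, (x_train, y_train) and (x_test, y_test). Only the training split is used below.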

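# To load a local copy instead of downloading, uncomment the lines below
# (they also require `import os`):
# current_dir = os.getcwd()
# data_path = os.path.join(current_dir, "data/mnist.npz")
# (x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path=data_path)

# Download MNIST (or reuse the cached copy) and keep only the training split
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

# Scale pixel values from the range [0, 255] to [0, 1]
x_train = x_train / 255.0
data_shape = x_train.shape

print(f"There are {data_shape[0]} examples with shape ({data_shape[1]}, {data_shape[2]})")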
There are 60000 examples with shape (28, 28)
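Dividing by 255.0 rescales the raw uint8 pixel intensities (0 to 255) to floating-point values in [0, 1]. This kind of input scaling usually makes gradient-based training better behaved. The minimal NumPy sketch below uses a small synthetic batch of random pixels as a stand-in for real MNIST images. It shows how the dtype and value range change:

```python
import numpy as np

# Synthetic stand-in for a batch of MNIST images: uint8 pixels in [0, 255]
rng = np.random.default_rng(seed=0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same operation as the notebook: true division by a float promotes to float64
scaled = images / 255.0

print(images.dtype, images.min(), images.max())
print(scaled.dtype, scaled.min(), scaled.max())
```

If memory matters, `x_train.astype("float32") / 255.0` produces the same scaling in single precision. That is the default float type Keras layers compute in.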

Defining your callback

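class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Keras passes this epoch's metrics in `logs`; use None instead of a mutable default
        logs = logs or {}
        accuracy = logs.get('accuracy')
        # Stop once training accuracy exceeds 99%
        if accuracy is not None and accuracy > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True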
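Keras calls `on_epoch_end` after every epoch and passes a `logs` dictionary holding that epoch's metrics. `fit` then checks `model.stop_training` to decide whether to start another epoch. The pure-Python loop below is a simplified illustration of that contract, not Keras's actual implementation. `FakeModel`, `SimpleEarlyStop`, and `run_epochs` are hypothetical names. The first five accuracies mirror the training log in this notebook, and the remaining ones are made up so the loop has epochs left to skip.

```python
class FakeModel:
    """Hypothetical stand-in for a Keras model: only the stop_training flag."""

    def __init__(self):
        self.stop_training = False


class SimpleEarlyStop:
    """Same threshold logic as myCallback, with no TensorFlow dependency."""

    def __init__(self, model, threshold=0.99):
        # Keras attaches the model via set_model(); here it is passed in directly
        self.model = model
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy > self.threshold:
            print(f"\nReached {self.threshold:.0%} accuracy so cancelling training!")
            self.model.stop_training = True


def run_epochs(model, callback, epoch_accuracies):
    """Hypothetical training loop: run the hook after each epoch, then check the flag."""
    epochs_run = 0
    for epoch, accuracy in enumerate(epoch_accuracies):
        # (a real loop would train on every batch here and compute the metrics)
        epochs_run += 1
        callback.on_epoch_end(epoch, logs={"accuracy": accuracy})
        if model.stop_training:
            break
    return epochs_run


model = FakeModel()
callback = SimpleEarlyStop(model)
# First five values mirror the training log in this notebook; the rest are made up
accuracies = [0.9411, 0.9758, 0.9835, 0.9880, 0.9910, 0.9935, 0.9950]
epochs_run = run_epochs(model, callback, accuracies)
print(epochs_run)  # 5: the loop stops after the fifth epoch
```

Keras also ships `tf.keras.callbacks.EarlyStopping`. That callback stops when a monitored metric stops improving rather than when it crosses a fixed threshold, so a custom callback fits this task better.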

Create and train your model

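def train_mnist(x_train, y_train):
    # Instantiate the callback that stops training at >99% accuracy
    callback = myCallback()

    # Fully connected classifier: 28x28 input -> 512 hidden units -> 10 class probabilities
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    # Labels are integers 0-9, so use sparse categorical cross-entropy
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Train for up to 10 epochs; the callback may stop training earlier
    history = model.fit(x_train, y_train, epochs=10, callbacks=[callback])

    return history

hist = train_mnist(x_train, y_train)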
Epoch 1/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.1992 - accuracy: 0.9411
Epoch 2/10
1875/1875 [==============================] - 12s 6ms/step - loss: 0.0786 - accuracy: 0.9758
Epoch 3/10
1875/1875 [==============================] - 12s 7ms/step - loss: 0.0522 - accuracy: 0.9835
Epoch 4/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0366 - accuracy: 0.9880
Epoch 5/10
1869/1875 [============================>.] - ETA: 0s - loss: 0.0276 - accuracy: 0.9910
Reached 99% accuracy so cancelling training!
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0277 - accuracy: 0.9910
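The model is compiled with `sparse_categorical_crossentropy` because `y_train` holds integer class labels (0 to 9) rather than one-hot vectors. Take one example with predicted class probabilities p and true class k. Its loss is -log(p[k]), which is exactly what categorical cross-entropy computes after one-hot encoding the label. The small NumPy check below illustrates this equivalence. Its softmax outputs are made up and are not values from the trained model.

```python
import numpy as np

# Made-up softmax outputs for 3 examples over 10 classes (each row sums to 1)
rng = np.random.default_rng(seed=0)
logits = rng.normal(size=(3, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Integer labels, in the same format as y_train
labels = np.array([3, 0, 7])

# Sparse form: pick the probability of the true class directly
sparse_ce = -np.log(probs[np.arange(len(labels)), labels]).mean()

# Dense form: one-hot encode the labels, then sum -y * log(p) over classes
one_hot = np.eye(10)[labels]
dense_ce = -(one_hot * np.log(probs)).sum(axis=1).mean()

print(sparse_ce, dense_ce)
```

The sparse variant avoids materializing a 60,000 x 10 one-hot label array. With one-hot labels you would use `categorical_crossentropy` instead.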