keras.ipynb#

Based on:

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras import layers, losses
import os
# Load MNIST as (image, label) pairs, together with the dataset metadata.
datasets, info = tfds.load(name='mnist', as_supervised=True, with_info=True)
mnist_train = datasets['train']
mnist_test = datasets['test']

# Distribution strategy used below to size the global batch and build the model.
strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
# Size of the shuffle buffer used for the training pipeline.
BUFFER_SIZE = 10_000

# The global batch is the per-replica batch multiplied by the replica count.
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA
def scale(image, label):
    """Cast ``image`` to float32 and divide it by 255; return ``label`` unchanged."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 255, label
# Training pipeline: scale, cache the scaled examples, shuffle, then batch.
train_dataset = (
    mnist_train
    .map(scale)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
# Evaluation pipeline: scale and batch only (no shuffling).
eval_dataset = (
    mnist_test
    .map(scale)
    .batch(BATCH_SIZE)
)
# Build AND compile inside the strategy scope, so that everything created for
# the model (layers, optimizer, metrics) is created under the same
# distribution strategy.
with strategy.scope():
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPool2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        # No softmax here: the loss below is configured with from_logits=True.
        layers.Dense(10),
    ])

    model.compile(
        optimizer='adam',
        loss=losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
# Directory that receives the per-epoch weight checkpoints.
checkpoint_dir = './training_checkpoints'
# '{epoch}' stays a literal placeholder here; the files listed after training
# (ckpt_1.*, ckpt_2.*) show it being replaced by the epoch number on save.
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')
def decay(epoch):
    """Return the learning rate for the given epoch index.

    Step schedule: 1e-3 while ``epoch < 3``, 1e-4 while ``epoch < 7``,
    and 1e-5 from then on.
    """
    # (exclusive upper bound, rate) pairs, checked in ascending order.
    for upper_bound, rate in ((3, 1e-3), (7, 1e-4)):
        if epoch < upper_bound:
            return rate
    return 1e-5
class PrintLR(tf.keras.callbacks.Callback):
    """Keras callback that prints the optimizer's learning rate after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the current learning rate.

        Args:
            epoch: Epoch index supplied by Keras; reported as ``epoch + 1``.
            logs: Metrics dict supplied by Keras; not used here.
        """
        # Read the model Keras attached to this callback instead of the
        # notebook-global `model`, so the callback reports on whichever model
        # it is actually training with.
        print(f'\nLearning rate for epoch {epoch + 1} is {self.model.optimizer.lr.numpy()}')
        
# Callbacks run during training, in this order.
callbacks = [
    tf.keras.callbacks.TensorBoard(),
    # Save weights only, to a file name built from checkpoint_prefix.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True,
    ),
    # Take the learning rate from decay().
    tf.keras.callbacks.LearningRateScheduler(schedule=decay),
    # PrintLR()
]
# With epochs=2 only the first branch of decay() (epoch < 3) is exercised.
model.fit(
    train_dataset,
    epochs=2,
    callbacks=callbacks,
)
Epoch 1/2
938/938 [==============================] - 18s 16ms/step - loss: 0.2096 - accuracy: 0.9391 - lr: 0.0010
Epoch 2/2
938/938 [==============================] - 15s 16ms/step - loss: 0.0736 - accuracy: 0.9786 - lr: 0.0010
<keras.callbacks.History at 0x7faf6e29b310>
!ls {checkpoint_dir}
checkpoint		    ckpt_1.index		ckpt_2.index
ckpt_1.data-00000-of-00001  ckpt_2.data-00000-of-00001
# Restore the weights from the newest checkpoint found in checkpoint_dir.
model.load_weights(
    tf.train.latest_checkpoint(checkpoint_dir)
)

def eval_model(model):
    """Evaluate ``model`` on ``eval_dataset`` and print its loss and accuracy."""
    results = model.evaluate(eval_dataset)
    # The model is expected to report exactly two values: loss, then accuracy.
    eval_loss, eval_acc = results
    print(f'Eval loss: {eval_loss}, Eval Accuracy: {eval_acc}')


# Evaluate the model whose weights were just restored from the checkpoint.
eval_model(model)
157/157 [==============================] - 2s 8ms/step - loss: 0.0573 - accuracy: 0.9811
Eval loss: 0.05726996436715126, Eval Accuracy: 0.9811000227928162
# Export the model in TensorFlow SavedModel format (save_format='tf').
path = 'saved_model/'
model.save(filepath=path, save_format='tf')
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: saved_model/assets
INFO:tensorflow:Assets written to: saved_model/assets
# Reload the exported model outside strategy.scope() and check that it still
# evaluates on the same data.
unreplicated_model = tf.keras.models.load_model(filepath=path)
eval_model(unreplicated_model)
157/157 [==============================] - 1s 5ms/step - loss: 0.0573 - accuracy: 0.9811
Eval loss: 0.05726996436715126, Eval Accuracy: 0.9811000227928162