keras.ipynb
Based on:
../production_ml/solutions/keras.ipynb
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras import layers, losses
import os
datasets, info = tfds.load('mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
strategy = tf.distribute.MirroredStrategy()
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
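As a quick sanity check (not in the original notebook), the number of replicas the strategy will synchronize over can be printed directly:
# Sketch: confirm how many devices the strategy is mirroring across.
print(f'Number of devices: {strategy.num_replicas_in_sync}')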
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
# Global batch size: each replica processes BATCH_SIZE_PER_REPLICA examples per step.
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
def scale(image, label):
    # Normalize pixel values to the [0, 1] range.
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
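Not part of the original pipeline, but a common addition: prefetching lets tf.data prepare the next batch while the current one is training. A minimal sketch:
# Sketch (assumption): overlap input preprocessing with training.
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
eval_dataset = eval_dataset.prefetch(tf.data.AUTOTUNE)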
with strategy.scope():
    # Build and compile the model inside the strategy scope so that its
    # variables are mirrored across all replicas.
    model = tf.keras.Sequential([
        layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPool2D(),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10)
    ])

    model.compile(optimizer='adam',
                  loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
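Optionally (not in the original cell), the layer shapes and parameter counts can be checked before training:
model.summary()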
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
def decay(epoch):
    # Piecewise-constant learning rate schedule, stepped down at epochs 3 and 7.
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5
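A small illustration (not in the original notebook) of what the schedule returns for each epoch:
for epoch in range(8):
    print(f'epoch {epoch}: lr = {decay(epoch)}')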
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras when the callback is attached.
        print(f'\nLearning rate for epoch {epoch + 1} is {self.model.optimizer.lr.numpy()}')
callbacks = [
    tf.keras.callbacks.TensorBoard(),
    tf.keras.callbacks.ModelCheckpoint(checkpoint_prefix, save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    # PrintLR()
]
model.fit(train_dataset, epochs=2, callbacks=callbacks)
Epoch 1/2
938/938 [==============================] - 18s 16ms/step - loss: 0.2096 - accuracy: 0.9391 - lr: 0.0010
Epoch 2/2
938/938 [==============================] - 15s 16ms/step - loss: 0.0736 - accuracy: 0.9786 - lr: 0.0010
<keras.callbacks.History at 0x7faf6e29b310>
!ls {checkpoint_dir}
checkpoint ckpt_1.index ckpt_2.index
ckpt_1.data-00000-of-00001 ckpt_2.data-00000-of-00001
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
def eval_model(model):
    eval_loss, eval_acc = model.evaluate(eval_dataset)
    print(f'Eval loss: {eval_loss}, Eval Accuracy: {eval_acc}')
eval_model(model)
157/157 [==============================] - 2s 8ms/step - loss: 0.0573 - accuracy: 0.9811
Eval loss: 0.05726996436715126, Eval Accuracy: 0.9811000227928162
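Because only weights were checkpointed, they can also be restored into an identical model built outside the strategy scope. A sketch reusing the same architecture (the name single_device_model is just illustrative):
# Sketch (assumption): rebuild the same architecture without a strategy
# and restore the latest checkpoint into it.
single_device_model = tf.keras.Sequential([
    layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPool2D(),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])
single_device_model.compile(optimizer='adam',
                            loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                            metrics=['accuracy'])
single_device_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_model(single_device_model)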
path = 'saved_model/'
model.save(path, save_format='tf')
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: saved_model/assets
unreplicated_model = tf.keras.models.load_model(path)
eval_model(unreplicated_model)
157/157 [==============================] - 1s 5ms/step - loss: 0.0573 - accuracy: 0.9811
Eval loss: 0.05726996436715126, Eval Accuracy: 0.9811000227928162
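The SavedModel can also be loaded back inside the strategy scope so that evaluation itself runs with the distribution strategy. A sketch following the same pattern as above (recompiling after loading, as when the model was first built in scope):
with strategy.scope():
    replicated_model = tf.keras.models.load_model(path)
    replicated_model.compile(optimizer='adam',
                             loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                             metrics=['accuracy'])
    eval_model(replicated_model)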