Multi-worker training with Keras

import json
import os
import sys
# Hide all GPUs from this process so TensorFlow runs on CPU. This is set
# before `import tensorflow` below so it is already in place when TF starts.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Drop any TF_CONFIG inherited from the environment; later cells set it
# explicitly. The None default avoids a KeyError when it is not set.
os.environ.pop('TF_CONFIG', None)
import tensorflow as tf

Dataset and model definition

%%writefile mnist_setup.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  """Return an endlessly repeating, shuffled, batched MNIST training dataset.

  Images are divided by 255 as float32 and labels are cast to int64. The
  test split returned by `load_data` is discarded.

  Args:
    batch_size: Number of examples in each batch.

  Returns:
    A `tf.data.Dataset` yielding `(images, labels)` batches.
  """
  (images, labels), _ = tf.keras.datasets.mnist.load_data()

  # Dividing by a float32 scalar keeps the scaled pixels in float32.
  images = images / np.float32(255)
  labels = labels.astype(np.int64)

  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  # 60000 matches the number of MNIST training examples.
  dataset = dataset.shuffle(60000)
  dataset = dataset.repeat()
  return dataset.batch(batch_size)

def build_and_compile_cnn_model():
  """Create and compile a small convolutional classifier for 28x28 inputs.

  The final Dense layer has 10 units and no activation, so the model emits
  logits; the loss is configured with `from_logits=True` to match.

  Returns:
    A compiled `tf.keras.Sequential` model using SGD (learning rate 0.001)
    and an accuracy metric.
  """
  layers = tf.keras.layers
  stack = [
      layers.InputLayer(input_shape=(28, 28)),
      # Append a channel axis so Conv2D receives (28, 28, 1) inputs.
      layers.Reshape(target_shape=(28, 28, 1)),
      layers.Conv2D(32, 3, activation='relu'),
      layers.Flatten(),
      layers.Dense(128, activation='relu'),
      layers.Dense(10),
  ]
  model = tf.keras.Sequential(stack)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'],
  )
  return model
Overwriting mnist_setup.py

Model training on a single worker

import mnist_setup

# Train in this single process, without any distribution strategy.
batch_size = 64
single_worker_dataset = mnist_setup.mnist_dataset(batch_size=batch_size)
single_worker_model = mnist_setup.build_and_compile_cnn_model()
single_worker_model.fit(
    single_worker_dataset,
    epochs=3,
    steps_per_epoch=70,
)
2023-05-02 13:36:54.972340: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/3
70/70 [==============================] - 2s 16ms/step - loss: 2.2720 - accuracy: 0.2219
Epoch 2/3
70/70 [==============================] - 1s 15ms/step - loss: 2.2214 - accuracy: 0.4277
Epoch 3/3
70/70 [==============================] - 1s 17ms/step - loss: 2.1633 - accuracy: 0.5511
<keras.callbacks.History at 0x7f6f51b8d6a0>

Multi-worker configuration

# Cluster description for two workers on this machine, plus this process's
# own role. It is serialized into the TF_CONFIG environment variable before
# main.py is launched; 'index' picks which 'worker' address this task uses.
tf_config = dict(
    cluster=dict(worker=['localhost:12345', 'localhost:23456']),
    task=dict(type='worker', index=0),
)
%%writefile main.py

import os
import json

import tensorflow as tf
import mnist_setup

# Batch size intended for each worker; the dataset below is batched with
# this value scaled by the number of workers.
per_worker_batch_size = 64
# TF_CONFIG holds the JSON cluster/task description exported by the notebook
# before `python main.py` is run. Raises KeyError if it is not set.
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

# Created before the dataset and model are built; TensorFlow recommends
# instantiating MultiWorkerMirroredStrategy early in the program. It locates
# the other workers through TF_CONFIG.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Batch with the global batch size (per-worker size * worker count);
# presumably the strategy splits each global batch across the workers --
# NOTE(review): confirm against the tf.distribute input-pipeline docs.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)

with strategy.scope():
    # Model building/compiling happens inside strategy.scope() so the
    # variables it creates are managed by the strategy.
    multi_worker_model = mnist_setup.build_and_compile_cnn_model()


multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
Overwriting main.py
# Export the cluster spec so the `python main.py` subprocess can read it.
os.environ.update(TF_CONFIG=json.dumps(tf_config))
%killbgscripts
All background processes were killed.
%%bash --bg
# Start worker 0 (the TF_CONFIG exported above uses task index 0) in the
# background, redirecting both stdout and stderr to job_0.log.
python main.py &> job_0.log
# Same cluster, but this process now plays worker 1; re-export TF_CONFIG.
tf_config['task']['index'] = 1
os.environ.update(TF_CONFIG=json.dumps(tf_config))
%%bash
# Run worker 1 (task index 1) in the foreground so its training log appears
# in the notebook; it trains together with the background worker 0.
python main.py
2023-05-02 13:37:13.426826: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Epoch 1/3
70/70 [==============================] - 6s 74ms/step - loss: 2.2786 - accuracy: 0.1185
Epoch 2/3
70/70 [==============================] - 5s 71ms/step - loss: 2.2383 - accuracy: 0.2475
Epoch 3/3
70/70 [==============================] - 6s 82ms/step - loss: 2.1927 - accuracy: 0.4070
# Remove TF_CONFIG from this kernel's environment now that the multi-worker
# run is finished; the None default avoids a KeyError if it is already gone.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.