Stop Training After a Certain Time in TensorFlow 2

Model training can take a long time. For TensorFlow 2, the TimeStopping callback can stop training once a specified amount of time has passed.

TimeStopping is part of TensorFlow Addons, which is distributed as the tensorflow-addons package. It can be installed from the command line using the pip package manager:

pip install tensorflow-addons

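In the following example, we train a model that classifies images from the Fashion-MNIST dataset. An instance of TimeStopping is passed to the fit method through the callbacks argument. The seconds parameter of the TimeStopping constructor sets the maximum training time. The callback checks the elapsed time at the end of each epoch, so the epoch in progress always completes and training can run past the limit by up to roughly one epoch.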

import tensorflow_addons as tfa
from tensorflow import keras

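# Load the Fashion-MNIST dataset
fashionMnist = keras.datasets.fashion_mnist
(trainImages, trainLabels), (testImages, testLabels) = fashionMnist.load_data()

# Scale pixel values to the range [0, 1]
trainImages = trainImages / 255

# Simple fully connected classifier for the 10 clothing classes
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Stop training once 10 seconds have elapsed; verbose=1 prints a message when it happens
timeStopping = tfa.callbacks.TimeStopping(seconds=10, verbose=1)

# Request up to 15 epochs; the callback may end training earlier
model.fit(trainImages, trainLabels, epochs=15, callbacks=[timeStopping])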
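The seconds argument takes a plain number of seconds. For longer budgets, it may be more readable to express the limit with Python's datetime.timedelta and convert it. The sketch below assumes an arbitrary budget of 1 hour 30 minutes; the TimeStopping line is left as a comment so the snippet runs without TensorFlow Addons installed.

```python
import datetime

# Example training budget (an arbitrary value chosen for illustration): 1 h 30 min
budget = datetime.timedelta(hours=1, minutes=30)

# Convert the budget to a whole number of seconds for the `seconds` argument
seconds = int(budget.total_seconds())
print(seconds)  # 5400

# With TensorFlow Addons installed, the value would be passed like this:
# timeStopping = tfa.callbacks.TimeStopping(seconds=seconds, verbose=1)
```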

Setting the verbose argument to 1 makes the callback print the epoch at which training was stopped:

Epoch 1/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.4974
Epoch 2/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3750
Epoch 3/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3372
Epoch 4/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3131
Epoch 5/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.2959
Timed stopping at epoch 5 after training for 0:00:10
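Because the time limit is only checked at epoch boundaries, the mechanism is straightforward to reproduce with the core Keras callback API. This can be useful because TensorFlow Addons entered minimal-maintenance mode with an announced end of life in May 2024, so it may not install alongside recent TensorFlow releases. The following is a minimal sketch, not the Addons source: the class name TimeLimitCallback is hypothetical, and the tiny random dataset and seconds=0 limit are assumptions chosen so the example finishes after a single epoch.

```python
import time

import numpy as np
from tensorflow import keras


class TimeLimitCallback(keras.callbacks.Callback):
    """Hypothetical callback that stops training once `seconds` have elapsed.

    Like TimeStopping, it only checks the clock at the end of each epoch,
    so the epoch in progress is always allowed to finish.
    """

    def __init__(self, seconds, verbose=0):
        super().__init__()
        self.seconds = seconds
        self.verbose = verbose
        self.stopped_epoch = None  # 0-based index of the epoch after which training stopped

    def on_train_begin(self, logs=None):
        # Record the wall-clock deadline when fit() starts
        self.deadline = time.time() + self.seconds
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        if time.time() >= self.deadline:
            # Keras checks this flag after each epoch and ends fit() early
            self.model.stop_training = True
            self.stopped_epoch = epoch

    def on_train_end(self, logs=None):
        if self.stopped_epoch is not None and self.verbose > 0:
            print(f"Stopped at epoch {self.stopped_epoch + 1} "
                  f"(time limit: {self.seconds} s)")


# Tiny random dataset so the example runs quickly (illustrative data only)
x = np.random.rand(64, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=(64,))

model = keras.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# A zero-second limit means the deadline has passed by the end of the first epoch
time_limit = TimeLimitCallback(seconds=0, verbose=1)
history = model.fit(x, y, epochs=5, batch_size=32, verbose=0,
                    callbacks=[time_limit])

# The History object records one loss value per completed epoch
print(len(history.history["loss"]))
```

Reading len(history.history["loss"]) is a way to find out how many epochs actually ran without parsing the printed log.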
