Stop Training after Certain Time in TensorFlow 2

Model training can take a long time. The TimeStopping callback, available for TensorFlow 2 through the TensorFlow Addons package, stops training once a given amount of time has passed.

The TimeStopping callback is provided as part of TensorFlow Addons. The package is named tensorflow-addons and can be installed from the command line using the pip package manager.

pip install tensorflow-addons

We create a model that classifies images from the Fashion-MNIST dataset. An instance of TimeStopping is passed to the fit method through the callbacks argument. The TimeStopping constructor has a seconds parameter, which defines the maximum amount of time that may pass before training is stopped.

import tensorflow_addons as tfa
from tensorflow import keras

# Load the Fashion-MNIST dataset (28x28 grayscale images, 10 classes)
fashionMnist = keras.datasets.fashion_mnist
(trainImages, trainLabels), (testImages, testLabels) = fashionMnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]
trainImages = trainImages / 255

# A small fully connected classifier
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Stop training once 10 seconds have passed; verbose=1 prints a message on stop
timeStopping = tfa.callbacks.TimeStopping(seconds=10, verbose=1)

# Request 15 epochs; the callback can end training earlier
model.fit(trainImages, trainLabels, epochs=15, callbacks=[timeStopping])

Setting the verbose argument to 1 prints a message that shows the epoch at which training was stopped.

Epoch 1/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.4974
Epoch 2/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3750
Epoch 3/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3372
Epoch 4/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.3131
Epoch 5/15
1875/1875 [==============================] - 2s 1ms/step - loss: 0.2959
Timed stopping at epoch 5 after training for 0:00:10
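
In the output above, training ends at the close of epoch 5 rather than in the middle of an epoch, which is what one would expect if the time limit is checked only at epoch boundaries. The following is a minimal sketch of that idea in plain Python, not the tensorflow-addons implementation. The names TimeBudgetStopper and simulate, the injectable clock, and the fixed per-epoch durations are all assumptions made for illustration. The hook names (on_train_begin, on_epoch_end) and the model.stop_training flag mirror the Keras callback interface.

```python
import time
from types import SimpleNamespace


class TimeBudgetStopper:
    """Hypothetical stand-in that mimics the hooks of a Keras callback."""

    def __init__(self, seconds, clock=time.monotonic):
        self.seconds = seconds
        self.clock = clock        # injectable, so the logic can run without waiting
        self.model = None         # Keras sets this attribute on real callbacks
        self.deadline = None
        self.stopped_epoch = None

    def on_train_begin(self, logs=None):
        # The time budget starts when training starts.
        self.deadline = self.clock() + self.seconds

    def on_epoch_end(self, epoch, logs=None):
        # The deadline is checked only here, so a running epoch always finishes.
        if self.clock() >= self.deadline:
            self.model.stop_training = True
            self.stopped_epoch = epoch  # zero-based epoch index


def simulate(epoch_seconds, budget_seconds, max_epochs):
    """Run a fake training loop in which every epoch takes epoch_seconds."""
    now = [0.0]  # state of the fake clock, in seconds
    stopper = TimeBudgetStopper(budget_seconds, clock=lambda: now[0])
    stopper.model = SimpleNamespace(stop_training=False)

    stopper.on_train_begin()
    epochs_run = 0
    for epoch in range(max_epochs):
        now[0] += epoch_seconds   # pretend the epoch took this long
        epochs_run += 1
        stopper.on_epoch_end(epoch)
        if stopper.model.stop_training:
            break
    return epochs_run


# Epochs of about 2 seconds with a 10-second budget, as in the run above
print(simulate(epoch_seconds=2, budget_seconds=10, max_epochs=15))
# Slower epochs: the third epoch ends at 12 seconds, past the 10-second budget
print(simulate(epoch_seconds=4, budget_seconds=10, max_epochs=15))
```

Under these assumptions the first call reports 5 epochs, matching the log above, and the second reports 3 epochs with 12 simulated seconds elapsed. With an epoch-boundary check, the seconds value acts as a lower bound on when training stops, so the limit is best chosen with the duration of one epoch in mind.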
