Остановка и возобновление обучения сети

Введение

В этом туториале мы рассмотрим такую возможность библиотеки PuzzleLib, как продолжение обучения с некоторого чекпойнта. Скрипт называется ResumeTrain.py, находится в папке TestLib.

Возможность останавливать обучение модели по необходимости, чтобы потом продолжить - важная, но не сложная техническая задача. Она может помочь организовывать длинные процессы обучения, сохранять чекпойнты, загружаться и продолжать обучение от любого из них и, как следствие, безопасно экспериментировать.

Перед продолжением чтения рекомендуем пройти, если вы ещё этого не сделали, туториал Обучение MNIST классификатора, так как данный материал базируется на нём и не затрагивает некоторые моменты, которые были освещены в MNIST-туториале.

Обучающая выборка

Зайдите на сайт Яна ЛеКуна (создателя датасета MNIST) и скачайте следующие файлы:

  • t10k-images.idx3-ubyte.gz
  • t10k-labels.idx1-ubyte.gz
  • train-images.idx3-ubyte.gz
  • train-labels.idx1-ubyte.gz

Поместите скачанные файлы в выбранный вами каталог, а затем распакуйте. Получившиеся файлы содержат 70000 чёрно-белых изображений размера 28 на 28 пикселей с рукописными цифрами от 0 до 9.

Проверка скрипта

Откройте скрипт ResumeTrain.py. Запустите, проверьте, что скрипт запускается и доходит до конца, не упав. Если скрипт падает, значит, вам нужно разобраться с установкой PuzzleLib (у вас может не работать Cuda или может отсутствовать какая-нибудь из Python-библиотек).

Если у вас всё получилось, переходите дальше: мы последовательно пойдём по содержимому скрипта.

Структура скрипта

Давайте взглянем на структуру скрипта, т.к. хоть он и базируется на туториале Обучение MNIST классификатора, здесь всё же присутствуют некоторые отличия.

Импорты аналогичные:

import os

import numpy as np

from PuzzleLib.Datasets import MnistLoader

from PuzzleLib.Containers import Sequential
from PuzzleLib.Modules import Conv2D, MaxPool2D, Activation, Flatten, Linear
from PuzzleLib.Modules.Activation import relu
from PuzzleLib.Handlers import Trainer, Validator
from PuzzleLib.Optimizers import MomentumSGD
from PuzzleLib.Cost import CrossEntropy

Первое значительное отличие - появилась отдельная функция для построения архитектуры сети - buildNet:

def buildNet():
    net = Sequential()
    net.append(Conv2D(1, 16, 3))
    net.append(MaxPool2D())
    net.append(Activation(relu))

    net.append(Conv2D(16, 32, 4))
    net.append(MaxPool2D())
    net.append(Activation(relu))

    net.append(Flatten())
    net.append(Linear(32 * 5 * 5, 1024))
    net.append(Activation(relu))

    net.append(Linear(1024, 10))

    return net

Также вынесен в отдельную функцию сам процесс обучения - train:

def train(net, optimizer, data, labels, epochs):
    cost = CrossEntropy(maxlabels=10)
    trainer = Trainer(net, cost, optimizer)
    validator = Validator(net, cost)

    for i in range(epochs):
        trainer.trainFromHost(data[:60000], labels[:60000], macroBatchSize=60000,
                              onMacroBatchFinish=lambda tr: print("Train error: %s" % tr.cost.getMeanError()))
        print("Accuracy: %s" % (1.0 - validator.validateFromHost(data[60000:], labels[60000:], macroBatchSize=10000)))

        optimizer.learnRate *= 0.9
        print("Reduced optimizer learn rate to %s" % optimizer.learnRate)

Аргументы функции:

  • net - объект, представляющий сеть в библиотеке (в нашем случае Sequential);
  • optimizer - оптимизатор сети, являющийся объектом класса Optimizer из семейства Optimizers;
  • data - тензор данных формата np.ndarray; в нашем случае это будут тензора размерности (N, C, H, W), где N - общее количество картинок, C - канальность картинок (для MNIST картинки ч/б, т.е. одноканальные), H - высота и W - ширина картинок (28 и 28 соответственно);
  • labels - вектор ярлыков (лейблов) формата np.ndarray и длины N для соответствующих картинок;
  • epochs - количество эпох обучения.

И последняя функция - main. Её начало похоже на аналогичную функцию из вышеупомянутого туториала:

def main():
    path = "../TestData/"
    mnist = MnistLoader()
    data, labels = mnist.load(path=path)
    data, labels = data[:], labels[:]
    print("Loaded mnist")

    np.random.seed(1234)

Important

Не забудьте поменять значение переменной path на тот путь, по которому вы распаковали архивы с датасетом.

Следующим шагом создаём сеть и устанавливаем на неё оптимизатор:

    net = buildNet()

    optimizer = MomentumSGD()
    optimizer.setupOn(net, useGlobalState=True)
    optimizer.learnRate = 0.1
    optimizer.momRate = 0.9

Теперь нам нужно провести несколько эпох обучения - мы выбрали 10:

    epochs = 10
    print("Training for %s epochs ..." % epochs)
    train(net, optimizer, data, labels, epochs)

По истечении 10 эпох объекты net и optimizer стали обладать некоторым специфическим внутренним состоянием (распределением значений параметров), которое мы хотим зафиксировать, чтобы использовать в дальнейшем для продолжения обучения:

    print("Saving net and optimizer ...")
    net.save(os.path.join(path, "net.hdf"))
    optimizer.save(os.path.join(path, "optimizer.hdf"))

И наконец, в тот момент, когда мы захотим продолжить обучение, мы просто загружаем параметры из файлов (в предварительно созданные объекты соответствующих классов):

    print("Reloading net and optimizer ...")
    net.load(os.path.join(path, "net.hdf"))
    optimizer.load(os.path.join(path, "optimizer.hdf"))

    print("Continuing training for %s epochs ..." % epochs)
    train(net, optimizer, data, labels, epochs)

В конце скрипта выполняется чистка от файлов, созданных за время его выполнения:

    os.remove(os.path.join(path, "net.hdf"))
    os.remove(os.path.join(path, "optimizer.hdf"))