Skip to content

Остановка и продолжение обучения сети

Запустить в Google Colab

Roundicons.com Посмотреть на GitHub

Скачать ноутбук

Введение

В этом туториале мы рассмотрим такую возможность библиотеки PuzzleLib, как продолжение обучения с некоторого чекпойнта.

Возможность останавливать обучение модели по необходимости, чтобы потом продолжить - важная, но не сложная техническая задача. Она может помочь организовывать длинные процессы обучения, сохранять чекпойнты, загружаться и продолжать обучение от любого из них и, как следствие, безопасно экспериментировать.

Перед продолжением чтения рекомендуем пройти, если вы ещё этого не сделали, туториал Обучение MNIST классификатора, так как данный материал базируется на нём и не затрагивает некоторые моменты, которые были освещены в MNIST-туториале.

Обучающая выборка

Зайдите на сайт Яна ЛеКуна (создателя датасета MNIST) и скачайте следующие файлы:

  • t10k-images.idx3-ubyte.gz
  • t10k-labels.idx1-ubyte.gz
  • train-images.idx3-ubyte.gz
  • train-labels.idx1-ubyte.gz

Поместите скачанные файлы в выбранный вами каталог, а затем распакуйте. Получившиеся файлы содержат 70000 чёрно-белых изображений размера 28 на 28 пикселей с рукописными цифрами от 0 до 9.

Реализация инструментами библиотеки

Давайте взглянем на структуру скрипта, т.к. хоть он и базируется на туториале Обучение MNIST классификатора, здесь всё же присутствуют некоторые отличия.

Импорты аналогичные:

import numpy as np
import os

from PuzzleLib.Datasets import MnistLoader

from PuzzleLib.Models.Nets.LeNet import loadLeNet

from PuzzleLib.Containers import Sequential
from PuzzleLib.Modules import Conv2D, MaxPool2D, Activation, Flatten, Linear
from PuzzleLib.Modules.Activation import relu
from PuzzleLib.Handlers import Trainer, Validator
from PuzzleLib.Optimizers import MomentumSGD
from PuzzleLib.Cost import CrossEntropy
В библиотеке уже реализована сеть Lenet и в следующей строке показано как её можно вызвать для дальнейшнего использования.
#net = loadLeNet(None, initscheme=None)

Но для лучшего понимания реализуем все слои самостоятельно. Для удобства вынесем построение архитекутры сети в отдельную функцию buildNet:

def buildNet():
    net = Sequential()
    net.append(Conv2D(1, 16, 3))
    net.append(MaxPool2D())
    net.append(Activation(relu))

    net.append(Conv2D(16, 32, 4))
    net.append(MaxPool2D())
    net.append(Activation(relu))

    net.append(Flatten())
    net.append(Linear(32 * 5 * 5, 1024))
    net.append(Activation(relu))

    net.append(Linear(1024, 10))

    return net

Также вынесен в отдельную функцию сам процесс обучения - train:

def train(net, optimizer, data, labels, epochs):
    cost = CrossEntropy(maxlabels=10)
    trainer = Trainer(net, cost, optimizer)
    validator = Validator(net, cost)

    for i in range(epochs):
        trainer.trainFromHost(data[:60000], labels[:60000], macroBatchSize=60000,
                              onMacroBatchFinish=lambda tr: print("Train error: %s" % tr.cost.getMeanError()))
        print("Accuracy: %s" % (1.0 - validator.validateFromHost(data[60000:], labels[60000:], macroBatchSize=10000)))

        optimizer.learnRate *= 0.9
        print("Reduced optimizer learn rate to %s" % optimizer.learnRate)

Аргументы функции:

  • net - объект, представляющий сеть в библиотеке (в нашем случае Sequential);
  • optimizer - оптимизатор сети, являющийся объектом класса Optimizer из семейства Optimizers;
  • data - тензор данных формата np.ndarray; в нашем случае это будут тензора размерности (N, C, H, W), где

N - общее количество картинок,

C - канальность картинок (для MNIST картинки ч/б, т.е. одноканальные),

H - высота и W - ширина картинок (28 и 28 соответственно);
- labels - вектор ярлыков (лейблов) формата np.ndarray и длины N для соответствующих картинок;
- epochs - количество эпох обучения.

Загрузим данные:

path="./"
mnist = MnistLoader()
data, labels = mnist.load(path)
data, labels = data[:], labels[:]
print("Loaded mnist")

np.random.seed(1234)

Important

Не забудьте поменять значение переменной path на тот путь, по которому вы распаковали архивы с датасетом.

Следующим шагом создаём сеть и устанавливаем на неё оптимизатор:

net = buildNet()

optimizer = MomentumSGD()
optimizer.setupOn(net, useGlobalState=True)
optimizer.learnRate = 0.1
optimizer.momRate = 0.9

Теперь нам нужно провести несколько эпох обучения - мы выбрали 10:

epochs = 10
print("Training for %s epochs ..." % epochs)
train(net, optimizer, data, labels, epochs)

По истечении 10 эпох объекты net и optimizer стали обладать некоторым специфическим внутренним состоянием (распределением значений параметров), которое мы хотим зафиксировать, чтобы использовать в дальнейшем для продолжения обучения:

print("Saving net and optimizer ...")
net.save(os.path.join(path, "net.hdf"))
optimizer.save(os.path.join(path, "optimizer.hdf"))

И наконец, в тот момент, когда мы захотим продолжить обучение, мы просто загружаем параметры из файлов (в предварительно созданные объекты соответствующих классов):

print("Reloading net and optimizer ...")
net.load(os.path.join(path, "net.hdf"))
optimizer.load(os.path.join(path, "optimizer.hdf"))

print("Continuing training for %s epochs ..." % epochs)
train(net, optimizer, data, labels, epochs)

В конце скрипта выполняется чистка от файлов, созданных за время его выполнения:

os.remove(os.path.join(path, "net.hdf"))
os.remove(os.path.join(path, "optimizer.hdf"))