Skip to content

Trainer

Описание

Обработчик, предназначенный упростить пользователю процесс обучения модели путём устранения необходимости вручную прописывать последовательность действий. Является своего рода обёрткой вокруг операций над данными.

Info

В обработчиках, в зависимости от места размещения данных, может происходить их разбиение следующими способами:

  • данные размещены на диске - сначала производится разделение данных на макробатчи (macrobatch) - блоки, которые целиком размещаются на GPU, после чего идёт разбиение макробатча на более мелкие батчи (batch), которые уже подаются непосредственно на вход модели;
  • данные уже размещены на GPU - происходит только разбиение на батчи, подаваемые непосредственно на вход модели.

Инициализация

def __init__(self, mod, cost, optimizer, onBatchFinish=None, batchsize=128):

Параметры

Параметр Возможные типы Описание По умолчанию
mod Module Обучаемая нейросеть -
cost Cost Целевая функция -
optimizer Optimizer Оптимизатор модели -
onBatchFinish callable Функция, которая будет вызываться по окончании обработки блока (batch) данных None
batchsize int Размер блока данных 128

Пояснения

-

Методы

Со всеми основными методами обработчиков можно ознакомиться в документации к родительскому классу Handler.

trainFromHost

def trainFromHost(self, data, target, macroBatchSize=10000, onMacroBatchFinish=None, random=True):

Функционал

Обёртка вокруг метода handleFromHost() родительского класса Handler, учитывающая специфику процесса обучения: при вызове метода происходит сброс накопленной ошибки и модель переводится в режим обучения (некоторые слои сети могут себя вести по-разному в режимах обучения и инференса).

Параметры

Параметр Возможные типы Описание По умолчанию
data tensor Тензор данных -
target tensor Тензор соответствующих данным меток None
macroBatchSize int Размер макробатча. Данные будут разбиты на макроблоки размера macroBatchSize 10000
onMacroBatchFinish callable Функция, которая будет вызываться после обработки макробатча None
random bool Перемешивать ли блоки данных перед обработкой в случайном порядке True

Пояснения

-

train

def train(self, data, target, random=True):

Функционал

Обёртка вокруг метода handle() родительского класса Handler, учитывающая специфику процесса обучения: при вызове метода происходит сброс накопленной ошибки и модель переводится в режим обучения (некоторые слои сети могут себя вести по-разному в режимах обучения и инференса).

Параметры

Параметр Возможные типы Описание По умолчанию
data GPUArray Тензор данных, размещённый на GPU -
target GPUArray Тензор соответствующих данным меток, размещённый на GPU None
random bool Перемешивать ли блоки данных перед обработкой в случайном порядке True

Пояснения

-

handleBatch

def handleBatch(self, batch, idx, state):

Функционал

Корневой метод обучающего обработчика. Вычисляет градиент функции ошибки на переданном батче и производит обратное распространение ошибки, после чего запускает процесс обновления весов модели оптимизатором.

Параметры

Параметр Возможные типы Описание По умолчанию
batch list Лист из двух элементов: [data, target] -
idx int Порядковый номер блока данных -
state dict Словарь, содержащий информацию о состоянии обработки данных -

Пояснения

-