Trainer¶
Описание¶
Обработчик, предназначенный упростить пользователю процесс обучения модели путём устранения необходимости вручную прописывать последовательность действий. Является своего рода обёрткой вокруг операций над данными.
Info
В обработчиках, в зависимости от места размещения данных, может происходить их разбиение следующими способами:
- данные размещены на диске - сначала производится разделение данных на макробатчи (macrobatch) - блоки, которые целиком размещаются на GPU, после чего идёт разбиение макробатча на более мелкие батчи (batch), которые уже подаются непосредственно на вход модели;
- данные уже размещены на GPU - происходит только разбиение на батчи, подаваемые непосредственно на вход модели.
Инициализация¶
def __init__(self, mod, cost, optimizer, onBatchFinish=None, batchsize=128):
Параметры
Параметр | Возможные типы | Описание | По умолчанию |
---|---|---|---|
mod | Module | Обучаемая нейросеть | - |
cost | Cost | Целевая функция | - |
optimizer | Optimizer | Оптимизатор модели | - |
onBatchFinish | callable | Функция, которая будет вызываться по окончании обработки блока (batch) данных | None |
batchsize | int | Размер блока данных | 128 |
Пояснения
-
Методы¶
Со всеми основными методами обработчиков можно ознакомиться в документации к родительскому классу Handler.
trainFromHost
¶
def trainFromHost(self, data, target, macroBatchSize=10000, onMacroBatchFinish=None, random=True):
Функционал
Обёртка вокруг метода handleFromHost() родительского класса Handler, учитывающая специфику процесса обучения: при вызове метода происходит сброс накопленной ошибки и модель переводится в режим обучения (некоторые слои сети могут себя вести по-разному в режимах обучения и инференса).
Параметры
Параметр | Возможные типы | Описание | По умолчанию |
---|---|---|---|
data | tensor | Тензор данных | - |
target | tensor | Тензор соответствующих данным меток | None |
macroBatchSize | int | Размер макробатча. Данные будут разбиты на макроблоки размера macroBatchSize | 10000 |
onMacroBatchFinish | callable | Функция, которая будет вызываться после обработки макробатча | None |
random | bool | Перемешивать ли блоки данных перед обработкой в случайном порядке | True |
Пояснения
-
train
¶
def train(self, data, target, random=True):
Функционал
Обёртка вокруг метода handle() родительского класса Handler, учитывающая специфику процесса обучения: при вызове метода происходит сброс накопленной ошибки и модель переводится в режим обучения (некоторые слои сети могут себя вести по-разному в режимах обучения и инференса).
Параметры
Параметр | Возможные типы | Описание | По умолчанию |
---|---|---|---|
data | GPUArray | Тензор данных, размещённый на GPU | - |
target | GPUArray | Тензор соответствующих данным меток, размещённый на GPU | None |
random | bool | Перемешивать ли блоки данных перед обработкой в случайном порядке | True |
Пояснения
-
handleBatch
¶
def handleBatch(self, batch, idx, state):
Функционал
Корневой метод обучающего обработчика. Вычисляет градиент функции ошибки на переданном батче и производит обратное распространение ошибки, после чего запускает процесс обновления весов модели оптимизатором.
Параметры
Параметр | Возможные типы | Описание | По умолчанию |
---|---|---|---|
batch | list | Лист из двух элементов: [data, target] |
- |
idx | int | Порядковый номер блока данных | - |
state | dict | Словарь, содержащий информацию о состоянии обработки данных | - |
Пояснения
-