Skip to content

SMORMS3

Описание

Info

Родительский класс: Optimizer

Производные классы: -

SMORMS3 (squared mean over root mean squared cubed) - один из алгоритмов оптимизации, впервые упомянутый в данном источнике. Является гибридом RMSProp и метода ЛеКуна.

Введём:

r = \frac{1}{mem + 1}

После чего:

\begin{equation} m_t = r{m_{t-1}} + (1 - r)g_t \end{equation}

\begin{equation} \upsilon_t = r{\upsilon_{t-1}} + (1 - r)g_t^2 \end{equation}

Тогда процесс обновления параметров:

\theta_{t+1} = \theta_t - \frac{min(\alpha, \frac{m_t^2}{\upsilon + \epsilon})}{\sqrt{\upsilon_t^2 + \epsilon}} g_t
mem_{t+1} = 1 + mem_t(1 - \frac{m_t^2}{\upsilon + \epsilon})

Инициализация

def __init__(self, learnRate=1e-3, epsilon=1e-16, nodeinfo=None):

Параметры

Параметр Возможные типы Описание По умолчанию
learnRate float Скорость обучения 1e-3
epsilon float Сглаживающий параметр 1e-5
nodeinfo NodeInfo Объект, содержащий информацию о вычислительном узле None

Пояснения

-

Примеры


Необходимые импорты:

import numpy as np
from PuzzleLib.Optimizers import RMSProp
from PuzzleLib.Backend import gpuarray

Info

gpuarray необходим для правильного размещения тензора на GPU.

Создадим синтетическую обучающую выборку:

data = gpuarray.to_gpu(np.random.randn(16, 128).astype(np.float32))
target = gpuarray.to_gpu(np.random.randn(16, 1).astype(np.float32))

Объявляем оптимизатор:

optimizer = RMSProp(learnRate=0.001, factor=0.9)

Пусть уже есть некоторая сеть net, определённая, например, через Graph, тогда, чтобы установить оптимизатор на сеть, требуется следующее:

optimizer.setupOn(net, useGlobalState=True)

Info

Подробнее про методы оптимизаторов и их параметры читайте в описании родительского класса Optimizer

Также пусть есть некая функция ошибки loss, наследованная от Cost, рассчитывающая в т.ч. её градиент. Тогда получаем реализацию процесса оптимизации:

for i in range(100):
... predictions = net(data)
... error, grad = loss(predictions, target)

... optimizer.zeroGradParams()
... net.backward(grad)
... optimizer.update()

... if (i + 1) % 5 == 0:
...   print("Iteration #%d error: %s" % (i + 1, error))