Cost¶
Описание¶
Info
Данный модуль предназначен, в первую очередь, для разработчиков, которые хотят глубже понять устройство библиотеки, а также для тех, кто собирается реализовывать свои модули.
Costs (критерии обучения, ошибки, функции потерь) – это функции, вычисляющие ошибку сети.
Cost - родительский класс, от которого наследуются все функции потерь.
Инициализация¶
def __init__(self):
Параметры
-
Пояснения
-
Методы¶
resetAccumulator¶
def resetAccumulator(self):
Выполняет сброс глобальных параметров класса, то есть тех, которые принимают участие в вычислении на конкретной эпохе.
Сбрасывает накопленную ошибку, обнуляет количество батчей и семплов. Обёртка вокруг функции resetDeviceAccumulator.
Применяется при начале вычислений на новой эпохе.
Параметры
-
Пояснения
-
updateState¶
def updateState(self, samples):
Обновляет размер батча и количество семплов.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| samples | int | Размер батча | - | 
Пояснения
-
resetDeviceAccumulator¶
def resetDeviceAccumulator(self):
Выполняет сброс накопленной ошибки.
Параметры
-
Пояснения
-
getError¶
def getError(self):
Достает из GPU-памяти и возвращает ошибку на батче.
Параметры
-
Пояснения
-
getMeanError¶
def getMeanError(self):
Достает из GPU-памяти и возвращает среднюю ошибку на выборке.
Параметры
-
Пояснения
-
getValError¶
def getValError(self):
Возвращает валидационную ошибку. Метод должен быть вызван после вызова метода validate.
Параметры
-
Пояснения
-
__call__¶
def __call__(self, pred, target, queryError=True):
Считает градиент и/или ошибку для двух векторов.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
| queryError | bool | Стягивать или нет значение ошибки с GPU | True | 
Пояснения
queryError - Если False, то метод будет выполняться немного быстрее, так как не будет происходить процесс загрузки значения ошибки с GPU. Также, если выставлено значение False, метод вернёт только self.grad, в оборатном случае возвращаются self.error, self.grad.
calcError¶
def calcError(self, pred, target):
Абстрактный метод, который необходимо реализовать в классах-потомках. После реализации должен обновлять величину накопленной в течение данной эпохи ошибки.
Опциональное вычисление - расчёт ошибки на батч (может проводиться и в calcGrad, зависит от функции ошибки).
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
calcGrad¶
def calcGrad(self, pred, target):
Абстрактный метод, который необходимо реализовать в классах-потомках. После реализации должен считать градиент ошибки сети.
Опциональное вычисление - расчёт ошибки на батч (может проводиться и в calcError, зависит от функции ошибки).
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
calcVal¶
def calcVal(self, pred, target):
Абстрактный метод, который необходимо реализовать в классах-потомках. После реализации должен считать и возвращать ошибку для заданных векторов, не оказывая влияния на внутреннее состояние класса.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
validate¶
def validate(self, pred, target):
Обёртка вокруг функции calcVal с проверкой на корректность.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
reset¶
def reset(self):
Выполняет сброс батчевых параметров класса, то есть тех, которые принимают участие в вычислении конкретного батча.
Применяется при начале вычислений на новом батче.
Параметры
-
Пояснения
-
checkDataShape¶
def checkDataShape(self, pred, target):
Выполняет проверку корректности размерности тензоров входных данных. По умолчанию не делает ничего.
Метод может быть содержательно переопределен в классах-наследниках: для этого он должен поднимать исключение в случае, когда размерность некорректная.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
checkValDataShape¶
def checkValDataShape(self, pred, target):
Выполняет проверку корректности размерности тензоров входных данных для метода validate. По умолчанию не делает ничего.
Метод может быть содержательно переопределен в классах-наследниках: для этого он должен поднимать исключение в случае, когда размерность некорректная.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Предсказанные лейблы | - | 
| target | tensor | Истинные лейблы | - | 
Пояснения
-
getBatchsize¶
def getBatchsize(self, pred):
Считает размер батча для заданного тензора.
Параметры
| Параметр | Возможные типы | Описание | По умолчанию | 
|---|---|---|---|
| pred | tensor | Тензор лейблов | - | 
Пояснения
-