MaxUnpool2D

Описание

Info

Родительский класс: Module

Производные классы: -

Данный модуль вычисляет частичную обратную величину двумерного максимизирующего пулинга MaxPool2D.

Операция пулинга не является полностью обратимой, так как при обработке теряются немаксимальные значения.

Данный модуль принимает в качестве входных данных выходные данные MaxPool2D, включая индексы максимальных значений, и вычисляет частичную инверсию тензора-кандидата, располагая его элементы в выходном тензоре согласно индексам максимальных значений. Все немаксимальные значения нового тензора будут нулевыми.

Инициализация

def __init__(self, maxpool2d, name=None):

Параметры

Параметр Возможные типы Описание По умолчанию
maxpool2d Module Экземпляр класса MaxPool2D -
name str Имя слоя None

Пояснения

-

Примеры

Необходимые импорты.

import numpy as np
from PuzzleLib.Backend import gpuarray
from PuzzleLib.Modules import MaxPool2D, MaxUnpool2D

Info

gpuarray необходим для правильного размещения тензора на GPU

Для упрощения размер батча и количество карт примем равными 1:

np.random.seed(123)
batchsize, maps, h, w = 1, 1, 6, 6
indata = gpuarray.to_gpu(np.random.randint(0, 9, (batchsize, maps, h, w)).astype(np.float32))
print(indata)
[[[[2. 2. 6. 1. 3. 6.]
   [1. 0. 1. 0. 0. 3.]
   [4. 0. 0. 4. 1. 7.]
   [3. 2. 4. 7. 2. 4.]
   [8. 0. 7. 3. 4. 6.]
   [1. 5. 6. 2. 1. 8.]]]]

Экземпляр класса MaxPool2D инициализируем с параметрами по умолчанию (size=2, stride=2, pad=0). Как и говорилось выше, экземпляр класса MaxUnpool2D должен принимать на вход объект MaxPool2D:

maxpool2d = MaxPool2D()
maxunpool2d = MaxUnpool2D(maxpool2d)
print(maxpool2d(indata))
[[[[2. 6. 6.]
   [4. 7. 7.]
   [8. 7. 8.]]]]

Создадим тензор-кандидат и проведём над ним операцию:

data = data = gpuarray.to_gpu(np.random.randint(0, 9, maxpool2d.data.shape).astype(np.float32))
print(data)
[[[[3. 5. 0.]
   [2. 6. 2.]
   [4. 4. 6.]]]]
print(maxunpool2d(data))
[[[[3. 0. 5. 0. 0. 0.]
   [0. 0. 0. 0. 0. 0.]
   [2. 0. 0. 0. 0. 2.]
   [0. 0. 0. 6. 0. 0.]
   [4. 0. 4. 0. 0. 0.]
   [0. 0. 0. 0. 0. 6.]]]]