Upsample2D

Описание

Info

Родительский класс: Module

Производные классы: -

Увеличивает размерность двумерных карт в заданное количество раз, заполняя новые ячейки значениями согласно выбранному режиму (см. параметры).

В отличие от деконволюции данный слой не является обучаемым.

Инициализация

def __init__(self, scale=2, mode="nearest", name=None):

Параметры

Параметр Возможные типы Описание По умолчанию
scale Union[int, tuple] Масштаб: во сколько раз увеличить входной тензор 2
mode str Режим заполнения новых ячеек "nearest"
name str Имя слоя None

Пояснения

scale - возможна передача как единой величины масштаба по высоте и ширине, так и tuple вида (scale_h, scale_w), где scale_h - величина масштаба для высоты карты, и scale_w - для ширины;


mode - возможные варианты: "nearest" (копирует значение ближайшей ячейки), "linear" (использует линейную интерполяцию по значениям близлежащих ячеек).

Примеры

Необходимые импорты:

import numpy as np
from PuzzleLib.Backend import gpuarray
from PuzzleLib.Modules import Upsample2D

Info

gpuarray необходим для правильного размещения тензора на GPU

np.random.seed(123)
batchsize, maps, h, w = 1, 1, 3, 3
data = gpuarray.to_gpu(np.random.randint(0, 10, (batchsize, maps, h, w)).astype(np.float32))
print(data)
[[[[2. 2. 6.]
   [1. 3. 9.]
   [6. 1. 0.]]]]
upsample = Upsample2D(scale=2, mode="nearest")
print(upsample(data))
[[[[2. 2. 2. 2. 6. 6.]
   [2. 2. 2. 2. 6. 6.]
   [1. 1. 3. 3. 9. 9.]
   [1. 1. 3. 3. 9. 9.]
   [6. 6. 1. 1. 0. 0.]
   [6. 6. 1. 1. 0. 0.]]]]

Как и говорилось выше, масштаб можно задать разным для высоты и ширины (в этом случае длина кортежа масштабов должна соответствовать количеству размерностей карт):

upsample = Upsample2D(scale=(2, 1), mode="nearest")
print(upsample(data))
[[[[2. 2. 6.]
   [2. 2. 6.]
   [1. 3. 9.]
   [1. 3. 9.]
   [6. 1. 0.]
   [6. 1. 0.]]]]

При линейной интерполяции результаты будут отличаться:

upsample = Upsample2D(scale=2, mode="linear")
print(upsample(data))
[[[[2.         2.         2.         2.8000002  4.4        6.        ]
   [1.6        1.9200001  2.2400002  3.3600004  5.28       7.2000003 ]
   [1.2        1.84       2.48       3.9200003  6.1600003  8.400001  ]
   [2.0000002  2.2400002  2.48       3.52       5.3599997  7.2       ]
   [4.         3.1200001  2.24       2.16       2.8799999  3.6       ]
   [6.         4.         2.         0.79999995 0.39999998 0.        ]]]]