Upsample2D

Описание

Info

Родительский класс: Module

Производные классы: -

Увеличивает размерность двумерных карт в заданное количество раз, заполняя новые ячейки значениями согласно выбранному режиму (см. параметры).

В отличие от деконволюции данный слой не является обучаемым.

Инициализация

def __init__(self, scale=2, mode="nearest", name=None):

Параметры

Параметр Возможные типы Описание По умолчанию
scale Union[int, tuple] Масштаб: во сколько раз увеличить входной тензор. 2
mode str Режим заполнения новых ячеек. "nearest"
name str Имя слоя. None

Пояснения

scale - возможна передача как единой величины масштаба по высоте и ширине, так и tuple вида (scale_h, scale_w), где scale_h - величина масштаба для высоты карты, и scale_w - для ширины;


mode - возможные варианты: "nearest" (копирует значение ближайшей ячейки), "linear" (использует линейную интерполяцию по значениям близлежащих ячеек).

Примеры

Необходимые импорты:

>>> import numpy as np
>>> from PuzzleLib.Backend import gpuarray
>>> from PuzzleLib.Modules import Upsample2D

Info

gpuarray необходим для правильного размещения тензора на GPU

>>> batchsize, maps, h, w = 1, 1, 3, 3
>>> data = gpuarray.to_gpu(np.random.randint(0, 10, (batchsize, maps, h, w)).astype(np.float32))
>>> print(data)
[[[[9. 8. 4.]
   [2. 7. 5.]
   [3. 8. 3.]]]]
>>> upsample = Upsample2D(scale=2, mode="nearest")
>>> upsample(data)
[[[[9. 9. 8. 8. 4. 4.]
   [9. 9. 8. 8. 4. 4.]
   [2. 2. 7. 7. 5. 5.]
   [2. 2. 7. 7. 5. 5.]
   [3. 3. 8. 8. 3. 3.]
   [3. 3. 8. 8. 3. 3.]]]]

Как и говорилось выше, масштаб можно задать разным для высоты и ширины (в этом случае длина кортежа масштабов должна соответствовать количеству размерностей карт):

>>> upsample = Upsample2D(scale=(2, 1), mode="nearest")
>>> upsample(data)
[[[[9. 8. 4.]
   [9. 8. 4.]
   [2. 7. 5.]
   [2. 7. 5.]
   [3. 8. 3.]
   [3. 8. 3.]]]]

При линейной интерполяции результаты будут отличаться:

>>> upsample = Upsample2D(scale=2, mode="linear")
>>> upsample(data)
[[[[9.        8.6       8.2       7.2       5.6       4.       ]
   [6.2000003 6.76      7.32      6.96      5.68      4.4      ]
   [3.3999999 4.92      6.44      6.7200003 5.7599998 4.8      ]
   [2.2       4.2       6.2       6.68      5.64      4.6      ]
   [2.6       4.6       6.6000004 6.84      5.3199997 3.8      ]
   [3.        5.        7.        7.        5.        3.       ]]]]