Upsample2D¶
Описание¶
Увеличивает размерность двумерных карт в заданное количество раз, заполняя новые ячейки значениями согласно выбранному режиму (см. параметры).
В отличие от деконволюции данный слой не является обучаемым.
Инициализация¶
def __init__(self, scale=2, mode="nearest", name=None):
Параметры
Параметр | Возможные типы | Описание | По умолчанию |
---|---|---|---|
scale | Union[int, tuple] | Масштаб: во сколько раз увеличить входной тензор | 2 |
mode | str | Режим заполнения новых ячеек | "nearest" |
name | str | Имя слоя | None |
Пояснения
scale
- возможна передача как единой величины масштаба по высоте и ширине, так и tuple вида (scale_h, scale_w)
, где scale_h
- величина масштаба для высоты карты, и scale_w
- для ширины;
mode
- возможные варианты: "nearest" (копирует значение ближайшей ячейки), "linear" (использует линейную интерполяцию по значениям близлежащих ячеек).
Примеры¶
Необходимые импорты:
import numpy as np
from PuzzleLib.Backend import gpuarray
from PuzzleLib.Modules import Upsample2D
Info
gpuarray
необходим для правильного размещения тензора на GPU
np.random.seed(123)
batchsize, maps, h, w = 1, 1, 3, 3
data = gpuarray.to_gpu(np.random.randint(0, 10, (batchsize, maps, h, w)).astype(np.float32))
print(data)
[[[[2. 2. 6.]
[1. 3. 9.]
[6. 1. 0.]]]]
upsample = Upsample2D(scale=2, mode="nearest")
print(upsample(data))
[[[[2. 2. 2. 2. 6. 6.]
[2. 2. 2. 2. 6. 6.]
[1. 1. 3. 3. 9. 9.]
[1. 1. 3. 3. 9. 9.]
[6. 6. 1. 1. 0. 0.]
[6. 6. 1. 1. 0. 0.]]]]
Как и говорилось выше, масштаб можно задать разным для высоты и ширины (в этом случае длина кортежа масштабов должна соответствовать количеству размерностей карт):
upsample = Upsample2D(scale=(2, 1), mode="nearest")
print(upsample(data))
[[[[2. 2. 6.]
[2. 2. 6.]
[1. 3. 9.]
[1. 3. 9.]
[6. 1. 0.]
[6. 1. 0.]]]]
При линейной интерполяции результаты будут отличаться:
upsample = Upsample2D(scale=2, mode="linear")
print(upsample(data))
[[[[2. 2. 2. 2.8000002 4.4 6. ]
[1.6 1.9200001 2.2400002 3.3600004 5.28 7.2000003 ]
[1.2 1.84 2.48 3.9200003 6.1600003 8.400001 ]
[2.0000002 2.2400002 2.48 3.52 5.3599997 7.2 ]
[4. 3.1200001 2.24 2.16 2.8799999 3.6 ]
[6. 4. 2. 0.79999995 0.39999998 0. ]]]]