Split
Description
This module splits a tensor along a given axis into several output tensors.
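The following CPU reference shows the semantics of the operation using NumPy only. It is a minimal sketch, not PuzzleLib's GPU implementation, and split_reference is a hypothetical helper. The example further down suggests that the section sizes exactly cover the split axis (3 + 2 + 1 = 6), so this sketch enforces that as an assumption. Because np.split takes cut indices rather than sizes, the sketch converts the sizes with a cumulative sum.

```python
import numpy as np


def split_reference(data, axis, sections):
    """CPU reference: split `data` along `axis` into parts of the given sizes.

    Hypothetical helper for illustration; not part of PuzzleLib.
    """
    # Assumption of this sketch: the section sizes must cover the axis exactly.
    if sum(sections) != data.shape[axis]:
        raise ValueError("sum(sections) must equal data.shape[axis]")
    # np.split expects cut indices, so sizes (3, 2, 1) become offsets [3, 5].
    indices = np.cumsum(sections)[:-1]
    return np.split(data, indices, axis=axis)


data = np.random.randn(5, 3, 6).astype(np.float32)
parts = split_reference(data, axis=2, sections=(3, 2, 1))
print([p.shape for p in parts])  # [(5, 3, 3), (5, 3, 2), (5, 3, 1)]
```

Concatenating the parts along the same axis restores the original array. This is a convenient way to check that the sketch behaves as intended.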
Initializing
def __init__(self, axis, sections, name=None):
Parameters
Parameter | Allowed types | Description | Default |
---|---|---|---|
axis | int | Axis on which the operation is performed | - |
sections | tuple | Sizes of the parts into which the specified axis is split. The number of output tensors equals the length of the tuple | - |
name | str | Layer name | None |
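To make the sections parameter concrete, each entry can be read as the size of one output along the split axis. Output i then covers a half-open index range that starts where output i - 1 ends. This reading is inferred from the example below (sections=(3, 2, 1) on an axis of size 6). The helper section_bounds is hypothetical and only illustrates the mapping.

```python
from itertools import accumulate


def section_bounds(sections):
    """Map section sizes to half-open (start, stop) ranges along the split axis.

    Hypothetical helper for illustration; not part of PuzzleLib.
    """
    # Running totals give each part's stop index: (3, 2, 1) -> [3, 5, 6].
    stops = list(accumulate(sections))
    # Each part starts where the previous one stopped; the first starts at 0.
    starts = [0] + stops[:-1]
    return list(zip(starts, stops))


print(section_bounds((3, 2, 1)))  # [(0, 3), (3, 5), (5, 6)]
```

One range is produced per tuple element, which matches the statement that the number of outputs equals the length of sections.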
Explanations
-
Examples
Necessary imports.
>>> import numpy as np
>>> from PuzzleLib.Backend import gpuarray
>>> from PuzzleLib.Modules import Split
Info
gpuarray is required to place the tensor on the GPU.
>>> batchsize, groups, size = 5, 3, 6
>>> data = gpuarray.to_gpu(np.random.randn(batchsize, groups, size).astype(np.float32))
>>> split = Split(axis=2, sections=(3, 2, 1))
>>> outdata = split(data)
>>> for outd in outdata:
...     print(outd.shape)
(5, 3, 3)
(5, 3, 2)
(5, 3, 1)