
Split

Description

Info

Parent class: Module

Derived classes: -

This module splits a tensor into several tensors along a given axis.

Initializing

def __init__(self, axis, sections, name=None):

Parameters

| Parameter | Allowed types | Description | Default |
|---|---|---|---|
| axis | int | Axis along which the tensor is split | - |
| sections | tuple | Sizes of the parts into which the specified axis is split. The number of output tensors equals the length of the tuple | - |
| name | str | Layer name | None |
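To make the meaning of `sections` concrete, here is a minimal CPU sketch in NumPy. It is not PuzzleLib's implementation: `split_reference` is a hypothetical helper, and it assumes (as in the example on this page, where 3 + 2 + 1 = 6) that the sizes in `sections` add up to the length of the split axis. Note that `np.split` expects cut positions rather than part sizes, so the sizes are first converted with a cumulative sum.

```python
import numpy as np

def split_reference(data, axis, sections):
    """Hypothetical CPU sketch of Split semantics.

    Assumption: `sections` holds part sizes that add up to data.shape[axis].
    """
    if sum(sections) != data.shape[axis]:
        raise ValueError("sections must add up to the size of the split axis")
    # np.split takes cut positions, not sizes: (3, 2, 1) -> [3, 5]
    cut_points = np.cumsum(sections)[:-1]
    return np.split(data, cut_points, axis=axis)

data = np.random.randn(5, 3, 6).astype(np.float32)
parts = split_reference(data, axis=2, sections=(3, 2, 1))
print([p.shape for p in parts])  # [(5, 3, 3), (5, 3, 2), (5, 3, 1)]
```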

Explanations

-

Examples

Necessary imports.

import numpy as np
from PuzzleLib.Backend import gpuarray
from PuzzleLib.Modules import Split

Info

gpuarray is required to place the tensor on the GPU.

batchsize, groups, size = 5, 3, 6
data = gpuarray.to_gpu(np.random.randn(batchsize, groups, size).astype(np.float32))

# Split the last axis (length 6) into parts of length 3, 2 and 1
split = Split(axis=2, sections=(3, 2, 1))
outdata = split(data)

# One output tensor per entry in sections
for outd in outdata:
    print(outd.shape)
# (5, 3, 3)
# (5, 3, 2)
# (5, 3, 1)
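Each output in the example above can be viewed as a contiguous slice of the input along the split axis. The NumPy sketch below is a CPU stand-in that assumes this slicing behaviour (it is not the module's GPU code): it uses the same shapes and sections, builds the slices by hand, and checks that concatenating the parts back along the same axis restores the input.

```python
import numpy as np

data = np.random.randn(5, 3, 6).astype(np.float32)
sections = (3, 2, 1)

# Assumed behaviour: each output is a contiguous slice of axis 2,
# here data[:, :, 0:3], data[:, :, 3:5] and data[:, :, 5:6]
parts, start = [], 0
for size in sections:
    parts.append(data[:, :, start:start + size])
    start += size

for part in parts:
    print(part.shape)

# Concatenating the parts along the same axis restores the input exactly
assert np.array_equal(np.concatenate(parts, axis=2), data)
```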