summaryrefslogtreecommitdiff
path: root/torch/functional.py
blob: bcbceffca8ba6c364f2ee56d1141aa5dc6cc9906 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from ._utils import _range


def split(tensor, split_size, dim=0):
    """Split ``tensor`` into consecutive pieces along dimension ``dim``.

    Every piece spans ``split_size`` entries along ``dim`` except possibly
    the last one, which holds the remainder when the size of that dimension
    is not a multiple of ``split_size``.  Each piece is produced by calling
    ``tensor.narrow(dim, start, length)``.

    Args:
        tensor: object providing ``dim()``, ``size(dim)`` and
            ``narrow(dim, start, length)``.
        split_size: positive integer length of each piece along ``dim``.
        dim: dimension to split along; a negative value has
            ``tensor.dim()`` added to it before use.

    Returns:
        tuple: the pieces in order of increasing offset.  The tuple is
        empty when ``tensor.size(dim)`` is 0.

    Raises:
        ValueError: if ``split_size`` is not positive.
    """
    split_size = int(split_size)
    if split_size <= 0:
        # A zero size used to surface as a bare ZeroDivisionError and a
        # negative size silently produced an empty tuple; fail loudly instead.
        raise ValueError(
            "split expects split_size to be positive, got {}".format(split_size))
    if dim < 0:
        dim += tensor.dim()
    dim = int(dim)
    dim_size = int(tensor.size(dim))

    pieces = []
    start = 0
    while start < dim_size:
        # The final piece is clipped to whatever is left of the dimension.
        length = min(split_size, dim_size - start)
        pieces.append(tensor.narrow(dim, start, length))
        start += split_size
    return tuple(pieces)


def chunk(tensor, n_chunks, dim=0):
    """Split ``tensor`` into at most ``n_chunks`` pieces along ``dim``.

    The piece length is the size of ``dim`` divided by ``n_chunks``, rounded
    up; the actual splitting is delegated to ``split``.  Because of the
    rounding, fewer than ``n_chunks`` pieces may come back (for example a
    dimension of size 5 with ``n_chunks=4`` gives a piece length of 2 and
    therefore 3 pieces).

    Args:
        tensor: object providing ``dim()`` and ``size(dim)``, and accepted
            by ``split``.
        n_chunks: positive integer upper bound on the number of pieces.
        dim: dimension to split along; a negative value has
            ``tensor.dim()`` added to it before use.

    Returns:
        Whatever ``split(tensor, split_size, dim)`` returns for the computed
        piece length.

    Raises:
        ValueError: if ``n_chunks`` is not positive.
    """
    if n_chunks <= 0:
        # Zero used to raise a bare ZeroDivisionError and negative counts
        # produced a negative piece length; reject both explicitly.
        raise ValueError(
            "chunk expects n_chunks to be positive, got {}".format(n_chunks))
    if dim < 0:
        dim += tensor.dim()
    # Ceiling division.  Clamp to 1 so a zero-length dimension does not hand
    # ``split`` a piece length of 0.
    split_size = max(1, (tensor.size(dim) + n_chunks - 1) // n_chunks)
    return split(tensor, split_size, dim)


def stack(sequence, dim=0):
    """Concatenate tensors along a newly inserted dimension.

    Each element of ``sequence`` is given an extra dimension at position
    ``dim`` via ``unsqueeze`` and the results are joined with ``torch.cat``
    along that same dimension.

    Args:
        sequence: non-empty sequence (must support ``len`` and indexing) of
            tensors.
        dim: position of the new dimension in the result.  The result has
            one more dimension than the inputs, so a negative value is
            wrapped against ``sequence[0].dim() + 1``: ``-1`` places the new
            dimension last.

    Returns:
        The value returned by ``torch.cat`` for the unsqueezed inputs.

    Raises:
        TypeError: if ``sequence`` is empty.
    """
    if len(sequence) == 0:
        raise TypeError("stack expects a non-empty sequence of tensors")
    if dim < 0:
        # Wrap relative to the rank of the *output* (inputs' rank + 1);
        # wrapping against the input rank would put dim=-1 one slot too early.
        dim += sequence[0].dim() + 1
    return torch.cat([t.unsqueeze(dim) for t in sequence], dim)