summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/split/__init__.py
blob: 3a323670ef8515a8ddb07d6a5434b43bb0a7630d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn


# model
class net_split(nn.Module):
    """Minimal module whose forward pass is a single ``torch.split`` call."""

    def __init__(self):
        # No parameters or buffers are registered; only nn.Module's own
        # initialization is required.
        super(net_split, self).__init__()

    def forward(self, input):
        """Split ``input`` along dim 0 into chunks of size 2.

        Returns the sequence of tensors produced by ``torch.split``; the
        final chunk is smaller when the extent of dim 0 is not a multiple
        of 2.
        """
        chunks = torch.split(input, split_size_or_sections=2, dim=0)
        return chunks


# module-level instance of the model defined above
_model_ = net_split()

# dummy input for onnx generation: a random 4-D tensor whose dim 0 has
# extent 2
_dummy_ = torch.randn((2, 2, 3, 3))