summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/ConvTranspose2d/__init__.py
blob: 17b1b4214e70360257b7e1f080063296b88e9b1a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn


# model
class net_ConvTranspose2d(nn.Module):
    """Minimal single-layer model wrapping ``nn.ConvTranspose2d``.

    The wrapped layer maps 2 input channels to 2 output channels using a
    1x1 kernel. Stride, padding, dilation, groups and bias are left at the
    PyTorch defaults, so with a 1x1 kernel the spatial size of the input
    is preserved.
    """

    def __init__(self):
        super(net_ConvTranspose2d, self).__init__()
        # Equivalent to the positional form nn.ConvTranspose2d(2, 2, 1).
        self.op = nn.ConvTranspose2d(
            in_channels=2,
            out_channels=2,
            kernel_size=1,
        )

    def forward(self, input):
        """Run ``input`` through the transposed convolution and return it.

        ``input`` presumably has shape (N, 2, H, W) to match the layer's
        in_channels; TODO confirm against callers.
        """
        transposed = self.op(input)
        return transposed


# Example model instance; presumably read by an external ONNX-generation
# script together with _dummy_ below (TODO confirm the consumer).
_model_ = net_ConvTranspose2d()

# dummy input for onnx generation: one sample with 2 channels (matching the
# layer's in_channels) and a 3x3 spatial extent.
_dummy_ = torch.randn((1, 2, 3, 3))