summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/Tanh/__init__.py
blob: 76b46298a3560f15382d46fde9cf299cb75305a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn


# model
class net_Tanh(nn.Module):
    """Single-layer model that applies ``nn.Tanh`` to its input.

    The module holds exactly one child, ``op`` (an ``nn.Tanh``).
    ``forward`` routes the input through that child and returns the result.
    Tanh is element-wise, so the output has the same shape as the input.
    """

    def __init__(self):
        super(net_Tanh, self).__init__()
        # Assigning an nn.Module attribute registers it as a child named "op".
        self.op = nn.Tanh()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(input)`` as computed by the ``op`` submodule."""
        # ``input`` shadows the builtin. The name is kept so keyword-argument
        # callers (``forward(input=...)``) keep working.
        activated = self.op(input)
        return activated


# Module-level model instance; presumably discovered by the example runner
# through the ``_model_`` name — confirm against the consuming tool.
_model_ = net_Tanh()

# Dummy input for ONNX generation: a random 4-D tensor of shape (1, 2, 3, 3).
# Tanh is element-wise, so any layout works here. The dims may be meant as
# (N, C, H, W) — TODO confirm with the exporter that reads ``_dummy_``.
_dummy_ = torch.randn((1, 2, 3, 3))