# path: root/res/PyTorchExamples/examples/sqrt/__init__.py
# blob: ceba051076b40ef4443684d178705b0d66a1aa35
import torch
import torch.nn as nn


# model
class net_sqrt(nn.Module):
    """Single-op model that applies an element-wise square root.

    The module registers no parameters or buffers; ``forward`` just
    delegates to ``torch.sqrt``.
    """

    def __init__(self):
        super(net_sqrt, self).__init__()

    def forward(self, input):
        """Return ``torch.sqrt(input)``, computed element-wise."""
        root = torch.sqrt(input)
        return root


# Module-level instance of the model defined above; presumably consumed
# together with `_dummy_` by the example-to-ONNX export script — confirm
# against the caller.
_model_ = net_sqrt()

# dummy input for onnx generation
_dummy_ = torch.randn(1, 2, 3, 3)