summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/div/__init__.py
blob: b94a5d9ab43148d3be30e43e9151c0b5c099119d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn


# model
class net_div(nn.Module):
    """Parameter-free module that divides one tensor by another via ``torch.div``."""

    def __init__(self):
        # No layers or parameters: only the base nn.Module setup is needed.
        super(net_div, self).__init__()

    def forward(self, inputs):
        """Return ``torch.div(inputs[0], inputs[1])``.

        Args:
            inputs: an indexable container; only ``inputs[0]`` (dividend)
                and ``inputs[1]`` (divisor) are read, any further entries
                are ignored.

        Returns:
            The quotient produced by ``torch.div``. No ``rounding_mode`` is
            passed, so this is true (non-truncating) division.
        """
        dividend = inputs[0]
        divisor = inputs[1]
        return torch.div(dividend, divisor)


# Module-level instance of the division model defined above.
# NOTE(review): presumably consumed by an external example/ONNX-export
# harness together with `_dummy_` below — confirm against the caller.
_model_ = net_div()

# dummy input for onnx generation: a list of two independently drawn
# standard-normal tensors, each of shape (1, 2, 3, 3)
_dummy_ = [torch.randn(1, 2, 3, 3) for _ in range(2)]