1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
|
import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
import warnings
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import partial, wraps
import numpy
import math
# EDITING THIS FILE? READ THIS FIRST!
#
# - This file is ONLY for ATen operators (e.g., operators that show up in the
# trace as aten::blah). If you need to special case a primitive operator,
# look at _run_symbolic_function
# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
# tensors are always first, then non-tensor arguments.
# - Parameter names must *exactly* match the names in VariableType.cpp, because
# dispatch is done with keyword arguments.
# - Looking for inplace ops? They're detected by the trailing underscore, and
# transparently dispatched to their non inplace versions in
# 'run_symbolic_function'. See Note [Export inplace]
#
# ----------------------------------------------------------------------------------
# A note on Tensor types
# ----------------------------------------------------------------------------------
#
# In general, we should avoid depending on the type of Tensor Values contained
# within the trace graph. However, this is sometimes unavoidable (due to ONNX
# spec requirements, etc). If you are implementing a symbolic and need Tensor
# type information, note that there are several levels of Tensor types, defined
# in aten/src/ATen/core/jit_type.h:
#
# TensorType - This is a Tensor, but we don't know anything about its
# properties (e.g. scalar type, # dims, shapes).
# Appears as `Tensor` in graph print-outs.
# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
# type and number of dimensions, but not the concrete
# shapes. For example, appears as 'Float(*, *)' in
# graph print-outs. Useful accessor methods include
# dim() and scalarType()
# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
# concrete sizes in addition to the information
# contained in DimensionedTensorType. This adds a sizes()
# method which can be used to retrieve the
# concrete sizes.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
# on the number of dimensions (DimensionedTensorType) which is better than relying on
# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
# more robust to different graphs.
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as module-level aliases, because we'll shadow them below
# (e.g. `sum` is rebound to the ReduceSum symbolic further down this file).
_sum = sum
def _parse_arg(value, desc):
    """Decode ``value`` according to the descriptor string ``desc``.

    ``'none'`` and ``'v'`` return ``value`` untouched, as does anything that
    is not a graph ``Value``. Otherwise ``value`` must be the output of an
    ``onnx::Constant`` node and its ``value`` attribute is converted:
    ``'i'`` -> int, ``'f'`` -> float, ``'b'`` -> bool, ``'t'`` -> the raw
    tensor, ``'is'`` -> list of ints.

    Raises:
        RuntimeError: if the producing node is not an ``onnx::Constant`` or
            ``desc`` is not one of the descriptors above.
    """
    if desc == 'none':
        return value
    passthrough = desc == 'v' or not _is_value(value)
    if passthrough:
        return value

    producer = value.node()
    if producer.kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value in the trace")
    constant = producer['value']

    converters = {
        'i': int,
        'f': float,
        'b': bool,
        't': lambda tensor: tensor,
        'is': lambda tensor: [int(v) for v in tensor],
    }
    convert = converters.get(desc)
    if convert is None:
        raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
    return convert(constant)
def _maybe_get_const(value, desc):
    """Return the decoded constant behind ``value`` when it is produced by an
    ``onnx::Constant`` node; otherwise return ``value`` unchanged."""
    is_onnx_constant = _is_value(value) and value.node().kind() == 'onnx::Constant'
    return _parse_arg(value, desc) if is_onnx_constant else value
def _maybe_get_scalar(value):
    """Return the zero-dimensional tensor held by a constant ``value``.

    If ``value`` is not a constant, or its constant is not a 0-dim tensor,
    the original ``value`` is returned.
    """
    candidate = _maybe_get_const(value, 't')
    is_zero_dim_tensor = isinstance(candidate, torch.Tensor) and candidate.shape == ()
    return candidate if is_zero_dim_tensor else value
def _get_const(value, desc, arg_name):
    """Decode ``value`` like ``_parse_arg``, insisting that a graph ``Value``
    be an ``onnx::Constant``.

    Raises:
        RuntimeError: naming ``arg_name`` when ``value`` is a non-constant Value.
    """
    if not _is_value(value) or value.node().kind() == 'onnx::Constant':
        return _parse_arg(value, desc)
    raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
def _unpack_list(list_value):
    """Return, as a list, the element Values fed into a ``prim::ListConstruct``."""
    producer = list_value.node()
    assert producer.kind() == "prim::ListConstruct"
    return [*producer.inputs()]
def parse_args(*arg_descriptors):
    """Decorator factory that pre-parses a symbolic's positional arguments.

    Every positional argument after ``g`` is passed through ``_parse_arg``
    with the descriptor at the same position. Fewer arguments than
    descriptors are accepted because trailing arguments may be optional.
    """
    def decorator(fn):
        def wrapper(g, *args):
            # Optional trailing arguments may be missing; only an excess of
            # arguments over descriptors is an error.
            assert len(arg_descriptors) >= len(args)
            parsed = [_parse_arg(arg, spec) for arg, spec in zip(args, arg_descriptors)]
            return fn(g, *parsed)

        # functools.wraps fails on partially applied functions under Python 2,
        # so copying the metadata is best-effort only.
        try:
            return wraps(fn)(wrapper)
        except Exception:
            return wrapper

    return decorator
def _scalar(x):
    """Convert a single-element tensor into the equivalent Python number.

    Asserts that ``x`` holds exactly one element.
    """
    element_count = x.numel()
    assert element_count == 1
    return x.item()
def _if_scalar_type_as(g, self, tensor):
    """Cast a Python-side scalar ``self`` to the scalar type of ``tensor``.

    Only implicit scalar casting is supported, so no ONNX Cast node is ever
    emitted. A graph ``Value`` is returned untouched. A torch scalar is
    converted with the tensor method named after ``tensor``'s lower-cased
    scalar type, but only when ``tensor``'s JIT type is dimensioned or
    complete; otherwise ``self`` is returned as is. ``g`` is unused.
    """
    if isinstance(self, torch._C.Value):
        return self
    tensor_kind = tensor.type().kind()
    if tensor_kind not in ("DimensionedTensorType", "CompleteTensorType"):
        return self
    cast_method = tensor.type().scalarType().lower()
    return getattr(self, cast_method)()
def _is_value(x):
    """Return True when ``x`` is a JIT graph ``Value`` rather than a Python object."""
    value_type = torch._C.Value
    return isinstance(x, value_type)
def _is_tensor_list(x):
    """Return True when the JIT type of ``x`` is a subtype of ``List[Tensor]``."""
    tensor_list_type = ListType.ofTensors()
    return x.type().isSubtypeOf(tensor_list_type)
def _unimplemented(op, msg):
    """Warn that ``op`` cannot be exported because ``msg`` is unsupported.

    Returns None, which symbolics propagate as "no translation".
    """
    text = "ONNX export failed on " + op + " because " + msg + " not supported"
    warnings.warn(text)
def _try_get_scalar_type(*args):
    """Return the scalar type of the first argument that exposes one.

    Arguments whose ``type().scalarType()`` raises ``RuntimeError`` are
    skipped; None is returned if no argument yields a scalar type.
    """
    for candidate in args:
        try:
            scalar_type = candidate.type().scalarType()
        except RuntimeError:
            continue
        return scalar_type
    return None
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------
# READ ME BEFORE EDITING _default_onnx_opset_version:
#
# The variable below controls which ONNX operator set version we are
# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
# change occurred in version 8. As long as this variable < 8, you can
# export models targeting the old behavior. However, if you bump
# this variable to 8 or later, the breaking change will take into effect:
# you MUST adjust any symbolic affected by breaking changes. The ONNX
# spec publishes a *comprehensive* list of BC-breaking changes for every
# operator revision at:
#
# https://github.com/onnx/onnx/blob/master/docs/Changelog.md
#
# Please be sure to go through and check all of our implementations here before
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes)
#
# Besides, opset_version can be specified in the invocation of export()
# and export_to_pretty_string(), and _export_onnx_opset_version will be set
# and the symbolic functions should check it to determine the behavior
# of the exporter.
# Opset targeted when the caller does not request one explicitly.
_default_onnx_opset_version = 9
# In-development opset that _set_opset_version also accepts.
_onnx_master_opset = 10
# Opsets considered stable; _set_opset_version accepts these plus the master opset.
_onnx_stable_opsets = [9]
# Opset for the export in progress; symbolic functions consult this global.
_export_onnx_opset_version = _default_onnx_opset_version
def _set_opset_version(opset_version):
    """Select the ONNX opset targeted by the current export.

    Stores ``opset_version`` in the module-level
    ``_export_onnx_opset_version``. The default opset is stored explicitly
    as well, so a previous export that targeted a different opset cannot
    leak its setting into a later export that asks for the default.

    Raises:
        ValueError: if ``opset_version`` is neither the default, a stable
            opset, nor the master opset.
    """
    global _export_onnx_opset_version
    supported = _onnx_stable_opsets + [_onnx_master_opset]
    if opset_version != _default_onnx_opset_version and opset_version not in supported:
        raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
    _export_onnx_opset_version = opset_version
# ---------------------------------------------------------------------
# Symbolic definitions
# ---------------------------------------------------------------------
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
# a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
# want it to be a zero dim tensor but this change has not happened yet.)
# However, the type of this scalar is *exactly* what the user wrote in
# Python, which may not match the tensor it is being added to. PyTorch
# will do implicit conversions on scalars; however, ONNX will not, so
# we must do the conversion ourselves. This is what _if_scalar_type_as
# does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
# between the tensor and scalar name. When we add two tensors together,
# you get the dispatch:
#
# add(*[self, other], **{"alpha": alpha})
#
# When you add a tensor and a scalar, you get the dispatch:
#
# add(*[self], **{"other": other, "alpha": alpha})
#
# By having the argument name line up with the name of the scalar attribute
# if it exists, we can write a single function for both overloads.
#
# used to represent "missing" optional inputs
def unused(g):
    """Emit a ``prim::Constant`` typed as an optional tensor, used as a
    placeholder for a missing optional input."""
    placeholder = g.op("prim::Constant")
    placeholder.setType(OptionalType.ofTensor())
    return placeholder
def _shape_as_tensor(g, input):
    """Emit an ONNX ``Shape`` node for ``input``."""
    onnx_op = 'Shape'
    return g.op(onnx_op, input)
def _reshape_from_tensor(g, input, shape):
    """Emit an ONNX ``Reshape`` of ``input`` to the runtime tensor ``shape``."""
    onnx_op = 'Reshape'
    return g.op(onnx_op, input, shape)
def reshape(g, self, shape):
    """aten::reshape — exported exactly like ``view``."""
    return view(g, self, size=shape)
def reshape_as(g, self, other):
    """aten::reshape_as — reshape ``self`` to the runtime shape of ``other``."""
    return reshape(g, self, g.op('Shape', other))
def add(g, self, other, alpha=None):
    """aten::add — ONNX Add; only ``alpha == 1`` (or no alpha) is supported.

    ``alpha`` defaults to None so that the aten add overload without an
    alpha argument dispatches here too.
    """
    alpha_is_not_one = alpha and _scalar(_maybe_get_scalar(alpha)) != 1
    if alpha_is_not_one:
        return _unimplemented("add", "alpha != 1")
    # See Note [Pointwise by scalar]
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), self)
    return g.op("Add", self, rhs)
def sub(g, self, other, alpha=None):
    """aten::sub — ONNX Sub; only ``alpha == 1`` (or no alpha) is supported.

    ``alpha`` defaults to None so that the aten sub overload without an
    alpha argument dispatches here too.
    """
    alpha_is_not_one = alpha and _scalar(_maybe_get_scalar(alpha)) != 1
    if alpha_is_not_one:
        return _unimplemented("sub", "alpha != 1")
    # See Note [Pointwise by scalar]. Note that self or other may be scalars.
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), self)
    return g.op("Sub", self, rhs)
def rsub(g, self, other, alpha=None):
    """aten::rsub — computes ``other - self`` by delegating to ``sub`` with
    the operands swapped (``other`` first cast to ``self``'s scalar type)."""
    lhs = _if_scalar_type_as(g, _maybe_get_scalar(other), self)
    return sub(g, lhs, self, alpha=alpha)
def mul(g, self, other):
    """aten::mul — ONNX Mul, with a scalar ``other`` cast to ``self``'s type."""
    # See Note [Pointwise by scalar]
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), self)
    return g.op("Mul", self, rhs)
def div(g, self, other):
    """aten::div — ONNX Div, with a scalar ``other`` cast to ``self``'s type."""
    # See Note [Pointwise by scalar]
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), self)
    return g.op("Div", self, rhs)
def reciprocal(g, self):
    """aten::reciprocal — exported as ``Div(1, self)``."""
    one = _if_scalar_type_as(g, torch.ones(1), self)
    return g.op("Div", one, self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
    """aten::cat — ONNX Concat along ``dim`` over the ListConstruct elements."""
    return g.op("Concat", *_unpack_list(tensor_list), axis_i=dim)
@parse_args('v', 'i')
def stack(g, tensor_list, dim):
    """aten::stack — unsqueeze every input at ``dim``, then Concat along ``dim``."""
    def _unsqueeze(item):
        return g.op("Unsqueeze", item, axes_i=[dim])

    expanded = [_unsqueeze(item) for item in _unpack_list(tensor_list)]
    return g.op("Concat", *expanded, axis_i=dim)
def mm(g, self, other):
    """aten::mm — exported as ``Gemm(self, other, C)`` with alpha=1, beta=0.

    Gemm takes a C input, so a dummy zero constant is created; its value is
    irrelevant because beta = 0, but it must carry a concrete scalar type.

    Raises:
        RuntimeError: if neither operand has a known scalar type (previously
            this surfaced as an opaque AttributeError on ``None.lower()``).
    """
    scalar_type = _try_get_scalar_type(self, other)
    if scalar_type is None:
        raise RuntimeError("ONNX export of mm requires at least one operand with a known scalar type")
    C = g.constant(0, [1], scalar_type.lower())
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
def bmm(g, self, other):
    """aten::bmm — batched matrix product, exported as ONNX MatMul."""
    onnx_op = "MatMul"
    return g.op(onnx_op, self, other)


def matmul(g, self, other):
    """aten::matmul — exported as ONNX MatMul."""
    onnx_op = "MatMul"
    return g.op(onnx_op, self, other)
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    """aten::addmm — ``beta * self + alpha * (mat1 @ mat2)`` as one ONNX Gemm."""
    beta_value = _scalar(beta)
    alpha_value = _scalar(alpha)
    return g.op("Gemm", mat1, mat2, self, beta_f=beta_value, alpha_f=alpha_value)
def neg(g, self):
    """aten::neg -> ONNX Neg."""
    onnx_op = "Neg"
    return g.op(onnx_op, self)


def sqrt(g, self):
    """aten::sqrt -> ONNX Sqrt."""
    onnx_op = "Sqrt"
    return g.op(onnx_op, self)


def tanh(g, self):
    """aten::tanh -> ONNX Tanh."""
    onnx_op = "Tanh"
    return g.op(onnx_op, self)


def sin(g, self):
    """aten::sin -> ONNX Sin."""
    onnx_op = "Sin"
    return g.op(onnx_op, self)


def cos(g, self):
    """aten::cos -> ONNX Cos."""
    onnx_op = "Cos"
    return g.op(onnx_op, self)


def tan(g, self):
    """aten::tan -> ONNX Tan."""
    onnx_op = "Tan"
    return g.op(onnx_op, self)


def asin(g, self):
    """aten::asin -> ONNX Asin."""
    onnx_op = "Asin"
    return g.op(onnx_op, self)


def acos(g, self):
    """aten::acos -> ONNX Acos."""
    onnx_op = "Acos"
    return g.op(onnx_op, self)


def atan(g, self):
    """aten::atan -> ONNX Atan."""
    onnx_op = "Atan"
    return g.op(onnx_op, self)


def sigmoid(g, self):
    """aten::sigmoid -> ONNX Sigmoid."""
    onnx_op = "Sigmoid"
    return g.op(onnx_op, self)
def _reduce_op_symbolic(onnx_op_name):
    """Build the symbolic for an aten reduction that maps onto ``onnx_op_name``.

    Without ``dim`` the whole tensor is reduced (``keepdims=0``). With
    ``dim`` a single axis is reduced and ``keepdim`` is forwarded as
    ``keepdims``; both must then be constants.
    """
    def symbolic(g, self, dim=None, keepdim=None):
        if dim is not None:
            # dim-reduce path
            axis = _get_const(dim, 'i', 'dim')
            keep = _get_const(keepdim, 'i', 'keepdim')
            return g.op(onnx_op_name, self, axes_i=[axis], keepdims_i=keep)
        # all-reduce path
        return g.op(onnx_op_name, self, keepdims_i=0)
    return symbolic


mean = _reduce_op_symbolic('ReduceMean')
sum = _reduce_op_symbolic('ReduceSum')
prod = _reduce_op_symbolic('ReduceProd')
@parse_args('v', 'i')
def cumsum(g, input, dim):
    """aten::cumsum — no ONNX equivalent here, so emit an ATen fallback node."""
    aten_attrs = {"operator_s": "cumsum", "dim_i": dim}
    return g.op("ATen", input, **aten_attrs)
def t(g, self):
    """aten::t — 2-D transpose, i.e. ONNX Transpose with perm (1, 0)."""
    swapped_axes = (1, 0)
    return g.op("Transpose", self, perm_i=swapped_axes)
def expand(g, self, size, implicit):
    """aten::expand — ONNX Expand to ``size``.

    A constant ``size`` is materialised as an int64 Constant node; a runtime
    ``size`` is passed through. ``implicit`` is not used.
    """
    target = _maybe_get_const(size, 'is')
    if _is_value(target):
        return g.op("Expand", self, target)
    return g.op("Expand", self, g.op("Constant", value_t=torch.LongTensor(target)))
def expand_as(g, self, other):
    """aten::expand_as — ONNX Expand to the runtime shape of ``other``."""
    return g.op("Expand", self, g.op("Shape", other))
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    """aten::embedding — row lookup exported as ONNX Gather on ``weight``.

    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse`` are not used by
    this export.
    """
    lookup_op = "Gather"
    return g.op(lookup_op, weight, indices)
@parse_args('v', 'v', 'v', 'i', 'i', 'i')
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse):
    """aten::embedding_bag — emitted as a four-output ATen fallback node."""
    flags = dict(scale_grad_by_freq_i=scale_grad_by_freq,
                 mode_i=mode,
                 sparse_i=sparse)
    return g.op("ATen", embedding_matrix, indices, offsets,
                operator_s="embedding_bag", outputs=4, **flags)
def size(g, self, dim):
    """aten::size(self, dim) — pick entry ``dim`` of ``Shape(self)``.

    Delegates to ``select`` with a constant ``[0]`` as its ``dim`` argument
    (axis 0 of the shape tensor) and the requested ``dim`` as its ``index``.
    """
    shape_tensor = g.op("Shape", self)
    axis_zero = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, shape_tensor, axis_zero, dim)
@parse_args('v', 'i', 'i')
def transpose(g, self, dim0, dim1):
    """aten::transpose — swap two axes via a full ONNX Transpose permutation."""
    # micro-optimization: swapping an axis with itself is a no-op
    if dim0 == dim1:
        return self
    # NB: ONNX Transpose is really a general permute, so build the full perm
    perm = list(range(self.type().dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return g.op("Transpose", self, perm_i=perm)
@parse_args('v', 'is')
def permute(g, self, dims):
    """aten::permute — ONNX Transpose, skipped for the identity permutation."""
    identity = list(range(len(dims)))
    if dims != identity:
        return g.op("Transpose", self, perm_i=dims)
    return self
def view(g, self, size):
    """aten::view — ONNX Reshape to ``size``.

    A runtime (non-constant) ``size`` is fed to Reshape directly. For a
    constant 2-D target whose leading extent equals the leading extent of a
    complete-typed ``self``, ONNX Flatten(axis=1) is emitted instead.
    Otherwise the constant is materialised as an int64 Constant.
    """
    target = _maybe_get_const(size, 'is')
    if _is_value(target):
        return g.op("Reshape", self, target)
    if self.isCompleteTensor():
        current = self.type().sizes()
        if current and len(target) == 2 and current[0] == target[0]:
            return g.op("Flatten", self, axis_i=1)
    return g.op("Reshape", self, g.op("Constant", value_t=torch.LongTensor(target)))
def prim_ConstantSplit(g, self, split_size, dim):
    """Split ``self`` along ``dim`` into pieces of ``split_size``.

    The last piece holds the remainder when the extent is not a multiple of
    ``split_size``. Needs the static extent from ``self.type().sizes()``.
    Emits one Split output per piece.
    """
    extent = self.type().sizes()[dim]
    full_pieces, remainder = divmod(extent, split_size)
    pieces = [split_size] * full_pieces
    if remainder:
        pieces.append(remainder)
    return g.op("Split", self, split_i=pieces, axis_i=dim, outputs=len(pieces))
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
def prim_ConstantChunk(g, self, chunks, dim):
    """Split ``self`` into ``chunks`` pieces of ceil(extent / chunks) along ``dim``."""
    extent = self.type().sizes()[dim]
    piece = (extent + chunks - 1) // chunks  # ceiling division
    return prim_ConstantSplit(g, self, piece, dim)
@parse_args('v', 'i', 'i')
def split(g, self, split_size, dim):
    """aten::split with a uniform ``split_size`` (last piece takes the remainder)."""
    extent = self.type().sizes()[dim]
    full_pieces, remainder = divmod(extent, split_size)
    pieces = [split_size] * full_pieces + ([remainder] if remainder else [])
    # NOTE(review): a single output is declared even though several pieces
    # may be produced; presumably a later export pass expands it — confirm.
    return g.op("Split", self, split_i=pieces, axis_i=dim, outputs=1)
@parse_args('v', 'is', 'i')
def split_with_sizes(g, self, split_sizes, dim):
    """aten::split_with_sizes — ONNX Split with explicit piece sizes.

    Declares a single output, the same as ``split``.
    """
    attrs = {"split_i": split_sizes, "axis_i": dim}
    return g.op("Split", self, outputs=1, **attrs)
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
    """aten::select — take slice ``index`` along ``dim`` and drop that axis.

    For ``dim > 1`` the index must be a constant and the export uses
    Slice + Squeeze; otherwise it uses ONNX Gather with ``axis=dim``.
    """
    if dim > 1:
        # TODO: this is a temporary hack because of the implementation details
        # of Gather in caffe2. We need to change this as soon as possible.
        index_val = _parse_arg(index, 'i')
        if index_val == -1:
            # index_val + 1 would be 0, i.e. an empty slice. Slice to the end
            # instead: ONNX Slice clamps an oversized `ends` to the extent.
            end_val = 9223372036854775807  # INT64_MAX
        else:
            end_val = index_val + 1
        slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index_val], ends_i=[end_val])
        return g.op("Squeeze", slice_node, axes_i=[dim])
    return g.op("Gather", self, index, axis_i=dim)
def squeeze(g, self, dim=None):
    """aten::squeeze — ONNX Squeeze.

    With ``dim`` given (and constant), squeeze just that axis. Otherwise
    squeeze every axis whose static extent is 1, which requires
    ``self.type().sizes()``.
    """
    if dim is not None:
        axes = [_get_const(dim, 'i', 'dim')]
    else:
        axes = [axis for axis, extent in enumerate(self.type().sizes()) if extent == 1]
    return g.op("Squeeze", self, axes_i=axes)
def prelu(g, self, weight):
    """aten::prelu -> ONNX PRelu with slope tensor ``weight``."""
    onnx_op = "PRelu"
    return g.op(onnx_op, self, weight)


def relu(g, input):
    """aten::relu -> ONNX Relu."""
    onnx_op = "Relu"
    return g.op(onnx_op, input)


def ceil(g, input):
    """aten::ceil -> ONNX Ceil."""
    onnx_op = "Ceil"
    return g.op(onnx_op, input)


def floor(g, input):
    """aten::floor -> ONNX Floor."""
    onnx_op = "Floor"
    return g.op(onnx_op, input)
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
    """aten::threshold — exportable only as ONNX Relu (threshold = value = 0)."""
    # See Note [Export inplace]
    checks = ((threshold, "non-zero threshold"), (value, "non-zero value"))
    for scalar, reason in checks:
        if _scalar(scalar) != 0:
            return _unimplemented("threshold", reason)
    return g.op("Relu", self)
def leaky_relu(g, input, negative_slope, inplace=False):
    """aten::leaky_relu -> ONNX LeakyRelu; ``negative_slope`` must be constant.

    ``inplace`` is not used.
    """
    # See Note [Export inplace]
    # TODO: Talk to ONNX about unconditional cast of scalar to float
    slope = _scalar(_get_const(negative_slope, 't', 'negative_slope'))
    return g.op("LeakyRelu", input, alpha_f=slope)
@parse_args('v', 'i')
def glu(g, input, dim):
    """aten::glu — split ``input`` in two along ``dim`` and return
    ``first * sigmoid(second)``. The extent of ``dim`` must be even."""
    assert input.type().sizes()[dim] % 2 == 0
    first_half, second_half = g.op('Split', input, axis_i=dim, outputs=2)
    gate = g.op('Sigmoid', second_half)
    return g.op('Mul', first_half, gate)
@parse_args('v', 'i', 'i')
def softmax(g, input, dim, dtype=None):
    """aten::softmax — ONNX Softmax when ``dim`` is the last axis, otherwise
    an explicit max-shifted Exp / ReduceSum / Div decomposition.

    ``dtype`` (when truthy) adds a Cast of the result via ``scalar_type_to_onnx``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1
    # otherwise compute softmax using a subgraph with other operators
    if input.type().kind() == "CompleteTensorType" or input.type().kind() == "DimensionedTensorType":
        if dim < 0:
            dim = input.type().dim() + dim
        if input.type().dim() == dim + 1:
            result = g.op('Softmax', input, axis_i=dim)
            if dtype:
                result = g.op("Cast", result, to_i=scalar_type_to_onnx[dtype])
            return result
    # Subtract the per-slice maximum before exponentiating so large inputs do
    # not overflow Exp; softmax is invariant to this shift.
    max_value = g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)
    shifted = g.op('Sub', input, max_value)
    exp = g.op('Exp', shifted)
    # ReduceSum keeps the reduced axis by default, so Div broadcasts.
    denom = g.op('ReduceSum', exp, axes_i=[dim])
    result = g.op('Div', exp, denom)
    if dtype:
        result = g.op("Cast", result, to_i=scalar_type_to_onnx[dtype])
    return result
@parse_args('v', 't', 'v')
def softplus(g, self, beta, threshold):
    """aten::softplus -> ONNX Softplus; only ``beta == 1`` is exportable.

    ``threshold`` is not used by this export.
    """
    if beta != 1:
        # Report the op being exported, not the argument name.
        return _unimplemented("softplus", "beta != 1")
    return g.op('Softplus', self)
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute per-dimension extra end padding used to emulate ``ceil_mode=True``.

    The pooling symbolics in this file add the returned values to the
    original padding to form the end pads of a (floor-mode) ONNX pool.

    Args:
        input: graph Value whose type exposes ``sizes()``; only the trailing
            ``len(padding)`` (spatial) extents are read.
        kernel_size, stride, padding: per-spatial-dimension sequences.

    Returns:
        List of ints, one per spatial dimension.
    """
    dim = input.type().sizes()[-len(padding):]
    # Output length with ceil rounding: ceil((D + 2p - k) / s) + 1
    ceiled_output_dim = [int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1
                         for i in range(0, len(padding))]
    # ensure last pooling starts inside
    # (drop the last window when its start (n - 1) * s is at or beyond D + p)
    ceiled_output_dim = [ceiled_output_dim[i] - 1
                         if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
                         else ceiled_output_dim[i]
                         for i in range(0, len(ceiled_output_dim))]
    # Extra end padding so the last window fits; stride 1 needs none, since
    # ceil and floor rounding then give the same output length.
    padding_ceil = [0
                    if (stride[i] == 1)
                    else
                    (kernel_size[i] - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)))
                    for i in range(0, len(padding))]
    # ensure padding is not > kernel_size
    # (when the total padding reaches the kernel size, clamp the extra
    # padding to at most kernel_size - 1)
    padding_ceil = [(int(padding_ceil[i]) if padding_ceil[i] < kernel_size[i] - 1 else int(kernel_size[i] - 1))
                    if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
                    else
                    int(padding_ceil[i])
                    for i in range(0, len(padding_ceil))]
    return padding_ceil
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """aten::max_pool1d_with_indices — ONNX MaxPool returning (values, indices).

    An empty or None ``stride`` means "same as kernel_size", matching the
    2-D and 3-D variants. ``ceil_mode`` is emulated with extra end padding
    from ``get_pool_ceil_padding`` and needs a complete input type; dilation
    other than 1 is not supported.
    """
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool1d_with_indices", "input size not accesible")
    if set(_single(dilation)) != {1}:
        return _unimplemented("max_pool1d_with_indices", "dilation")
    # `not stride` (instead of `stride is None`) also replaces an empty stride
    # list, as the 2-D/3-D variants already do; an empty list would otherwise
    # produce empty `strides` and break get_pool_ceil_padding.
    if not stride:
        stride = kernel_size
    padding = tuple(_single(padding))
    if ceil_mode:
        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
        padding = padding + tuple(numpy.add(padding_ceil, padding))
    else:
        padding = padding * 2
    r, indices = g.op("MaxPool", input, outputs=2,
                      kernel_shape_i=_single(kernel_size),
                      pads_i=padding,
                      strides_i=_single(stride))
    # easy but hacky way to get flattened indices values
    # to be used to convert the indices values to non-flattened.
    # In ONNX the indices are computed as a flatten 1-D tensor,
    # so the values in indices are in [0, N x C x D1 x ... x Dn).
    # To convert the indices to the same format used by Pytorch,
    # we first execute a maxpool with a kernel and stride of 1 on the same input.
    # This will result in a tensor of indices in which each index will have its own value.
    # Using this tensor as a reference, we extract the first index of each axis and subtract
    # it from each index of this axis in the indices to convert.
    # This step will result in a tensor where each dimension has values of indices within
    # the dimension it is in.
    # For more information :
    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
    _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                kernel_shape_i=[1],
                                strides_i=[1])
    # convert indices to have non-flattened indices values
    s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1])
    indices = sub(g, indices, s)
    return r, indices
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """aten::max_pool2d_with_indices — ONNX MaxPool returning (values, indices).

    A falsy ``stride`` defaults to ``kernel_size``. ``ceil_mode`` is emulated
    with extra end padding and needs a complete input type; dilation other
    than 1 is unsupported. Indices are converted from ONNX's flattened form
    by subtracting the per-(N, C) offset taken from a 1x1 MaxPool.
    """
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool2d_with_indices", "input size not accesible")
    if set(_pair(dilation)) != {1}:
        return _unimplemented("max_pool2d_with_indices", "dilation")
    stride = stride or kernel_size

    begin_pads = tuple(_pair(padding))
    if ceil_mode:
        extra = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
        end_pads = tuple(numpy.add(extra, begin_pads))
    else:
        end_pads = begin_pads

    pooled, raw_indices = g.op("MaxPool", input, outputs=2,
                               kernel_shape_i=_pair(kernel_size),
                               pads_i=begin_pads + end_pads,
                               strides_i=_pair(stride))
    # Index un-flattening trick (same as max_pool1d_with_indices): a 1x1 pool
    # yields every element's flat index; its first spatial entry per (N, C)
    # is the offset to subtract.
    _, reference = g.op("MaxPool", input, outputs=2,
                        kernel_shape_i=[1, 1],
                        strides_i=[1, 1])
    offsets = g.op("Slice", reference, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1])
    return pooled, sub(g, raw_indices, offsets)
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """aten::max_pool3d_with_indices — ONNX MaxPool returning (values, indices).

    A falsy ``stride`` defaults to ``kernel_size``. ``ceil_mode`` is emulated
    with extra end padding and needs a complete input type; dilation other
    than 1 is unsupported. Indices are converted from ONNX's flattened form
    by subtracting the per-(N, C) offset taken from a 1x1x1 MaxPool.
    """
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool3d_with_indices", "input size not accesible")
    if set(_triple(dilation)) != {1}:
        return _unimplemented("max_pool3d_with_indices", "dilation")
    stride = stride or kernel_size

    begin_pads = tuple(_triple(padding))
    if ceil_mode:
        extra = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
        end_pads = tuple(numpy.add(extra, begin_pads))
    else:
        end_pads = begin_pads

    pooled, raw_indices = g.op("MaxPool", input, outputs=2,
                               kernel_shape_i=_triple(kernel_size),
                               pads_i=begin_pads + end_pads,
                               strides_i=_triple(stride))
    # Index un-flattening trick (same as max_pool1d_with_indices).
    _, reference = g.op("MaxPool", input, outputs=2,
                        kernel_shape_i=[1, 1, 1],
                        strides_i=[1, 1, 1])
    offsets = g.op("Slice", reference, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1])
    return pooled, sub(g, raw_indices, offsets)
def _avg_pool(name, tuple_fn):
    """Build the symbolic for ``name`` (avg_pool{1,2,3}d).

    ``tuple_fn`` (``_single``/``_pair``/``_triple``) expands the per-dimension
    arguments.
    """
    @parse_args('v', 'is', 'is', 'is', 'i', 'i')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
        if ceil_mode and input.type().kind() != "CompleteTensorType":
            return _unimplemented(name, "input size not accesible")
        stride = stride or kernel_size
        spatial_pads = tuple(tuple_fn(padding))
        if ceil_mode:
            # Computed from the user's padding, before it may be moved into an
            # explicit Pad node below.
            ceil_extra = get_pool_ceil_padding(input, kernel_size, stride, spatial_pads)
        if count_include_pad:
            # Apply the padding as an explicit zero Pad (N and C untouched) so
            # AveragePool itself sees no padding for these dimensions.
            input = g.op("Pad", input,
                         pads_i=((0,) * 2 + spatial_pads) * 2,
                         mode_s='constant',
                         value_f=0.)
            spatial_pads = (0,) * len(spatial_pads)
        if ceil_mode:
            end_pads = tuple(numpy.add(ceil_extra, spatial_pads))
        else:
            end_pads = spatial_pads
        return g.op("AveragePool", input,
                    kernel_shape_i=tuple_fn(kernel_size),
                    strides_i=tuple_fn(stride),
                    pads_i=spatial_pads + end_pads)
    return symbolic_fn


avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the ONNX symbolic for an ``adaptive_{avg,max}_pool{1,2,3}d`` operator.

    ``type`` is the ONNX pooling op ("AveragePool" or "MaxPool"), ``tuple_fn``
    expands sizes to the spatial rank, and ``fn`` is the matching
    ``max_pool{1,2,3}d_with_indices`` symbolic (used for MaxPool only).

    Export is possible when every output size is 1 (global pooling) or when
    each spatial input size is a multiple of its output size. In the second
    case the kernel and stride are uniform per dimension
    (``input_size / output_size``). MaxPool delegates to ``fn`` so that
    indices are produced. When that is not possible, an all-ones output
    falls back to ``GlobalMaxPool`` and returns ``None`` in place of the
    indices.
    """
    @parse_args('v', 'is')
    def symbolic_fn(g, input, output_size):
        is_global = output_size == [1] * len(output_size)
        if is_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        if input.type().kind() != "CompleteTensorType":
            if not is_global:
                return _unimplemented(name, 'input size not accesible')
            return g.op("GlobalMaxPool", input), None
        dim = input.type().sizes()[2:]
        # Each spatial input size must be divisible by its output size.
        remainders = [dim[i] % output_size[i] for i in range(len(dim))]
        if any(remainders):
            if not is_global:
                return _unimplemented(name, 'output size that are not factor of input size')
            return g.op("GlobalMaxPool", input), None
        k = [int(dim[i] / output_size[i]) for i in range(len(dim))]
        if type == "MaxPool":
            # Delegate to max_poolxd_with_indices so the indices output exists too.
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        return g.op(type, input,
                    kernel_shape_i=tuple_fn(k),
                    strides_i=tuple_fn(k))
    return symbolic_fn
# Adaptive average pooling, one symbolic per spatial rank.
adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d = (
    _adaptive_pool('adaptive_avg_pool1d', "AveragePool", _single),
    _adaptive_pool('adaptive_avg_pool2d', "AveragePool", _pair),
    _adaptive_pool('adaptive_avg_pool3d', "AveragePool", _triple),
)
# Adaptive max pooling delegates to the matching max_pool*_with_indices symbolic.
adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d = (
    _adaptive_pool('adaptive_max_pool1d', "MaxPool", _single, max_pool1d_with_indices),
    _adaptive_pool('adaptive_max_pool2d', "MaxPool", _pair, max_pool2d_with_indices),
    _adaptive_pool('adaptive_max_pool3d', "MaxPool", _triple, max_pool3d_with_indices),
)
@parse_args('v', 'is', 'f')
def constant_pad_nd(g, input, padding, value):
    """Export ``constant_pad_nd`` as ONNX ``Pad`` in constant mode, filling with ``value``."""
    from torch.autograd._functions.utils import prepare_onnx_paddings
    onnx_pads = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="constant", value_f=value)
@parse_args('v', 'is')
def reflection_pad(g, input, padding):
    """Export ``reflection_pad{1,2,3}d`` as ONNX ``Pad`` in ``reflect`` mode."""
    from torch.autograd._functions.utils import prepare_onnx_paddings
    onnx_pads = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="reflect")
@parse_args('v', 'is')
def replication_pad(g, input, padding):
    """Export ``replication_pad{1,2,3}d`` as ONNX ``Pad`` in ``edge`` mode."""
    from torch.autograd._functions.utils import prepare_onnx_paddings
    onnx_pads = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="edge")
# The padding symbolics are rank-agnostic, so every rank shares one implementation.
reflection_pad1d = reflection_pad2d = reflection_pad3d = reflection_pad
replication_pad1d = replication_pad2d = replication_pad3d = replication_pad
@parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
    """Export nearest-neighbour 2-D upsampling as ONNX ``Upsample``.

    The scale is 1 for the two leading axes and ``output / input`` for the
    last two, which requires the input's static sizes.
    """
    in_sizes = input.type().sizes()
    scale_h = float(output_size[-2]) / in_sizes[-2]
    scale_w = float(output_size[-1]) / in_sizes[-1]
    scales = g.op("Constant", value_t=torch.tensor([1., 1., scale_h, scale_w]))
    return g.op("Upsample", input, scales, mode_s="nearest")
@parse_args('v', 'is', 'i')
def upsample_bilinear2d(g, input, output_size, align_corners):
    """Export bilinear 2-D upsampling as ONNX ``Upsample`` in ``linear`` mode.

    ``align_corners=True`` has no equivalent here and is reported as
    unimplemented.
    """
    if align_corners:
        return _unimplemented("upsample_bilinear2d", "align_corners == True")
    in_sizes = input.type().sizes()
    scale_h = float(output_size[-2]) / in_sizes[-2]
    scale_w = float(output_size[-1]) / in_sizes[-1]
    scales = g.op("Constant", value_t=torch.tensor([1., 1., scale_h, scale_w]))
    return g.op("Upsample", input, scales, mode_s="linear")
def wrap_logical_op_with_cast_to_uint8(func):
    """Decorate a comparison symbolic so its boolean result is cast to ONNX UINT8 (PyTorch Byte)."""
    def wrap_with_cast(g, input, other):
        comparison = func(g, input, other)
        return g.op("Cast", comparison, to_i=cast_pytorch_to_onnx['Byte'])
    return wrap_with_cast
def wrap_logical_op_with_negation(func):
    """Decorate a comparison symbolic so its result is wrapped in ONNX ``Not``."""
    def wrap_with_not(g, input, other):
        comparison = func(g, input, other)
        return g.op("Not", comparison)
    return wrap_with_not
@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
    """Export elementwise ``==`` as ONNX ``Equal``; the decorator casts the result to uint8."""
    equal = g.op("Equal", self, other)
    return equal
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ne(g, self, other):
    """Export elementwise ``!=`` as ``Not(Equal(...))`` cast to uint8 by the decorators."""
    equal = g.op("Equal", self, other)
    return equal
@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
    """Export elementwise ``>``; the result of ``gt_impl`` is cast to uint8 by the decorator."""
    greater = gt_impl(g, input, other)
    return greater
def gt_impl(g, input, other):
    """Emit ONNX ``Greater``, converting a scalar ``other`` to ``input``'s type first."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return g.op("Greater", input, rhs)
@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
    """Export elementwise ``<``; the result of ``lt_impl`` is cast to uint8 by the decorator."""
    less = lt_impl(g, input, other)
    return less
def lt_impl(g, input, other):
    """Emit ONNX ``Less``, converting a scalar ``other`` to ``input``'s type first."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return g.op("Less", input, rhs)
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
    """Export ``input >= other`` as ``Not(Less(input, other))``, cast to uint8 by the decorators."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return lt_impl(g, input, rhs)
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
    """Export ``input <= other`` as ``Not(Greater(input, other))``, cast to uint8 by the decorators."""
    rhs = _if_scalar_type_as(g, _maybe_get_scalar(other), input)
    return gt_impl(g, input, rhs)
def where(g, condition, self, other):
    """Export ``torch.where(condition, self, other)`` as ONNX ``Where``."""
    selected = g.op("Where", condition, self, other)
    return selected
@parse_args('v', 'i', 'i')
def log_softmax(g, input, dim=None, dtype=None):
    """Export ``log_softmax`` as ONNX ``LogSoftmax``, optionally followed by a ``Cast``.

    PyTorch ``dim`` and ONNX ``axis`` have different meanings (see the Softmax
    symbolic), so only ``dim`` equal to the last dimension is exported;
    anything else is reported as unimplemented under this op's name.
    ``dtype`` indexes ``scalar_type_to_onnx`` for the optional output cast.
    """
    if dim < 0:
        dim = input.type().dim() + dim
    if input.type().dim() != dim + 1:
        # Report under the operator's own name, like every other symbolic here.
        return _unimplemented("log_softmax", "ONNX and PyTorch use different strategies to split the input.")
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    # NOTE(review): a dtype code of 0 (uint8) is falsy and skips the Cast.
    if dtype:
        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
    return return_op
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
def _convolution(g, input, weight, bias, stride, padding, dilation,
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
    """Export ``aten::_convolution`` as ONNX ``Conv`` or ``ConvTranspose``.

    The kernel shape is taken from ``weight``'s spatial sizes. A 1-D bias is
    passed straight to the conv node; any other bias is applied with a
    separate ``Add``. ``benchmark``, ``deterministic`` and ``cudnn_enabled``
    do not affect the export.
    """
    weight_size = weight.type().sizes()
    has_bias = not bias.node().mustBeNone()
    # ONNX only supports 1D bias
    bias_is_1d = has_bias and bias.type().dim() == 1

    conv_inputs = [input, weight]
    if bias_is_1d:
        conv_inputs.append(bias)

    attrs = {"kernel_shape_i": weight_size[2:],
             "strides_i": stride,
             # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
             # symmetric padding, so the same pads are used for both ends.
             "pads_i": padding + padding,
             "dilations_i": dilation,
             "group_i": groups}
    if any(o != 0 for o in output_padding):
        # output_padding and output_shape are equally expressive in ONNX; output_padding
        # maps directly: output_shape = stride * (input_shape - 1) + output_padding
        #                               + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        attrs["output_padding_i"] = output_padding

    conv = g.op("ConvTranspose" if transposed else "Conv", *conv_inputs, **attrs)
    if has_bias and not bias_is_1d:
        return g.op("Add", conv, bias)
    return conv
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
    """Export ``batch_norm`` as ONNX ``BatchNormalization``.

    A 2-D input is unsqueezed to 3-D for the ONNX op and squeezed back
    afterwards. A missing ``weight``/``bias`` is replaced by a constant of
    ones/zeros sized by the channel dimension (``input_sizes[1]``). In
    training mode, the extra ONNX outputs get their types set or are renamed
    as dead outputs, and only the normalized result is returned.
    """
    input_sizes = input.type().sizes()
    # Default weight/bias take their dtype from the caller's input. Read it from
    # this value, not from the Unsqueeze output created below: that output is a
    # fresh graph value whose type may not carry a scalar type.
    original_input = input
    if len(input_sizes) == 2:
        # batchnorm1d accepts 2d and 3d array, but ONNX only accepts 3d
        input = g.op("Unsqueeze", input, axes_i=[2])

    def channel_constant(fill_value):
        # Per-channel constant tensor in the input's scalar type.
        assert len(input_sizes) > 1
        value = torch.tensor([fill_value] * input_sizes[1]).type(
            'torch.' + original_input.type().scalarType() + 'Tensor')
        return g.op("Constant", value_t=value)

    if weight is None or weight.node().mustBeNone():
        weight = channel_constant(1.)
    if bias is None or bias.node().mustBeNone():
        bias = channel_constant(0.)

    # ONNX's momentum weights the running statistic, whereas PyTorch's weights
    # the new batch statistic, hence 1 - momentum.
    out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,
               outputs=1 if not training else 5)
    if not training:
        if len(input_sizes) == 2:
            out = g.op("Squeeze", out, axes_i=[2])
        return out
    res, new_running_mean, new_running_var, saved_mean, saved_var = out
    new_running_mean.setType(running_mean.type())
    new_running_var.setType(running_var.type())
    # saved_mean/saved_var are not returned; the name prefix marks them as dead outputs.
    saved_mean.setUniqueName("batch_norm_dead_output-" + saved_mean.uniqueName())
    saved_var.setUniqueName("batch_norm_dead_output-" + saved_var.uniqueName())
    if len(input_sizes) == 2:
        res = g.op("Squeeze", res, axes_i=[2])
    return res
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    """Export ``instance_norm`` as ONNX ``InstanceNormalization``.

    A missing ``weight``/``bias`` becomes a per-channel constant of ones/zeros
    in the input's scalar type. The running statistics, ``use_input_stats``
    and ``momentum`` are not forwarded.
    """
    input_sizes = input.type().sizes()

    def channel_fill(fill_value):
        assert len(input_sizes) > 1
        tensor = torch.tensor([fill_value] * input_sizes[1]).type(
            'torch.' + input.type().scalarType() + 'Tensor')
        return g.op("Constant", value_t=tensor)

    if weight is None or weight.node().mustBeNone():
        weight = channel_fill(1.)
    if bias is None or bias.node().mustBeNone():
        bias = channel_fill(0.)
    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
@parse_args('v', 'i', 'i', 'i')
def unfold(g, input, dimension, size, step):
    """Export ``Tensor.unfold`` through the ``ATen`` fallback op."""
    return g.op("ATen", input,
                operator_s="unfold",
                dimension_i=dimension,
                size_i=size,
                step_i=step)
@parse_args('v', 'v', 'i')
def _weight_norm(graph, v, g, dim):
    """Export ``_weight_norm`` through the ``ATen`` fallback op.

    The graph is named ``graph`` here because ``g`` is the norm tensor argument.
    """
    return graph.op("ATen", v, g,
                    dim_i=dim,
                    operator_s="_weight_norm")
@parse_args('v', 't', 't', 't')
def elu(g, input, alpha, scale, input_scale):
    """Export ``elu`` as ONNX ``Elu``; a non-unit ``scale`` or ``input_scale`` is unsupported."""
    for arg_name, arg_value in (("scale", scale), ("input_scale", input_scale)):
        if arg_value and arg_value != 1.:
            return _unimplemented(arg_name, "does not support {} in Elu".format(arg_name))
    # See Note [Export inplace]
    return g.op("Elu", input, alpha_f=_scalar(alpha))
def selu(g, input):
    """Export ``selu`` as ONNX ``Selu`` with the op's default constants."""
    activated = g.op("Selu", input)
    return activated
@parse_args('v', 'i', 'v')
def index_select(g, self, dim, index):
    """Export ``index_select`` as ONNX ``Gather`` along ``dim``."""
    gathered = g.op("Gather", self, index, axis_i=dim)
    return gathered
def index_put(g, self, indices_list_value, values, accumulate):
    """Export ``index_put`` through the ``ATen`` fallback op.

    The indices list is unpacked into individual inputs between ``self`` and
    ``values``/``accumulate``.
    """
    args = [self]
    args.extend(_unpack_list(indices_list_value))
    args.extend([values, accumulate])
    return g.op("ATen", *args, operator_s='index_put')
def type_as(g, self, other):
    """Export ``type_as``.

    Returns ``self`` unchanged when both dtypes are known and equal, emits a
    ``Cast`` when only ``other``'s dtype is needed, and falls back to an
    ``ATen`` node when ``other``'s type is not known.
    """
    if not other.isCompleteTensor():
        # We don't know the type of other, bail by emitting ATen
        return g.op("ATen", self, other, operator_s="type_as")
    target = other.type().scalarType()
    if self.isCompleteTensor() and self.type().scalarType() == target:
        return self
    return g.op("Cast", self, to_i=cast_pytorch_to_onnx[target])
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
    """Export ``layer_norm`` through the ``ATen`` fallback op."""
    attrs = dict(normalized_shape_i=normalized_shape,
                 eps_f=eps,
                 cudnn_enable_i=cudnn_enable,
                 operator_s="layer_norm")
    return g.op("ATen", self, weight, bias, **attrs)
def clone(g, input):
    """Drop ``clone`` from the exported graph (PyTorch autograd inserts these) by returning the input."""
    return input
def abs(g, self):
    """Export elementwise ``abs`` as ONNX ``Abs``."""
    result = g.op("Abs", self)
    return result
def log(g, self):
    """Export elementwise natural ``log`` as ONNX ``Log``."""
    result = g.op("Log", self)
    return result
def pow(g, self, exponent):
    """Export ``pow`` as ONNX ``Pow``, converting a scalar exponent to ``self``'s type first."""
    typed_exponent = _if_scalar_type_as(g, _maybe_get_scalar(exponent), self)
    return g.op("Pow", self, typed_exponent)
def clamp(g, self, min, max):
    """Export ``clamp`` as ONNX ``Clip``.

    ONNX has no None syntax, so a missing bound dispatches to the one-sided
    ``clamp_max``/``clamp_min`` symbolic instead.
    """
    min_missing = min.node().mustBeNone()
    if min_missing:
        return clamp_max(g, self, max)
    if max.node().mustBeNone():
        return clamp_min(g, self, min)
    lower = _parse_arg(min, 'f')
    upper = _parse_arg(max, 'f')
    return g.op("Clip", self, min_f=lower, max_f=upper)
@parse_args('v', 'f')
def clamp_min(g, self, min):
    """Export a lower-bound-only clamp as ONNX ``Clip`` with just ``min``."""
    clipped = g.op("Clip", self, min_f=min)
    return clipped
@parse_args('v', 'f')
def clamp_max(g, self, max):
    """Export an upper-bound-only clamp as ONNX ``Clip`` with just ``max``."""
    clipped = g.op("Clip", self, max_f=max)
    return clipped
def max(g, self, dim_or_y=None, keepdim=None):
    """Export ``torch.max``, which packs several call forms into one operator.

    * ``max(x)``: full reduction via ``ReduceMax`` with ``keepdims=0``.
    * ``max(x, y)``: elementwise ``Max``.
    * ``max(x, dim, keepdim)``: ``ATen`` fallback node with two outputs.
    """
    if keepdim is not None:
        dim = _get_const(dim_or_y, 'i', 'dim')
        keepdim = _get_const(keepdim, 'i', 'keepdim')
        # TODO: export it as ReduceMax
        return g.op("ATen", self, operator_s="max", dim_i=dim, keepdim_i=keepdim, outputs=2)
    if dim_or_y is None:
        return g.op("ReduceMax", self, keepdims_i=0)
    return g.op("Max", self, dim_or_y)
def min(g, self, dim_or_y=None, keepdim=None):
    """Export ``torch.min``, mirroring the three call forms handled by ``max``.

    * ``min(x)``: full reduction via ``ReduceMin`` with ``keepdims=0``.
    * ``min(x, y)``: elementwise ``Min``.
    * ``min(x, dim, keepdim)``: ``ATen`` fallback node with two outputs.
    """
    if keepdim is not None:
        dim = _get_const(dim_or_y, 'i', 'dim')
        keepdim = _get_const(keepdim, 'i', 'keepdim')
        # TODO: export it as ReduceMin
        return g.op("ATen", self, operator_s="min", dim_i=dim, keepdim_i=keepdim, outputs=2)
    if dim_or_y is None:
        return g.op("ReduceMin", self, keepdims_i=0)
    return g.op("Min", self, dim_or_y)
def exp(g, self):
    """Export elementwise ``exp`` as ONNX ``Exp``."""
    result = g.op("Exp", self)
    return result
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
    """Export ``dropout``: identity in eval mode, ONNX ``Dropout`` (output only, mask dropped) in training."""
    if train:
        output, _mask = g.op("Dropout", input, ratio_f=p, outputs=2)
        return output
    # in eval mode, dropout is non-op
    return input
def _unsupported_dropout(name):
    """Build a symbolic for a dropout variant that is only exportable in inference mode.

    In eval mode the op is exported as identity; in training mode it is
    reported as unimplemented under ``name``.
    """
    @parse_args('v', 'f', 'i')
    def feature_dropout(g, input, p, train):
        # NB: In inference mode, FeatureDropout is exported as an identity op.
        from torch.onnx.symbolic import _unimplemented
        if not train:
            return input
        return _unimplemented(name, "training mode")
    return feature_dropout
# Dropout variants that can only be exported in inference mode.
feature_dropout, alpha_dropout, feature_alpha_dropout = (
    _unsupported_dropout("feature_dropout"),
    _unsupported_dropout("alpha_dropout"),
    _unsupported_dropout("feature_alpha_dropout"),
)
# See Note [Export inplace]
dropout_, feature_dropout_, alpha_dropout_, feature_alpha_dropout_ = (
    dropout, feature_dropout, alpha_dropout, feature_alpha_dropout)
@parse_args('v', 't', 'i', 'i')
def norm(g, self, p, dim, keepdim):
    """Export ``norm`` as ``ReduceL1``/``ReduceL2``; only p = 1 or p = 2 are exportable."""
    if p == 1:
        reduce_op = "ReduceL1"
    elif p == 2:
        reduce_op = "ReduceL2"
    else:
        raise RuntimeError("ONNX export only p-norms with p of 1 or 2")
    return _reduce_op_symbolic(reduce_op)(g, self, dim=dim, keepdim=keepdim)
@parse_args('v', 'v', 'v', 'i')
def conv_tbc(g, input, weight, bias, pad):
    """Export ``conv_tbc`` through the ``ATen`` fallback op."""
    node = g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
    return node
@parse_args('v', 'i', 'i')
def _unique(g, input, sorted, return_inverse):
    """Export ``_unique`` through the ``ATen`` fallback op with two outputs."""
    attrs = dict(operator_s="_unique", sorted_i=sorted, return_inverse_i=return_inverse, outputs=2)
    return g.op("ATen", input, **attrs)
# PyTorch scalar-type name -> ONNX TensorProto data type.
# Used below to metaprogram one symbolic per ATen specialized cast operator,
# named after the key: e.g. `_cast_Byte` instantiates an ONNX Cast node with
# `to` attribute UINT8.
#
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = dict(
    Byte=torch.onnx.TensorProtoDataType.UINT8,
    Char=torch.onnx.TensorProtoDataType.INT8,
    Double=torch.onnx.TensorProtoDataType.DOUBLE,
    Float=torch.onnx.TensorProtoDataType.FLOAT,
    Half=torch.onnx.TensorProtoDataType.FLOAT16,
    Int=torch.onnx.TensorProtoDataType.INT32,
    Long=torch.onnx.TensorProtoDataType.INT64,
    Short=torch.onnx.TensorProtoDataType.INT16,
)
# C++ scalar type spelling -> PyTorch scalar-type name (keys of cast_pytorch_to_onnx).
scalar_name_to_pytorch = dict(zip(
    ('uint8_t', 'int8_t', 'double', 'float', 'half', 'int', 'int64_t', 'int16_t'),
    ('Byte', 'Char', 'Double', 'Float', 'Half', 'Int', 'Long', 'Short'),
))
# Position i holds the torch dtype for PyTorch ScalarType code i
# (0 = uint8, 1 = int8, 2 = short, 3 = int, 4 = int64, 5 = half, 6 = float, 7 = double).
# Related source:
# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
scalar_type_to_pytorch_type = [
    torch.uint8, torch.int8, torch.short, torch.int,
    torch.int64, torch.half, torch.float, torch.double,
]
def _cast_func_template(to_i, g, input, non_blocking):
return g.op("Cast", input, to_i=to_i)
# Metaprogram one symbolic per entry of cast_pytorch_to_onnx: e.g. `_cast_Byte`
# emits an ONNX Cast to UINT8. A generator expression is used instead of a
# module-level for loop so that the iteration variables (previously `k`, `v`
# and `name`) do not leak into this module, whose top-level names define the
# symbolic functions.
globals().update(
    ('_cast_{}'.format(type_name), parse_args('v', 'i')(partial(_cast_func_template, onnx_type)))
    for type_name, onnx_type in cast_pytorch_to_onnx.items()
)
# ONNX data type for each PyTorch ScalarType code, in the same order as
# scalar_type_to_pytorch_type (0 = Byte ... 7 = Double).
scalar_type_to_onnx = [cast_pytorch_to_onnx[type_name] for type_name in
                       ("Byte", "Char", "Short", "Int", "Long", "Half", "Float", "Double")]
@parse_args('v', 'i', 'v', 'v', 'b')
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    """Export ``zeros`` as ONNX ``ConstantOfShape`` with a 0 fill of dtype code ``dtype``."""
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    fill = torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory)
    return g.op("ConstantOfShape", sizes, value_t=fill)
@parse_args('v', 'i', 'v', 'v', 'b')
def zeros_like(g, input, dtype, layout, device, pin_memory=False):
    """Export ``zeros_like`` as ``ConstantOfShape(Shape(input))`` with a 0 fill; layout/device are ignored."""
    input_shape = g.op("Shape", input)
    fill = torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory)
    return g.op("ConstantOfShape", input_shape, value_t=fill)
@parse_args('v', 'i', 'v', 'v', 'b')
def ones(g, sizes, dtype, layout, device, pin_memory=False):
    """Export ``ones`` as ONNX ``ConstantOfShape`` with a 1 fill; layout/device are ignored."""
    fill = torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory)
    return g.op("ConstantOfShape", sizes, value_t=fill)
@parse_args('v', 'i', 'v', 'v', 'b')
def ones_like(g, input, dtype, layout, device, pin_memory=False):
    """Export ``ones_like`` as ``ConstantOfShape(Shape(input))`` with a 1 fill; layout/device are ignored."""
    input_shape = g.op("Shape", input)
    fill = torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory)
    return g.op("ConstantOfShape", input_shape, value_t=fill)
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Export ``full``.

    If ``value`` folds to a constant, a single ``ConstantOfShape`` carries it
    as its fill attribute. Otherwise a zero tensor of the requested shape is
    built with the ``zeros`` symbolic and ``value`` is added to it
    (alpha = 1). ``layout`` and ``device`` have no ONNX equivalent.
    """
    const_value = _maybe_get_const(value, 't')
    if _is_value(const_value):
        # Sibling symbolics take the graph first. ``zeros`` is wrapped by
        # ``parse_args``, whose descriptors are positional, so pin_memory is
        # passed positionally as well.
        tmp = zeros(g, sizes, dtype, layout, device, pin_memory)
        return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    dtype = _get_const(dtype, 'i', 'dtype')
    pin_memory = _get_const(pin_memory, 'b', 'pin_memory')
    return g.op("ConstantOfShape", sizes,
                value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory))
@parse_args('v', 'f', 'i', 'v', 'v', 'b')
def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False):
    """Export ``full_like`` as ``ConstantOfShape(Shape(input))`` filled with ``fill_value``."""
    input_shape = g.op("Shape", input)
    fill = torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype], pin_memory=pin_memory)
    return g.op("ConstantOfShape", input_shape, value_t=fill)
@parse_args('v', 'v', 'v', 'v', 'i')
def slice(g, self, dim, start, end, step):
    """Export ``aten::slice`` along a single axis.

    Constant ``dim``/``start``/``end`` produce a static ONNX ``Slice``;
    otherwise the bounds are unsqueezed to 1-D and fed to ``DynamicSlice``.
    Only a step of 1 is supported.
    """
    if step != 1:
        # Return the _unimplemented result, as the other unsupported cases in
        # this file do, rather than emitting a Slice that ignores the step.
        return _unimplemented("slice", "step!=1 is currently not supported")
    if start.node().kind() != 'onnx::Constant' or \
            end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
        start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
        end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
        dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
        return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
    start = _parse_arg(start, 'i')
    end = _parse_arg(end, 'i')
    dim = _parse_arg(dim, 'i')
    return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
@parse_args('v', 'f', 'f')
def hardtanh(g, self, min_val, max_val):
    """Export ``hardtanh`` as ONNX ``Clip`` between ``min_val`` and ``max_val``."""
    clipped = g.op("Clip", self, min_f=min_val, max_f=max_val)
    return clipped
def alias(g, self):
    """Drop ``alias`` from the exported graph by returning its input."""
    return self
@parse_args('v', 'i')
def unsqueeze(g, self, dim):
    """Export ``unsqueeze`` as ONNX ``Unsqueeze`` on the single axis ``dim``."""
    axes = [dim]
    return g.op("Unsqueeze", self, axes_i=axes)
@parse_args('v', 'i', 'i', 'i', 'i')
def topk(g, self, k, dim, largest, sorted, out=None):
    """Export ``topk`` as ONNX ``TopK`` (values and indices).

    An explicit ``out`` tensor and ascending order (``largest=False``) are
    not supported. Both cases return the ``_unimplemented`` result, like the
    other symbolics in this file, instead of emitting a TopK that silently
    ignores the option. ``sorted`` is not forwarded.
    """
    if out is not None:
        return _unimplemented("TopK", "Out parameter is not supported for topk")
    if not largest:
        return _unimplemented("TopK", "Ascending TopK is not supported")
    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
def to(g, self, *args):
    """Export ``aten::to``.

    ONNX doesn't have a concept of a device, so device-only casts return
    ``self`` and dtype casts become ONNX ``Cast``. The overload is identified
    by the number of arguments after the tensor:

    * 3: ``(Device, bool, bool)`` when ``args[0]`` is an int list, otherwise
      ``(ScalarType, bool, bool)``
    * 4: ``(Device, ScalarType, bool, bool)``; dtype is ``args[1]``
    * 5: ``(ScalarType, Layout, Device, bool, bool)``; dtype is ``args[0]``
    * 6: ``(ScalarType, Layout, Device, bool, bool, bool)``; dtype is ``args[0]``

    Layout and device are ignored.
    """
    arg_count = len(args)
    if arg_count == 3 and args[0].type().isSubtypeOf(ListType.ofInts()):
        return self
    if arg_count == 4:
        dtype_position = 1
    elif arg_count in (3, 5, 6):
        dtype_position = 0
    else:
        raise NotImplementedError("Unknown aten::to signature")
    dtype = _get_const(args[dtype_position], 'i', 'dtype')
    return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
def repeat(g, self, repeats):
    """Export ``Tensor.repeat`` as ONNX ``Tile``.

    When the repeat counts are constant and the input sizes are known, and
    more repeats than input dims are given, leading size-1 dims are added to
    the input first so the ranks line up.
    """
    if not _is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    const_repeats = _maybe_get_const(repeats, 'is')
    can_reshape = self.isCompleteTensor() and not _is_value(const_repeats)
    if can_reshape:
        sizes = self.type().sizes()
        missing_dims = len(const_repeats) - len(sizes)
        if missing_dims > 0:
            self = view(g, self, [1] * missing_dims + sizes)
    return g.op("Tile", self, repeats)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    """Export ``pixel_shuffle`` as Reshape -> Transpose -> Reshape.

    This mirrors PyTorch's definition. An input of shape (N, C*r*r, H, W) is
    viewed as (N, C, r, r, H, W), permuted to (N, C, H, r, W, r), and reshaped
    to (N, C, H*r, W*r), where r is ``upscale_factor``. Only 4-D inputs are
    supported, and their sizes are read from the input's type.
    """
    dims = self.type().sizes()
    if len(dims) != 4:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    output_channel = dims[1] // upscale_factor // upscale_factor
    # The channel axis splits as (C, r, r), with the output channel first;
    # viewing it as (r, r, C) scrambles the output.
    after_view = view(g, self, [-1, output_channel, upscale_factor, upscale_factor,
                                dims[2], dims[3]])
    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
    return view(g, after_transpose,
                [-1, output_channel, dims[2] * upscale_factor, dims[3] * upscale_factor])
@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    """Export ``group_norm`` through the ``ATen`` fallback op."""
    attrs = dict(num_groups_i=num_groups,
                 eps_f=eps,
                 cudnn_enabled_i=cudnn_enabled,
                 operator_s="group_norm")
    return g.op("ATen", input, weight, bias, **attrs)
# Shared exporter for RNN_TANH / RNN_RELU / GRU / LSTM (called by the gru,
# rnn_tanh, rnn_relu and _lstm_* wrappers). Emits one ONNX RNN/GRU/LSTM node per
# layer, feeding each layer's output into the next.
# Returns (output, h_n) for RNN/GRU and (output, h_n, c_n) for LSTM.
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
# Parameters are grouped per (layer, direction): w_ih, w_hh and, with biases, b_ih, b_hh.
weights_per_layer = 4 if has_biases else 2
assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
if batch_first:
return _unimplemented("RNN/GRU/LSTM", "batch_first")
if dropout and train:
return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")
# 'RNN_TANH' / 'RNN_RELU' become ONNX RNN with activation 'tanh' / 'relu'.
if variant.startswith('RNN'):
nonlinearity = variant[4:].lower()
variant = 'RNN'
# hidden_size is dim 1 of the first hidden-hidden weight
# (presumably shaped (gates * hidden_size, hidden_size); verify).
w_hh = all_weights[1]
hidden_size = w_hh.type().sizes()[1]
unidirectional = not bidirectional
prev_output = input
h_outs = []
if variant == 'RNN' or variant == 'GRU':
h0 = initial_states
elif variant == 'LSTM':
h0, c0 = initial_states
c_outs = []
# Packed input passes batch_sizes as ONNX sequence_lens; otherwise the optional
# input is left as the placeholder produced by `unused`.
sequence_lens = unused(g) if batch_sizes is None else batch_sizes
if variant == 'GRU':
# pytorch is reset, input, hidden
# onnx is input, reset, hidden
reform_permutation = [(1, 2), (0, 1), (2, 3)]
elif variant == 'LSTM':
# pytorch is input, forget, cell, output.
# onnx is input, output, forget, cell.
reform_permutation = [(0, 1), (3, 4), (1, 3)]
# Re-slice the stacked gate blocks of `w` along axis 0 into ONNX gate order;
# each (x, y) interval is [x * n, y * n) in rows, with n = hidden_size.
def reform_weights(g, w, n, intervals):
slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
return g.op('Concat', *slices, axis_i=0)
# Returns (W, R, B) for one (layer, direction), each with a leading
# num_directions axis of size 1; B is b_ih and b_hh concatenated.
# NOTE(review): four tensors are always unpacked here, so has_biases=False
# (two weights per layer) looks unsupported; confirm.
def transform_weights(layer_index):
if variant == 'RNN':
weight_ih, weight_hh, bias_ih, bias_hh = layer_weights[layer_index]
elif variant == 'GRU' or variant == 'LSTM':
weight_ih, weight_hh, bias_ih, bias_hh = \
[reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
# Initial states are stacked along axis 0 per (layer, direction); take rows [start, end).
def retrieve_state(x, start, end):
return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
for i in range(num_layers):
if unidirectional:
weight_ih, weight_hh, bias_concat = transform_weights(i)
state_indices = i, i + 1
else:
# Forward and backward parameters are stacked along the num_directions axis.
weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
bias_concat = g.op('Concat', bias_f, bias_b, axis_i=0)
state_indices = 2 * i, 2 * i + 2
inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
inputs.append(retrieve_state(h0, *state_indices))
if variant == 'LSTM':
inputs.append(retrieve_state(c0, *state_indices))
extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
if variant == 'RNN':
prev_output, h_out = g.op('RNN', *inputs, outputs=2,
hidden_size_i=hidden_size,
activations_s=[nonlinearity],
**extra_kwargs)
elif variant == 'GRU':
prev_output, h_out = g.op('GRU', *inputs, outputs=2,
hidden_size_i=hidden_size,
linear_before_reset_i=1,
**extra_kwargs)
elif variant == 'LSTM':
prev_output, h_out, c_out = g.op('LSTM', *inputs, outputs=3,
hidden_size_i=hidden_size,
**extra_kwargs)
if bidirectional:
# The ONNX RNN/GRU/LSTM produce an output of dimensions
# seq_len, num_directions, batch, hidden_size
# We have to convert to match pytorch's expected
# seq_len, batch, num_directions * hidden_size
# by first moving num_directions before hidden_size with
# Transpose, and then combining it with hidden_size
# with Reshape.
prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3])
prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1])))
else:
# Single direction: drop the size-1 num_directions axis (axis 1).
prev_output = g.op('Squeeze', prev_output, axes_i=[1])
h_outs.append(h_out)
if variant == 'LSTM':
c_outs.append(c_out)
# Final states of all layers are concatenated along axis 0.
h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
if variant == 'RNN' or variant == 'GRU':
return prev_output, h_outs
elif variant == 'LSTM':
c_outs = c_out if num_layers == 1 else g.op('Concat', *c_outs, axis_i=0)
return prev_output, h_outs, c_outs
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
    """Export the padded (non-packed) ``aten::lstm`` overload via ``_generic_rnn``."""
    hidden = _unpack_list(hidden_v)
    weight = _unpack_list(weight_v)
    return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                        dropout, train, bidirectional, batch_first)
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
    """Export the packed-sequence ``aten::lstm`` overload; ``batch_sizes`` becomes the sequence lengths."""
    hidden = _unpack_list(hidden_v)
    weight = _unpack_list(weight_v)
    return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                        dropout, train, bidirectional, batch_sizes=batch_sizes)
def lstm(g, *args):
    """Dispatch ``aten::lstm`` to the packed or full overload.

    In the packed overload ``args[3]`` is the weight list; in the full
    overload it is the ``has_biases`` flag.
    """
    handler = _lstm_packed if _is_tensor_list(args[3]) else _lstm_full
    return handler(g, *args)
def _one_hidden_rnn(kind):
    """Build the symbolic for an RNN with a single hidden state (GRU / RNN_TANH / RNN_RELU).

    The returned function dispatches to the packed overload when ``args[3]``
    is a tensor list (the weights), and to the full overload otherwise.
    """
    @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
    def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
        weights = _unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, weights, has_biases, num_layers,
                            dropout, train, bidirectional, batch_first)

    @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
    def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
        weights = _unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, weights, has_biases, num_layers,
                            dropout, train, bidirectional, batch_sizes=batch_sizes)

    def symbolic(g, *args):
        handler = _rnn_packed if _is_tensor_list(args[3]) else _rnn_full
        return handler(g, *args)

    return symbolic
# Single-hidden-state recurrent ops share one symbolic factory.
gru, rnn_tanh, rnn_relu = (
    _one_hidden_rnn('GRU'),
    _one_hidden_rnn('RNN_TANH'),
    _one_hidden_rnn('RNN_RELU'),
)
@parse_args('v', 'i')
def _dim_arange(g, like, dim):
    """Export ``_dim_arange`` through the ``ATen`` fallback op."""
    node = g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')
    return node
def detach(g, input):
    """Erase ``aten::detach`` nodes by returning the input; ONNX export is inference only."""
    return input
def contiguous(g, input):
    """Drop ``contiguous`` from the exported graph by returning its input."""
    return input
@parse_args('v', 'v', 'i')
def _pack_padded_sequence(g, input, lengths, batch_first):
    """Emit a ``prim::PackPadded`` placeholder with two outputs.

    There currently is no PackPadded operator in ONNX. An optimization pass
    is relied on to remove this node later; it is an error if any PackPadded
    operator cannot be optimized out. A batch-first input is transposed to
    sequence-first, and ``lengths`` must be a Tensor. It is cast to Int when
    needed, because the expansion only works with int32 types in Caffe2.
    """
    if batch_first:
        input = g.op('Transpose', input, perm_i=[1, 0, 2])
    if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
        raise RuntimeError("Lengths must be a Tensor for ONNX export")
    # lengths is known to be a TensorType here, so scalarType() is safe.
    needs_int_cast = lengths.type().scalarType() != 'Int'
    if needs_int_cast:
        lengths = _cast_Int(g, lengths, False)
    return g.op("prim::PackPadded", input, lengths, outputs=2)
@parse_args('v', 'v', 'i', 't', 'v')
def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
    """Emit a ``prim::PadPacked`` placeholder returning ``(data, lengths)``.

    ``total_length`` is ignored because it is not supported in
    _symbolic_pad_packed_sequence. It only matters when training with a
    data_parallel model, so it shouldn't be relevant for ONNX.
    ``padding_value`` is not forwarded either.
    """
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if batch_first:
        padded = g.op('Transpose', padded, perm_i=[1, 0, 2])
    return padded, lengths
def randn(g, *shapes):
    """Export ``randn`` as ONNX ``RandomNormal``; only the first argument (the shape) is used."""
    shape_arg = list(shapes)[0]
    return g.op('RandomNormal', shape_i=_maybe_get_const(shape_arg, "is"))
@parse_args('v', 'f', 'f', 'i', 'none')
def rrelu(g, input, lower, upper, training, generator):
    """Export ``rrelu``.

    In training mode, the negative slopes are sampled per element with
    ``RandomUniformLike(low=lower, high=upper)`` and applied with ``PRelu``.
    In eval mode PyTorch uses the fixed slope ``(lower + upper) / 2``, which
    maps to a deterministic ``LeakyRelu``; the exported inference graph
    should not be random. ``generator`` is ignored.
    """
    if not training:
        return g.op('LeakyRelu', input, alpha_f=(lower + upper) / 2.0)
    slopes = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
    return g.op('PRelu', input, slopes)
@parse_args('v')
def log_sigmoid(g, input):
    """Export ``log_sigmoid`` as ``Log(Sigmoid(input))``."""
    return g.op('Log', g.op('Sigmoid', input))
@parse_args('v')
def erf(g, input):
    """Export elementwise ``erf`` as ONNX ``Erf``."""
    result = g.op('Erf', input)
    return result
@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
    """Export ``flatten`` over the inclusive dimension range [start_dim, end_dim].

    Cases whose result is 2-D map directly to ONNX ``Flatten``. Any other
    range needs the static input sizes and is emitted as a ``Reshape`` to the
    computed output shape. Negative dims count from the end, as in
    ``torch.flatten``.
    """
    dim = input.type().dim()
    # Normalize both bounds; the dispatch and indexing below assume non-negative dims.
    if start_dim < 0:
        start_dim = dim + start_dim
    if end_dim < 0:
        end_dim = dim + end_dim
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1 and end_dim == dim - 1:
        return g.op("Flatten", input, axis_i=start_dim)
    if start_dim == 0 and end_dim == dim - 2:
        return g.op("Flatten", input, axis_i=end_dim + 1)
    # use Reshape for cases where the output shape is not 2D
    if input.type().kind() != "CompleteTensorType":
        return _unimplemented("flatten", "input size not accesible")
    input_dims = input.type().sizes()
    output_dims = []
    for i in range(dim):
        if start_dim < i <= end_dim:
            # Fold this dim into the entry already emitted for start_dim.
            output_dims[start_dim] *= input_dims[i]
        else:
            output_dims.append(input_dims[i])
    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
    return _reshape_from_tensor(g, input, shape)
@parse_args('v')
def nonzero(g, input):
    """Export ``nonzero`` as the transpose of ONNX ``NonZero``'s output."""
    indices = g.op('NonZero', input)
    return t(g, indices)
@parse_args('v')
def isnan(g, input):
    """Export ``isnan`` as ONNX ``IsNaN`` cast to UINT8 (PyTorch Byte)."""
    nan_mask = g.op('IsNaN', input)
    return _cast_func_template(cast_pytorch_to_onnx['Byte'], g, nan_mask, None)
@parse_args('v', 'i', 'i', 'i')
def narrow(g, input, dim, start, length):
    """Export ``narrow`` as ONNX ``Slice`` covering [start, start + length) along ``dim``."""
    stop = start + length
    return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[stop])
def argmax(g, input, dim, keepdim):
    """Export ``argmax`` as ONNX ``ArgMax``.

    With ``dim=None`` the input is flattened to 1-D and reduced along axis 0
    without keeping dims; otherwise ``dim``/``keepdim`` are forwarded.
    """
    if not dim.node().mustBeNone():
        return g.op('ArgMax', input,
                    axis_i=_parse_arg(dim, 'i'),
                    keepdims_i=_parse_arg(keepdim, 'i'))
    flattened = reshape(g, input, (-1,))
    return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False)
def argmin(g, input, dim, keepdim):
    """Export ``argmin`` as ONNX ``ArgMin``.

    With ``dim=None`` the input is flattened to 1-D and reduced along axis 0
    without keeping dims; otherwise ``dim``/``keepdim`` are forwarded.
    """
    if not dim.node().mustBeNone():
        return g.op('ArgMin', input,
                    axis_i=_parse_arg(dim, 'i'),
                    keepdims_i=_parse_arg(keepdim, 'i'))
    flattened = reshape(g, input, (-1,))
    return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False)
|