1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
|
## @package onnx
#Module caffe2.python.trt.transform
"""
TensorRT related transformation
Note that ONNX-TRT enforce an NCHW input!
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2
from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op
import caffe2.python.onnx.frontend as c2_front
import caffe2.python._import_c_extension as C
import numpy as np
def _dim_values_to_list(dim_values):
return [x.dim_value for x in dim_values]
def _get_output_shapes(output_value_infos):
names = [x.name for x in output_value_infos]
shapes = [_dim_values_to_list(x.type.tensor_type.shape.dim) for x in output_value_infos]
return dict(zip(names, shapes))
def check_gpu_():
try:
C.get_cuda_version()
except Exception as _:
raise Exception("TensorRT related functions require CUDA support")
def convert_onnx_model_to_trt_op(onnx_model,
max_batch_size=50,
max_workspace_size=2*1024*1024,
verbosity=1,
debug_builder=False):
"""
Convert the whole ONNX model to a TensorRT C2 op
"""
check_gpu_()
trt_str = C.onnx_to_trt_op(onnx_model.SerializeToString(),
_get_output_shapes(onnx_model.graph.output),
max_batch_size,
max_workspace_size,
verbosity,
debug_builder)
op = caffe2_pb2.OperatorDef()
op.ParseFromString(trt_str)
return op
def _infer_shapes(init_net, pred_net, inputs):
ws, outputs = c2_native_run_net(init_net, pred_net, inputs)
hints = {}
for op in pred_net.op:
for o in op.output:
if o not in hints:
blob = ws.FetchBlob(o)
if hasattr(blob, 'shape'):
hints[o] = blob.shape
for i in op.input:
if i not in hints:
blob = ws.FetchBlob(i)
if hasattr(blob, 'shape'):
hints[i] = blob.shape
return hints
def _ssa_rewrite_input(i):
return i + "_0";
def transform_caffe2_net(init_net,
pred_net,
input_shapes,
populate_shapes = False,
max_batch_size=50,
max_workspace_size=2*1024*1024,
verbosity=1,
debug_builder=False):
"""
Transfrom the caffe2_net by collapsing TRT-runnable nodes into trt c2 ops
"""
check_gpu_()
c2_front.ssa_rewrite(pred_net, init_net, value_info=[])
input_data = {}
for k,v in input_shapes.iteritems():
input_data[_ssa_rewrite_input(k)] = np.random.randn(*v).astype(np.float32)
# Hacky way to infer shapes as not all our operators have shape inference function.
# Normally this is not needed
if populate_shapes:
shape_hints = _infer_shapes(init_net, pred_net, input_data)
shape_hints = {}
for k,v in input_shapes.iteritems():
shape_hints[_ssa_rewrite_input(k)] = v
init_net_str, pred_net_str = C.transform_trt(init_net.SerializeToString(),
pred_net.SerializeToString(),
shape_hints,
max_batch_size,
max_workspace_size,
verbosity,
debug_builder)
init_net_cut = caffe2_pb2.NetDef()
init_net_cut.ParseFromString(init_net_str)
pred_net_cut = caffe2_pb2.NetDef()
pred_net_cut.ParseFromString(pred_net_str)
return init_net_cut, pred_net_cut
|