From 6252706feb23e448bff944c3505092b18718d3ab Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Tue, 17 Apr 2018 21:23:27 -0700 Subject: [Caffe2] Workspace centric API for TensorRT transformation (#6678) * Workspace centric API for trt transformation * Merge SSA rewrite code --- caffe2/contrib/tensorrt/tensorrt_tranformer.cc | 234 ++++++++++++++++++------- caffe2/contrib/tensorrt/tensorrt_tranformer.h | 8 +- caffe2/core/workspace.h | 2 +- caffe2/onnx/backend.cc | 2 +- caffe2/onnx/onnx_exporter.cc | 33 +++- caffe2/onnx/onnx_exporter.h | 9 +- caffe2/onnx/ssa_test.cc | 6 +- caffe2/python/pybind_state.cc | 4 + caffe2/python/pybind_state.h | 3 + caffe2/python/pybind_state_gpu.cc | 22 +-- caffe2/python/trt/test_trt.py | 34 ++-- caffe2/python/trt/transform.py | 40 ++--- 12 files changed, 258 insertions(+), 139 deletions(-) diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc index 2ceb9bc2ec..5cf610e35a 100644 --- a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc +++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc @@ -1,40 +1,76 @@ #include "caffe2/contrib/tensorrt/tensorrt_tranformer.h" + +#include +#include + #include +#include #include + #include "caffe2/contrib/tensorrt/trt_utils.h" +#include "caffe2/core/context_gpu.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/onnx/onnx_exporter.h" -#include -#include -#include - namespace caffe2 { namespace { // TODO(yinghai): Remove the awkward conversion between unordered_map and map std::unordered_map InferShapes( - NetDef* init_net, + Workspace* ws, NetDef* pred_net, - const std::unordered_map& input_shape_hints) { - CaffeMap shape_hints_ordered; - for (const auto& kv : input_shape_hints) { - shape_hints_ordered.emplace(kv.first, kv.second); + CaffeMap* shape_hints_ordered) { + + // Populate shapes from workplace + const std::vector& ws_blobs = ws->Blobs(); + for (const auto& s : ws_blobs) { + shape_hints_ordered->emplace(s, GetTensorShapeOfBlob(ws->GetBlob(s))); } + std::vector nets; - nets.emplace_back(init_net); nets.emplace_back(pred_net); - InferBlobShapesAndTypes(shape_hints_ordered, nets); + InferBlobShapesAndTypes(*shape_hints_ordered, nets); std::unordered_map shape_hints; - for (const auto& kv : shape_hints_ordered) { + for (const auto& kv : *shape_hints_ordered) { shape_hints.emplace(kv.first, kv.second); } return shape_hints; } +CaffeMap SsaRewriteAndMapIO( + Workspace* ws, + NetDef* pred_net, + const std::unordered_map& input_shape_hints) { + std::unordered_map input_mapping = + onnx::SsaRewrite(nullptr, pred_net); + std::unordered_map input_reverse_mapping; + for (const auto kv : input_mapping) { + input_reverse_mapping.emplace(kv.second, kv.first); + try { + if (ws->GetBlob(kv.second)) { + ws->RenameBlob(kv.second, kv.first); + } + } catch (const EnforceNotMet& e) { + LOG(WARNING) << "Cannot rename blob " << kv.second << " to " << kv.first + << ": " << e.what(); + } + } + CaffeMap shape_hints_ordered; + for (const auto& kv : input_shape_hints) { + const auto it = input_reverse_mapping.find(kv.first); + if (it != input_reverse_mapping.end()) { + LOG(INFO) << "Adding input hint: " << it->second; + shape_hints_ordered.emplace(it->second, kv.second); + } else { + shape_hints_ordered.emplace(kv.first, kv.second); + } + } + return shape_hints_ordered; +} + // Figuring out the input the tensorrt runnable subgraph // `start` and `end` defines the continuous chunk of ops that can be readily // converted into an TensorRT op. And this function tries to figure out what's @@ -123,6 +159,94 @@ FigureOutputs(const NetDef& pred_net, int start, int end) { return all_outputs_vec; } +void CPUTensorToTensorProto( + const TensorCPU& cpu_tensor, + ::ONNX_NAMESPACE::TensorProto* t) { + const auto len = cpu_tensor.size(); + if (cpu_tensor.template IsType()) { + t->set_data_type(::ONNX_NAMESPACE::TensorProto::FLOAT); + const float* data = cpu_tensor.template data(); + for (auto i = 0; i < len; ++i) { + t->add_float_data(*data++); + } + } else if (cpu_tensor.template IsType()) { + t->set_data_type(::ONNX_NAMESPACE::TensorProto::INT64); + const int64_t* data = cpu_tensor.template data(); + for (auto i = 0; i < len; ++i) { + t->add_int64_data(*data++); + } + } else if (cpu_tensor.template IsType()) { + t->set_data_type(::ONNX_NAMESPACE::TensorProto::INT32); + const int32_t* data = cpu_tensor.template data(); + for (auto i = 0; i < len; ++i) { + t->add_int32_data(*data++); + } + } else { + CAFFE_THROW( + "Don't know how to convert workspace tensor type ", + cpu_tensor.meta().name(), + " to ONNX TensorProto"); + } +} + +void BlobToTensorProto( + const std::string& name, + Workspace* ws, + CUDAContext* context, + ::ONNX_NAMESPACE::TensorProto* t) { + // Set name + t->set_name(name); + const Blob* blob = ws->GetBlob(name); + CAFFE_ENFORCE(blob, "Blob ", name, " doesn't exist"); + + // Set dims + const auto shape = GetTensorShapeOfBlob(blob); + for (const auto i : shape.dims()) { + t->add_dims(i); + } + + // Set values + if (blob->template IsType()) { + const auto& cpu_tensor = blob->template Get(); + CPUTensorToTensorProto(cpu_tensor, t); + } else if (blob->template IsType()) { + const auto& cuda_tensor = blob->template Get(); + const auto cpu_tensor = TensorCPU(cuda_tensor, context); + context->FinishDeviceComputation(); + CPUTensorToTensorProto(cpu_tensor, t); + } else { + CAFFE_THROW( + "Initialization blob ", + name, + " needs to be either TensorCPU or TensorCUDA"); + } +} + +void BuildInitializationList( + Workspace* ws, + ::ONNX_NAMESPACE::GraphProto* g, + std::unordered_set* initialization_list) { + const std::vector& ws_blobs = ws->Blobs(); + + // Create a CUDA context and reuse it for potential tensor copies across + // devices + CUDAContext context; + + for (const auto& s : ws_blobs) { + auto it = initialization_list->find(s); + if (it != initialization_list->end()) { + auto* init_tensor = g->add_initializer(); + BlobToTensorProto(s, ws, &context, init_tensor); + initialization_list->erase(it); + } + } + CAFFE_ENFORCE( + initialization_list->empty(), "Unfulfilled initialization list"); + for (const auto& t : g->initializer()) { + VLOG(1) << "Initializer: " << t.name(); + } +} + std::vector<::ONNX_NAMESPACE::ValueInfoProto> ConvertToValueInfo( const std::vector& names, const std::unordered_map& shape_hints) { @@ -148,7 +272,7 @@ std::vector<::ONNX_NAMESPACE::ValueInfoProto> ConvertToValueInfo( return r; } -void PruneUsedWeights(const NetDef& pred_net, NetDef* init_net) { +void PruneUsedWeights(Workspace* ws, const NetDef& pred_net) { std::unordered_set used_weights; for (const auto& op : pred_net.op()) { for (const auto& i : op.input()) { @@ -156,22 +280,8 @@ void PruneUsedWeights(const NetDef& pred_net, NetDef* init_net) { } } - int last = init_net->op_size(); - for (int i = 0; i < last;) { - if (!used_weights.count(init_net->op(i).output(0))) { - if (i != last - 1) { - init_net->mutable_op()->SwapElements(i, last - 1); - } else { - ++i; - } - --last; - } else { - ++i; - } - } - - if (last < init_net->op_size()) { - init_net->mutable_op()->DeleteSubrange(last, init_net->op_size() - last); + for (const auto& w : used_weights) { + ws->RemoveBlob(w); } } @@ -264,7 +374,7 @@ OperatorDef TensorRTTransformer::BuildTrtOp( } void TensorRTTransformer::ClusterToTrtOp( - const NetDef& init_net, + Workspace* ws, const NetDef& pred_net, int start, int end, @@ -318,19 +428,22 @@ void TensorRTTransformer::ClusterToTrtOp( } // Convert weights to initializing tensors - onnx::OnnxExporter exporter; - for (const auto& op : init_net.op()) { - CAFFE_ENFORCE_EQ(op.output_size(), 1); - auto it = initialization_list.find(op.output(0)); - if (it != initialization_list.end()) { - auto* init_tensor = model->mutable_graph()->add_initializer(); - exporter.InitOpToTensorProto(op, init_tensor); - initialization_list.erase(it); + BuildInitializationList(ws, model->mutable_graph(), &initialization_list); + + if (debug_builder_) { + std::ofstream ff("trt.onnx"); + for (const auto& t : model->graph().initializer()) { + ff << "tensor: " << t.name() << std::endl; + ff << " dims: "; + for (auto i : t.dims()) { + ff << i << " "; + } + ff << std::endl; + for (auto i : t.float_data()) { + ff << " " << i << std::endl; + } } - } - CAFFE_ENFORCE(initialization_list.empty(), "Unfulfilled initialization list"); - for (const auto& t : model->graph().initializer()) { - VLOG(1) << "Initializer: " << t.name(); + ff.close(); } // Onnx model is ready. Call onnx-trt to convert to one trt c2 op @@ -345,24 +458,22 @@ void TensorRTTransformer::ClusterToTrtOp( // Cutting off the runnable part and replace with tensor ops. Asssume the nets // were topologically sorted void TensorRTTransformer::Transform( - NetDef* init_net, + Workspace* ws, NetDef* pred_net, const std::unordered_map& input_shape_hints) { - auto shape_hints = InferShapes(init_net, pred_net, input_shape_hints); + CAFFE_ENFORCE(ws); + + auto shape_hints_ordered = + SsaRewriteAndMapIO(ws, pred_net, input_shape_hints); + auto shape_hints = InferShapes(ws, pred_net, &shape_hints_ordered); std::unordered_set weights; - if (init_net) { - for (const auto& op : init_net->op()) { - CAFFE_ENFORCE_EQ(op.type().find("GivenTensor"), 0); - CAFFE_ENFORCE_EQ(op.type().rfind("Fill"), op.type().size() - 4); - CAFFE_ENFORCE_EQ(op.output_size(), 1); - for (const auto& op_output : op.output()) { - weights.emplace(op_output); - } - } + const std::vector& ws_blobs = ws->Blobs(); + for (const auto& s : ws_blobs) { + weights.emplace(s); } - CAFFE_ENFORCE(pred_net, "pred_net cannot be nullptr"); + CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr"); ::ONNX_NAMESPACE::ModelProto onnx_model; FillModelInfo(&onnx_model); @@ -372,7 +483,7 @@ void TensorRTTransformer::Transform( int op_idx = 0; int start = 0; int end = 0; - onnx::OnnxExporter exporter(true); + onnx::OnnxExporter exporter(nullptr, true); for (const OperatorDef& op : pred_net->op()) { bool support_trt = true; const OpSchema* schema = OpSchemaRegistry::Schema(op.type()); @@ -428,7 +539,7 @@ void TensorRTTransformer::Transform( } else { end = op_idx; ClusterToTrtOp( - *init_net, + ws, *pred_net, start, end, @@ -444,14 +555,7 @@ void TensorRTTransformer::Transform( if (trt_group) { end = op_idx; ClusterToTrtOp( - *init_net, - *pred_net, - start, - end, - weights, - shape_hints, - &onnx_model, - &new_ops); + ws, *pred_net, start, end, weights, shape_hints, &onnx_model, &new_ops); trt_group = false; } @@ -459,7 +563,7 @@ void TensorRTTransformer::Transform( for (const auto& op : new_ops) { pred_net->add_op()->CopyFrom(op); } - PruneUsedWeights(*pred_net, init_net); + PruneUsedWeights(ws, *pred_net); } } // namespace caffe2 diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.h b/caffe2/contrib/tensorrt/tensorrt_tranformer.h index fd69fca131..2308b4e23a 100644 --- a/caffe2/contrib/tensorrt/tensorrt_tranformer.h +++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.h @@ -2,6 +2,7 @@ #include "caffe2/core/common.h" #include "caffe2/core/operator.h" +#include "caffe2/core/workspace.h" #include "caffe2/proto/caffe2.pb.h" #include "onnx/onnx_pb.h" @@ -29,13 +30,13 @@ class TensorRTTransformer { output_size_hints); void Transform( - NetDef* init_net, + Workspace* ws, NetDef* pred_net, const std::unordered_map& shape_hints); private: void ClusterToTrtOp( - const NetDef& init_net, + Workspace* ws, const NetDef& pred_net, int start, int end, @@ -44,10 +45,11 @@ class TensorRTTransformer { ::ONNX_NAMESPACE::ModelProto* model, std::vector* new_ops); + // TensorRT params size_t max_batch_size_{50}; size_t max_workspace_size_{1024 * 1024 * 2}; int verbosity_{2}; - bool debug_builder_{true}; + bool debug_builder_{false}; }; } // namespace caffe2 diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h index c6bd638de4..19ed302435 100644 --- a/caffe2/core/workspace.h +++ b/caffe2/core/workspace.h @@ -219,7 +219,7 @@ class Workspace { /** * Renames a local workspace blob. If blob is not found in the local blob list * or if the target name is already present in local or any parent blob list - * the function will through. + * the function will throw. */ Blob* RenameBlob(const string& old_name, const string& new_name); diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 561c13d081..f835bece40 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -405,7 +405,7 @@ Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) { "' dtype is not supported"); CAFFE_ENFORCE_EQ( - c2_op.ops[0].arg().size(), + c2_op.ops.Get(0).arg().size(), 1, "Unexpected number of attributes in 'Cast'"); c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype); diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc index 58f4dafe69..e1c1cc98c7 100644 --- a/caffe2/onnx/onnx_exporter.cc +++ b/caffe2/onnx/onnx_exporter.cc @@ -90,12 +90,10 @@ std::string SsaName(const std::string& n, int version) { } } // namespace -std::pair< - std::unordered_map, - std::unordered_map> -SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { +std::unordered_map SsaRewrite( + caffe2::NetDef* init_net, + caffe2::NetDef* pred_net) { std::unordered_map input_mapping; - std::unordered_map output_mapping; std::unordered_map blob_versions; #define REWRITE_EXTERNAL_IO(net, name) \ @@ -121,7 +119,6 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { blob_versions.emplace(output, 0); } REWRITE_EXTERNAL_IO(init_net, input); - REWRITE_EXTERNAL_IO(init_net, output); blob_versions.clear(); } @@ -151,11 +148,31 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) { } } } - REWRITE_EXTERNAL_IO(pred_net, output); + + // Fix the external output name back to original + std::unordered_set external_outputs; + for (const auto& output : pred_net->external_output()) { + external_outputs.emplace(output); + } + for (auto& op : *pred_net->mutable_op()) { + for (auto& output : *op.mutable_output()) { + auto pos = output.find_last_of('_'); + CAFFE_ENFORCE_NE(pos, 0); + auto basename = output.substr(0, pos); + if (!external_outputs.count(basename)) { + continue; + } + auto it = blob_versions.find(basename); + if (it != blob_versions.end() && + SsaName(basename, it->second) == output) { + output = basename; + } + } + } } #undef REWRITE_EXTERNAL_IO - return std::make_pair(std::move(input_mapping), std::move(output_mapping)); + return input_mapping; } const std::unordered_map& diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h index 8c997c065a..dd11006915 100644 --- a/caffe2/onnx/onnx_exporter.h +++ b/caffe2/onnx/onnx_exporter.h @@ -23,10 +23,11 @@ using ::ONNX_NAMESPACE::TensorProto; using ConvertedResult = std::pair, std::vector>; -std::pair< - std::unordered_map, - std::unordered_map> -SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net); +// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external +// output names for predict net. +std::unordered_map SsaRewrite( + caffe2::NetDef* init_net, + caffe2::NetDef* pred_net); class OnnxExporter { using SpecialOpConverter = ConvertedResult (OnnxExporter::*)( diff --git a/caffe2/onnx/ssa_test.cc b/caffe2/onnx/ssa_test.cc index db3ba451b8..9f83b8f486 100644 --- a/caffe2/onnx/ssa_test.cc +++ b/caffe2/onnx/ssa_test.cc @@ -22,9 +22,7 @@ TEST(SsaTest, ConvReluInplace) { net.add_external_input("X"); net.add_external_output("Y"); - std::unordered_map input_mapping; - std::unordered_map output_mapping; - std::tie(input_mapping, output_mapping) = + std::unordered_map input_mapping = caffe2::onnx::SsaRewrite(nullptr, &net); for (const auto& op : net.op()) { std::unordered_set inputs; @@ -37,5 +35,5 @@ TEST(SsaTest, ConvReluInplace) { } EXPECT_EQ(net.op(0).output(0), net.op(1).input(0)); EXPECT_EQ("X", input_mapping.at(net.external_input(0))); - EXPECT_EQ("Y", output_mapping.at(net.external_output(0))); + EXPECT_EQ("Y", net.external_output(0)); } diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 32e0a86603..7f0b0c81d8 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -55,6 +55,10 @@ CAFFE_DEFINE_TYPED_REGISTRY( REGISTER_BLOB_FETCHER((TypeMeta::Id()), TensorFetcher); REGISTER_BLOB_FEEDER(CPU, TensorFeeder); +Workspace* GetCurrentWorkspace() { + return gWorkspace; +} + class StringFetcher : public BlobFetcherBase { public: py::object Fetch(const Blob& blob) override { diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index 65882d9f15..957601fc8e 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -40,6 +40,9 @@ void addGlobalMethods(pybind11::module& m); // Expose Workspace, Net, Blob void addObjectMethods(pybind11::module& m); +// Get current workspace +Workspace* GetCurrentWorkspace(); + class BlobFetcherBase { public: struct FetchedBlob { diff --git a/caffe2/python/pybind_state_gpu.cc b/caffe2/python/pybind_state_gpu.cc index 69d69d0ed8..14c4c48af1 100644 --- a/caffe2/python/pybind_state_gpu.cc +++ b/caffe2/python/pybind_state_gpu.cc @@ -77,37 +77,29 @@ void addCUDAGlobalMethods(py::module& m) { }); m.def( "transform_trt", - [](const py::bytes& init_net_str, - const py::bytes& pred_net_str, + [](const py::bytes& pred_net_str, const std::unordered_map>& shapes, int max_batch_size, int max_workspace_size, int verbosity, - bool debug_builder) -> std::vector { + bool debug_builder) -> py::bytes { #ifdef CAFFE2_USE_TRT - caffe2::NetDef init_net; - if(!ParseProtoFromLargeString( - init_net_str.cast(), &init_net)) { - LOG(ERROR) << "broken init_net protobuf"; - } caffe2::NetDef pred_net; - if(!ParseProtoFromLargeString( - pred_net_str.cast(), &pred_net)) { + if (!ParseProtoFromLargeString( + pred_net_str.cast(), &pred_net)) { LOG(ERROR) << "broken pred_net protobuf"; } std::unordered_map tensor_shapes; - for (const auto& it: shapes) { + for (const auto& it : shapes) { tensor_shapes.emplace( it.first, CreateTensorShape(it.second, TensorProto::FLOAT)); } TensorRTTransformer ts( max_batch_size, max_workspace_size, verbosity, debug_builder); - ts.Transform(&init_net, &pred_net, tensor_shapes); - std::string init_net_str2; + ts.Transform(GetCurrentWorkspace(), &pred_net, tensor_shapes); std::string pred_net_str2; - init_net.SerializeToString(&init_net_str2); pred_net.SerializeToString(&pred_net_str2); - return {py::bytes(init_net_str2), py::bytes(pred_net_str2)}; + return py::bytes(pred_net_str2); #else CAFFE_THROW("Please build Caffe2 with USE_TENSORRT=1"); #endif // CAFFE2_USE_TRT diff --git a/caffe2/python/trt/test_trt.py b/caffe2/python/trt/test_trt.py index 7fbb2d4719..ca359f8df2 100644 --- a/caffe2/python/trt/test_trt.py +++ b/caffe2/python/trt/test_trt.py @@ -155,7 +155,7 @@ class TensorRTTransformTest(TestCase): downloadFromURLToFile(url, dest, show_progress=False) except TypeError: - # show_progress not supported prior to +# show_progress not supported prior to # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1 # (Sep 17, 2017) downloadFromURLToFile(url, dest) @@ -230,44 +230,42 @@ class TensorRTTransformTest(TestCase): Y_c2 = None data = np.random.randn(*input_blob_dims).astype(np.float32) c2_time = 1 - ws = Workspace() + workspace.SwitchWorkspace("gpu_test", True) with core.DeviceScope(device_option): - ws.FeedBlob(input_name, data) - ws.RunNetOnce(init_net) - ws.CreateNet(pred_net) + workspace.FeedBlob(input_name, data) + workspace.RunNetOnce(init_net) + workspace.CreateNet(pred_net) for _ in range(warmup): - ws.RunNet(pred_net.name) + workspace.RunNet(pred_net.name) start = time.time() for _ in range(repeat): - ws.RunNet(pred_net.name) + workspace.RunNet(pred_net.name) end = time.time() c2_time = end - start - output_values = [ws.FetchBlob(name) for name in net_outputs] + output_values = [workspace.FetchBlob(name) for name in net_outputs] Y_c2 = namedtupledict('Outputs', net_outputs)(*output_values) - ws.ResetWorkspace() + workspace.ResetWorkspace() # Cut the graph - init_net_cut, pred_net_cut = transform_caffe2_net(init_net, pred_net, {input_name: input_blob_dims}) + pred_net_cut = transform_caffe2_net(device_option, init_net, pred_net, {input_name: input_blob_dims}) del init_net, pred_net - #print_net(pred_net_cut) + #_print_net(pred_net_cut) Y_trt = None input_name = pred_net_cut.external_input[0] print("C2 runtime: {}s".format(c2_time)) - ws = Workspace() with core.DeviceScope(device_option): - ws.FeedBlob(input_name, data) - ws.RunNetOnce(init_net_cut) - ws.CreateNet(pred_net_cut) + workspace.FeedBlob(input_name, data) + workspace.CreateNet(pred_net_cut) for _ in range(warmup): - ws.RunNet(pred_net_cut.name) + workspace.RunNet(pred_net_cut.name) start = time.time() for _ in range(repeat): - ws.RunNet(pred_net_cut.name) + workspace.RunNet(pred_net_cut.name) end = time.time() trt_time = end - start print("TRT runtime: {}s, improvement: {}%".format(trt_time, (c2_time-trt_time)/c2_time*100)) - output_values = [ws.FetchBlob(name) for name in net_outputs] + output_values = [workspace.FetchBlob(name) for name in net_outputs] Y_trt = namedtupledict('Outputs', net_outputs)(*output_values) np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3) diff --git a/caffe2/python/trt/transform.py b/caffe2/python/trt/transform.py index b2b3a41c5b..28e973b077 100644 --- a/caffe2/python/trt/transform.py +++ b/caffe2/python/trt/transform.py @@ -13,6 +13,7 @@ from __future__ import unicode_literals from caffe2.proto import caffe2_pb2 from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op +from caffe2.python import core, workspace import caffe2.python.onnx.frontend as c2_front import caffe2.python._import_c_extension as C import numpy as np @@ -69,10 +70,9 @@ def _infer_shapes(init_net, pred_net, inputs): return hints -def _ssa_rewrite_input(i): - return i + "_0"; - -def transform_caffe2_net(init_net, +def transform_caffe2_net( + device_option, + init_net, pred_net, input_shapes, populate_shapes = False, @@ -84,29 +84,29 @@ def transform_caffe2_net(init_net, Transfrom the caffe2_net by collapsing TRT-runnable nodes into trt c2 ops """ check_gpu_() - c2_front.ssa_rewrite(pred_net, init_net, value_info=[]) - input_data = {} - for k,v in input_shapes.iteritems(): - input_data[_ssa_rewrite_input(k)] = np.random.randn(*v).astype(np.float32) + + # Fill the workspace with the weights + with core.DeviceScope(device_option): + workspace.RunNetOnce(init_net) # Hacky way to infer shapes as not all our operators have shape inference function. # Normally this is not needed + shape_hints = {} if populate_shapes: + input_data = {} + for k,v in input_shapes.iteritems(): + input_data[k] = np.random.randn(*v).astype(np.float32) shape_hints = _infer_shapes(init_net, pred_net, input_data) - shape_hints = {} for k,v in input_shapes.iteritems(): - shape_hints[_ssa_rewrite_input(k)] = v - init_net_str, pred_net_str = C.transform_trt(init_net.SerializeToString(), - pred_net.SerializeToString(), - shape_hints, - max_batch_size, - max_workspace_size, - verbosity, - debug_builder) - init_net_cut = caffe2_pb2.NetDef() - init_net_cut.ParseFromString(init_net_str) + shape_hints[k] = v + pred_net_str = C.transform_trt(pred_net.SerializeToString(), + shape_hints, + max_batch_size, + max_workspace_size, + verbosity, + debug_builder) pred_net_cut = caffe2_pb2.NetDef() pred_net_cut.ParseFromString(pred_net_str) - return init_net_cut, pred_net_cut + return pred_net_cut -- cgit v1.2.3