summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYinghai Lu <yinghai@fb.com>2018-04-17 21:23:27 -0700
committerGitHub <noreply@github.com>2018-04-17 21:23:27 -0700
commit6252706feb23e448bff944c3505092b18718d3ab (patch)
tree199e070bbbfc0a84287d3f4eb990dace343a8c48
parentdc94182db0f320f09894a99eb169fe83631e8440 (diff)
downloadpytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.gz
pytorch-6252706feb23e448bff944c3505092b18718d3ab.tar.bz2
pytorch-6252706feb23e448bff944c3505092b18718d3ab.zip
[Caffe2] Workspace centric API for TensorRT transformation (#6678)
* Workspace centric API for trt transformation * Merge SSA rewrite code
-rw-r--r--caffe2/contrib/tensorrt/tensorrt_tranformer.cc234
-rw-r--r--caffe2/contrib/tensorrt/tensorrt_tranformer.h8
-rw-r--r--caffe2/core/workspace.h2
-rw-r--r--caffe2/onnx/backend.cc2
-rw-r--r--caffe2/onnx/onnx_exporter.cc33
-rw-r--r--caffe2/onnx/onnx_exporter.h9
-rw-r--r--caffe2/onnx/ssa_test.cc6
-rw-r--r--caffe2/python/pybind_state.cc4
-rw-r--r--caffe2/python/pybind_state.h3
-rw-r--r--caffe2/python/pybind_state_gpu.cc22
-rw-r--r--caffe2/python/trt/test_trt.py34
-rw-r--r--caffe2/python/trt/transform.py40
12 files changed, 258 insertions, 139 deletions
diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
index 2ceb9bc2ec..5cf610e35a 100644
--- a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
+++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc
@@ -1,40 +1,76 @@
#include "caffe2/contrib/tensorrt/tensorrt_tranformer.h"
+
+#include <iostream>
+#include <unordered_set>
+
#include <NvInfer.h>
+#include <google/protobuf/text_format.h>
#include <onnx2trt.hpp>
+
#include "caffe2/contrib/tensorrt/trt_utils.h"
+#include "caffe2/core/context_gpu.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/onnx/onnx_exporter.h"
-#include <google/protobuf/text_format.h>
-#include <iostream>
-#include <unordered_set>
-
namespace caffe2 {
namespace {
// TODO(yinghai): Remove the awkward conversion between unordered_map and map
std::unordered_map<std::string, TensorShape> InferShapes(
- NetDef* init_net,
+ Workspace* ws,
NetDef* pred_net,
- const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
- CaffeMap<std::string, TensorShape> shape_hints_ordered;
- for (const auto& kv : input_shape_hints) {
- shape_hints_ordered.emplace(kv.first, kv.second);
+ CaffeMap<std::string, TensorShape>* shape_hints_ordered) {
+
+ // Populate shapes from workplace
+ const std::vector<string>& ws_blobs = ws->Blobs();
+ for (const auto& s : ws_blobs) {
+ shape_hints_ordered->emplace(s, GetTensorShapeOfBlob(ws->GetBlob(s)));
}
+
std::vector<NetDef*> nets;
- nets.emplace_back(init_net);
nets.emplace_back(pred_net);
- InferBlobShapesAndTypes(shape_hints_ordered, nets);
+ InferBlobShapesAndTypes(*shape_hints_ordered, nets);
std::unordered_map<std::string, TensorShape> shape_hints;
- for (const auto& kv : shape_hints_ordered) {
+ for (const auto& kv : *shape_hints_ordered) {
shape_hints.emplace(kv.first, kv.second);
}
return shape_hints;
}
+CaffeMap<std::string, TensorShape> SsaRewriteAndMapIO(
+ Workspace* ws,
+ NetDef* pred_net,
+ const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
+ std::unordered_map<std::string, std::string> input_mapping =
+ onnx::SsaRewrite(nullptr, pred_net);
+ std::unordered_map<std::string, std::string> input_reverse_mapping;
+ for (const auto kv : input_mapping) {
+ input_reverse_mapping.emplace(kv.second, kv.first);
+ try {
+ if (ws->GetBlob(kv.second)) {
+ ws->RenameBlob(kv.second, kv.first);
+ }
+ } catch (const EnforceNotMet& e) {
+ LOG(WARNING) << "Cannot rename blob " << kv.second << " to " << kv.first
+ << ": " << e.what();
+ }
+ }
+ CaffeMap<std::string, TensorShape> shape_hints_ordered;
+ for (const auto& kv : input_shape_hints) {
+ const auto it = input_reverse_mapping.find(kv.first);
+ if (it != input_reverse_mapping.end()) {
+ LOG(INFO) << "Adding input hint: " << it->second;
+ shape_hints_ordered.emplace(it->second, kv.second);
+ } else {
+ shape_hints_ordered.emplace(kv.first, kv.second);
+ }
+ }
+ return shape_hints_ordered;
+}
+
// Figuring out the input the tensorrt runnable subgraph
// `start` and `end` defines the continuous chunk of ops that can be readily
// converted into an TensorRT op. And this function tries to figure out what's
@@ -123,6 +159,94 @@ FigureOutputs(const NetDef& pred_net, int start, int end) {
return all_outputs_vec;
}
+void CPUTensorToTensorProto(
+ const TensorCPU& cpu_tensor,
+ ::ONNX_NAMESPACE::TensorProto* t) {
+ const auto len = cpu_tensor.size();
+ if (cpu_tensor.template IsType<float>()) {
+ t->set_data_type(::ONNX_NAMESPACE::TensorProto::FLOAT);
+ const float* data = cpu_tensor.template data<float>();
+ for (auto i = 0; i < len; ++i) {
+ t->add_float_data(*data++);
+ }
+ } else if (cpu_tensor.template IsType<int64_t>()) {
+ t->set_data_type(::ONNX_NAMESPACE::TensorProto::INT64);
+ const int64_t* data = cpu_tensor.template data<int64_t>();
+ for (auto i = 0; i < len; ++i) {
+ t->add_int64_data(*data++);
+ }
+ } else if (cpu_tensor.template IsType<int32_t>()) {
+ t->set_data_type(::ONNX_NAMESPACE::TensorProto::INT32);
+ const int32_t* data = cpu_tensor.template data<int32_t>();
+ for (auto i = 0; i < len; ++i) {
+ t->add_int32_data(*data++);
+ }
+ } else {
+ CAFFE_THROW(
+ "Don't know how to convert workspace tensor type ",
+ cpu_tensor.meta().name(),
+ " to ONNX TensorProto");
+ }
+}
+
+void BlobToTensorProto(
+ const std::string& name,
+ Workspace* ws,
+ CUDAContext* context,
+ ::ONNX_NAMESPACE::TensorProto* t) {
+ // Set name
+ t->set_name(name);
+ const Blob* blob = ws->GetBlob(name);
+ CAFFE_ENFORCE(blob, "Blob ", name, " doesn't exist");
+
+ // Set dims
+ const auto shape = GetTensorShapeOfBlob(blob);
+ for (const auto i : shape.dims()) {
+ t->add_dims(i);
+ }
+
+ // Set values
+ if (blob->template IsType<TensorCPU>()) {
+ const auto& cpu_tensor = blob->template Get<TensorCPU>();
+ CPUTensorToTensorProto(cpu_tensor, t);
+ } else if (blob->template IsType<TensorCUDA>()) {
+ const auto& cuda_tensor = blob->template Get<TensorCUDA>();
+ const auto cpu_tensor = TensorCPU(cuda_tensor, context);
+ context->FinishDeviceComputation();
+ CPUTensorToTensorProto(cpu_tensor, t);
+ } else {
+ CAFFE_THROW(
+ "Initialization blob ",
+ name,
+ " needs to be either TensorCPU or TensorCUDA");
+ }
+}
+
+void BuildInitializationList(
+ Workspace* ws,
+ ::ONNX_NAMESPACE::GraphProto* g,
+ std::unordered_set<std::string>* initialization_list) {
+ const std::vector<string>& ws_blobs = ws->Blobs();
+
+ // Create a CUDA context and reuse it for potential tensor copies across
+ // devices
+ CUDAContext context;
+
+ for (const auto& s : ws_blobs) {
+ auto it = initialization_list->find(s);
+ if (it != initialization_list->end()) {
+ auto* init_tensor = g->add_initializer();
+ BlobToTensorProto(s, ws, &context, init_tensor);
+ initialization_list->erase(it);
+ }
+ }
+ CAFFE_ENFORCE(
+ initialization_list->empty(), "Unfulfilled initialization list");
+ for (const auto& t : g->initializer()) {
+ VLOG(1) << "Initializer: " << t.name();
+ }
+}
+
std::vector<::ONNX_NAMESPACE::ValueInfoProto> ConvertToValueInfo(
const std::vector<std::string>& names,
const std::unordered_map<std::string, TensorShape>& shape_hints) {
@@ -148,7 +272,7 @@ std::vector<::ONNX_NAMESPACE::ValueInfoProto> ConvertToValueInfo(
return r;
}
-void PruneUsedWeights(const NetDef& pred_net, NetDef* init_net) {
+void PruneUsedWeights(Workspace* ws, const NetDef& pred_net) {
std::unordered_set<std::string> used_weights;
for (const auto& op : pred_net.op()) {
for (const auto& i : op.input()) {
@@ -156,22 +280,8 @@ void PruneUsedWeights(const NetDef& pred_net, NetDef* init_net) {
}
}
- int last = init_net->op_size();
- for (int i = 0; i < last;) {
- if (!used_weights.count(init_net->op(i).output(0))) {
- if (i != last - 1) {
- init_net->mutable_op()->SwapElements(i, last - 1);
- } else {
- ++i;
- }
- --last;
- } else {
- ++i;
- }
- }
-
- if (last < init_net->op_size()) {
- init_net->mutable_op()->DeleteSubrange(last, init_net->op_size() - last);
+ for (const auto& w : used_weights) {
+ ws->RemoveBlob(w);
}
}
@@ -264,7 +374,7 @@ OperatorDef TensorRTTransformer::BuildTrtOp(
}
void TensorRTTransformer::ClusterToTrtOp(
- const NetDef& init_net,
+ Workspace* ws,
const NetDef& pred_net,
int start,
int end,
@@ -318,19 +428,22 @@ void TensorRTTransformer::ClusterToTrtOp(
}
// Convert weights to initializing tensors
- onnx::OnnxExporter exporter;
- for (const auto& op : init_net.op()) {
- CAFFE_ENFORCE_EQ(op.output_size(), 1);
- auto it = initialization_list.find(op.output(0));
- if (it != initialization_list.end()) {
- auto* init_tensor = model->mutable_graph()->add_initializer();
- exporter.InitOpToTensorProto(op, init_tensor);
- initialization_list.erase(it);
+ BuildInitializationList(ws, model->mutable_graph(), &initialization_list);
+
+ if (debug_builder_) {
+ std::ofstream ff("trt.onnx");
+ for (const auto& t : model->graph().initializer()) {
+ ff << "tensor: " << t.name() << std::endl;
+ ff << " dims: ";
+ for (auto i : t.dims()) {
+ ff << i << " ";
+ }
+ ff << std::endl;
+ for (auto i : t.float_data()) {
+ ff << " " << i << std::endl;
+ }
}
- }
- CAFFE_ENFORCE(initialization_list.empty(), "Unfulfilled initialization list");
- for (const auto& t : model->graph().initializer()) {
- VLOG(1) << "Initializer: " << t.name();
+ ff.close();
}
// Onnx model is ready. Call onnx-trt to convert to one trt c2 op
@@ -345,24 +458,22 @@ void TensorRTTransformer::ClusterToTrtOp(
// Cutting off the runnable part and replace with tensor ops. Asssume the nets
// were topologically sorted
void TensorRTTransformer::Transform(
- NetDef* init_net,
+ Workspace* ws,
NetDef* pred_net,
const std::unordered_map<std::string, TensorShape>& input_shape_hints) {
- auto shape_hints = InferShapes(init_net, pred_net, input_shape_hints);
+ CAFFE_ENFORCE(ws);
+
+ auto shape_hints_ordered =
+ SsaRewriteAndMapIO(ws, pred_net, input_shape_hints);
+ auto shape_hints = InferShapes(ws, pred_net, &shape_hints_ordered);
std::unordered_set<std::string> weights;
- if (init_net) {
- for (const auto& op : init_net->op()) {
- CAFFE_ENFORCE_EQ(op.type().find("GivenTensor"), 0);
- CAFFE_ENFORCE_EQ(op.type().rfind("Fill"), op.type().size() - 4);
- CAFFE_ENFORCE_EQ(op.output_size(), 1);
- for (const auto& op_output : op.output()) {
- weights.emplace(op_output);
- }
- }
+ const std::vector<string>& ws_blobs = ws->Blobs();
+ for (const auto& s : ws_blobs) {
+ weights.emplace(s);
}
- CAFFE_ENFORCE(pred_net, "pred_net cannot be nullptr");
+ CAFFE_ENFORCE(pred_net, "Predict net cannot be nullptr");
::ONNX_NAMESPACE::ModelProto onnx_model;
FillModelInfo(&onnx_model);
@@ -372,7 +483,7 @@ void TensorRTTransformer::Transform(
int op_idx = 0;
int start = 0;
int end = 0;
- onnx::OnnxExporter exporter(true);
+ onnx::OnnxExporter exporter(nullptr, true);
for (const OperatorDef& op : pred_net->op()) {
bool support_trt = true;
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
@@ -428,7 +539,7 @@ void TensorRTTransformer::Transform(
} else {
end = op_idx;
ClusterToTrtOp(
- *init_net,
+ ws,
*pred_net,
start,
end,
@@ -444,14 +555,7 @@ void TensorRTTransformer::Transform(
if (trt_group) {
end = op_idx;
ClusterToTrtOp(
- *init_net,
- *pred_net,
- start,
- end,
- weights,
- shape_hints,
- &onnx_model,
- &new_ops);
+ ws, *pred_net, start, end, weights, shape_hints, &onnx_model, &new_ops);
trt_group = false;
}
@@ -459,7 +563,7 @@ void TensorRTTransformer::Transform(
for (const auto& op : new_ops) {
pred_net->add_op()->CopyFrom(op);
}
- PruneUsedWeights(*pred_net, init_net);
+ PruneUsedWeights(ws, *pred_net);
}
} // namespace caffe2
diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.h b/caffe2/contrib/tensorrt/tensorrt_tranformer.h
index fd69fca131..2308b4e23a 100644
--- a/caffe2/contrib/tensorrt/tensorrt_tranformer.h
+++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.h
@@ -2,6 +2,7 @@
#include "caffe2/core/common.h"
#include "caffe2/core/operator.h"
+#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2.pb.h"
#include "onnx/onnx_pb.h"
@@ -29,13 +30,13 @@ class TensorRTTransformer {
output_size_hints);
void Transform(
- NetDef* init_net,
+ Workspace* ws,
NetDef* pred_net,
const std::unordered_map<std::string, TensorShape>& shape_hints);
private:
void ClusterToTrtOp(
- const NetDef& init_net,
+ Workspace* ws,
const NetDef& pred_net,
int start,
int end,
@@ -44,10 +45,11 @@ class TensorRTTransformer {
::ONNX_NAMESPACE::ModelProto* model,
std::vector<OperatorDef>* new_ops);
+
// TensorRT params
size_t max_batch_size_{50};
size_t max_workspace_size_{1024 * 1024 * 2};
int verbosity_{2};
- bool debug_builder_{true};
+ bool debug_builder_{false};
};
} // namespace caffe2
diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h
index c6bd638de4..19ed302435 100644
--- a/caffe2/core/workspace.h
+++ b/caffe2/core/workspace.h
@@ -219,7 +219,7 @@ class Workspace {
/**
* Renames a local workspace blob. If blob is not found in the local blob list
* or if the target name is already present in local or any parent blob list
- * the function will through.
+ * the function will throw.
*/
Blob* RenameBlob(const string& old_name, const string& new_name);
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 561c13d081..f835bece40 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -405,7 +405,7 @@ Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) {
"' dtype is not supported");
CAFFE_ENFORCE_EQ(
- c2_op.ops[0].arg().size(),
+ c2_op.ops.Get(0).arg().size(),
1,
"Unexpected number of attributes in 'Cast'");
c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype);
diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc
index 58f4dafe69..e1c1cc98c7 100644
--- a/caffe2/onnx/onnx_exporter.cc
+++ b/caffe2/onnx/onnx_exporter.cc
@@ -90,12 +90,10 @@ std::string SsaName(const std::string& n, int version) {
}
} // namespace
-std::pair<
- std::unordered_map<std::string, std::string>,
- std::unordered_map<std::string, std::string>>
-SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
+std::unordered_map<std::string, std::string> SsaRewrite(
+ caffe2::NetDef* init_net,
+ caffe2::NetDef* pred_net) {
std::unordered_map<std::string, std::string> input_mapping;
- std::unordered_map<std::string, std::string> output_mapping;
std::unordered_map<std::string, int> blob_versions;
#define REWRITE_EXTERNAL_IO(net, name) \
@@ -121,7 +119,6 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
blob_versions.emplace(output, 0);
}
REWRITE_EXTERNAL_IO(init_net, input);
- REWRITE_EXTERNAL_IO(init_net, output);
blob_versions.clear();
}
@@ -151,11 +148,31 @@ SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net) {
}
}
}
- REWRITE_EXTERNAL_IO(pred_net, output);
+
+ // Fix the external output name back to original
+ std::unordered_set<std::string> external_outputs;
+ for (const auto& output : pred_net->external_output()) {
+ external_outputs.emplace(output);
+ }
+ for (auto& op : *pred_net->mutable_op()) {
+ for (auto& output : *op.mutable_output()) {
+ auto pos = output.find_last_of('_');
+ CAFFE_ENFORCE_NE(pos, 0);
+ auto basename = output.substr(0, pos);
+ if (!external_outputs.count(basename)) {
+ continue;
+ }
+ auto it = blob_versions.find(basename);
+ if (it != blob_versions.end() &&
+ SsaName(basename, it->second) == output) {
+ output = basename;
+ }
+ }
+ }
}
#undef REWRITE_EXTERNAL_IO
- return std::make_pair(std::move(input_mapping), std::move(output_mapping));
+ return input_mapping;
}
const std::unordered_map<std::string, std::string>&
diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h
index 8c997c065a..dd11006915 100644
--- a/caffe2/onnx/onnx_exporter.h
+++ b/caffe2/onnx/onnx_exporter.h
@@ -23,10 +23,11 @@ using ::ONNX_NAMESPACE::TensorProto;
using ConvertedResult =
std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
-std::pair<
- std::unordered_map<std::string, std::string>,
- std::unordered_map<std::string, std::string>>
-SsaRewrite(caffe2::NetDef* init_net, caffe2::NetDef* pred_net);
+// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external
+// output names for predict net.
+std::unordered_map<std::string, std::string> SsaRewrite(
+ caffe2::NetDef* init_net,
+ caffe2::NetDef* pred_net);
class OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
diff --git a/caffe2/onnx/ssa_test.cc b/caffe2/onnx/ssa_test.cc
index db3ba451b8..9f83b8f486 100644
--- a/caffe2/onnx/ssa_test.cc
+++ b/caffe2/onnx/ssa_test.cc
@@ -22,9 +22,7 @@ TEST(SsaTest, ConvReluInplace) {
net.add_external_input("X");
net.add_external_output("Y");
- std::unordered_map<std::string, std::string> input_mapping;
- std::unordered_map<std::string, std::string> output_mapping;
- std::tie(input_mapping, output_mapping) =
+ std::unordered_map<std::string, std::string> input_mapping =
caffe2::onnx::SsaRewrite(nullptr, &net);
for (const auto& op : net.op()) {
std::unordered_set<std::string> inputs;
@@ -37,5 +35,5 @@ TEST(SsaTest, ConvReluInplace) {
}
EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
- EXPECT_EQ("Y", output_mapping.at(net.external_output(0)));
+ EXPECT_EQ("Y", net.external_output(0));
}
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 32e0a86603..7f0b0c81d8 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -55,6 +55,10 @@ CAFFE_DEFINE_TYPED_REGISTRY(
REGISTER_BLOB_FETCHER((TypeMeta::Id<TensorCPU>()), TensorFetcher<CPUContext>);
REGISTER_BLOB_FEEDER(CPU, TensorFeeder<CPUContext>);
+Workspace* GetCurrentWorkspace() {
+ return gWorkspace;
+}
+
class StringFetcher : public BlobFetcherBase {
public:
py::object Fetch(const Blob& blob) override {
diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h
index 65882d9f15..957601fc8e 100644
--- a/caffe2/python/pybind_state.h
+++ b/caffe2/python/pybind_state.h
@@ -40,6 +40,9 @@ void addGlobalMethods(pybind11::module& m);
// Expose Workspace, Net, Blob
void addObjectMethods(pybind11::module& m);
+// Get current workspace
+Workspace* GetCurrentWorkspace();
+
class BlobFetcherBase {
public:
struct FetchedBlob {
diff --git a/caffe2/python/pybind_state_gpu.cc b/caffe2/python/pybind_state_gpu.cc
index 69d69d0ed8..14c4c48af1 100644
--- a/caffe2/python/pybind_state_gpu.cc
+++ b/caffe2/python/pybind_state_gpu.cc
@@ -77,37 +77,29 @@ void addCUDAGlobalMethods(py::module& m) {
});
m.def(
"transform_trt",
- [](const py::bytes& init_net_str,
- const py::bytes& pred_net_str,
+ [](const py::bytes& pred_net_str,
const std::unordered_map<std::string, std::vector<int>>& shapes,
int max_batch_size,
int max_workspace_size,
int verbosity,
- bool debug_builder) -> std::vector<py::bytes> {
+ bool debug_builder) -> py::bytes {
#ifdef CAFFE2_USE_TRT
- caffe2::NetDef init_net;
- if(!ParseProtoFromLargeString(
- init_net_str.cast<std::string>(), &init_net)) {
- LOG(ERROR) << "broken init_net protobuf";
- }
caffe2::NetDef pred_net;
- if(!ParseProtoFromLargeString(
- pred_net_str.cast<std::string>(), &pred_net)) {
+ if (!ParseProtoFromLargeString(
+ pred_net_str.cast<std::string>(), &pred_net)) {
LOG(ERROR) << "broken pred_net protobuf";
}
std::unordered_map<std::string, TensorShape> tensor_shapes;
- for (const auto& it: shapes) {
+ for (const auto& it : shapes) {
tensor_shapes.emplace(
it.first, CreateTensorShape(it.second, TensorProto::FLOAT));
}
TensorRTTransformer ts(
max_batch_size, max_workspace_size, verbosity, debug_builder);
- ts.Transform(&init_net, &pred_net, tensor_shapes);
- std::string init_net_str2;
+ ts.Transform(GetCurrentWorkspace(), &pred_net, tensor_shapes);
std::string pred_net_str2;
- init_net.SerializeToString(&init_net_str2);
pred_net.SerializeToString(&pred_net_str2);
- return {py::bytes(init_net_str2), py::bytes(pred_net_str2)};
+ return py::bytes(pred_net_str2);
#else
CAFFE_THROW("Please build Caffe2 with USE_TENSORRT=1");
#endif // CAFFE2_USE_TRT
diff --git a/caffe2/python/trt/test_trt.py b/caffe2/python/trt/test_trt.py
index 7fbb2d4719..ca359f8df2 100644
--- a/caffe2/python/trt/test_trt.py
+++ b/caffe2/python/trt/test_trt.py
@@ -155,7 +155,7 @@ class TensorRTTransformTest(TestCase):
downloadFromURLToFile(url, dest,
show_progress=False)
except TypeError:
- # show_progress not supported prior to
+# show_progress not supported prior to
# Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
# (Sep 17, 2017)
downloadFromURLToFile(url, dest)
@@ -230,44 +230,42 @@ class TensorRTTransformTest(TestCase):
Y_c2 = None
data = np.random.randn(*input_blob_dims).astype(np.float32)
c2_time = 1
- ws = Workspace()
+ workspace.SwitchWorkspace("gpu_test", True)
with core.DeviceScope(device_option):
- ws.FeedBlob(input_name, data)
- ws.RunNetOnce(init_net)
- ws.CreateNet(pred_net)
+ workspace.FeedBlob(input_name, data)
+ workspace.RunNetOnce(init_net)
+ workspace.CreateNet(pred_net)
for _ in range(warmup):
- ws.RunNet(pred_net.name)
+ workspace.RunNet(pred_net.name)
start = time.time()
for _ in range(repeat):
- ws.RunNet(pred_net.name)
+ workspace.RunNet(pred_net.name)
end = time.time()
c2_time = end - start
- output_values = [ws.FetchBlob(name) for name in net_outputs]
+ output_values = [workspace.FetchBlob(name) for name in net_outputs]
Y_c2 = namedtupledict('Outputs', net_outputs)(*output_values)
- ws.ResetWorkspace()
+ workspace.ResetWorkspace()
# Cut the graph
- init_net_cut, pred_net_cut = transform_caffe2_net(init_net, pred_net, {input_name: input_blob_dims})
+ pred_net_cut = transform_caffe2_net(device_option, init_net, pred_net, {input_name: input_blob_dims})
del init_net, pred_net
- #print_net(pred_net_cut)
+ #_print_net(pred_net_cut)
Y_trt = None
input_name = pred_net_cut.external_input[0]
print("C2 runtime: {}s".format(c2_time))
- ws = Workspace()
with core.DeviceScope(device_option):
- ws.FeedBlob(input_name, data)
- ws.RunNetOnce(init_net_cut)
- ws.CreateNet(pred_net_cut)
+ workspace.FeedBlob(input_name, data)
+ workspace.CreateNet(pred_net_cut)
for _ in range(warmup):
- ws.RunNet(pred_net_cut.name)
+ workspace.RunNet(pred_net_cut.name)
start = time.time()
for _ in range(repeat):
- ws.RunNet(pred_net_cut.name)
+ workspace.RunNet(pred_net_cut.name)
end = time.time()
trt_time = end - start
print("TRT runtime: {}s, improvement: {}%".format(trt_time, (c2_time-trt_time)/c2_time*100))
- output_values = [ws.FetchBlob(name) for name in net_outputs]
+ output_values = [workspace.FetchBlob(name) for name in net_outputs]
Y_trt = namedtupledict('Outputs', net_outputs)(*output_values)
np.testing.assert_allclose(Y_c2, Y_trt, rtol=1e-3)
diff --git a/caffe2/python/trt/transform.py b/caffe2/python/trt/transform.py
index b2b3a41c5b..28e973b077 100644
--- a/caffe2/python/trt/transform.py
+++ b/caffe2/python/trt/transform.py
@@ -13,6 +13,7 @@ from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2
from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op
+from caffe2.python import core, workspace
import caffe2.python.onnx.frontend as c2_front
import caffe2.python._import_c_extension as C
import numpy as np
@@ -69,10 +70,9 @@ def _infer_shapes(init_net, pred_net, inputs):
return hints
-def _ssa_rewrite_input(i):
- return i + "_0";
-
-def transform_caffe2_net(init_net,
+def transform_caffe2_net(
+ device_option,
+ init_net,
pred_net,
input_shapes,
populate_shapes = False,
@@ -84,29 +84,29 @@ def transform_caffe2_net(init_net,
Transfrom the caffe2_net by collapsing TRT-runnable nodes into trt c2 ops
"""
check_gpu_()
- c2_front.ssa_rewrite(pred_net, init_net, value_info=[])
- input_data = {}
- for k,v in input_shapes.iteritems():
- input_data[_ssa_rewrite_input(k)] = np.random.randn(*v).astype(np.float32)
+
+ # Fill the workspace with the weights
+ with core.DeviceScope(device_option):
+ workspace.RunNetOnce(init_net)
# Hacky way to infer shapes as not all our operators have shape inference function.
# Normally this is not needed
+ shape_hints = {}
if populate_shapes:
+ input_data = {}
+ for k,v in input_shapes.iteritems():
+ input_data[k] = np.random.randn(*v).astype(np.float32)
shape_hints = _infer_shapes(init_net, pred_net, input_data)
- shape_hints = {}
for k,v in input_shapes.iteritems():
- shape_hints[_ssa_rewrite_input(k)] = v
- init_net_str, pred_net_str = C.transform_trt(init_net.SerializeToString(),
- pred_net.SerializeToString(),
- shape_hints,
- max_batch_size,
- max_workspace_size,
- verbosity,
- debug_builder)
- init_net_cut = caffe2_pb2.NetDef()
- init_net_cut.ParseFromString(init_net_str)
+ shape_hints[k] = v
+ pred_net_str = C.transform_trt(pred_net.SerializeToString(),
+ shape_hints,
+ max_batch_size,
+ max_workspace_size,
+ verbosity,
+ debug_builder)
pred_net_cut = caffe2_pb2.NetDef()
pred_net_cut.ParseFromString(pred_net_str)
- return init_net_cut, pred_net_cut
+ return pred_net_cut