summary | refs | log | tree | commit | diff
path: root/caffe2/operators/onnxifi_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2/operators/onnxifi_op.h')
-rw-r--r-- caffe2/operators/onnxifi_op.h | 38
1 files changed, 34 insertions, 4 deletions
diff --git a/caffe2/operators/onnxifi_op.h b/caffe2/operators/onnxifi_op.h
index e8994444fd..34976109c6 100644
--- a/caffe2/operators/onnxifi_op.h
+++ b/caffe2/operators/onnxifi_op.h
@@ -98,10 +98,22 @@ class OnnxifiOp final : public Operator<Context> {
~OnnxifiOp() {
backend_graph_shared_ptr_.reset();
backend_graph_map_ptr_->remove(op_id_string_);
+#ifdef ONNXIFI_ENABLE_EXT
+ traces_.reset();
+#endif
}
bool RunOnDevice() override;
+ void setEnableTracing(bool b) {
+ enable_tracing_ = b;
+ }
+
+#ifdef ONNXIFI_ENABLE_EXT
+ std::shared_ptr<onnxTraceEventList> traces() const {
+ return traces_;
+ }
+#endif
private:
uint64_t SetOutputShapeAndType(int output_idx, std::vector<size_t>* dims) {
uint64_t type = ONNXIFI_DATATYPE_FLOAT32;
@@ -204,17 +216,29 @@ class OnnxifiOp final : public Operator<Context> {
backend_ = backend_graph_shared_ptr_->backend;
graph_ = backend_graph_shared_ptr_->graph;
-// Set up function pointer if onnxifi_ext is enabled
+ getExtFunctionPointers();
+ }
+
+ /// Set up function pointer if onnxifi_ext is enabled
+ void getExtFunctionPointers() {
#ifdef ONNXIFI_ENABLE_EXT
onnxExtensionFunctionPointer p;
if (lib_->onnxGetExtensionFunctionAddress(
backend_id_, "onnxSetIOAndRunGraphFunction", &p) !=
ONNXIFI_STATUS_SUCCESS) {
onnxSetIOAndRunGraphPointer_ = nullptr;
- return;
+ } else {
+ onnxSetIOAndRunGraphPointer_ =
+ reinterpret_cast<decltype(onnxSetIOAndRunGraphPointer_)>(p);
+ }
+ if (lib_->onnxGetExtensionFunctionAddress(
+ backend_id_, "onnxReleaseTraceEventsFunction", &p) !=
+ ONNXIFI_STATUS_SUCCESS) {
+ onnxReleaseTraceEventsPointer_ = nullptr;
+ } else {
+ onnxReleaseTraceEventsPointer_ =
+ reinterpret_cast<decltype(onnxReleaseTraceEventsPointer_)>(p);
}
- onnxSetIOAndRunGraphPointer_ =
- reinterpret_cast<decltype(onnxSetIOAndRunGraphPointer_)>(p);
#endif
}
@@ -253,6 +277,10 @@ class OnnxifiOp final : public Operator<Context> {
const onnxTensorDescriptorV1*,
onnxMemoryFenceV1*,
onnxTraceEventList*);
+
+ onnxStatus (*onnxReleaseTraceEventsPointer_)(onnxTraceEventList*);
+
+ std::shared_ptr<onnxTraceEventList> traces_{nullptr};
#endif
bool use_onnx_{false};
@@ -277,6 +305,8 @@ class OnnxifiOp final : public Operator<Context> {
// value: position of the input where the real batch size can be extracted
// from its first dimension
std::unordered_map<int, int> batch_pos_map_;
+ // Whether we enable tracing in one run of inference
+ bool enable_tracing_{false};
};
} // namespace caffe2