diff options
Diffstat (limited to 'caffe2/operators/onnxifi_op.h')
-rw-r--r-- | caffe2/operators/onnxifi_op.h | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/caffe2/operators/onnxifi_op.h b/caffe2/operators/onnxifi_op.h index e8994444fd..34976109c6 100644 --- a/caffe2/operators/onnxifi_op.h +++ b/caffe2/operators/onnxifi_op.h @@ -98,10 +98,22 @@ class OnnxifiOp final : public Operator<Context> { ~OnnxifiOp() { backend_graph_shared_ptr_.reset(); backend_graph_map_ptr_->remove(op_id_string_); +#ifdef ONNXIFI_ENABLE_EXT + traces_.reset(); +#endif } bool RunOnDevice() override; + void setEnableTracing(bool b) { + enable_tracing_ = b; + } + +#ifdef ONNXIFI_ENABLE_EXT + std::shared_ptr<onnxTraceEventList> traces() const { + return traces_; + } +#endif private: uint64_t SetOutputShapeAndType(int output_idx, std::vector<size_t>* dims) { uint64_t type = ONNXIFI_DATATYPE_FLOAT32; @@ -204,17 +216,29 @@ class OnnxifiOp final : public Operator<Context> { backend_ = backend_graph_shared_ptr_->backend; graph_ = backend_graph_shared_ptr_->graph; -// Set up function pointer if onnxifi_ext is enabled + getExtFunctionPointers(); + } + + /// Set up function pointer if onnxifi_ext is enabled + void getExtFunctionPointers() { #ifdef ONNXIFI_ENABLE_EXT onnxExtensionFunctionPointer p; if (lib_->onnxGetExtensionFunctionAddress( backend_id_, "onnxSetIOAndRunGraphFunction", &p) != ONNXIFI_STATUS_SUCCESS) { onnxSetIOAndRunGraphPointer_ = nullptr; - return; + } else { + onnxSetIOAndRunGraphPointer_ = + reinterpret_cast<decltype(onnxSetIOAndRunGraphPointer_)>(p); + } + if (lib_->onnxGetExtensionFunctionAddress( + backend_id_, "onnxReleaseTraceEventsFunction", &p) != + ONNXIFI_STATUS_SUCCESS) { + onnxReleaseTraceEventsPointer_ = nullptr; + } else { + onnxReleaseTraceEventsPointer_ = + reinterpret_cast<decltype(onnxReleaseTraceEventsPointer_)>(p); } - onnxSetIOAndRunGraphPointer_ = - reinterpret_cast<decltype(onnxSetIOAndRunGraphPointer_)>(p); #endif } @@ -253,6 +277,10 @@ class OnnxifiOp final : public Operator<Context> { const onnxTensorDescriptorV1*, onnxMemoryFenceV1*, onnxTraceEventList*); + + onnxStatus (*onnxReleaseTraceEventsPointer_)(onnxTraceEventList*); + + std::shared_ptr<onnxTraceEventList> traces_{nullptr}; #endif bool use_onnx_{false}; @@ -277,6 +305,8 @@ class OnnxifiOp final : public Operator<Context> { // value: position of the input where the real batch size can be extracted // from its first dimension std::unordered_map<int, int> batch_pos_map_; + // Whether we enable tracing in one run of inference + bool enable_tracing_{false}; }; } // namespace caffe2 |