diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/observers/perf_observer.cc | 32 | ||||
-rw-r--r-- | modules/observers/perf_observer.h | 3 |
2 files changed, 9 insertions, 26 deletions
diff --git a/modules/observers/perf_observer.cc b/modules/observers/perf_observer.cc index f5ab3f11af..ed391a3e3f 100644 --- a/modules/observers/perf_observer.cc +++ b/modules/observers/perf_observer.cc @@ -85,15 +85,13 @@ void PerfNetObserver::Stop() { p.latency = static_cast<const PerfOperatorObserver*>(observerMap_[op]) ->getMilliseconds(); -#ifndef CAFFE2_IOS - auto cost = static_cast<const PerfOperatorObserver*>(observerMap_[op]) - ->getAnalyticalCost(); - p.flops = cost.flops; -#endif // CAFFE2_MOBILE p.engine = op->engine(); p.type = op->type(); - p.tensor_shapes = op->InputTensorShapes(); + p.tensor_shapes = + static_cast<const PerfOperatorObserver*>(observerMap_[op]) + ->getTensorShapes(); + if (op->has_debug_def()) { for (auto arg : op->debug_def().arg()) { p.args.emplace_back(arg); @@ -152,31 +150,15 @@ void PerfOperatorObserver::Stop() { /* Time from the start of the net minus the time spent on all other operators is the time spent on this operator */ milliseconds_ = netObserver_->getTimer().MilliSeconds() - milliseconds_; + tensor_shapes_ = subject_->InputTensorShapes(); } double PerfOperatorObserver::getMilliseconds() const { return milliseconds_; } -OpSchema::Cost PerfOperatorObserver::getAnalyticalCost() const { - auto* op = subject_; - auto* schema = OpSchemaRegistry::Schema(op->type()); - OpSchema::Cost cost; - if (schema && schema->HasCostInferenceFunction()) { - vector<TensorShape> shapes = op->InputTensorShapes(); - - auto all_good_shapes = std::accumulate( - shapes.begin(), - shapes.end(), - true, - [](bool acc, const TensorShape& shape) { - return acc && !shape.unknown_shape(); - }); - if (all_good_shapes) { - cost = schema->InferCost(op->debug_def(), shapes); - } - } - return cost; +std::vector<TensorShape> PerfOperatorObserver::getTensorShapes() const { + return tensor_shapes_; } } // namespace caffe2 diff --git a/modules/observers/perf_observer.h b/modules/observers/perf_observer.h index 122eca8905..6fb4063ffe 100644 --- a/modules/observers/perf_observer.h +++ b/modules/observers/perf_observer.h @@ -45,7 +45,7 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> { virtual ~PerfOperatorObserver(); double getMilliseconds() const; - OpSchema::Cost getAnalyticalCost() const; + std::vector<TensorShape> getTensorShapes() const; private: void Start() override; @@ -60,5 +60,6 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> { // costly here and a raw pointer is a cheapest sholution PerfNetObserver* netObserver_; double milliseconds_; + std::vector<TensorShape> tensor_shapes_; }; } // namespace caffe2 |