diff options
author | Hao Lu <hlu@fb.com> | 2018-07-06 15:07:20 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-07-06 15:15:17 -0700 |
commit | af107c4d162cba761de7acddb36552885bacf0a2 (patch) | |
tree | 16a8cf960382ab7385a9400a552d756c909183cb /modules | |
parent | f87499a8f3ee05e21c2ac0e97cd3e6e1393971f0 (diff) | |
download | pytorch-af107c4d162cba761de7acddb36552885bacf0a2.tar.gz pytorch-af107c4d162cba761de7acddb36552885bacf0a2.tar.bz2 pytorch-af107c4d162cba761de7acddb36552885bacf0a2.zip |
Fix shape inference bug (#9199)
Summary:
Closes https://github.com/pytorch/pytorch/pull/9199
The input shapes are not logged correctly in production because `PerfNetObserver::Stop()` only gets called after the inference is done for the net and in the mobile models, it's common practice to reuse the blobs as much as possible to save memory. And the shapes of the blobs keep changing during inference. By the time you you query `InputTensorShapes()` in `PerfNetObserver::Stop()`, you only get the final shape of the blobs.
To fix this bug, I moved the `InputTensorShapes()` query from `PerfNetObserver::Stop()` to `PerfOperatorObserver::Stop()`. The latter gets called at the end of `operator->Run()`, whereas `PerfNetObserver::Stop()` gets called at the end of `net->Run()`.
Also remove `PerfOperatorObserver::getAnalyticalCost()`, since the analytical-cost computation is now done on the server side and is no longer needed on mobile.
Reviewed By: Maratyszcza
Differential Revision: D8743346
fbshipit-source-id: 5d2d0132e3f5e084be7d0173863e695e62a6b4a0
Diffstat (limited to 'modules')
-rw-r--r-- | modules/observers/perf_observer.cc | 32 | ||||
-rw-r--r-- | modules/observers/perf_observer.h | 3 |
2 files changed, 9 insertions, 26 deletions
diff --git a/modules/observers/perf_observer.cc b/modules/observers/perf_observer.cc index f5ab3f11af..ed391a3e3f 100644 --- a/modules/observers/perf_observer.cc +++ b/modules/observers/perf_observer.cc @@ -85,15 +85,13 @@ void PerfNetObserver::Stop() { p.latency = static_cast<const PerfOperatorObserver*>(observerMap_[op]) ->getMilliseconds(); -#ifndef CAFFE2_IOS - auto cost = static_cast<const PerfOperatorObserver*>(observerMap_[op]) - ->getAnalyticalCost(); - p.flops = cost.flops; -#endif // CAFFE2_MOBILE p.engine = op->engine(); p.type = op->type(); - p.tensor_shapes = op->InputTensorShapes(); + p.tensor_shapes = + static_cast<const PerfOperatorObserver*>(observerMap_[op]) + ->getTensorShapes(); + if (op->has_debug_def()) { for (auto arg : op->debug_def().arg()) { p.args.emplace_back(arg); @@ -152,31 +150,15 @@ void PerfOperatorObserver::Stop() { /* Time from the start of the net minus the time spent on all other operators is the time spent on this operator */ milliseconds_ = netObserver_->getTimer().MilliSeconds() - milliseconds_; + tensor_shapes_ = subject_->InputTensorShapes(); } double PerfOperatorObserver::getMilliseconds() const { return milliseconds_; } -OpSchema::Cost PerfOperatorObserver::getAnalyticalCost() const { - auto* op = subject_; - auto* schema = OpSchemaRegistry::Schema(op->type()); - OpSchema::Cost cost; - if (schema && schema->HasCostInferenceFunction()) { - vector<TensorShape> shapes = op->InputTensorShapes(); - - auto all_good_shapes = std::accumulate( - shapes.begin(), - shapes.end(), - true, - [](bool acc, const TensorShape& shape) { - return acc && !shape.unknown_shape(); - }); - if (all_good_shapes) { - cost = schema->InferCost(op->debug_def(), shapes); - } - } - return cost; +std::vector<TensorShape> PerfOperatorObserver::getTensorShapes() const { + return tensor_shapes_; } } // namespace caffe2 diff --git a/modules/observers/perf_observer.h b/modules/observers/perf_observer.h index 122eca8905..6fb4063ffe 100644 --- 
a/modules/observers/perf_observer.h +++ b/modules/observers/perf_observer.h @@ -45,7 +45,7 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> { virtual ~PerfOperatorObserver(); double getMilliseconds() const; - OpSchema::Cost getAnalyticalCost() const; + std::vector<TensorShape> getTensorShapes() const; private: void Start() override; @@ -60,5 +60,6 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> { // costly here and a raw pointer is a cheapest sholution PerfNetObserver* netObserver_; double milliseconds_; + std::vector<TensorShape> tensor_shapes_; }; } // namespace caffe2 |