summaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorHao Lu <hlu@fb.com>2018-07-06 15:07:20 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-07-06 15:15:17 -0700
commitaf107c4d162cba761de7acddb36552885bacf0a2 (patch)
tree16a8cf960382ab7385a9400a552d756c909183cb /modules
parentf87499a8f3ee05e21c2ac0e97cd3e6e1393971f0 (diff)
downloadpytorch-af107c4d162cba761de7acddb36552885bacf0a2.tar.gz
pytorch-af107c4d162cba761de7acddb36552885bacf0a2.tar.bz2
pytorch-af107c4d162cba761de7acddb36552885bacf0a2.zip
Fix shape inference bug (#9199)
Summary: Closes https://github.com/pytorch/pytorch/pull/9199 The input shapes are not logged correctly in production because `PerfNetObserver::Stop()` only gets called after the inference is done for the net and in the mobile models, it's common practice to reuse the blobs as much as possible to save memory. And the shapes of the blobs keep changing during inference. By the time you query `InputTensorShapes()` in `PerfNetObserver::Stop()`, you only get the final shape of the blobs. To fix this bug, I moved the `InputTensorShapes()` query from `PerfNetObserver::Stop()` to `PerfOperatorObserver::Stop()`. The latter gets called at the end of operator->run() whereas `PerfNetObserver::Stop()` gets called at the end of net->run(). Also remove `PerfOperatorObserver::getAnalyticalCost()` since it's now done on the server side and no longer needed on mobile. Reviewed By: Maratyszcza Differential Revision: D8743346 fbshipit-source-id: 5d2d0132e3f5e084be7d0173863e695e62a6b4a0
Diffstat (limited to 'modules')
-rw-r--r--modules/observers/perf_observer.cc32
-rw-r--r--modules/observers/perf_observer.h3
2 files changed, 9 insertions, 26 deletions
diff --git a/modules/observers/perf_observer.cc b/modules/observers/perf_observer.cc
index f5ab3f11af..ed391a3e3f 100644
--- a/modules/observers/perf_observer.cc
+++ b/modules/observers/perf_observer.cc
@@ -85,15 +85,13 @@ void PerfNetObserver::Stop() {
p.latency = static_cast<const PerfOperatorObserver*>(observerMap_[op])
->getMilliseconds();
-#ifndef CAFFE2_IOS
- auto cost = static_cast<const PerfOperatorObserver*>(observerMap_[op])
- ->getAnalyticalCost();
- p.flops = cost.flops;
-#endif // CAFFE2_MOBILE
p.engine = op->engine();
p.type = op->type();
- p.tensor_shapes = op->InputTensorShapes();
+ p.tensor_shapes =
+ static_cast<const PerfOperatorObserver*>(observerMap_[op])
+ ->getTensorShapes();
+
if (op->has_debug_def()) {
for (auto arg : op->debug_def().arg()) {
p.args.emplace_back(arg);
@@ -152,31 +150,15 @@ void PerfOperatorObserver::Stop() {
/* Time from the start of the net minus the time spent on all other
operators is the time spent on this operator */
milliseconds_ = netObserver_->getTimer().MilliSeconds() - milliseconds_;
+ tensor_shapes_ = subject_->InputTensorShapes();
}
double PerfOperatorObserver::getMilliseconds() const {
return milliseconds_;
}
-OpSchema::Cost PerfOperatorObserver::getAnalyticalCost() const {
- auto* op = subject_;
- auto* schema = OpSchemaRegistry::Schema(op->type());
- OpSchema::Cost cost;
- if (schema && schema->HasCostInferenceFunction()) {
- vector<TensorShape> shapes = op->InputTensorShapes();
-
- auto all_good_shapes = std::accumulate(
- shapes.begin(),
- shapes.end(),
- true,
- [](bool acc, const TensorShape& shape) {
- return acc && !shape.unknown_shape();
- });
- if (all_good_shapes) {
- cost = schema->InferCost(op->debug_def(), shapes);
- }
- }
- return cost;
+std::vector<TensorShape> PerfOperatorObserver::getTensorShapes() const {
+ return tensor_shapes_;
}
} // namespace caffe2
diff --git a/modules/observers/perf_observer.h b/modules/observers/perf_observer.h
index 122eca8905..6fb4063ffe 100644
--- a/modules/observers/perf_observer.h
+++ b/modules/observers/perf_observer.h
@@ -45,7 +45,7 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> {
virtual ~PerfOperatorObserver();
double getMilliseconds() const;
- OpSchema::Cost getAnalyticalCost() const;
+ std::vector<TensorShape> getTensorShapes() const;
private:
void Start() override;
@@ -60,5 +60,6 @@ class PerfOperatorObserver : public ObserverBase<OperatorBase> {
// costly here and a raw pointer is a cheapest sholution
PerfNetObserver* netObserver_;
double milliseconds_;
+ std::vector<TensorShape> tensor_shapes_;
};
} // namespace caffe2