diff options
author | Aapo Kyrola <akyrola@fb.com> | 2017-11-20 16:12:54 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-11-20 16:17:52 -0800 |
commit | e0c8c539e75093dd11bb8bc3f50964012194e438 (patch) | |
tree | c8d413cca5fe57847080b26177f083c5bcf43e5f /caffe2/observers | |
parent | 335c7dc6812ae8224a96d73f9f41d725be658ff0 (diff) | |
download | pytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.tar.gz pytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.tar.bz2 pytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.zip |
Backed out changeset 119623addbbd
Summary: Unlanding D6327460 because seems to be causing unstability.
Differential Revision: D6377117
fbshipit-source-id: 4e1241fe65cd4c7a127fa6fa724f60b75965a096
Diffstat (limited to 'caffe2/observers')
-rw-r--r-- | caffe2/observers/runcnt_observer.cc | 14 | ||||
-rw-r--r-- | caffe2/observers/runcnt_observer.h | 8 | ||||
-rw-r--r-- | caffe2/observers/time_observer.cc | 12 | ||||
-rw-r--r-- | caffe2/observers/time_observer.h | 7 |
4 files changed, 26 insertions, 15 deletions
diff --git a/caffe2/observers/runcnt_observer.cc b/caffe2/observers/runcnt_observer.cc index d28a3b7d0b..d52a7d6fa3 100644 --- a/caffe2/observers/runcnt_observer.cc +++ b/caffe2/observers/runcnt_observer.cc @@ -18,18 +18,24 @@ std::string RunCountNetObserver::debugInfo() { return "This operator runs " + caffe2::to_string(cnt_) + " times."; } -void RunCountNetObserver::Start() { +bool RunCountNetObserver::Start() { const auto& operators = subject_->GetOperators(); for (auto* op : operators) { op->AttachObserver(caffe2::make_unique<RunCountOperatorObserver>(op, this)); } + return true; } -void RunCountNetObserver::Stop() {} +bool RunCountNetObserver::Stop() { + return true; +} -void RunCountOperatorObserver::Start() { +bool RunCountOperatorObserver::Start() { ++netObserver_->cnt_; + return true; +} +bool RunCountOperatorObserver::Stop() { + return true; } -void RunCountOperatorObserver::Stop() {} } // namespace caffe2 diff --git a/caffe2/observers/runcnt_observer.h b/caffe2/observers/runcnt_observer.h index a27e773a08..44b016a57b 100644 --- a/caffe2/observers/runcnt_observer.h +++ b/caffe2/observers/runcnt_observer.h @@ -16,8 +16,8 @@ class RunCountOperatorObserver final : public ObserverBase<OperatorBase> { std::unique_ptr<ObserverBase<OperatorBase>> clone() override; private: - void Start() override; - void Stop() override; + bool Start() override; + bool Stop() override; private: RunCountNetObserver* netObserver_; @@ -34,8 +34,8 @@ class RunCountNetObserver final : public ObserverBase<NetBase> { friend class RunCountOperatorObserver; private: - void Start() override; - void Stop() override; + bool Start() override; + bool Stop() override; protected: std::atomic<int> cnt_; diff --git a/caffe2/observers/time_observer.cc b/caffe2/observers/time_observer.cc index 80f5a95944..630d632f8f 100644 --- a/caffe2/observers/time_observer.cc +++ b/caffe2/observers/time_observer.cc @@ -20,31 +20,35 @@ namespace caffe2 { template <> -void TimeObserverBase<NetBase>::Start() { +bool TimeObserverBase<NetBase>::Start() { CAFFE_THROW( "This function is overridden by TimeObserver<NetBase>.\ If it was called there is an issue with compilation."); + return false; } template <> -void TimeObserverBase<NetBase>::Stop() { +bool TimeObserverBase<NetBase>::Stop() { double current_run = timer_.MilliSeconds() - start_time_; total_time_ += current_run; VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n"; + return true; } template <> -void TimeObserverBase<OperatorBase>::Start() { +bool TimeObserverBase<OperatorBase>::Start() { start_time_ = timer_.MilliSeconds(); ++iterations_; + return true; } template <> -void TimeObserverBase<OperatorBase>::Stop() { +bool TimeObserverBase<OperatorBase>::Stop() { double current_run = timer_.MilliSeconds() - start_time_; total_time_ += current_run; VLOG(1) << "This operator iteration took " << current_run << " ms to complete.\n"; + return true; } } // namespace caffe2 diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h index add88ea19f..10bf0097b0 100644 --- a/caffe2/observers/time_observer.h +++ b/caffe2/observers/time_observer.h @@ -36,8 +36,8 @@ class TimeObserverBase : public ObserverBase<T> { } ~TimeObserverBase() {} - void Start() override; - void Stop() override; + bool Start() override; + bool Stop() override; protected: Timer timer_; @@ -77,7 +77,7 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> { return sum / subject_->GetOperators().size(); } - void Start() override { + bool Start() override { for (auto* op : subject_->GetOperators()) { const auto* observer = op->AttachObserver( caffe2::make_unique<TimeObserver<OperatorBase>>(op)); @@ -87,6 +87,7 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> { } start_time_ = timer_.MilliSeconds(); ++iterations_; + return true; } private: |