summaryrefslogtreecommitdiff
path: root/caffe2/observers
diff options
context:
space:
mode:
authorAapo Kyrola <akyrola@fb.com>2017-11-20 16:12:54 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-11-20 16:17:52 -0800
commite0c8c539e75093dd11bb8bc3f50964012194e438 (patch)
treec8d413cca5fe57847080b26177f083c5bcf43e5f /caffe2/observers
parent335c7dc6812ae8224a96d73f9f41d725be658ff0 (diff)
downloadpytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.tar.gz
pytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.tar.bz2
pytorch-e0c8c539e75093dd11bb8bc3f50964012194e438.zip
Backed out changeset 119623addbbd
Summary: Unlanding D6327460 because seems to be causing unstability. Differential Revision: D6377117 fbshipit-source-id: 4e1241fe65cd4c7a127fa6fa724f60b75965a096
Diffstat (limited to 'caffe2/observers')
-rw-r--r--caffe2/observers/runcnt_observer.cc14
-rw-r--r--caffe2/observers/runcnt_observer.h8
-rw-r--r--caffe2/observers/time_observer.cc12
-rw-r--r--caffe2/observers/time_observer.h7
4 files changed, 26 insertions, 15 deletions
diff --git a/caffe2/observers/runcnt_observer.cc b/caffe2/observers/runcnt_observer.cc
index d28a3b7d0b..d52a7d6fa3 100644
--- a/caffe2/observers/runcnt_observer.cc
+++ b/caffe2/observers/runcnt_observer.cc
@@ -18,18 +18,24 @@ std::string RunCountNetObserver::debugInfo() {
return "This operator runs " + caffe2::to_string(cnt_) + " times.";
}
-void RunCountNetObserver::Start() {
+bool RunCountNetObserver::Start() {
const auto& operators = subject_->GetOperators();
for (auto* op : operators) {
op->AttachObserver(caffe2::make_unique<RunCountOperatorObserver>(op, this));
}
+ return true;
}
-void RunCountNetObserver::Stop() {}
+bool RunCountNetObserver::Stop() {
+ return true;
+}
-void RunCountOperatorObserver::Start() {
+bool RunCountOperatorObserver::Start() {
++netObserver_->cnt_;
+ return true;
+}
+bool RunCountOperatorObserver::Stop() {
+ return true;
}
-void RunCountOperatorObserver::Stop() {}
} // namespace caffe2
diff --git a/caffe2/observers/runcnt_observer.h b/caffe2/observers/runcnt_observer.h
index a27e773a08..44b016a57b 100644
--- a/caffe2/observers/runcnt_observer.h
+++ b/caffe2/observers/runcnt_observer.h
@@ -16,8 +16,8 @@ class RunCountOperatorObserver final : public ObserverBase<OperatorBase> {
std::unique_ptr<ObserverBase<OperatorBase>> clone() override;
private:
- void Start() override;
- void Stop() override;
+ bool Start() override;
+ bool Stop() override;
private:
RunCountNetObserver* netObserver_;
@@ -34,8 +34,8 @@ class RunCountNetObserver final : public ObserverBase<NetBase> {
friend class RunCountOperatorObserver;
private:
- void Start() override;
- void Stop() override;
+ bool Start() override;
+ bool Stop() override;
protected:
std::atomic<int> cnt_;
diff --git a/caffe2/observers/time_observer.cc b/caffe2/observers/time_observer.cc
index 80f5a95944..630d632f8f 100644
--- a/caffe2/observers/time_observer.cc
+++ b/caffe2/observers/time_observer.cc
@@ -20,31 +20,35 @@
namespace caffe2 {
template <>
-void TimeObserverBase<NetBase>::Start() {
+bool TimeObserverBase<NetBase>::Start() {
CAFFE_THROW(
"This function is overridden by TimeObserver<NetBase>.\
If it was called there is an issue with compilation.");
+ return false;
}
template <>
-void TimeObserverBase<NetBase>::Stop() {
+bool TimeObserverBase<NetBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n";
+ return true;
}
template <>
-void TimeObserverBase<OperatorBase>::Start() {
+bool TimeObserverBase<OperatorBase>::Start() {
start_time_ = timer_.MilliSeconds();
++iterations_;
+ return true;
}
template <>
-void TimeObserverBase<OperatorBase>::Stop() {
+bool TimeObserverBase<OperatorBase>::Stop() {
double current_run = timer_.MilliSeconds() - start_time_;
total_time_ += current_run;
VLOG(1) << "This operator iteration took " << current_run
<< " ms to complete.\n";
+ return true;
}
} // namespace caffe2
diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h
index add88ea19f..10bf0097b0 100644
--- a/caffe2/observers/time_observer.h
+++ b/caffe2/observers/time_observer.h
@@ -36,8 +36,8 @@ class TimeObserverBase : public ObserverBase<T> {
}
~TimeObserverBase() {}
- void Start() override;
- void Stop() override;
+ bool Start() override;
+ bool Stop() override;
protected:
Timer timer_;
@@ -77,7 +77,7 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
return sum / subject_->GetOperators().size();
}
- void Start() override {
+ bool Start() override {
for (auto* op : subject_->GetOperators()) {
const auto* observer = op->AttachObserver(
caffe2::make_unique<TimeObserver<OperatorBase>>(op));
@@ -87,6 +87,7 @@ class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
}
start_time_ = timer_.MilliSeconds();
++iterations_;
+ return true;
}
private: