summaryrefslogtreecommitdiff
path: root/caffe2/observers
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2017-10-10 15:57:48 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-10-10 16:10:41 -0700
commit63caca89dbb4c79a966ff2b8d191bd9c768624c4 (patch)
treee8bdb1954f103f242c48344f3ba2a1f736f315ac /caffe2/observers
parentf11ff5befb0be4c6d0c282ad21b0d3de8d6f9695 (diff)
downloadpytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.tar.gz
pytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.tar.bz2
pytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.zip
expose observers to python
Summary: observer framework can now be used in python + a small writeup of how to use it Reviewed By: salexspb Differential Revision: D5905002 fbshipit-source-id: e40ec24a55e08fb73beea9b4f3b68e71fc66ffb1
Diffstat (limited to 'caffe2/observers')
-rw-r--r--caffe2/observers/CMakeLists.txt9
-rw-r--r--caffe2/observers/README.md36
-rw-r--r--caffe2/observers/time_observer.cc54
-rw-r--r--caffe2/observers/time_observer.h85
-rw-r--r--caffe2/observers/time_observer_test.cc90
5 files changed, 274 insertions, 0 deletions
diff --git a/caffe2/observers/CMakeLists.txt b/caffe2/observers/CMakeLists.txt
new file mode 100644
index 0000000000..90e9a16ee5
--- /dev/null
+++ b/caffe2/observers/CMakeLists.txt
@@ -0,0 +1,9 @@
+if(USE_OBSERVERS)
+ message(STATUS "Include Observer library")
+ set(Caffe2_CONTRIB_OBSERVERS_CPU_SRC
+ "${CMAKE_CURRENT_SOURCE_DIR}/time_observer.cc"
+ )
+
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${Caffe2_CONTRIB_OBSERVERS_CPU_SRC})
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
+endif()
diff --git a/caffe2/observers/README.md b/caffe2/observers/README.md
new file mode 100644
index 0000000000..86ef3aed9f
--- /dev/null
+++ b/caffe2/observers/README.md
@@ -0,0 +1,36 @@
+# Observers
+
+## Usage
+
+Observers are a small framework that allow users to attach code to the execution of SimpleNets and Operators.
+
+An example of an Observer is the `TimeObserver`, used as follows:
+
+### C++
+
+```
+unique_ptr<TimeObserver<NetBase>> net_ob =
+ make_unique<TimeObserver<NetBase>>(net.get());
+auto* ob = net->AddObserver(std::move(net_ob));
+net->Run();
+LOG(INFO) << "av time children: " << ob->average_time_children();
+LOG(INFO) << "av time: " << ob->average_time();
+```
+
+### Python
+
+```
+model.net.AddTimeObserver()
+ws.RunNet(model.net)
+ob = model.net.GetObserver()
+
+print("av time children:", ob.average_time_children())
+print("av time:", ob.average_time())
+```
+
+
+## Implementing An Observer
+
+To implement an observer you must inherit from `ObserverBase` and implement the `Start` and `Stop` functions.
+
+Observers are instantiated with a `subject` of a generic type, such as a `Net` or `Operator`. The observer framework is built to be generic enough to "observe" various other types, however.
diff --git a/caffe2/observers/time_observer.cc b/caffe2/observers/time_observer.cc
new file mode 100644
index 0000000000..630d632f8f
--- /dev/null
+++ b/caffe2/observers/time_observer.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "time_observer.h"
+#include "caffe2/core/logging.h"
+
+namespace caffe2 {
+
+template <>
+bool TimeObserverBase<NetBase>::Start() {
+ CAFFE_THROW(
+ "This function is overridden by TimeObserver<NetBase>.\
+ If it was called there is an issue with compilation.");
+ return false;
+}
+
+template <>
+bool TimeObserverBase<NetBase>::Stop() {
+ double current_run = timer_.MilliSeconds() - start_time_;
+ total_time_ += current_run;
+ VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n";
+ return true;
+}
+
+template <>
+bool TimeObserverBase<OperatorBase>::Start() {
+ start_time_ = timer_.MilliSeconds();
+ ++iterations_;
+ return true;
+}
+
+template <>
+bool TimeObserverBase<OperatorBase>::Stop() {
+ double current_run = timer_.MilliSeconds() - start_time_;
+ total_time_ += current_run;
+ VLOG(1) << "This operator iteration took " << current_run
+ << " ms to complete.\n";
+ return true;
+}
+
+} // namespace caffe2
diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h
new file mode 100644
index 0000000000..f3d129a1bc
--- /dev/null
+++ b/caffe2/observers/time_observer.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
+#define CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
+
+#include <unordered_map>
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/net.h"
+#include "caffe2/core/observer.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/timer.h"
+
+namespace caffe2 {
+
+template <class T>
+class TimeObserverBase : public ObserverBase<T> {
+ public:
+ explicit TimeObserverBase<T>(T* subject) : ObserverBase<T>(subject) {}
+ inline float average_time() const {
+ return total_time_ / iterations_;
+ }
+ ~TimeObserverBase() {}
+
+ bool Start() override;
+ bool Stop() override;
+
+ protected:
+ Timer timer_;
+ float start_time_ = 0.0f;
+ float total_time_ = 0.0f;
+ int iterations_ = 0;
+};
+
+template <class T>
+class TimeObserver final : public TimeObserverBase<T> {
+ public:
+ explicit TimeObserver<T>(T* subject) : TimeObserverBase<T>(subject) {}
+};
+
+template <>
+class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> {
+ public:
+ explicit TimeObserver<NetBase>(NetBase* subject)
+ : TimeObserverBase<NetBase>(subject) {}
+ float average_time_children() const {
+ float sum = 0.0f;
+ for (const auto* observer : operator_observers_) {
+ sum += observer->average_time();
+ }
+ return sum / subject_->GetOperators().size();
+ }
+
+ bool Start() override {
+ for (auto* op : subject_->GetOperators()) {
+ operator_observers_.push_back(
+ dynamic_cast_if_rtti<TimeObserver<OperatorBase>*>(op->AddObserver(
+ caffe2::make_unique<TimeObserver<OperatorBase>>(op))));
+ }
+ start_time_ = timer_.MilliSeconds();
+ ++iterations_;
+ return true;
+ }
+
+ private:
+ vector<const TimeObserver<OperatorBase>*> operator_observers_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_
diff --git a/caffe2/observers/time_observer_test.cc b/caffe2/observers/time_observer_test.cc
new file mode 100644
index 0000000000..8e55e9d135
--- /dev/null
+++ b/caffe2/observers/time_observer_test.cc
@@ -0,0 +1,90 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/net.h"
+#include "caffe2/core/observer.h"
+#include "caffe2/core/operator.h"
+#include "time_observer.h"
+
+#include <google/protobuf/text_format.h>
+#include <gtest/gtest.h>
+#include <chrono>
+#include <thread>
+
+namespace caffe2 {
+
+namespace {
+
+class SleepOp final : public OperatorBase {
+ public:
+ using OperatorBase::OperatorBase;
+ bool Run(int /* unused */) override {
+ StartAllObservers();
+ std::this_thread::sleep_for(std::chrono::milliseconds(3000));
+ StopAllObservers();
+ return true;
+ }
+};
+
+REGISTER_CPU_OPERATOR(SleepOp, SleepOp);
+REGISTER_CUDA_OPERATOR(SleepOp, SleepOp);
+
+OPERATOR_SCHEMA(SleepOp)
+ .NumInputs(0, INT_MAX)
+ .NumOutputs(0, INT_MAX)
+ .AllowInplace({{0, 0}, {1, 1}});
+
+unique_ptr<NetBase> CreateNetTestHelper(Workspace* ws) {
+ NetDef net_def;
+ {
+ auto& op = *(net_def.add_op());
+ op.set_type("SleepOp");
+ op.add_input("in");
+ op.add_output("hidden");
+ }
+ {
+ auto& op = *(net_def.add_op());
+ op.set_type("SleepOp");
+ op.add_input("hidden");
+ op.add_output("out");
+ }
+ net_def.add_external_input("in");
+ net_def.add_external_output("out");
+
+ return CreateNet(net_def, ws);
+}
+}
+
+TEST(TimeObserverTest, Test3Seconds) {
+ Workspace ws;
+ ws.CreateBlob("in");
+ NetDef net_def;
+ unique_ptr<NetBase> net(CreateNetTestHelper(&ws));
+ unique_ptr<TimeObserver<NetBase>> net_ob =
+ make_unique<TimeObserver<NetBase>>(net.get());
+ auto* ob = dynamic_cast_if_rtti<TimeObserver<NetBase>*>(
+ net->AddObserver(std::move(net_ob)));
+ net->Run();
+ CAFFE_ENFORCE(ob);
+ LOG(INFO) << "av time children: " << ob->average_time_children();
+ LOG(INFO) << "av time: " << ob->average_time();
+ CAFFE_ENFORCE(ob->average_time_children() > 3000);
+ CAFFE_ENFORCE(ob->average_time_children() < 3500);
+ CAFFE_ENFORCE(ob->average_time() > 6000);
+ CAFFE_ENFORCE(ob->average_time() < 6500);
+}
+}