diff options
author | Bram Wasti <bwasti@fb.com> | 2017-10-10 15:57:48 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-10-10 16:10:41 -0700 |
commit | 63caca89dbb4c79a966ff2b8d191bd9c768624c4 (patch) | |
tree | e8bdb1954f103f242c48344f3ba2a1f736f315ac /caffe2/observers | |
parent | f11ff5befb0be4c6d0c282ad21b0d3de8d6f9695 (diff) | |
download | pytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.tar.gz pytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.tar.bz2 pytorch-63caca89dbb4c79a966ff2b8d191bd9c768624c4.zip |
expose observers to python
Summary: observer framework can now be used in python + a small writeup of how to use it
Reviewed By: salexspb
Differential Revision: D5905002
fbshipit-source-id: e40ec24a55e08fb73beea9b4f3b68e71fc66ffb1
Diffstat (limited to 'caffe2/observers')
-rw-r--r-- | caffe2/observers/CMakeLists.txt | 9 | ||||
-rw-r--r-- | caffe2/observers/README.md | 36 | ||||
-rw-r--r-- | caffe2/observers/time_observer.cc | 54 | ||||
-rw-r--r-- | caffe2/observers/time_observer.h | 85 | ||||
-rw-r--r-- | caffe2/observers/time_observer_test.cc | 90 |
5 files changed, 274 insertions, 0 deletions
diff --git a/caffe2/observers/CMakeLists.txt b/caffe2/observers/CMakeLists.txt new file mode 100644 index 0000000000..90e9a16ee5 --- /dev/null +++ b/caffe2/observers/CMakeLists.txt @@ -0,0 +1,9 @@ +if(USE_OBSERVERS) + message(STATUS "Include Observer library") + set(Caffe2_CONTRIB_OBSERVERS_CPU_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/time_observer.cc" + ) + + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${Caffe2_CONTRIB_OBSERVERS_CPU_SRC}) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) +endif() diff --git a/caffe2/observers/README.md b/caffe2/observers/README.md new file mode 100644 index 0000000000..86ef3aed9f --- /dev/null +++ b/caffe2/observers/README.md @@ -0,0 +1,36 @@ +# Observers + +## Usage + +Observers are a small framework that allow users to attach code to the execution of SimpleNets and Operators. + +An example of an Observer is the `TimeObserver`, used as follows: + +### C++ + +``` +unique_ptr<TimeObserver<NetBase>> net_ob = + make_unique<TimeObserver<NetBase>>(net.get()); +auto* ob = net->AddObserver(std::move(net_ob)); +net->Run(); +LOG(INFO) << "av time children: " << ob->average_time_children(); +LOG(INFO) << "av time: " << ob->average_time(); +``` + +### Python + +``` +model.net.AddTimeObserver() +ws.RunNet(model.net) +ob = model.net.GetObserver() + +print("av time children:", ob.average_time_children()) +print("av time:", ob.average_time()) +``` + + +## Implementing An Observer + +To implement an observer you must inherit from `ObserverBase` and implement the `Start` and `Stop` functions. + +Observers are instantiated with a `subject` of a generic type, such as a `Net` or `Operator`. The observer framework is built to be generic enough to "observe" various other types, however. diff --git a/caffe2/observers/time_observer.cc b/caffe2/observers/time_observer.cc new file mode 100644 index 0000000000..630d632f8f --- /dev/null +++ b/caffe2/observers/time_observer.cc @@ -0,0 +1,54 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "time_observer.h" +#include "caffe2/core/logging.h" + +namespace caffe2 { + +template <> +bool TimeObserverBase<NetBase>::Start() { + CAFFE_THROW( + "This function is overridden by TimeObserver<NetBase>.\ + If it was called there is an issue with compilation."); + return false; +} + +template <> +bool TimeObserverBase<NetBase>::Stop() { + double current_run = timer_.MilliSeconds() - start_time_; + total_time_ += current_run; + VLOG(1) << "This net iteration took " << current_run << " ms to complete.\n"; + return true; +} + +template <> +bool TimeObserverBase<OperatorBase>::Start() { + start_time_ = timer_.MilliSeconds(); + ++iterations_; + return true; +} + +template <> +bool TimeObserverBase<OperatorBase>::Stop() { + double current_run = timer_.MilliSeconds() - start_time_; + total_time_ += current_run; + VLOG(1) << "This operator iteration took " << current_run + << " ms to complete.\n"; + return true; +} + +} // namespace caffe2 diff --git a/caffe2/observers/time_observer.h b/caffe2/observers/time_observer.h new file mode 100644 index 0000000000..f3d129a1bc --- /dev/null +++ b/caffe2/observers/time_observer.h @@ -0,0 +1,85 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_ +#define CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_ + +#include <unordered_map> + +#include "caffe2/core/common.h" +#include "caffe2/core/net.h" +#include "caffe2/core/observer.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/timer.h" + +namespace caffe2 { + +template <class T> +class TimeObserverBase : public ObserverBase<T> { + public: + explicit TimeObserverBase<T>(T* subject) : ObserverBase<T>(subject) {} + inline float average_time() const { + return total_time_ / iterations_; + } + ~TimeObserverBase() {} + + bool Start() override; + bool Stop() override; + + protected: + Timer timer_; + float start_time_ = 0.0f; + float total_time_ = 0.0f; + int iterations_ = 0; +}; + +template <class T> +class TimeObserver final : public TimeObserverBase<T> { + public: + explicit TimeObserver<T>(T* subject) : TimeObserverBase<T>(subject) {} +}; + +template <> +class TimeObserver<NetBase> final : public TimeObserverBase<NetBase> { + public: + explicit TimeObserver<NetBase>(NetBase* subject) + : TimeObserverBase<NetBase>(subject) {} + float average_time_children() const { + float sum = 0.0f; + for (const auto* observer : operator_observers_) { + sum += observer->average_time(); + } + return sum / subject_->GetOperators().size(); + } + + bool Start() override { + for (auto* op : subject_->GetOperators()) { + operator_observers_.push_back( + dynamic_cast_if_rtti<TimeObserver<OperatorBase>*>(op->AddObserver( + caffe2::make_unique<TimeObserver<OperatorBase>>(op)))); + } + start_time_ = timer_.MilliSeconds(); + ++iterations_; + return true; + } + + private: + vector<const TimeObserver<OperatorBase>*> operator_observers_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_CONTRIB_OBSERVERS_TIME_OBSERVER_H_ diff --git a/caffe2/observers/time_observer_test.cc b/caffe2/observers/time_observer_test.cc new file mode 100644 index 0000000000..8e55e9d135 --- /dev/null +++ b/caffe2/observers/time_observer_test.cc @@ -0,0 +1,90 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "caffe2/core/common.h" +#include "caffe2/core/net.h" +#include "caffe2/core/observer.h" +#include "caffe2/core/operator.h" +#include "time_observer.h" + +#include <google/protobuf/text_format.h> +#include <gtest/gtest.h> +#include <chrono> +#include <thread> + +namespace caffe2 { + +namespace { + +class SleepOp final : public OperatorBase { + public: + using OperatorBase::OperatorBase; + bool Run(int /* unused */) override { + StartAllObservers(); + std::this_thread::sleep_for(std::chrono::milliseconds(3000)); + StopAllObservers(); + return true; + } +}; + +REGISTER_CPU_OPERATOR(SleepOp, SleepOp); +REGISTER_CUDA_OPERATOR(SleepOp, SleepOp); + +OPERATOR_SCHEMA(SleepOp) + .NumInputs(0, INT_MAX) + .NumOutputs(0, INT_MAX) + .AllowInplace({{0, 0}, {1, 1}}); + +unique_ptr<NetBase> CreateNetTestHelper(Workspace* ws) { + NetDef net_def; + { + auto& op = *(net_def.add_op()); + op.set_type("SleepOp"); + op.add_input("in"); + op.add_output("hidden"); + } + { + auto& op = *(net_def.add_op()); + op.set_type("SleepOp"); + op.add_input("hidden"); + op.add_output("out"); + } + net_def.add_external_input("in"); + net_def.add_external_output("out"); + + return CreateNet(net_def, ws); +} +} + +TEST(TimeObserverTest, Test3Seconds) { + Workspace ws; + ws.CreateBlob("in"); + NetDef net_def; + unique_ptr<NetBase> net(CreateNetTestHelper(&ws)); + unique_ptr<TimeObserver<NetBase>> net_ob = + make_unique<TimeObserver<NetBase>>(net.get()); + auto* ob = dynamic_cast_if_rtti<TimeObserver<NetBase>*>( + net->AddObserver(std::move(net_ob))); + net->Run(); + CAFFE_ENFORCE(ob); + LOG(INFO) << "av time children: " << ob->average_time_children(); + LOG(INFO) << "av time: " << ob->average_time(); + CAFFE_ENFORCE(ob->average_time_children() > 3000); + CAFFE_ENFORCE(ob->average_time_children() < 3500); + CAFFE_ENFORCE(ob->average_time() > 6000); + CAFFE_ENFORCE(ob->average_time() < 6500); +} +} |