-rw-r--r--  .gitignore                              |   1
-rw-r--r--  aten/src/ATen/core/interned_strings.h   |   2
-rw-r--r--  test/test_jit.py                        | 101
-rw-r--r--  tools/build_variables.py                |   1
-rw-r--r--  torch/CMakeLists.txt                    |   1
-rw-r--r--  torch/csrc/jit/graph_executor.cpp       |  16
-rw-r--r--  torch/csrc/jit/interpreter.cpp          |   7
-rw-r--r--  torch/csrc/jit/ir.cpp                   |   2
-rw-r--r--  torch/csrc/jit/register_prim_ops.cpp    |  42
-rw-r--r--  torch/csrc/jit/script/init.cpp          |  26
-rw-r--r--  torch/csrc/jit/script/logging.cpp       |  73
-rw-r--r--  torch/csrc/jit/script/logging.h         |  90
-rw-r--r--  torch/jit/_logging.py                   |  10
13 files changed, 365 insertions, 7 deletions
diff --git a/.gitignore b/.gitignore
index 66eb2df0e4..b67a756efb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,7 +203,6 @@ docs/dev
 *.sst
 *.ldb
 LOCK
-LOG*
 CURRENT
 MANIFEST-*
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 4a6bf37bcb..25783d19af 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -86,6 +86,8 @@ namespace c10 {
   _(prim, CreateObject)            \
   _(prim, SetAttr)                 \
   _(prim, GetAttr)                 \
+  _(prim, AddStatValue)            \
+  _(prim, TimePoint)               \
   _(aten, append)                  \
   _(aten, item)                    \
   _(aten, format)                  \
diff --git a/test/test_jit.py b/test/test_jit.py
index 6229f0ca73..a66f8479ce 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1,6 +1,7 @@
 from __future__ import division
 import torch
 import torch.jit
+import torch.jit._logging
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.parallel as dp
@@ -13874,6 +13875,106 @@ class TestClassType(JitTestCase):
         self.assertEqual(y, f2.y)
 
 
+class TestLogging(JitTestCase):
+    def test_bump_numeric_counter(self):
+        class ModuleThatLogs(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                for i in range(x.size(0)):
+                    x += 1.0
+                    torch.jit._logging.add_stat_value('foo', 1)
+
+                if bool(x.sum() > 0.0):
+                    torch.jit._logging.add_stat_value('positive', 1)
+                else:
+                    torch.jit._logging.add_stat_value('negative', 1)
+                return x
+
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtl = ModuleThatLogs()
+            for i in range(5):
+                mtl(torch.rand(3, 4, 5))
+
+            self.assertEqual(logger.get_counter_val('foo'), 15)
+            self.assertEqual(logger.get_counter_val('positive'), 5)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_trace_numeric_counter(self):
+        def foo(x):
+            torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_time_measurement_counter_script(self):
+        class ModuleThatTimes(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x):
+                tp_start = torch.jit._logging.time_point()
+                for i in range(30):
+                    x += 1.0
+                tp_end = torch.jit._logging.time_point()
+                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+                return x
+
+        mtm = ModuleThatTimes()
+        logger = torch.jit._logging.LockingLogger()
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            mtm(torch.rand(3, 4))
+            self.assertGreater(logger.get_counter_val('mytimer'), 0)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+    def test_counter_aggregation(self):
+        def foo(x):
+            for i in range(3):
+                torch.jit._logging.add_stat_value('foo', 1)
+            return x + 1.0
+
+        traced = torch.jit.trace(foo, torch.rand(3, 4))
+        logger = torch.jit._logging.LockingLogger()
+        logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
+        old_logger = torch.jit._logging.set_logger(logger)
+        try:
+            traced(torch.rand(3, 4))
+
+            self.assertEqual(logger.get_counter_val('foo'), 1)
+        finally:
+            torch.jit._logging.set_logger(old_logger)
+
+
 for test in autograd_method_tests():
     add_autograd_test(*test)
diff --git a/tools/build_variables.py b/tools/build_variables.py
index a293a56f8b..5718599d98 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -95,6 +95,7 @@ libtorch_sources = [
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
+    "torch/csrc/jit/script/logging.cpp",
     "torch/csrc/jit/script/final_returns.cpp",
     "torch/csrc/jit/script/schema_type_parser.cpp",
     "torch/csrc/jit/script/script_type_parser.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index deff9030b5..9c905ff011 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -185,6 +185,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
   ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
   ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 57768c0556..1e993dac7f 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -32,6 +32,7 @@
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/logging.h>
 
 #include <cstdint>
 #include <iterator>
@@ -362,7 +363,10 @@ struct GraphExecutorImpl {
         optimize(optimize),
         num_inputs(this->graph->inputs().size()),
         num_flat_inputs(countFlatInputs(graph)),
-        num_outputs(this->graph->outputs().size()) {}
+        num_outputs(this->graph->outputs().size()) {
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+  }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -373,6 +377,9 @@ struct GraphExecutorImpl {
         " inputs, but got only ",
         stack.size());
 
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
+
     if (tracer::isTracing()) {
       return runTraced(stack);
     }
@@ -441,10 +448,15 @@ struct GraphExecutorImpl {
     {
       std::lock_guard<std::mutex> lock(compile_mutex);
       auto it = plan_cache.find(spec);
-      if (it != plan_cache.end())
+      if (it != plan_cache.end()) {
+        logging::getLogger()->addStatValue(
+            logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
         return it->second;
+      }
       auto plan = compileSpec(spec);
       auto r = plan_cache.emplace(std::move(spec), std::move(plan));
+      logging::getLogger()->addStatValue(
+          logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
       return r.first->second;
     }
   }
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 87e69b569f..eda0242be8 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -1,19 +1,20 @@
 #include <torch/csrc/jit/interpreter.h>
 
+#include <ATen/core/ivalue.h>
+#include <c10/core/thread_pool.h>
+#include <c10/util/Exception.h>
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/variable.h>
-#include <c10/util/Exception.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
-#include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
-#include <c10/core/thread_pool.h>
+#include <torch/csrc/jit/script/logging.h>
 
 #include <exception>
 #include <iostream>
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 3a8e292f8d..b0ef049d69 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -845,6 +845,8 @@ bool Node::hasSideEffects() const {
     case prim::RaiseException:
     case prim::SetAttr:
     case aten::warn:
+    case prim::AddStatValue:
+    case prim::TimePoint:
       return true;
   }
   return false;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 7ae656a443..9e22d2412c 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -8,6 +8,7 @@
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/logging.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
 
@@ -887,7 +888,46 @@ RegisterOperators reg(
           userObj->setSlot(slot, std::move(v));
           return 0;
         };
-      })});
+      })
+});
+
+RegisterOperators logging_operators({
+    Operator(
+        "prim::AddStatValue(str key, int val) -> ()",
+        [](Stack& stack) {
+          auto val = pop(stack).toInt();
+          auto key = pop(stack).toString();
+
+          auto schema =
+              parseSchema("prim::AddStatValue(str key, int val) -> ()");
+          // TODO: remove this custom tracing code once the custom op bugfix
+          // lands
+          if (jit::tracer::isTracing()) {
+            const auto& graph = tracer::getTracingState()->graph;
+            Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+            tracer::recordSourceLocation(node);
+            node->addInput(insertConstant(*graph, key));
+            tracer::addInputs(node, "val", val);
+            graph->insertNode(node);
+          }
+          torch::jit::logging::getLogger()->addStatValue(*key, val);
+          return 0;
+        }),
+    Operator(
+        "prim::TimePoint() -> int",
+        [](Stack& stack) {
+          auto schema = parseSchema("prim::TimePoint() -> int");
+          Node* node = nullptr;
+          // TODO: remove this custom tracing code once the custom op bugfix
+          // lands
+          if (jit::tracer::isTracing()) {
+            const auto& graph = tracer::getTracingState()->graph;
+            // Assign the outer `node` rather than shadowing it, so that
+            // addOutput below records on the traced node instead of nullptr.
+            node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+            tracer::recordSourceLocation(node);
+            graph->insertNode(node);
+          }
+          auto output = autograd::profiler::getTime();
+          push(stack, output);
+          if (jit::tracer::isTracing()) {
+            jit::tracer::addOutput(node, output);
+          }
+          return 0;
+        })
+});
+
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index f4d1c89398..8175f3582d 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -16,7 +16,9 @@
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/logging.h>
 #include <torch/csrc/jit/script/parser.h>
+#include <torch/csrc/jit/tracer.h>
 
 #include <torch/csrc/api/include/torch/ordered_dict.h>
 
@@ -27,6 +29,7 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <pybind11/stl_bind.h>
+#include <chrono>
 #include <cstddef>
 #include <memory>
 #include <sstream>
@@ -1101,6 +1104,29 @@ void initJitScriptBindings(PyObject* module) {
       .def("run", [](testing::FileCheck& f, const Graph& g) {
         return f.run(g);
       });
+
+  m.def(
+      "_logging_set_logger",
+      [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
+      py::return_value_policy::reference);
+  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
+      m, "LoggerBase");
+  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
+      .value("SUM", logging::LockingLogger::AggregationType::SUM)
+      .value("AVG", logging::LockingLogger::AggregationType::AVG)
+      .export_values();
+  py::class_<
+      logging::LockingLogger,
+      logging::LoggerBase,
+      std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
+      .def(py::init<>())
+      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
+      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
+  py::class_<
+      logging::NoopLogger,
+      logging::LoggerBase,
+      std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
+      .def(py::init<>());
+
 }
 } // namespace script
 } // namespace jit
diff --git a/torch/csrc/jit/script/logging.cpp b/torch/csrc/jit/script/logging.cpp
new file mode 100644
index 0000000000..48407cc674
--- /dev/null
+++ b/torch/csrc/jit/script/logging.cpp
@@ -0,0 +1,73 @@
+#include "torch/csrc/jit/script/logging.h"
+
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+// TODO: multi-scale histogram for this thing
+
+void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
+  std::unique_lock<std::mutex> lk(m);
+  auto& raw_counter = raw_counters[stat_name];
+  raw_counter.sum += val;
+  raw_counter.count++;
+}
+
+TORCH_API int64_t LockingLogger::getCounterValue(
+    const std::string& name) const {
+  std::unique_lock<std::mutex> lk(m);
+  if (!raw_counters.count(name)) {
+    return 0;
+  }
+  AggregationType type =
+      agg_types.count(name) ? agg_types.at(name) : AggregationType::SUM;
+  const auto& raw_counter = raw_counters.at(name);
+  switch (type) {
+    case AggregationType::SUM: {
+      return raw_counter.sum;
+    } break;
+    case AggregationType::AVG: {
+      return raw_counter.sum / raw_counter.count;
+    } break;
+  }
+  throw std::runtime_error("Unknown aggregation type!");
+}
+
+void LockingLogger::setAggregationType(
+    const std::string& stat_name,
+    AggregationType type) {
+  agg_types[stat_name] = type;
+}
+
+std::atomic<LoggerBase*> global_logger{new NoopLogger()};
+
+LoggerBase* getLogger() {
+  return global_logger.load();
+}
+
+LoggerBase* setLogger(LoggerBase* logger) {
+  LoggerBase* previous = global_logger.load();
+  while (!global_logger.compare_exchange_strong(previous, logger)) {
+    previous = global_logger.load();
+  }
+  return previous;
+}
+
+JITTimePoint timePoint() {
+  return JITTimePoint{std::chrono::high_resolution_clock::now()};
+}
+
+void recordDurationSince(const std::string& name, JITTimePoint tp) {
+  auto end = std::chrono::high_resolution_clock::now();
+  // Measurement in nanoseconds.
+  auto nanos = std::chrono::duration<double>(end - tp.point).count() * 1e9;
+  logging::getLogger()->addStatValue(name, nanos);
+}
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/logging.h b/torch/csrc/jit/script/logging.h
new file mode 100644
index 0000000000..60d1bc3db6
--- /dev/null
+++ b/torch/csrc/jit/script/logging.h
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+class LoggerBase {
+ public:
+  TORCH_API virtual void addStatValue(
+      const std::string& stat_name,
+      int64_t val) = 0;
+  virtual ~LoggerBase() {}
+};
+
+TORCH_API LoggerBase* getLogger();
+TORCH_API LoggerBase* setLogger(LoggerBase* logger);
+
+// No-op logger. This is the default and is meant to incur almost no runtime
+// overhead.
+class NoopLogger : public LoggerBase {
+ public:
+  void addStatValue(const std::string& stat_name, int64_t val) override {}
+  ~NoopLogger() {}
+};
+
+// Trivial locking logger. Pass in an instance of this to setLogger() to use
+// it. This keeps track of the sum and count of all statistics.
+//
+// NOTE: this is not written in a scalable way and should probably only be
+// used in the single-threaded case or for testing.
+class LockingLogger : public LoggerBase {
+ public:
+  TORCH_API void addStatValue(const std::string& stat_name, int64_t val)
+      override;
+  TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
+  enum class AggregationType { SUM, AVG };
+  TORCH_API void setAggregationType(
+      const std::string& stat_name,
+      AggregationType type);
+  ~LockingLogger() {}
+
+ private:
+  mutable std::mutex m;
+  struct RawCounter {
+    RawCounter() : sum(0), count(0) {}
+    int64_t sum;
+    size_t count;
+  };
+  std::unordered_map<std::string, RawCounter> raw_counters;
+  std::unordered_map<std::string, AggregationType> agg_types;
+};
+
+// Make this struct so the timer internals are opaque to the user.
+struct JITTimePoint {
+  std::chrono::time_point<std::chrono::high_resolution_clock> point;
+};
+
+TORCH_API JITTimePoint timePoint();
+TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp);
+
+namespace runtime_counters {
+constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
+    "pytorch_runtime.graph_executors_constructed";
+constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
+    "pytorch_runtime.graph_executor_invocations";
+constexpr const char* EXECUTION_PLAN_CACHE_HIT =
+    "pytorch_runtime.execution_plan_cache_hit";
+constexpr const char* EXECUTION_PLAN_CACHE_MISS =
+    "pytorch_runtime.execution_plan_cache_miss";
+
+inline std::vector<const char*> allRuntimeCounters() {
+  return {GRAPH_EXECUTORS_CONSTRUCTED,
+          GRAPH_EXECUTOR_INVOCATIONS,
+          EXECUTION_PLAN_CACHE_HIT,
+          EXECUTION_PLAN_CACHE_MISS};
+}
+
+} // namespace runtime_counters
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/jit/_logging.py b/torch/jit/_logging.py
new file mode 100644
index 0000000000..497c34293d
--- /dev/null
+++ b/torch/jit/_logging.py
@@ -0,0 +1,10 @@
+import torch
+
+add_stat_value = torch.ops.prim.AddStatValue
+
+set_logger = torch._C._logging_set_logger
+LockingLogger = torch._C.LockingLogger
+AggregationType = torch._C.AggregationType
+NoopLogger = torch._C.NoopLogger
+
+time_point = torch.ops.prim.TimePoint
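
For reference, a minimal usage sketch of the Python API this patch adds, modeled on test_trace_numeric_counter above (not part of the diff; the counter name 'my_app.calls' is illustrative and assumes a build with this patch applied):

    import torch
    import torch.jit
    import torch.jit._logging

    def count_calls(x):
        # Bumps a counter each time the traced function runs.
        torch.jit._logging.add_stat_value('my_app.calls', 1)
        return x + 1.0

    # Trace first: the trace-time execution logs to the default NoopLogger.
    traced = torch.jit.trace(count_calls, torch.rand(3, 4))

    logger = torch.jit._logging.LockingLogger()
    # Default aggregation is SUM; AVG is also available:
    # logger.set_aggregation_type('my_app.calls', torch.jit._logging.AggregationType.AVG)
    old_logger = torch.jit._logging.set_logger(logger)
    try:
        for _ in range(3):
            traced(torch.rand(3, 4))
        print(logger.get_counter_val('my_app.calls'))  # 3
    finally:
        # Restore the previous logger so other code is unaffected.
        torch.jit._logging.set_logger(old_logger)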