-rw-r--r--  .gitignore                               1
-rw-r--r--  aten/src/ATen/core/interned_strings.h    2
-rw-r--r--  test/test_jit.py                       101
-rw-r--r--  tools/build_variables.py                 1
-rw-r--r--  torch/CMakeLists.txt                     1
-rw-r--r--  torch/csrc/jit/graph_executor.cpp       16
-rw-r--r--  torch/csrc/jit/interpreter.cpp           7
-rw-r--r--  torch/csrc/jit/ir.cpp                    2
-rw-r--r--  torch/csrc/jit/register_prim_ops.cpp    42
-rw-r--r--  torch/csrc/jit/script/init.cpp          26
-rw-r--r--  torch/csrc/jit/script/logging.cpp       73
-rw-r--r--  torch/csrc/jit/script/logging.h         90
-rw-r--r--  torch/jit/_logging.py                   10
13 files changed, 365 insertions(+), 7 deletions(-)
diff --git a/.gitignore b/.gitignore
index 66eb2df0e4..b67a756efb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,7 +203,6 @@ docs/dev
*.sst
*.ldb
LOCK
-LOG*
CURRENT
MANIFEST-*
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 4a6bf37bcb..25783d19af 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -86,6 +86,8 @@ namespace c10 {
_(prim, CreateObject) \
_(prim, SetAttr) \
_(prim, GetAttr) \
+ _(prim, AddStatValue) \
+ _(prim, TimePoint) \
_(aten, append) \
_(aten, item) \
_(aten, format) \
diff --git a/test/test_jit.py b/test/test_jit.py
index 6229f0ca73..a66f8479ce 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1,6 +1,7 @@
from __future__ import division
import torch
import torch.jit
+import torch.jit._logging
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
@@ -13874,6 +13875,106 @@ class TestClassType(JitTestCase):
self.assertEqual(y, f2.y)
+class TestLogging(JitTestCase):
+ def test_bump_numeric_counter(self):
+ class ModuleThatLogs(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ for i in range(x.size(0)):
+ x += 1.0
+ torch.jit._logging.add_stat_value('foo', 1)
+
+ if bool(x.sum() > 0.0):
+ torch.jit._logging.add_stat_value('positive', 1)
+ else:
+ torch.jit._logging.add_stat_value('negative', 1)
+ return x
+
+ logger = torch.jit._logging.LockingLogger()
+ old_logger = torch.jit._logging.set_logger(logger)
+ try:
+ mtl = ModuleThatLogs()
+ for i in range(5):
+ mtl(torch.rand(3, 4, 5))
+
+ self.assertEqual(logger.get_counter_val('foo'), 15)
+ self.assertEqual(logger.get_counter_val('positive'), 5)
+ finally:
+ torch.jit._logging.set_logger(old_logger)
+
+ def test_trace_numeric_counter(self):
+ def foo(x):
+ torch.jit._logging.add_stat_value('foo', 1)
+ return x + 1.0
+
+ traced = torch.jit.trace(foo, torch.rand(3, 4))
+ logger = torch.jit._logging.LockingLogger()
+ old_logger = torch.jit._logging.set_logger(logger)
+ try:
+ traced(torch.rand(3, 4))
+
+ self.assertEqual(logger.get_counter_val('foo'), 1)
+ finally:
+ torch.jit._logging.set_logger(old_logger)
+
+ def test_time_measurement_counter(self):
+ class ModuleThatTimes(torch.jit.ScriptModule):
+ def forward(self, x):
+ tp_start = torch.jit._logging.time_point()
+ for i in range(30):
+ x += 1.0
+ tp_end = torch.jit._logging.time_point()
+ torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+ return x
+
+ mtm = ModuleThatTimes()
+ logger = torch.jit._logging.LockingLogger()
+ old_logger = torch.jit._logging.set_logger(logger)
+ try:
+ mtm(torch.rand(3, 4))
+ self.assertGreater(logger.get_counter_val('mytimer'), 0)
+ finally:
+ torch.jit._logging.set_logger(old_logger)
+
+ def test_time_measurement_counter_script(self):
+ class ModuleThatTimes(torch.jit.ScriptModule):
+ @torch.jit.script_method
+ def forward(self, x):
+ tp_start = torch.jit._logging.time_point()
+ for i in range(30):
+ x += 1.0
+ tp_end = torch.jit._logging.time_point()
+ torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
+ return x
+
+ mtm = ModuleThatTimes()
+ logger = torch.jit._logging.LockingLogger()
+ old_logger = torch.jit._logging.set_logger(logger)
+ try:
+ mtm(torch.rand(3, 4))
+ self.assertGreater(logger.get_counter_val('mytimer'), 0)
+ finally:
+ torch.jit._logging.set_logger(old_logger)
+
+ def test_counter_aggregation(self):
+ def foo(x):
+ for i in range(3):
+ torch.jit._logging.add_stat_value('foo', 1)
+ return x + 1.0
+
+ traced = torch.jit.trace(foo, torch.rand(3, 4))
+ logger = torch.jit._logging.LockingLogger()
+ logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
+ old_logger = torch.jit._logging.set_logger(logger)
+ try:
+ traced(torch.rand(3, 4))
+
+ self.assertEqual(logger.get_counter_val('foo'), 1)
+ finally:
+ torch.jit._logging.set_logger(old_logger)
+
+
for test in autograd_method_tests():
add_autograd_test(*test)
diff --git a/tools/build_variables.py b/tools/build_variables.py
index a293a56f8b..5718599d98 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -95,6 +95,7 @@ libtorch_sources = [
"torch/csrc/jit/scope.cpp",
"torch/csrc/jit/script/compiler.cpp",
"torch/csrc/jit/script/edit_distance.cpp",
+ "torch/csrc/jit/script/logging.cpp",
"torch/csrc/jit/script/final_returns.cpp",
"torch/csrc/jit/script/schema_type_parser.cpp",
"torch/csrc/jit/script/script_type_parser.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index deff9030b5..9c905ff011 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -185,6 +185,7 @@ set(TORCH_SRCS
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 57768c0556..1e993dac7f 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -32,6 +32,7 @@
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/logging.h>
#include <cstdint>
#include <iterator>
@@ -362,7 +363,10 @@ struct GraphExecutorImpl {
optimize(optimize),
num_inputs(this->graph->inputs().size()),
num_flat_inputs(countFlatInputs(graph)),
- num_outputs(this->graph->outputs().size()) {}
+ num_outputs(this->graph->outputs().size()) {
+ logging::getLogger()->addStatValue(
+ logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+ }
// entry point where execution begins
void run(Stack& stack) {
@@ -373,6 +377,9 @@ struct GraphExecutorImpl {
" inputs, but got only ",
stack.size());
+ logging::getLogger()->addStatValue(
+ logging::runtime_counters::GRAPH_EXECUTOR_INVOCATIONS, 1.0);
+
if (tracer::isTracing()) {
return runTraced(stack);
}
@@ -441,10 +448,15 @@ struct GraphExecutorImpl {
{
std::lock_guard<std::mutex> lock(compile_mutex);
auto it = plan_cache.find(spec);
- if (it != plan_cache.end())
+ if (it != plan_cache.end()) {
+ logging::getLogger()->addStatValue(
+ logging::runtime_counters::EXECUTION_PLAN_CACHE_HIT, 1.0);
return it->second;
+ }
auto plan = compileSpec(spec);
auto r = plan_cache.emplace(std::move(spec), std::move(plan));
+ logging::getLogger()->addStatValue(
+ logging::runtime_counters::EXECUTION_PLAN_CACHE_MISS, 1.0);
return r.first->second;
}
}
diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp
index 87e69b569f..eda0242be8 100644
--- a/torch/csrc/jit/interpreter.cpp
+++ b/torch/csrc/jit/interpreter.cpp
@@ -1,19 +1,20 @@
#include <torch/csrc/jit/interpreter.h>
+#include <ATen/core/ivalue.h>
+#include <c10/core/thread_pool.h>
+#include <c10/util/Exception.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
-#include <c10/util/Exception.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
-#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
-#include <c10/core/thread_pool.h>
+#include <torch/csrc/jit/script/logging.h>
#include <exception>
#include <iostream>
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 3a8e292f8d..b0ef049d69 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -845,6 +845,8 @@ bool Node::hasSideEffects() const {
case prim::RaiseException:
case prim::SetAttr:
case aten::warn:
+ case prim::AddStatValue:
+ case prim::TimePoint:
return true;
}
return false;
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 7ae656a443..9e22d2412c 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -8,6 +8,7 @@
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
@@ -887,7 +888,46 @@ RegisterOperators reg(
userObj->setSlot(slot, std::move(v));
return 0;
};
- })});
+ })
+ });
+
+RegisterOperators logging_operators({
+ Operator("prim::AddStatValue(str key, int val) -> ()", [](Stack& stack) {
+ auto val = pop(stack).toInt();
+ auto key = pop(stack).toString();
+
+ auto schema = parseSchema("prim::AddStatValue(str key, int val) -> ()");
+ // TODO: remove this custom tracing code once the custom op bugfix lands
+ if (jit::tracer::isTracing()) {
+ const auto& graph = tracer::getTracingState()->graph;
+ Node* node = graph->create(prim::AddStatValue, /*num_outputs=*/0);
+ tracer::recordSourceLocation(node);
+ node->addInput(insertConstant(*graph, key));
+ tracer::addInputs(node, "val", val);
+ graph->insertNode(node);
+ }
+ torch::jit::logging::getLogger()->addStatValue(*key, val);
+ return 0;
+ }),
+ Operator("prim::TimePoint() -> int", [](Stack& stack) {
+ auto schema = parseSchema("prim::TimePoint() -> int");
+ Node* node = nullptr;
+ // TODO: remove this custom tracing code once the custom op bugfix lands
+ if (jit::tracer::isTracing()) {
+ const auto& graph = tracer::getTracingState()->graph;
+ node = graph->create(prim::TimePoint, /*num_outputs=*/0);
+ tracer::recordSourceLocation(node);
+ graph->insertNode(node);
+ }
+ auto output = autograd::profiler::getTime();
+ push(stack, output);
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, output);
+ }
+ return 0;
+ })
+});
+
// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index f4d1c89398..8175f3582d 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -16,7 +16,9 @@
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/logging.h>
#include <torch/csrc/jit/script/parser.h>
+#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/api/include/torch/ordered_dict.h>
@@ -27,6 +29,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
+#include <chrono>
#include <cstddef>
#include <memory>
#include <sstream>
@@ -1101,6 +1104,29 @@ void initJitScriptBindings(PyObject* module) {
.def("run", [](testing::FileCheck& f, const Graph& g) {
return f.run(g);
});
+
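+ // Python bindings for the JIT logging API; torch/jit/_logging.py exposes
+ // these under friendlier names.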
+ m.def("_logging_set_logger", [](logging::LoggerBase* logger) {
+ return logging::setLogger(logger);
+ }, py::return_value_policy::reference);
+ py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
+ m, "LoggerBase");
+ py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
+ .value("SUM", logging::LockingLogger::AggregationType::SUM)
+ .value("AVG", logging::LockingLogger::AggregationType::AVG)
+ .export_values();
+ py::class_<
+ logging::LockingLogger,
+ logging::LoggerBase,
+ std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
+ .def(py::init<>())
+ .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
+ .def("get_counter_val", &logging::LockingLogger::getCounterValue);
+ py::class_<
+ logging::NoopLogger,
+ logging::LoggerBase,
+ std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
+ .def(py::init<>());
+
}
} // namespace script
} // namespace jit
diff --git a/torch/csrc/jit/script/logging.cpp b/torch/csrc/jit/script/logging.cpp
new file mode 100644
index 0000000000..48407cc674
--- /dev/null
+++ b/torch/csrc/jit/script/logging.cpp
@@ -0,0 +1,73 @@
+#include "torch/csrc/jit/script/logging.h"
+
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+// TODO: multi-scale histogram for this thing
+
+void LockingLogger::addStatValue(const std::string& stat_name, int64_t val) {
+ std::unique_lock<std::mutex> lk(m);
+ auto& raw_counter = raw_counters[stat_name];
+ raw_counter.sum += val;
+ raw_counter.count++;
+}
+
+TORCH_API int64_t LockingLogger::getCounterValue(const std::string& name) const {
+ std::unique_lock<std::mutex> lk(m);
+ if (!raw_counters.count(name)) {
+ return 0;
+ }
+ AggregationType type = agg_types.count(name) ? agg_types.at(name)
+ : AggregationType::SUM;
+ const auto& raw_counter = raw_counters.at(name);
+ switch (type) {
+ case AggregationType::SUM: {
+ return raw_counter.sum;
+ } break;
+ case AggregationType::AVG: {
+ return raw_counter.sum / raw_counter.count;
+ } break;
+ }
+ throw std::runtime_error("Unknown aggregation type!");
+}
+
+void LockingLogger::setAggregationType(
+ const std::string& stat_name,
+ AggregationType type) {
+ agg_types[stat_name] = type;
+}
+
+
+std::atomic<LoggerBase*> global_logger{new NoopLogger()};
+
+LoggerBase* getLogger() {
+ return global_logger.load();
+}
+
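+// Atomically install `logger` as the global logger and return the previous
+// one so callers can restore it later (see the try/finally pattern in
+// test/test_jit.py).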
+LoggerBase* setLogger(LoggerBase* logger) {
+ LoggerBase* previous = global_logger.load();
+ while (!global_logger.compare_exchange_strong(previous, logger)) {
+ previous = global_logger.load();
+ }
+ return previous;
+}
+
+JITTimePoint timePoint() {
+ return JITTimePoint{std::chrono::high_resolution_clock::now()};
+}
+
+void recordDurationSince(const std::string& name, JITTimePoint tp) {
+ auto end = std::chrono::high_resolution_clock::now();
+ // Measurement in nanoseconds.
+ auto nanos = std::chrono::duration<double>(end - tp.point).count() * 1e9;
+ logging::getLogger()->addStatValue(name, nanos);
+}
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/script/logging.h b/torch/csrc/jit/script/logging.h
new file mode 100644
index 0000000000..60d1bc3db6
--- /dev/null
+++ b/torch/csrc/jit/script/logging.h
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <chrono>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+
+namespace torch {
+namespace jit {
+namespace logging {
+
+class LoggerBase {
+ public:
+ TORCH_API virtual void addStatValue(
+ const std::string& stat_name,
+ int64_t val) = 0;
+ virtual ~LoggerBase() {}
+};
+
+TORCH_API LoggerBase* getLogger();
+TORCH_API LoggerBase* setLogger(LoggerBase* logger);
+
+// No-op logger. This is the default and is meant to incur almost no runtime
+// overhead.
+
+class NoopLogger : public LoggerBase {
+ public:
+ void addStatValue(const std::string& stat_name, int64_t val) override {}
+ ~NoopLogger() {}
+};
+
+// Trivial locking logger. Pass in an instance of this to setLogger() to use it.
+// By default it reports the running sum of each statistic; see
+// setAggregationType() for per-statistic SUM vs. AVG reporting.
+//
+// NOTE: this is not written in a scalable way and should probably only be used
+// in the single-threaded case or for testing.
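+//
+// A minimal usage sketch from C++ (illustrative only; not part of this diff):
+//
+//   auto logger = std::make_shared<torch::jit::logging::LockingLogger>();
+//   auto* prev = torch::jit::logging::setLogger(logger.get());
+//   torch::jit::logging::getLogger()->addStatValue("my.counter", 1);
+//   int64_t v = logger->getCounterValue("my.counter"); // == 1 under SUM
+//   torch::jit::logging::setLogger(prev); // restore the previous logger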
+class LockingLogger : public LoggerBase {
+ public:
+ TORCH_API void addStatValue(const std::string& stat_name, int64_t val) override;
+ TORCH_API virtual int64_t getCounterValue(const std::string& name) const;
+ enum class AggregationType { SUM, AVG };
+ TORCH_API void setAggregationType(
+ const std::string& stat_name,
+ AggregationType type);
+ ~LockingLogger() {}
+
+ private:
+ mutable std::mutex m;
+ struct RawCounter {
+ RawCounter() : sum(0), count(0) {}
+ int64_t sum;
+ size_t count;
+ };
+ std::unordered_map<std::string, RawCounter> raw_counters;
+ std::unordered_map<std::string, AggregationType> agg_types;
+};
+
+// Wrapper struct that keeps the timer internals opaque to the user.
+struct JITTimePoint {
+ std::chrono::time_point<std::chrono::high_resolution_clock> point;
+};
+
+TORCH_API JITTimePoint timePoint();
+TORCH_API void recordDurationSince(const std::string& name, JITTimePoint tp);
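+//
+// Intended timing pattern (a sketch; recordDurationSince in logging.cpp
+// records the elapsed time in nanoseconds):
+//
+//   auto start = torch::jit::logging::timePoint();
+//   /* ... timed work ... */
+//   torch::jit::logging::recordDurationSince("my.timer", start);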
+
+namespace runtime_counters {
+constexpr const char* GRAPH_EXECUTORS_CONSTRUCTED =
+ "pytorch_runtime.graph_executors_constructed";
+constexpr const char* GRAPH_EXECUTOR_INVOCATIONS =
+ "pytorch_runtime.graph_executor_invocations";
+constexpr const char* EXECUTION_PLAN_CACHE_HIT =
+ "pytorch_runtime.execution_plan_cache_hit";
+constexpr const char* EXECUTION_PLAN_CACHE_MISS =
+ "pytorch_runtime.execution_plan_cache_miss";
+
+inline std::vector<const char*> allRuntimeCounters() {
+ return {GRAPH_EXECUTORS_CONSTRUCTED,
+ GRAPH_EXECUTOR_INVOCATIONS,
+ EXECUTION_PLAN_CACHE_HIT,
+ EXECUTION_PLAN_CACHE_MISS};
+}
+
+} // namespace runtime_counters
+
+} // namespace logging
+} // namespace jit
+} // namespace torch
diff --git a/torch/jit/_logging.py b/torch/jit/_logging.py
new file mode 100644
index 0000000000..497c34293d
--- /dev/null
+++ b/torch/jit/_logging.py
@@ -0,0 +1,10 @@
+import torch
+
+add_stat_value = torch.ops.prim.AddStatValue
+
+set_logger = torch._C._logging_set_logger
+LockingLogger = torch._C.LockingLogger
+AggregationType = torch._C.AggregationType
+NoopLogger = torch._C.NoopLogger
+
+time_point = torch.ops.prim.TimePoint
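+
+# Rough usage sketch (mirrors test/test_jit.py; the names above are the
+# public surface of this module):
+#
+#   logger = torch.jit._logging.LockingLogger()
+#   old_logger = torch.jit._logging.set_logger(logger)
+#   try:
+#       ...  # run scripted/traced code that calls add_stat_value('foo', 1)
+#       print(logger.get_counter_val('foo'))
+#   finally:
+#       torch.jit._logging.set_logger(old_logger)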