summary | refs | log | tree | commit | diff
path: root/runtime/onert/core/src/util
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/core/src/util')
-rw-r--r--runtime/onert/core/src/util/ChromeTracingEventWriter.cc195
-rw-r--r--runtime/onert/core/src/util/ConfigSource.cc36
-rw-r--r--runtime/onert/core/src/util/EventCollector.cc77
-rw-r--r--runtime/onert/core/src/util/EventCollector.h70
-rw-r--r--runtime/onert/core/src/util/EventCollectorGlobal.cc93
-rw-r--r--runtime/onert/core/src/util/EventCollectorGlobal.h155
-rw-r--r--runtime/onert/core/src/util/EventRecorder.cc532
-rw-r--r--runtime/onert/core/src/util/EventRecorder.h63
-rw-r--r--runtime/onert/core/src/util/EventWriter.cc49
-rw-r--r--runtime/onert/core/src/util/EventWriter.h144
-rw-r--r--runtime/onert/core/src/util/Index.test.cc (renamed from runtime/onert/core/src/util/GeneralConfigSource.cc)39
-rw-r--r--runtime/onert/core/src/util/MDTableEventWriter.cc365
-rw-r--r--runtime/onert/core/src/util/ObjectManager.test.cc211
-rw-r--r--runtime/onert/core/src/util/SNPEEventWriter.cc186
-rw-r--r--runtime/onert/core/src/util/ShapeInference.cc243
-rw-r--r--runtime/onert/core/src/util/ShapeInference.test.cc544
-rw-r--r--runtime/onert/core/src/util/TracingCtx.cc (renamed from runtime/onert/core/src/util/EnvConfigSource.cc)22
17 files changed, 2107 insertions, 917 deletions
diff --git a/runtime/onert/core/src/util/ChromeTracingEventWriter.cc b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
new file mode 100644
index 000000000..c3f5179df
--- /dev/null
+++ b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+// json type for ChromeTracingWriter
+namespace
+{
+
+std::string quote(const std::string &value)
+{
+ std::stringstream ss;
+ ss << '"' << value << '"';
+ return ss.str();
+}
+
+std::string field(const std::string &k, const std::string &v)
+{
+ std::stringstream ss;
+ ss << quote(k) << " : " << quote(v);
+ return ss.str();
+}
+
+struct Content // One Entry in Chrome Event Trace
+{
+ std::vector<std::pair<std::string, std::string>> flds;
+ std::vector<std::pair<std::string, std::string>> args;
+};
+
+std::string object(const Content &content)
+{
+ std::stringstream ss;
+
+ ss << "{ ";
+
+ ss << field(content.flds[0].first, content.flds[0].second);
+
+ for (uint32_t n = 1; n < content.flds.size(); ++n)
+ {
+ ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second);
+ }
+
+ if (content.args.size() > 0)
+ {
+ ss << ", " << quote("args") << " : { ";
+ ss << field(content.args.at(0).first, content.args.at(0).second);
+
+ for (uint32_t n = 1; n < content.args.size(); ++n)
+ {
+ ss << ", " << field(content.args.at(n).first, content.args.at(n).second);
+ }
+
+ ss << "}";
+ }
+
+ ss << " }";
+
+ return ss.str();
+}
+
+void fill(Content &content, const DurationEvent &evt, const std::string &name,
+ const std::string &tid)
+{
+ content.flds.emplace_back("name", name);
+ content.flds.emplace_back("pid", "0");
+ content.flds.emplace_back("tid", tid);
+ content.flds.emplace_back("ph", evt.ph);
+ content.flds.emplace_back("ts", evt.ts);
+ content.args = evt.args;
+}
+
+void fill(Content &content, const CounterEvent &evt)
+{
+ assert(evt.name != "");
+
+ content.flds.emplace_back("name", evt.name);
+ content.flds.emplace_back("pid", "0");
+ content.flds.emplace_back("tid", evt.tid);
+ content.flds.emplace_back("ph", evt.ph);
+ content.flds.emplace_back("ts", evt.ts);
+ content.args = evt.args;
+}
+
+std::string object(const DurationEvent &evt, const std::string &name, const std::string &tid)
+{
+ Content content;
+
+ fill(content, evt, name, tid);
+
+ return ::object(content);
+}
+
+std::string object(const CounterEvent &evt)
+{
+ Content content;
+
+ fill(content, evt);
+
+ for (auto it = evt.values.begin(); it != evt.values.end(); ++it)
+ {
+ content.args.emplace_back(it->first, it->second);
+ }
+
+ return ::object(content);
+}
+
+std::string getSessionLabel(const DurationEvent &evt)
+{
+ return "$" + std::to_string(evt.session_index) + " sess";
+}
+
+std::string getSubgLabel(const DurationEvent &evt)
+{
+ return "$" + std::to_string(evt.subg_index) + " subg";
+}
+
+std::string getOpLabel(const OpSeqDurationEvent &evt)
+{
+ return "@" + std::to_string(evt.op_index) + " " + evt.op_name;
+}
+
+std::string getLabel(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ return getOpLabel(*evt_ptr);
+ }
+ else // SubgDurationEvent
+ {
+ return getSubgLabel(evt);
+ }
+}
+
+std::string getTid(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ return getSessionLabel(*evt_ptr) + ", " + getSubgLabel(*evt_ptr) + ", " + evt_ptr->backend;
+ }
+ else // SubgDurationEvent
+ {
+ return getSessionLabel(evt) + ", " + getSubgLabel(evt);
+ }
+}
+
+} // namespace
+
+void ChromeTracingWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders)
+{
+ _os << "{\n";
+ _os << " " << quote("traceEvents") << ": [\n";
+
+ for (const auto &recorder : recorders)
+ {
+ flushOneRecord(*recorder);
+ }
+
+ _os << " { }\n";
+ _os << " ]\n";
+ _os << "}\n";
+}
+
+void ChromeTracingWriter::flushOneRecord(const EventRecorder &recorder)
+{
+ for (const auto &evt : recorder.duration_events())
+ {
+ const std::string name = getLabel(*evt);
+ const std::string tid = getTid(*evt);
+
+ _os << " " << object(*evt, name, tid) << ",\n";
+ }
+
+ for (const auto &evt : recorder.counter_events())
+ {
+ _os << " " << object(evt) << ",\n";
+ }
+}
diff --git a/runtime/onert/core/src/util/ConfigSource.cc b/runtime/onert/core/src/util/ConfigSource.cc
index 45cce662e..b7fcefc7a 100644
--- a/runtime/onert/core/src/util/ConfigSource.cc
+++ b/runtime/onert/core/src/util/ConfigSource.cc
@@ -15,13 +15,15 @@
*/
#include "util/ConfigSource.h"
-#include "util/GeneralConfigSource.h"
-#include "util/EnvConfigSource.h"
+#include "util/logging.h"
+
+#include <misc/EnvConfigSource.h>
+#include <misc/GeneralConfigSource.h>
+#include <misc/IConfigSource.h>
-#include <array>
#include <algorithm>
+#include <array>
#include <cassert>
-
#include <memory>
namespace onert
@@ -29,9 +31,26 @@ namespace onert
namespace util
{
+using namespace nnfw::misc;
+
static std::unique_ptr<IConfigSource> _source;
+static std::unique_ptr<IConfigSource> _source_ext;
void config_source(std::unique_ptr<IConfigSource> &&source) { _source = std::move(source); }
+void config_source_ext(std::unique_ptr<IConfigSource> &&source) { _source_ext = std::move(source); }
+
+void setConfigKeyValues(const CfgKeyValues &keyValues)
+{
+ auto configsrc = std::make_unique<GeneralConfigSource>();
+
+ for (auto it = keyValues.begin(); it != keyValues.end(); ++it)
+ {
+ VERBOSE(NNPKG_CONFIGS) << "(" << it->first << ") = (" << it->second << ")" << std::endl;
+ configsrc->set(it->first, it->second);
+ }
+
+ onert::util::config_source_ext(std::move(configsrc));
+}
static IConfigSource *config_source()
{
@@ -67,6 +86,15 @@ static std::string getConfigOrDefault(const std::string &key)
auto ret = config_source()->get(key);
if (ret.empty())
{
+ // if env is not set, search from external
+ if (_source_ext.get())
+ {
+ ret = _source_ext.get()->get(key);
+ }
+ }
+ // if not found search from defaults
+ if (ret.empty())
+ {
auto itr = defaults.find(key);
if (itr != defaults.end())
{
diff --git a/runtime/onert/core/src/util/EventCollector.cc b/runtime/onert/core/src/util/EventCollector.cc
index de37276bf..c1b9c4315 100644
--- a/runtime/onert/core/src/util/EventCollector.cc
+++ b/runtime/onert/core/src/util/EventCollector.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "util/EventCollector.h"
+#include "EventCollector.h"
// C++ standard libraries
#include <chrono>
@@ -30,24 +30,62 @@ std::string timestamp(void)
{
auto now = std::chrono::steady_clock::now();
return std::to_string(
- std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count());
+ std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count());
}
-class DurationEventBuilder
+class DurationEventBuilder : public EventCollector::EventVisitor
{
public:
DurationEventBuilder(const std::string &ts) : _ts{ts} {}
- DurationEvent build(const std::string &tid, const std::string &name, const std::string &ph) const
+ std::unique_ptr<SubgDurationEvent> build(const EventCollector::SubgEvent &evt_collected,
+ const std::string &ph) const
{
- DurationEvent evt;
+ auto dur_evt = std::make_unique<SubgDurationEvent>();
- evt.name = name;
- evt.tid = tid;
- evt.ph = ph;
- evt.ts = _ts;
+ // The following will be set by a child of EventsWriter:
+ // dur_evt.name, dur_evt.tid
+ dur_evt->ph = ph;
+ dur_evt->ts = _ts;
+ dur_evt->tracing_ctx = evt_collected.tracing_ctx;
- return evt;
+ dur_evt->session_index = evt_collected.session_index;
+ dur_evt->subg_index = evt_collected.subg_index;
+
+ dur_evt->args = evt_collected.userData;
+ {
+ dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index));
+ dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index));
+ }
+
+ return dur_evt;
+ }
+
+ std::unique_ptr<OpSeqDurationEvent> build(const EventCollector::OpSeqEvent &evt_collected,
+ const std::string &ph) const
+ {
+ auto dur_evt = std::make_unique<OpSeqDurationEvent>();
+
+ // The following will be set by a child of EventsWriter:
+ // dur_evt.name, dur_evt.tid
+ dur_evt->ph = ph;
+ dur_evt->ts = _ts;
+ dur_evt->tracing_ctx = evt_collected.tracing_ctx;
+
+ dur_evt->session_index = evt_collected.session_index;
+ dur_evt->subg_index = evt_collected.subg_index;
+
+ dur_evt->backend = evt_collected.backend;
+ dur_evt->op_index = evt_collected.op_index;
+ dur_evt->op_name = evt_collected.op_name;
+
+ dur_evt->args = evt_collected.userData;
+ {
+ dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index));
+ dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index));
+ }
+
+ return dur_evt;
}
private:
@@ -86,19 +124,26 @@ inline void emit_rusage(EventRecorder *rec, const std::string &ts)
} // namespace
-void EventCollector::onEvent(const Event &event)
+template <typename EventT> void EventCollector::onEvent(const EventT &event)
{
auto ts = timestamp();
+ DurationEventBuilder builder(ts);
+
switch (event.edge)
{
case Edge::BEGIN:
- _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "B"));
+ {
+ auto duration_evt = builder.build(event, "B");
+ _rec->emit(std::move(duration_evt));
break;
-
+ }
case Edge::END:
- _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "E"));
+ {
+ auto duration_evt = builder.build(event, "E");
+ _rec->emit(std::move(duration_evt));
break;
+ }
}
// TODO: Add resource measurement (e.g. RSS)
@@ -107,3 +152,7 @@ void EventCollector::onEvent(const Event &event)
emit_rusage(_rec, ts);
#endif
}
+
+// template instantiation
+template void EventCollector::onEvent<EventCollector::SubgEvent>(const SubgEvent &event);
+template void EventCollector::onEvent<EventCollector::OpSeqEvent>(const OpSeqEvent &event);
diff --git a/runtime/onert/core/src/util/EventCollector.h b/runtime/onert/core/src/util/EventCollector.h
index 8154be592..effb72373 100644
--- a/runtime/onert/core/src/util/EventCollector.h
+++ b/runtime/onert/core/src/util/EventCollector.h
@@ -17,7 +17,13 @@
#ifndef __ONERT_UTIL_EVENT_COLLECTOR_H__
#define __ONERT_UTIL_EVENT_COLLECTOR_H__
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
+
+#include "util/TracingCtx.h"
+
+#include <string>
+#include <utility>
+#include <vector>
class EventCollector
{
@@ -28,11 +34,69 @@ public:
END
};
+ struct SubgEvent;
+ struct OpEvent;
+
+ class EventVisitor
+ {
+ public:
+ virtual ~EventVisitor() = default;
+
+ virtual std::unique_ptr<DurationEvent> visit(const SubgEvent &, const std::string &) const
+ {
+ throw std::runtime_error("Please implement");
+ }
+ virtual std::unique_ptr<DurationEvent> visit(const OpEvent &, const std::string &) const
+ {
+ throw std::runtime_error("Please implement");
+ }
+ };
+
struct Event
{
+ const onert::util::TracingCtx *tracing_ctx;
+
Edge edge;
+ uint32_t session_index;
+ uint32_t subg_index;
+
+ // user-defined data: pairs of (key, value)
+ std::vector<std::pair<std::string, std::string>> userData;
+
+ protected:
+ Event(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index)
+ : tracing_ctx(a_tracing_ctx), edge(a_edge), session_index(tracing_ctx->getSessionId()),
+ subg_index(a_subg_index)
+ { /* empty */
+ }
+
+ virtual ~Event() = default;
+ };
+
+ struct SubgEvent : public Event
+ {
+ // constructor for subgraph start and end event
+ SubgEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index)
+ : Event(a_tracing_ctx, a_edge, a_subg_index)
+ { /* empty */
+ }
+ };
+
+ // TODO Rename this to OperationEvent
+ struct OpSeqEvent : public Event
+ {
std::string backend;
- std::string label;
+ uint32_t op_index;
+ std::string op_name;
+
+ OpSeqEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index,
+ const std::string a_backend, uint32_t a_op_index, const std::string a_op_name)
+ : Event(a_tracing_ctx, a_edge, a_subg_index)
+ {
+ backend.assign(a_backend);
+ op_index = a_op_index;
+ op_name.assign(a_op_name);
+ }
};
public:
@@ -42,7 +106,7 @@ public:
}
public:
- void onEvent(const Event &event);
+ template <typename EventT> void onEvent(const EventT &event);
protected:
EventRecorder *_rec;
diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.cc b/runtime/onert/core/src/util/EventCollectorGlobal.cc
deleted file mode 100644
index d09b95210..000000000
--- a/runtime/onert/core/src/util/EventCollectorGlobal.cc
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/EventCollectorGlobal.h"
-
-#include <cassert>
-#include <fstream>
-#include <iostream>
-
-#include "util/ConfigSource.h"
-
-namespace onert
-{
-namespace util
-{
-
-EventCollectorGlobal::EventCollectorGlobal() : _recorder{}, _collector{&_recorder}
-{
- // DO NOTHING
-}
-
-EventCollectorGlobal::~EventCollectorGlobal()
-{
- if (!_recorder.empty())
- {
- try
- {
- // TODO Need better way for saved file path than the hardcoded path
- std::ofstream ofs{"trace.global.json"};
- _recorder.writeToFile(ofs);
- }
- catch (const std::exception &e)
- {
- std::cerr << "E: Fail to record event in EventCollectorGlobal: " << e.what() << std::endl;
- }
- }
-}
-
-EventCollectorGlobal &EventCollectorGlobal::get()
-{
- static EventCollectorGlobal instance;
- return instance;
-}
-
-EventDurationBlock::EventDurationBlock(const std::string &tag) : _tag{tag}
-{
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag});
-}
-EventDurationBlock::~EventDurationBlock()
-{
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag});
-}
-
-EventDurationManual::EventDurationManual(const std::string &tag) : _tag{tag}, _pair{true} {}
-
-EventDurationManual::~EventDurationManual()
-{
- // Check if it has called begin-end pair
- assert(_pair);
-}
-
-void EventDurationManual::begin()
-{
- _pair = false;
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag});
-}
-
-void EventDurationManual::end()
-{
- assert(!_pair);
- _pair = true;
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag});
-}
-
-} // namespace util
-} // namespace onert
diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.h b/runtime/onert/core/src/util/EventCollectorGlobal.h
deleted file mode 100644
index 1027ec84d..000000000
--- a/runtime/onert/core/src/util/EventCollectorGlobal.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
-#define __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
-
-#include "util/EventRecorder.h"
-#include "util/EventCollector.h"
-
-namespace onert
-{
-namespace util
-{
-
-/**
- * @brief Singleton class for event collection from anywhere in code
- *
- */
-class EventCollectorGlobal
-{
-public:
- /**
- * @brief Get the singleton object of this class
- *
- * @return EventCollectorGlobal& Singleton object
- */
- static EventCollectorGlobal &get();
-
-public:
- /**
- * @brief Getter for event collector object
- *
- * @return EventCollector& Collector object
- */
- EventCollector &collector() { return _collector; }
-
-private:
- EventCollectorGlobal();
- ~EventCollectorGlobal();
-
-private:
- EventRecorder _recorder;
- EventCollector _collector;
-};
-
-/**
- * @brief Helper class for emitting duration event which is handled automatically with ctor/dtor
- *
- */
-class EventDurationBlock
-{
-public:
- /**
- * @brief Raise a duration event with type of BEGIN
- *
- * @param tag A label for the duration event
- */
- EventDurationBlock(const std::string &tag);
- /**
- * @brief Raise a duration event with type of END
- *
- */
- ~EventDurationBlock();
-
-private:
- std::string _tag;
-};
-
-/**
- * @brief Helper class for emitting duration event which is handled manually
- *
- * Usage:
- * {
- * ...
- * EventDurationManual duration("some tag");
- * duration.begin();
- * ...
- * ... // Code for duration
- * ...
- * duration.end();
- * }
- *
- */
-class EventDurationManual
-{
-public:
- /**
- * @brief Construct a new Event Duration Manual object
- *
- * @param tag A label for the duration object
- */
- EventDurationManual(const std::string &tag);
- /**
- * @brief Destroy the Event Duration Manual object
- *
- */
- ~EventDurationManual();
-
- /**
- * @brief Raise a duration event with type of BEGIN
- *
- */
- void begin();
- /**
- * @brief Raise a duration event with type of END
- *
- */
- void end();
-
-private:
- std::string _tag;
- bool _pair;
-};
-
-} // namespace util
-} // namespace onert
-
-/**
- * Helper Macro Definitions
- *
- * HOW TO USE
- *
- * void f(args)
- * {
- * EVENT_DURATION_FUNCTION();
- * ...
- * if(cond)
- * {
- * EVENT_DURATION_REGION("if branch");
- * ...
- * }
- * ...
- * }
- */
-
-#define EVENT_DURATION_FUNCTION() \
- ::onert::util::EventDurationBlock __event_duration__##__LINE__ { __FUNCTION__ }
-
-#define EVENT_DURATION_REGION(tag) \
- ::onert::util::EventDurationBlock __event_duration__##__LINE__ { tag }
-
-#endif // __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
diff --git a/runtime/onert/core/src/util/EventRecorder.cc b/runtime/onert/core/src/util/EventRecorder.cc
index 13a599bed..85a588d38 100644
--- a/runtime/onert/core/src/util/EventRecorder.cc
+++ b/runtime/onert/core/src/util/EventRecorder.cc
@@ -14,396 +14,13 @@
* limitations under the License.
*/
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
-#include <sstream>
-#include <vector>
-#include <unordered_map>
-#include <json/json.h>
-#include <assert.h>
-#include <utility>
-#include <map>
-#include <set>
-#include <stdint.h>
-
-// json type for Chrome Event Trace
-namespace
-{
-
-std::string quote(const std::string &value)
-{
- std::stringstream ss;
- ss << '"' << value << '"';
- return ss.str();
-}
-
-std::string field(const std::string &k, const std::string &v)
-{
- std::stringstream ss;
- ss << quote(k) << " : " << quote(v);
- return ss.str();
-}
-
-struct Content // One Entry in Chrome Event Trace
-{
- std::vector<std::pair<std::string, std::string>> flds;
- std::vector<std::pair<std::string, std::string>> args;
-};
-
-std::string object(const Content &content)
-{
- std::stringstream ss;
-
- ss << "{ ";
-
- ss << field(content.flds[0].first, content.flds[0].second);
-
- for (uint32_t n = 1; n < content.flds.size(); ++n)
- {
- ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second);
- }
-
- if (content.args.size() > 0)
- {
- ss << ", " << quote("args") << " : { ";
- ss << field(content.args.at(0).first, content.args.at(0).second);
-
- for (uint32_t n = 1; n < content.args.size(); ++n)
- {
- ss << ", " << field(content.args.at(n).first, content.args.at(n).second);
- }
-
- ss << "}";
- }
-
- ss << " }";
-
- return ss.str();
-}
-
-void fill(Content &content, const Event &evt)
-{
- content.flds.emplace_back("name", evt.name);
- content.flds.emplace_back("pid", "0");
- content.flds.emplace_back("tid", evt.tid);
- content.flds.emplace_back("ph", evt.ph);
- content.flds.emplace_back("ts", evt.ts);
-}
-
-std::string object(const DurationEvent &evt)
-{
- Content content;
-
- fill(content, evt);
-
- return ::object(content);
-}
-
-std::string object(const CounterEvent &evt)
-{
- Content content;
-
- fill(content, evt);
-
- for (auto it = evt.values.begin(); it != evt.values.end(); ++it)
- {
- content.args.emplace_back(it->first, it->second);
- }
-
- return ::object(content);
-}
-
-} // namespace
-
-// md table type
-namespace
-{
-
-void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list)
-{
- os << "| ";
- for (auto &key : list)
- {
- os << key << " | ";
- }
- os << "\n";
-}
-
-struct MDContent
-{
- std::string name;
- uint64_t begin_ts;
- uint64_t end_ts;
- uint32_t min_rss;
- uint32_t max_rss;
- uint32_t min_page_reclaims;
- uint32_t max_page_reclaims;
-
- MDContent()
- : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX),
- max_page_reclaims(0)
- {
- // DO NOTHING
- }
-
- virtual ~MDContent() = default;
-
- void updateRss(uint32_t rss)
- {
- if (min_rss == UINT32_MAX)
- min_rss = rss;
- if (max_rss == 0)
- max_rss = rss;
-
- if (min_rss > rss)
- min_rss = rss;
- else if (max_rss < rss)
- max_rss = rss;
- }
-
- void updateMinflt(uint32_t minflt)
- {
- if (min_page_reclaims == UINT32_MAX)
- min_page_reclaims = minflt;
- if (max_page_reclaims == 0)
- max_page_reclaims = minflt;
-
- if (min_page_reclaims > minflt)
- min_page_reclaims = minflt;
- else if (max_page_reclaims < minflt)
- max_page_reclaims = minflt;
- }
-
- virtual void write(std::ostream &os) const = 0;
-};
-
-struct OpSeq : public MDContent
-{
- std::string backend;
- uint64_t graph_latency;
-
- struct OpSeqCmp
- {
- bool operator()(const OpSeq &lhs, const OpSeq &rhs) const
- {
- return lhs.begin_ts < rhs.begin_ts;
- }
- bool operator()(const OpSeq &lhs, const OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; }
- bool operator()(OpSeq &lhs, OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; }
- };
-
- void write(std::ostream &os) const override
- {
- uint64_t opseq_latency = end_ts - begin_ts;
- double opseq_per = static_cast<double>(opseq_latency) / graph_latency * 100.0;
- writeMDTableRow(os, {name, backend, std::to_string(opseq_latency), std::to_string(opseq_per),
- std::to_string(min_rss), std::to_string(max_rss),
- std::to_string(min_page_reclaims), std::to_string(max_page_reclaims)});
- }
-};
-
-struct Graph : public MDContent
-{
- std::set<OpSeq, OpSeq::OpSeqCmp> opseqs;
-
- void setOpSeqs(const std::map<std::string, OpSeq> &name_to_opseq)
- {
- uint64_t graph_latency = end_ts - begin_ts;
- for (auto it : name_to_opseq)
- {
- auto opseq = it.second;
- opseq.graph_latency = graph_latency;
-
- opseqs.insert(opseq);
-
- updateRss(opseq.min_rss);
- updateRss(opseq.max_rss);
- updateMinflt(opseq.min_page_reclaims);
- updateMinflt(opseq.max_page_reclaims);
- }
- }
-
- void write(std::ostream &os) const override
- {
- static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)",
- "page_reclaims_min", "page_reclaims_max"};
-
- static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------",
- "-----------------", "-----------------"};
-
- // Graph's Header
- writeMDTableRow(os, graph_headers);
- writeMDTableRow(os, graph_headers_line);
-
- // Graph's contents
- writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss),
- std::to_string(max_rss), std::to_string(min_page_reclaims),
- std::to_string(max_page_reclaims)});
-
- os << "\n";
-
- static std::vector<std::string> opseq_headers{
- "OpSeq name", "backend", "latency(us)", "latency(%)",
- "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"};
-
- static std::vector<std::string> opseq_headers_line{
- "----------", "-------", "-----------", "-----------",
- "-------", "-------", "-----------------", "-----------------"};
-
- os << "## OpSequences \n";
-
- // OpSeq's Header
- writeMDTableRow(os, opseq_headers);
- writeMDTableRow(os, opseq_headers_line);
-
- // OpSeq's contents
- for (auto opseq : opseqs)
- {
- opseq.write(os);
- }
-
- os << "\n";
- }
-};
-
-struct MDTableBuilder
-{
- MDTableBuilder(const std::vector<DurationEvent> &duration_events,
- const std::vector<CounterEvent> &counter_events)
- : _duration_events(duration_events), _counter_events(counter_events)
- {
- for (const auto &evt : _counter_events)
- {
- uint64_t ts = std::stoull(evt.ts);
- auto &name = evt.name;
- assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0);
- assert(evt.values.size() == 1);
- auto &val = evt.values.begin()->second;
- if (_ts_to_values.find(ts) == _ts_to_values.end())
- {
- std::pair<uint32_t, uint32_t> values;
- if (name.compare("maxrss") == 0)
- values.first = std::stoul(val);
- else
- values.second = std::stoul(val);
- _ts_to_values.insert({ts, values});
- }
- else
- {
- auto &values = _ts_to_values.at(ts);
- if (name.compare("maxrss") == 0)
- values.first = std::stoul(val);
- else
- values.second = std::stoul(val);
- }
- }
- }
-
- MDTableBuilder &build()
- {
- for (auto &it : divideGraph())
- {
- size_t begin_idx = it.first;
- size_t end_idx = it.second;
- std::map<std::string, OpSeq> name_to_opseq;
- for (size_t i = begin_idx + 1; i < end_idx; ++i)
- {
- const auto &evt = _duration_events[i];
- assert(evt.name.compare("Graph") != 0);
- assert(evt.ph.compare("B") == 0 || evt.ph.compare("E") == 0);
- if (evt.ph.compare("B") == 0)
- {
- assert(name_to_opseq.find(evt.name) == name_to_opseq.end());
- name_to_opseq.insert({evt.name, makeOpSeq(evt)});
- }
- else
- {
- assert(name_to_opseq.find(evt.name) != name_to_opseq.end());
- auto &opseq = name_to_opseq.at(evt.name);
- updateOpSeq(opseq, evt);
- }
- }
-
- _graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_opseq));
- }
-
- return *this;
- }
-
- std::vector<std::pair<size_t, size_t>> divideGraph()
- {
- std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx>
- for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i)
- {
- const auto &evt = _duration_events.at(i);
- if (evt.name.compare("Graph") == 0)
- {
- if (evt.ph.compare("B") == 0)
- begin_idx = i;
- else
- graph_idx_list.emplace_back(begin_idx, i);
- }
- }
- return graph_idx_list;
- }
-
- OpSeq makeOpSeq(const DurationEvent &evt)
- {
- OpSeq opseq;
- opseq.name = evt.name;
- opseq.begin_ts = std::stoull(evt.ts);
- opseq.updateRss(_ts_to_values.at(opseq.begin_ts).first);
- opseq.updateMinflt(_ts_to_values.at(opseq.begin_ts).second);
- opseq.backend = evt.tid;
- return opseq;
- }
-
- void updateOpSeq(OpSeq &opseq, const DurationEvent &evt)
- {
- opseq.end_ts = std::stoull(evt.ts);
- opseq.updateRss(_ts_to_values.at(opseq.end_ts).first);
- opseq.updateMinflt(_ts_to_values.at(opseq.end_ts).second);
- }
-
- Graph makeGraph(size_t begin_idx, size_t end_idx,
- const std::map<std::string, OpSeq> &name_to_opseq)
- {
- Graph graph;
- graph.name = "Graph";
- graph.begin_ts = std::stoull(_duration_events[begin_idx].ts);
- graph.updateRss(_ts_to_values.at(graph.begin_ts).first);
- graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second);
- graph.end_ts = std::stoull(_duration_events[end_idx].ts);
- graph.updateRss(_ts_to_values.at(graph.end_ts).first);
- graph.updateMinflt(_ts_to_values.at(graph.end_ts).second);
- graph.setOpSeqs(name_to_opseq);
- return graph;
- }
-
- void write(std::ostream &os)
- {
- // Write contents
- for (size_t i = 0; i < _graphs.size(); ++i)
- {
- os << "# Graph " << i << "\n";
- _graphs.at(i).write(os);
- }
- }
-
- const std::vector<DurationEvent> &_duration_events;
- const std::vector<CounterEvent> &_counter_events;
- // timestamp to std::pair<maxrss, minflt>
- std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values;
- std::vector<Graph> _graphs;
-};
-
-} // namespace
-
-void EventRecorder::emit(const DurationEvent &evt)
+void EventRecorder::emit(std::unique_ptr<DurationEvent> &&evt)
{
std::lock_guard<std::mutex> lock{_mu};
- _duration_events.push_back(evt);
+ _duration_events.push_back(std::move(evt));
}
void EventRecorder::emit(const CounterEvent &evt)
@@ -412,146 +29,3 @@ void EventRecorder::emit(const CounterEvent &evt)
_counter_events.push_back(evt);
}
-
-void EventRecorder::writeToFile(std::ostream &os)
-{
- std::lock_guard<std::mutex> lock{_mu};
-
- switch (_write_format)
- {
- case WriteFormat::CHROME_TRACING:
- writeChromeTrace(os);
- break;
- case WriteFormat::SNPE_BENCHMARK:
- writeSNPEBenchmark(os);
- break;
- case WriteFormat::MD_TABLE:
- writeMDTable(os);
- break;
- default:
- assert(!"Invalid value");
- break;
- }
-}
-
-void EventRecorder::writeSNPEBenchmark(std::ostream &os)
-{
- Json::Value root;
- auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue};
-
- struct Stat
- {
- uint64_t sum = 0;
- uint64_t count = 0;
- uint64_t max = 0;
- uint64_t min = std::numeric_limits<uint64_t>::max();
-
- void accumulate(uint64_t val)
- {
- sum += val;
- count++;
- max = std::max(max, val);
- min = std::min(min, val);
- }
- };
-
- // Memory
- {
- std::unordered_map<std::string, Stat> mem_stats;
- for (auto &evt : _counter_events)
- {
- auto &mem_stat = mem_stats[evt.name];
- uint64_t val = std::stoull(evt.values["value"]);
- mem_stat.accumulate(val);
- }
-
- auto &mem = exec_data["memory"] = Json::Value{Json::objectValue};
- for (auto &kv : mem_stats)
- {
- auto &key = kv.first;
- auto &val = kv.second;
- mem[key]["Avg_Size"] = val.sum / val.count;
- mem[key]["Max_Size"] = val.max;
- mem[key]["Min_Size"] = val.min;
- mem[key]["Runtime"] = "NA";
- }
- }
-
- // Operation Execution Time
- {
- // NOTE This assumes _duration_events is sorted by "ts" ascending
-
- // 2D keys : stats[tid][name]
- std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats;
- std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps;
- for (auto &evt : _duration_events)
- {
- auto &stat = stats[evt.tid][evt.name];
- auto &begin_ts = begin_timestamps[evt.tid][evt.name];
- uint64_t timestamp = std::stoull(evt.ts);
- if (evt.ph == "B")
- {
- if (begin_ts != 0)
- throw std::runtime_error{"Invalid Data"};
- begin_ts = timestamp;
- }
- else if (evt.ph == "E")
- {
- if (begin_ts == 0 || timestamp < begin_ts)
- throw std::runtime_error{"Invalid Data"};
- stat.accumulate(timestamp - begin_ts);
- begin_ts = 0;
- }
- else
- throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt.ph + "\""};
- }
-
- for (auto &kv : begin_timestamps)
- for (auto &kv2 : kv.second)
- if (kv2.second != 0)
- throw std::runtime_error{"Invalid Data - B and E pair does not match."};
-
- for (auto &kv : stats)
- {
- auto &tid = kv.first;
- auto &map = kv.second;
- auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue};
- for (auto &kv : map)
- {
- auto &name = kv.first;
- auto &val = kv.second;
- json_tid[name]["Avg_Time"] = val.sum / val.count;
- json_tid[name]["Max_Time"] = val.max;
- json_tid[name]["Min_Time"] = val.min;
- json_tid[name]["Runtime"] = tid;
- }
- }
- }
-
- os << root;
-}
-
-void EventRecorder::writeChromeTrace(std::ostream &os)
-{
- os << "{\n";
- os << " " << quote("traceEvents") << ": [\n";
-
- for (auto &evt : _duration_events)
- {
- os << " " << object(evt) << ",\n";
- }
-
- for (auto &evt : _counter_events)
- {
- os << " " << object(evt) << ",\n";
- }
-
- os << " { }\n";
- os << " ]\n";
- os << "}\n";
-}
-
-void EventRecorder::writeMDTable(std::ostream &os)
-{
- MDTableBuilder(_duration_events, _counter_events).build().write(os);
-}
diff --git a/runtime/onert/core/src/util/EventRecorder.h b/runtime/onert/core/src/util/EventRecorder.h
index 37ec1a0f1..5cf03d8ac 100644
--- a/runtime/onert/core/src/util/EventRecorder.h
+++ b/runtime/onert/core/src/util/EventRecorder.h
@@ -17,28 +17,52 @@
#ifndef __ONERT_UTIL_EVENT_RECORDER_H__
#define __ONERT_UTIL_EVENT_RECORDER_H__
+#include "util/TracingCtx.h"
+
#include <map>
#include <memory>
#include <mutex>
-#include <ostream>
#include <vector>
+// refer to https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit#
struct Event
{
- std::string name;
- std::string tid;
- std::string ph; /* REQUIRED */
- std::string ts; /* REQUIRED */
+ const onert::util::TracingCtx *tracing_ctx;
+
+ std::string ph; // Event type.
+ std::string ts; // tracing clock of timestamp of this event
+ std::vector<std::pair<std::string, std::string>> args; // user-defined data: pairs of (key, value)
+
+ virtual ~Event() = default;
};
struct DurationEvent : public Event
{
- // TO BE FILLED
+ uint32_t session_index = 0;
+ uint32_t subg_index = 0;
+
+protected:
+ DurationEvent() = default;
+};
+
+struct SubgDurationEvent : public DurationEvent
+{ /* same with DurationEvent */
+};
+
+// TODO Rename it to OperationDurationEvent
+struct OpSeqDurationEvent : public DurationEvent
+{
+ // Note: DurationEvent's name and tid will be set by EventWriter
+ std::string backend;
+ uint32_t op_index;
+ std::string op_name;
};
struct CounterEvent : public Event
{
+ std::string name; // name of event
+ std::string tid; // thread ID
std::map<std::string, std::string> values;
};
@@ -50,35 +74,22 @@ struct CounterEvent : public Event
class EventRecorder
{
public:
- enum class WriteFormat
- {
- CHROME_TRACING,
- SNPE_BENCHMARK,
- MD_TABLE,
- };
-
-public:
EventRecorder() = default;
public:
- void emit(const DurationEvent &evt);
+ void emit(std::unique_ptr<DurationEvent> &&evt);
void emit(const CounterEvent &evt);
public:
- bool empty() { return _duration_events.empty() && _counter_events.empty(); }
- void writeToFile(std::ostream &os);
- void setWriteFormat(WriteFormat write_format) { _write_format = write_format; }
-
-private:
- void writeSNPEBenchmark(std::ostream &os);
- void writeChromeTrace(std::ostream &os);
- void writeMDTable(std::ostream &os);
+ const std::vector<std::unique_ptr<DurationEvent>> &duration_events() const
+ {
+ return _duration_events;
+ }
+ const std::vector<CounterEvent> &counter_events() const { return _counter_events; }
private:
std::mutex _mu;
- // TODO: Allow user to control write_format
- WriteFormat _write_format{WriteFormat::SNPE_BENCHMARK};
- std::vector<DurationEvent> _duration_events;
+ std::vector<std::unique_ptr<DurationEvent>> _duration_events;
std::vector<CounterEvent> _counter_events;
};
diff --git a/runtime/onert/core/src/util/EventWriter.cc b/runtime/onert/core/src/util/EventWriter.cc
new file mode 100644
index 000000000..ca4bd302e
--- /dev/null
+++ b/runtime/onert/core/src/util/EventWriter.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+
+// initialization
+std::mutex EventWriter::_mutex;
+
+void EventWriter::readyToFlush(std::unique_ptr<EventRecorder> &&recorder)
+{
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+
+ _recorders.emplace_back(std::move(recorder));
+
+ if (--_ref_count > 0)
+ return;
+ }
+ // The caller of this method is the last instance that uses EventWriter.
+ // Let's write log files.
+
+  // Note: per an internal issue, the SNPE log uses the plain file name, not a '.snpe.json' suffix
+ flush(WriteFormat::SNPE_BENCHMARK);
+ flush(WriteFormat::CHROME_TRACING);
+ flush(WriteFormat::MD_TABLE);
+}
+
+void EventWriter::flush(WriteFormat write_format)
+{
+ auto *writer = _actual_writers[write_format].get();
+ assert(writer);
+
+ writer->flush(_recorders);
+}
diff --git a/runtime/onert/core/src/util/EventWriter.h b/runtime/onert/core/src/util/EventWriter.h
new file mode 100644
index 000000000..0a35a8508
--- /dev/null
+++ b/runtime/onert/core/src/util/EventWriter.h
@@ -0,0 +1,144 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_UTIL_EVENT_WRITER_H__
+#define __ONERT_UTIL_EVENT_WRITER_H__
+
+#include "EventRecorder.h"
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+#include <mutex>
+#include <fstream>
+
+class EventFormatWriter
+{
+public:
+ EventFormatWriter(const std::string &filepath) : _os{filepath, std::ofstream::out} {}
+ virtual ~EventFormatWriter()
+ { /* empty */
+ }
+
+ virtual void flush(const std::vector<std::unique_ptr<EventRecorder>> &) = 0;
+
+protected:
+ std::ofstream _os;
+};
+
+class SNPEWriter : public EventFormatWriter
+{
+public:
+ SNPEWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~SNPEWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+};
+
+class ChromeTracingWriter : public EventFormatWriter
+{
+public:
+ ChromeTracingWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~ChromeTracingWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+
+private:
+ void flushOneRecord(const EventRecorder &);
+};
+
+class MDTableWriter : public EventFormatWriter
+{
+public:
+ MDTableWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~MDTableWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+};
+
+#include <mutex>
+
+class EventWriter
+{
+public:
+ enum class WriteFormat
+ {
+ CHROME_TRACING,
+ SNPE_BENCHMARK,
+ MD_TABLE,
+ };
+
+ /**
+   * @brief Returns a singleton object
+ */
+ static EventWriter *get(const std::string &filename)
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+
+ static EventWriter singleton(filename);
+ return &singleton;
+ }
+
+ /**
+ * @brief Call this when observer which use EventWriter starts
+ */
+ void startToUse()
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+ _ref_count++;
+ }
+
+ /**
+ * @brief Call this when observer which use EventWriter finishes.
+   * After multiple observers call this method, the reference count will eventually be 0.
+ * Then, EventWriter will write profiling result file.
+ */
+ void readyToFlush(std::unique_ptr<EventRecorder> &&recorder);
+
+private:
+ EventWriter(const std::string &filepath) : _ref_count(0)
+ {
+ std::string snpe_log_name(filepath);
+ std::string chrome_tracing_log_name(filepath + ".chrome.json");
+ std::string md_table_log_name(filepath + ".table.md");
+
+ _actual_writers[WriteFormat::SNPE_BENCHMARK] = std::make_unique<SNPEWriter>(snpe_log_name);
+ _actual_writers[WriteFormat::CHROME_TRACING] =
+ std::make_unique<ChromeTracingWriter>(chrome_tracing_log_name);
+ _actual_writers[WriteFormat::MD_TABLE] = std::make_unique<MDTableWriter>(md_table_log_name);
+ };
+
+ void flush(WriteFormat write_format);
+
+private:
+ static std::mutex _mutex;
+
+ // number of observer of an executor that want to write profiling data
+ int32_t _ref_count;
+
+ // one recorder object per executor
+ std::vector<std::unique_ptr<EventRecorder>> _recorders;
+
+ std::unordered_map<WriteFormat, std::unique_ptr<EventFormatWriter>> _actual_writers;
+};
+
+#endif // __ONERT_UTIL_EVENT_WRITER_H__
diff --git a/runtime/onert/core/src/util/GeneralConfigSource.cc b/runtime/onert/core/src/util/Index.test.cc
index 7d2757e58..ff73e5e59 100644
--- a/runtime/onert/core/src/util/GeneralConfigSource.cc
+++ b/runtime/onert/core/src/util/Index.test.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,32 +14,21 @@
* limitations under the License.
*/
-#include "util/GeneralConfigSource.h"
-#include "util/logging.h"
+#include "util/Index.h"
-namespace onert
-{
-namespace util
-{
+#include <gtest/gtest.h>
-std::string GeneralConfigSource::get(const std::string &key) const
-{
- auto itr = _map.find(key);
- if (itr == _map.end())
- {
- return "";
- }
- else
- {
- return itr->second;
- }
-}
+using Index = ::onert::util::Index<uint32_t, struct TestTag>;
-void GeneralConfigSource::set(const std::string &key, const std::string &val)
+TEST(Index, neg_index_test)
{
- VERBOSE(GeneralConfigSource) << key << " : " << val << std::endl;
- _map[key] = val;
-}
+ Index idx1{1u};
+ Index idx2{2u};
+ Index idx3{idx1};
-} // namespace util
-} // namespace onert
+ ASSERT_EQ(idx1, 1);
+ ASSERT_EQ(idx1, 1u);
+ ASSERT_EQ(idx1.value(), 1u);
+ ASSERT_NE(idx1, idx2);
+ ASSERT_EQ(idx1, idx3);
+}
diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc
new file mode 100644
index 000000000..e7d90eec4
--- /dev/null
+++ b/runtime/onert/core/src/util/MDTableEventWriter.cc
@@ -0,0 +1,365 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+#include <map>
+#include <set>
+#include <sstream>
+#include <stdint.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+// md table type
+namespace
+{
+
+void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list)
+{
+ os << "| ";
+ for (const auto &key : list)
+ {
+ os << key << " | ";
+ }
+ os << "\n";
+}
+
+struct MDContent
+{
+ std::string name;
+ uint64_t begin_ts;
+ uint64_t end_ts;
+ uint32_t min_rss;
+ uint32_t max_rss;
+ uint32_t min_page_reclaims;
+ uint32_t max_page_reclaims;
+
+ MDContent()
+ : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX),
+ max_page_reclaims(0)
+ {
+ // DO NOTHING
+ }
+
+ virtual ~MDContent() = default;
+
+ void updateRss(uint32_t rss)
+ {
+ if (min_rss == UINT32_MAX)
+ min_rss = rss;
+ if (max_rss == 0)
+ max_rss = rss;
+
+ if (min_rss > rss)
+ min_rss = rss;
+ else if (max_rss < rss)
+ max_rss = rss;
+ }
+
+ void updateMinflt(uint32_t minflt)
+ {
+ if (min_page_reclaims == UINT32_MAX)
+ min_page_reclaims = minflt;
+ if (max_page_reclaims == 0)
+ max_page_reclaims = minflt;
+
+ if (min_page_reclaims > minflt)
+ min_page_reclaims = minflt;
+ else if (max_page_reclaims < minflt)
+ max_page_reclaims = minflt;
+ }
+
+ virtual void write(std::ostream &os) const = 0;
+};
+
+struct Operation : public MDContent
+{
+ std::string backend;
+ uint64_t graph_latency;
+
+ struct OperationCmp
+ {
+ bool operator()(const Operation &lhs, const Operation &rhs) const
+ {
+ return lhs.begin_ts < rhs.begin_ts;
+ }
+ bool operator()(const Operation &lhs, const Operation &rhs)
+ {
+ return lhs.begin_ts < rhs.begin_ts;
+ }
+ bool operator()(Operation &lhs, Operation &rhs) { return lhs.begin_ts < rhs.begin_ts; }
+ };
+
+ void write(std::ostream &os) const override
+ {
+ uint64_t op_latency = end_ts - begin_ts;
+ double op_per = static_cast<double>(op_latency) / graph_latency * 100.0;
+ writeMDTableRow(os, {name, backend, std::to_string(op_latency), std::to_string(op_per),
+ std::to_string(min_rss), std::to_string(max_rss),
+ std::to_string(min_page_reclaims), std::to_string(max_page_reclaims)});
+ }
+};
+
+struct Graph : public MDContent
+{
+ std::set<Operation, Operation::OperationCmp> ops;
+ std::string session_index;
+ std::string subgraph_index;
+
+ void setOperations(const std::map<std::string, Operation> &name_to_op)
+ {
+ uint64_t graph_latency = end_ts - begin_ts;
+ for (auto &&it : name_to_op)
+ {
+ auto op = it.second;
+ op.graph_latency = graph_latency;
+
+ ops.insert(op);
+
+ updateRss(op.min_rss);
+ updateRss(op.max_rss);
+ updateMinflt(op.min_page_reclaims);
+ updateMinflt(op.max_page_reclaims);
+ }
+ }
+
+ void write(std::ostream &os) const override
+ {
+ static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)",
+ "page_reclaims_min", "page_reclaims_max"};
+
+ static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------",
+ "-----------------", "-----------------"};
+
+ // Graph's Header
+ writeMDTableRow(os, graph_headers);
+ writeMDTableRow(os, graph_headers_line);
+
+ // Graph's contents
+ writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss),
+ std::to_string(max_rss), std::to_string(min_page_reclaims),
+ std::to_string(max_page_reclaims)});
+
+ os << "\n";
+
+ static std::vector<std::string> op_headers{
+ "Op name", "backend", "latency(us)", "latency(%)",
+ "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"};
+
+ static std::vector<std::string> op_headers_line{
+ "-------", "-------", "-----------", "-----------",
+ "-------", "-------", "-----------------", "-----------------"};
+
+ os << "## Op \n";
+
+ // Operation's Header
+ writeMDTableRow(os, op_headers);
+ writeMDTableRow(os, op_headers_line);
+
+ // Operation's contents
+ for (auto &&op : ops)
+ {
+ op.write(os);
+ }
+
+ os << "\n";
+ }
+};
+
+std::string getLabel(const OpSeqDurationEvent &evt)
+{
+ std::string subg_label("$" + std::to_string(evt.subg_index) + " subgraph");
+ std::string op_label("@" + std::to_string(evt.op_index) + " " + evt.op_name);
+
+ return subg_label + " " + op_label;
+}
+
+struct MDTableBuilder
+{
+ MDTableBuilder(const std::vector<std::unique_ptr<DurationEvent>> &duration_events,
+ const std::vector<CounterEvent> &counter_events)
+ : _duration_events(duration_events), _counter_events(counter_events)
+ {
+// when ready with low overhead in release build
+#ifdef DEBUG
+ for (const auto &evt : _counter_events)
+ {
+ uint64_t ts = std::stoull(evt.ts);
+ auto &name = evt.name;
+ assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0);
+ assert(evt.values.size() == 1);
+ auto &val = evt.values.begin()->second;
+ if (_ts_to_values.find(ts) == _ts_to_values.end())
+ {
+ std::pair<uint32_t, uint32_t> values;
+ if (name.compare("maxrss") == 0)
+ values.first = std::stoul(val);
+ else
+ values.second = std::stoul(val);
+ _ts_to_values.insert({ts, values});
+ }
+ else
+ {
+ auto &values = _ts_to_values.at(ts);
+ if (name.compare("maxrss") == 0)
+ values.first = std::stoul(val);
+ else
+ values.second = std::stoul(val);
+ }
+ }
+#endif
+ }
+
+ MDTableBuilder &build()
+ {
+ for (const auto &it : divideGraph())
+ {
+ size_t begin_idx = it.first;
+ size_t end_idx = it.second;
+ std::map<std::string, Operation> name_to_op;
+ for (size_t i = begin_idx + 1; i < end_idx; ++i)
+ {
+ const auto *evt = dynamic_cast<const OpSeqDurationEvent *>(_duration_events[i].get());
+ if (evt == nullptr)
+ continue;
+
+ const std::string evt_name = getLabel(*evt);
+ assert(evt->ph.compare("B") == 0 || evt->ph.compare("E") == 0);
+ if (evt->ph.compare("B") == 0)
+ {
+ assert(name_to_op.find(evt_name) == name_to_op.end());
+ name_to_op.insert({evt_name, makeOperation(*evt)});
+ }
+ else
+ {
+ assert(name_to_op.find(evt_name) != name_to_op.end());
+ auto &op = name_to_op.at(evt_name);
+ updateOperation(op, *evt);
+ }
+ }
+
+ _graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_op));
+ }
+
+ return *this;
+ }
+
+ std::vector<std::pair<size_t, size_t>> divideGraph()
+ {
+ std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx>
+ for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i)
+ {
+ const auto subg_evt = dynamic_cast<const SubgDurationEvent *>(_duration_events.at(i).get());
+ if (subg_evt == nullptr)
+ continue;
+
+ if (subg_evt->ph.compare("B") == 0)
+ begin_idx = i;
+ else
+ graph_idx_list.emplace_back(begin_idx, i);
+ }
+ return graph_idx_list;
+ }
+
+ Operation makeOperation(const OpSeqDurationEvent &evt)
+ {
+ Operation op;
+ const std::string &evt_name = getLabel(evt);
+ op.name = evt_name;
+ op.begin_ts = std::stoull(evt.ts);
+ op.backend = evt.backend;
+#ifdef DEBUG
+ op.updateRss(_ts_to_values.at(op.begin_ts).first);
+ op.updateMinflt(_ts_to_values.at(op.begin_ts).second);
+#else
+ op.updateRss(0);
+ op.updateMinflt(0);
+#endif
+ return op;
+ }
+
+ void updateOperation(Operation &op, const DurationEvent &evt)
+ {
+ op.end_ts = std::stoull(evt.ts);
+#ifdef DEBUG
+ op.updateRss(_ts_to_values.at(op.end_ts).first);
+ op.updateMinflt(_ts_to_values.at(op.end_ts).second);
+#else
+ op.updateRss(0);
+ op.updateMinflt(0);
+#endif
+ }
+
+ Graph makeGraph(size_t begin_idx, size_t end_idx,
+ const std::map<std::string, Operation> &name_to_op)
+ {
+ Graph graph;
+ graph.name = "Subgraph";
+ graph.begin_ts = std::stoull(_duration_events[begin_idx]->ts);
+ graph.end_ts = std::stoull(_duration_events[end_idx]->ts);
+ graph.setOperations(name_to_op);
+
+ for (const auto &arg : _duration_events[end_idx]->args)
+ {
+ if (arg.first == "session")
+ graph.session_index = arg.second;
+ if (arg.first == "subgraph")
+ graph.subgraph_index = arg.second;
+ }
+
+#ifdef DEBUG
+ graph.updateRss(_ts_to_values.at(graph.begin_ts).first);
+ graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second);
+ graph.updateRss(_ts_to_values.at(graph.end_ts).first);
+ graph.updateMinflt(_ts_to_values.at(graph.end_ts).second);
+#else
+ graph.updateRss(0);
+ graph.updateMinflt(0);
+#endif
+ return graph;
+ }
+
+ void write(std::ostream &os)
+ {
+ // Write contents
+ for (size_t i = 0; i < _graphs.size(); ++i)
+ {
+ auto &graph = _graphs.at(i);
+ os << "# Session: " << graph.session_index << ", Subgraph: " << graph.subgraph_index
+ << ", Running count: " << i << "\n";
+ _graphs.at(i).write(os);
+ }
+ }
+
+ const std::vector<std::unique_ptr<DurationEvent>> &_duration_events;
+ const std::vector<CounterEvent> &_counter_events;
+
+ // timestamp to std::pair<maxrss, minflt>
+ std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values;
+ std::vector<Graph> _graphs;
+};
+
+} // namespace
+
+void MDTableWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &records)
+{
+ for (const auto &recorder : records)
+ {
+ MDTableBuilder(recorder->duration_events(), recorder->counter_events()).build().write(_os);
+ }
+}
diff --git a/runtime/onert/core/src/util/ObjectManager.test.cc b/runtime/onert/core/src/util/ObjectManager.test.cc
new file mode 100644
index 000000000..3fe735732
--- /dev/null
+++ b/runtime/onert/core/src/util/ObjectManager.test.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/Index.h"
+#include "util/ObjectManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert;
+
+struct TestTag;
+using Index = typename util::Index<uint32_t, TestTag>;
+
+TEST(ObjectManager, emplace)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index = man.emplace(100);
+ ASSERT_EQ(man.at(index), 100);
+}
+
+TEST(ObjectManager, neg_remove_1)
+{
+ util::ObjectManager<Index, int> man;
+
+ Index index = man.emplace(100);
+ ASSERT_TRUE(man.exist(index));
+ ASSERT_EQ(man.at(index), 100);
+
+ man.remove(index);
+ ASSERT_FALSE(man.exist(index));
+}
+
+TEST(ObjectManager, neg_remove_2)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ ASSERT_TRUE(man.exist(index0));
+ ASSERT_EQ(man.at(index0), 100);
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+
+ man.remove(index0);
+ ASSERT_FALSE(man.exist(index0));
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+}
+
+TEST(ObjectManager, push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Not specify index
+ auto index = man.push(std::make_unique<int>(100));
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify index
+ auto index2 = man.push(std::make_unique<int>(200), Index{33});
+ ASSERT_EQ(index2.value(), 33);
+ ASSERT_EQ(man.at(index2), 200);
+
+ auto index3 = man.push(std::make_unique<int>(300));
+ // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1)
+ ASSERT_EQ(index3.value(), 34);
+ ASSERT_EQ(man.at(index3), 300);
+
+ auto index4 = man.push(std::make_unique<int>(400), Index{22});
+ ASSERT_EQ(index4.value(), 22);
+ ASSERT_EQ(man.at(index4), 400);
+
+ auto index5 = man.push(std::make_unique<int>(500));
+ // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1)
+ ASSERT_EQ(index5.value(), 35);
+ ASSERT_EQ(man.at(index5), 500);
+}
+
+TEST(ObjectManager, neg_push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Specify index
+ auto index = man.push(std::make_unique<int>(100), Index{55});
+ ASSERT_EQ(index.value(), 55);
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify the same index
+ auto index2 = man.push(std::make_unique<int>(200), Index{55});
+ ASSERT_FALSE(index2.valid());
+}
+
+static const uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max();
+
+TEST(ObjectManager, neg_push_undefined_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Try inserting invalid(undefined) index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32});
+ ASSERT_FALSE(index.valid());
+ ASSERT_EQ(man.size(), 0);
+}
+
+TEST(ObjectManager, neg_push_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+ // Reached to the final index so next push/emplace must fail
+ auto index2 = man.push(std::make_unique<int>(200));
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index2.valid());
+}
+
+TEST(ObjectManager, neg_emplace_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+ // Reached to the final index so next push/emplace must fail
+ auto index3 = man.emplace(200);
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index3.valid());
+}
+
+TEST(ObjectManager, const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ int sum = 0;
+ man.iterate([&](const Index &index, const int &val) { sum += val; });
+ ASSERT_EQ(sum, 600);
+}
+
+TEST(ObjectManager, non_const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ man.iterate([&](const Index &index, int &val) { val += 1; });
+ ASSERT_EQ(man.at(index0), 101);
+ ASSERT_EQ(man.at(index1), 201);
+ ASSERT_EQ(man.at(index2), 301);
+}
+
+TEST(ObjectManager, set)
+{
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{1}, std::make_unique<int>(100)); // Insert
+ ASSERT_EQ(index, Index{1});
+ auto index2 = man.set(index, std::make_unique<int>(200)); // Overwrite
+ ASSERT_EQ(index2, index);
+ ASSERT_EQ(man.at(index2), 200);
+}
+
+TEST(ObjectManager, neg_set)
+{
+ auto v = std::make_unique<int>(100);
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{}, std::move(v)); // Try set with an invalid index
+ ASSERT_EQ(index, Index{});
+ ASSERT_FALSE(index.valid());
+ ASSERT_NE(v, nullptr); // v must be kept when failure
+}
+
+TEST(ObjectManager, getRawPtr)
+{
+ auto v = std::make_unique<int>(100);
+ auto v_ptr = v.get();
+ util::ObjectManager<Index, int> man;
+ auto index = man.push(std::move(v));
+ ASSERT_EQ(v_ptr, man.getRawPtr(index));
+}
+
+TEST(ObjectManager, neg_getRawPtr)
+{
+ util::ObjectManager<Index, int> man;
+ auto ptr = man.getRawPtr(Index{1});
+ ASSERT_EQ(ptr, nullptr);
+}
diff --git a/runtime/onert/core/src/util/SNPEEventWriter.cc b/runtime/onert/core/src/util/SNPEEventWriter.cc
new file mode 100644
index 000000000..87bbfc662
--- /dev/null
+++ b/runtime/onert/core/src/util/SNPEEventWriter.cc
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <json/json.h>
+
+#include <cassert>
+#include <unordered_map>
+#include <utility>
+
+/**
+ * @brief Version of SNPE format
+ * In version 1
+ * - There is no "version" field in Json
+ * - Only one subgraph is supported
+ * - Operation name is a form of "$3 ADD"
+ *
+ * In version 2,
+ * - "version" : "2" was added in Json
+ * - Multiple session and multiple subgraphs are supported
+ * - When there is only one session, operation name is a form of "$2 subgraph $3 ADD",
+ * meaning ADD op whose operation index 3 in a subgraph whose index is 2
+ * - When there are two or more sessions, operation name is a form of
+ * "$1 session $2 subgraph $3 ADD", meaning ADD op whose operation index 3
+ * in a subgraph whose index is 2, which was run in 1st session.
+ */
+#define SNPE_JSON_SCHEMA_VERSION "2"
+
+namespace
+{
+
+std::string getLabel(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ std::string subg_label("$" + std::to_string(evt_ptr->subg_index) + " subgraph");
+ std::string op_label("$" + std::to_string(evt_ptr->op_index) + " " + evt_ptr->op_name);
+
+ // Note : At this moment, there is only one thread running for EventWriter
+ if (evt_ptr->tracing_ctx->hasMultipleSessions())
+ {
+ std::string session_label("$" + std::to_string(evt_ptr->session_index) + " session");
+ return session_label + " " + subg_label + " " + op_label;
+ }
+ else
+ {
+ // When there is only one session, do not include session info
+ // Refer to https://github.sec.samsung.net/STAR/nnfw/issues/11436#issuecomment-930332
+ return subg_label + " " + op_label;
+ }
+ }
+ else // SubgEvent
+ return "Graph";
+}
+
+std::string getBackend(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ return evt_ptr->backend;
+ else // SubgEvent
+ return "runtime";
+}
+
+} // namespace
+
+void SNPEWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders)
+{
+ struct Stat
+ {
+ uint64_t sum = 0;
+ uint64_t count = 0;
+ uint64_t max = 0;
+ uint64_t min = std::numeric_limits<uint64_t>::max();
+
+ void accumulate(uint64_t val)
+ {
+ sum += val;
+ count++;
+ max = std::max(max, val);
+ min = std::min(min, val);
+ }
+ };
+
+ Json::Value root;
+ root["version"] = SNPE_JSON_SCHEMA_VERSION;
+
+ auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue};
+
+ // Memory
+ {
+ std::unordered_map<std::string, Stat> mem_stats;
+ for (const auto &recorder : recorders)
+ {
+ for (const auto &evt : recorder->counter_events())
+ {
+ auto &mem_stat = mem_stats[evt.name];
+ uint64_t val = std::stoull(evt.values.at("value"));
+ mem_stat.accumulate(val);
+ }
+ }
+
+ auto &mem = exec_data["memory"] = Json::Value{Json::objectValue};
+ for (const auto &kv : mem_stats)
+ {
+ auto &key = kv.first;
+ auto &val = kv.second;
+ mem[key]["Avg_Size"] = val.sum / val.count;
+ mem[key]["Max_Size"] = val.max;
+ mem[key]["Min_Size"] = val.min;
+ mem[key]["Runtime"] = "NA";
+ }
+ }
+
+ // Operation Execution Time
+ {
+ // NOTE This assumes _duration_events is sorted by "ts" ascending
+
+ // 2D keys : stats[tid][name]
+ std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats;
+ std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps;
+ for (const auto &recorder : recorders)
+ {
+ for (const auto &evt : recorder->duration_events())
+ {
+ std::string evt_name = getLabel(*evt);
+ std::string evt_tid = getBackend(*evt);
+
+ auto &stat = stats[evt_tid][evt_name];
+ auto &begin_ts = begin_timestamps[evt_tid][evt_name];
+ uint64_t timestamp = std::stoull(evt->ts);
+ if (evt->ph == "B")
+ {
+ if (begin_ts != 0)
+ throw std::runtime_error{"Invalid Data"};
+ begin_ts = timestamp;
+ }
+ else if (evt->ph == "E")
+ {
+ if (begin_ts == 0 || timestamp < begin_ts)
+ throw std::runtime_error{"Invalid Data"};
+ stat.accumulate(timestamp - begin_ts);
+ begin_ts = 0;
+ }
+ else
+ throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt->ph + "\""};
+ }
+ }
+
+ for (const auto &kv : begin_timestamps)
+ for (const auto &kv2 : kv.second)
+ if (kv2.second != 0)
+ throw std::runtime_error{"Invalid Data - B and E pair does not match."};
+
+ for (const auto &kv : stats)
+ {
+ const auto &tid = kv.first;
+ const auto &map = kv.second;
+ auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue};
+ for (const auto &kv : map)
+ {
+ auto &name = kv.first;
+ auto &val = kv.second;
+ json_tid[name]["Avg_Time"] = val.sum / val.count;
+ json_tid[name]["Max_Time"] = val.max;
+ json_tid[name]["Min_Time"] = val.min;
+ json_tid[name]["Runtime"] = tid;
+ }
+ }
+ }
+
+ _os << root;
+}
diff --git a/runtime/onert/core/src/util/ShapeInference.cc b/runtime/onert/core/src/util/ShapeInference.cc
index 95c15049d..862d6f725 100644
--- a/runtime/onert/core/src/util/ShapeInference.cc
+++ b/runtime/onert/core/src/util/ShapeInference.cc
@@ -22,6 +22,7 @@
#include "util/logging.h"
#include <cassert>
+#include <numeric>
#include <sstream>
#include <cmath>
@@ -72,6 +73,19 @@ ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape
} // namespace
+namespace bcq
+{
+inline int getOutputSize(const ir::Shape &cluster_shape, const int32_t *cluster_buf)
+{
+ int size = 0;
+ for (int idx = 0; idx < cluster_shape.dim(0); idx++)
+ {
+ size += cluster_buf[idx * 2 + 1];
+ }
+ return size;
+}
+} // namespace bcq
+
//
// Shape inference
//
@@ -97,10 +111,9 @@ std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, c
break;
case ir::PaddingType::EXPLICIT:
out_h =
- (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
+ (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
out_w =
- (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal +
- 1;
+ (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
break;
default:
assert(false);
@@ -114,8 +127,13 @@ ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_sha
return broadcastShapes(lhs_shape, rhs_shape);
}
-ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank)
+ir::Shape inferArgMinMaxShape(const ir::Shape &input_shape, int axis, int rank)
{
+ if (axis < 0 || axis >= rank)
+ {
+ throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
+ }
+
ir::Shape out_shape;
for (int idx = 0; idx < rank; ++idx)
{
@@ -171,11 +189,12 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int>
for (int i = 0; i < num_axis; ++i)
{
int current = axes[i];
+ if (!(-input_num_dims <= current && current < input_num_dims))
+ throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
if (current < 0)
{
current += input_num_dims;
}
- assert(0 <= current && current < input_num_dims);
for (int j = 0; j < i; ++j)
{
int previous = axes[j];
@@ -259,19 +278,24 @@ ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs
return output_shape;
}
-ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer)
+/*
+ * shp_shape : SHAPE input tensor's shape
+ * shp_buf : SHAPE input tensor's buffer
+ */
+ir::Shape inferBroadcastToShape(const ir::Shape shp_shape, const int32_t *shp_buf)
{
- const int num_elements = wshape.num_elements();
+
+ const int num_elements = shp_shape.num_elements();
assert(num_elements != 0);
- assert(shape_buffer);
+ assert(shp_buf);
ir::Shape new_shape(num_elements);
for (int i = 0; i < num_elements; ++i)
{
- assert(shape_buffer[i] != 0); // It shouldn't be 0.
- new_shape.dim(i) = shape_buffer[i];
+ assert(shp_buf[i] != 0); // It shouldn't be 0.
+ new_shape.dim(i) = shp_buf[i];
}
return new_shape;
@@ -305,6 +329,9 @@ ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat:
ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
const ir::operation::Conv2D::Param &param, ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"Conv2D: stride values must be positive"};
+
auto ifm_shape = in_shape.asFeature(layout);
// Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
@@ -321,6 +348,9 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &
const ir::operation::DepthwiseConv2D::Param &param,
ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
+
assert(layout == ir::Layout::NHWC);
auto ifm_shape = in_shape.asFeature(layout);
@@ -330,7 +360,7 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &
assert(kf_shape.N == 1);
const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
- param.padding, param.stride);
+ param.padding, param.stride, param.dilation);
return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C};
}
@@ -354,18 +384,22 @@ ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis)
return out_shape;
}
-ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer)
+template <typename T> ir::Shape inferFillShape(const ir::Shape &fill_shape, const T *shape_buf)
{
- ir::Shape out_shape(in_shape.dim(0));
+ ir::Shape out_shape(fill_shape.dim(0));
for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
{
- out_shape.dim(out_x) = buffer[out_x];
+ out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
}
return out_shape;
}
+// template instantiation
+template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int32_t *shape_buf);
+template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int64_t *shape_buf);
+
ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
{
assert(in_shape.rank() >= 2);
@@ -380,11 +414,60 @@ ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &k
return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
}
+ir::Shape inferBCQFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &cluster_shape,
+ const int32_t *cluster_buf)
+{
+ assert(cluster_shape.rank() == 2);
+ assert(cluster_shape.dim(1) == 2);
+
+ const auto input_size = in_shape.dim(1);
+ const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
+
+ return {ir::Shape({output_size, input_size})};
+}
+
+ir::Shape inferBCQGatherShape(const ir::Shape &indices_shape, const ir::Shape &cluster_shape,
+ const int32_t *cluster_buf, int rank,
+ const ir::operation::BCQGather::Param &param)
+{
+ ir::Shape out_shape;
+ ir::Shape in_original_shape;
+
+ assert(cluster_shape.rank() == 2);
+ assert(cluster_shape.dim(1) == 2);
+
+ auto hidden_size = param.input_hidden_size;
+ auto axis = param.axis;
+
+ in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
+ in_original_shape.append(hidden_size);
+
+ const int indices_rank = indices_shape.rank();
+ for (int idx = 0; idx < rank; ++idx)
+ {
+ if (idx == (int)axis)
+ {
+ for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
+ {
+ out_shape.append(indices_shape.dim(indices_idx));
+ }
+ }
+ else
+ {
+ out_shape.append(in_original_shape.dim(idx));
+ }
+ }
+
+ return out_shape;
+}
+
ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
int rank)
{
ir::Shape out_shape;
+
const int indices_rank = indices_shape.rank();
+
for (int idx = 0; idx < rank; ++idx)
{
if (idx == axis)
@@ -470,6 +553,9 @@ ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const
ir::Shape inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param,
const ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"Pool2D: stride values must be positive"};
+
assert(layout == ir::Layout::NHWC);
auto ifm_shape = in_shape.asFeature(layout);
const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
@@ -482,6 +568,17 @@ ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t outp
const int32_t output_width)
{
assert(in_shape.rank() == 4);
+ if (output_height < 0)
+ {
+ throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_height = " +
+ std::to_string(output_height)};
+ }
+ if (output_width < 0)
+ {
+ throw std::runtime_error{"ResizeBilinear: size value must be positive value, output_width = " +
+ std::to_string(output_width)};
+ }
+
ir::Shape ret(in_shape.rank());
ret.dim(0) = in_shape.dim(0);
@@ -497,9 +594,9 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt
ir::Shape out_shape(static_cast<int>(1));
out_shape.dim(0) =
- (std::is_integral<T>::value
- ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
- : std::ceil(std::abs((start_val - limit_val) / delta_val)));
+ (std::is_integral<T>::value
+ ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
+ : std::ceil(std::abs((start_val - limit_val) / delta_val)));
return out_shape;
}
@@ -511,12 +608,12 @@ ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_el
const size_t total_num_elements)
{
ir::Shape ret(shape_num_elements);
- int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM;
+ int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
for (int32_t i = 0; i < shape_num_elements; ++i)
{
if (shape_buf[i] < 0)
{
- if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
+ if (flatten_dim != ir::Shape::kUnspecifiedDim)
throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
flatten_dim = i;
ret.dim(i) = 1;
@@ -526,7 +623,7 @@ ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_el
ret.dim(i) = shape_buf[i];
}
}
- if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
+ if (flatten_dim != ir::Shape::kUnspecifiedDim)
ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
// Check reshapable
@@ -566,9 +663,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
ir::Shape true_shape = input_true_shape;
ir::Shape false_shape = input_false_shape;
int most_rank =
- (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
- ? cond_shape.rank()
- : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
+ (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
+ ? cond_shape.rank()
+ : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
ir::Shape calculate_shape(most_rank);
@@ -579,9 +676,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
for (int i = 0; i < most_rank; ++i)
{
calculate_shape.dim(i) =
- (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
- ? cond_shape.dim(i)
- : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
+ (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
+ ? cond_shape.dim(i)
+ : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
(true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
@@ -613,7 +710,8 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
return new_shape;
}
-ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes)
+template <typename T>
+ir::Shape inferSliceShape(const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
{
const uint32_t rank = input_shape.rank();
ir::Shape out_shape(rank);
@@ -623,12 +721,12 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c
const auto input_dim = input_shape.dim(idx);
// begin is zero-based
- auto begin = begins[idx];
+ auto begin = begins_buf[idx];
if (begin < 0)
throw std::runtime_error("shape inference Slice: Invalid begin.");
// size is one-based
- auto size = sizes[idx];
+ auto size = sizes_buf[idx];
if (size < -1)
throw std::runtime_error("shape inference Slice: Invalid size.");
@@ -638,18 +736,23 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c
}
else
{
- if (input_dim < begin + size)
+ if (input_dim < static_cast<int32_t>(begin + size))
throw std::runtime_error("shape inference Slice: Invalid begin and size.");
}
- out_shape.dim(idx) = size;
+ out_shape.dim(idx) = static_cast<int32_t>(size);
}
return out_shape;
}
+// template instantiation
+template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins_buf,
+ const int32_t *sizes_buf);
+template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int64_t *begins_buf,
+ const int64_t *sizes_buf);
ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape,
- const ir::Shape &padding_shape, const int32_t *block_shape_data,
- const int32_t *padding_data)
+ const ir::Shape &padding_shape, const int32_t *block_shape_buf,
+ const int32_t *padding_buf)
{
const uint32_t rank = input_shape.rank();
ir::Shape out_shape(rank);
@@ -677,14 +780,14 @@ ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape
for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
{
int final_dim_size =
- (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]);
+ (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
- assert(final_dim_size % block_shape_data[dim] == 0);
+ assert(final_dim_size % block_shape_buf[dim] == 0);
- out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim];
+ out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
}
- const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1];
+ const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
const int output_channel_size = input_shape.dim(3);
out_shape.dim(0) = output_batch_size;
@@ -740,7 +843,7 @@ ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Sque
if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
{
throw std::runtime_error(
- "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
+ "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
}
if (!should_squeeze[current])
@@ -948,35 +1051,71 @@ ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSlic
return out_shape;
}
-ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier)
+ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier_buf,
+ const int32_t multiplier_size)
{
- // assert(in_shape.rank() == multiplier.rank());
+ if (multiplier_size != in_shape.rank())
+ {
+ throw std::runtime_error(
+ "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
+ ", bad multipliers size: " + std::to_string(multiplier_size) + "");
+ }
ir::Shape new_Shape(in_shape.rank());
for (int i = 0; i < in_shape.rank(); ++i)
{
- assert(multiplier[i]); // multiplier[i] shuld not be 0.
- new_Shape.dim(i) = in_shape.dim(i) * multiplier[i];
+ assert(multiplier_buf[i]); // multiplier_buf[i] should not be 0.
+ new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
}
return new_Shape;
}
-ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm)
+ir::Shape inferTransposeShape(const ir::Shape &in_shape, const int32_t *perm_buf,
+ const int32_t perm_size)
{
- if (static_cast<int>(perm.size()) > in_shape.rank())
+ const auto rank = in_shape.rank();
+ if (perm_size > rank)
{
- throw std::runtime_error("inferTransposeShape failed, bad rank size: " +
- std::to_string(static_cast<int>(perm.size())));
+ throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
+ std::to_string(perm_size));
}
- ir::Shape out_shape(static_cast<int>(perm.size()));
- for (int idx = 0; idx < static_cast<int>(perm.size()); idx++)
+
+ const int32_t *perm_data = perm_buf;
+ std::vector<int32_t> regular_perm_vec;
+ if (perm_size == 0)
+ {
+ // perm_data will be set to (n-1...0)
+ regular_perm_vec.resize(rank);
+ std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
+ std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
+ perm_data = regular_perm_vec.data();
+ }
+ else
{
- if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size()))
+ assert(rank == perm_size);
+ }
+
+ ir::Shape out_shape(rank);
+ std::vector<bool> visit_perms(rank, false);
+ for (int idx = 0; idx < rank; idx++)
+ {
+ const auto perm_val = perm_data[idx];
+ // Check invalid permutation value
+ if (perm_val < 0 || perm_val >= rank)
{
- throw std::runtime_error("inferTransposeShape failed, bad perm value: " +
- std::to_string(perm[idx]));
+ throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
+ std::to_string(perm_val));
}
- out_shape.dim(idx) = in_shape.dim(perm[idx]);
+
+ // Check duplicated permutation value
+ if (visit_perms.at(perm_val))
+ {
+ throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
+ std::to_string(perm_val));
+ }
+ visit_perms.at(perm_val) = true;
+
+ out_shape.dim(idx) = in_shape.dim(perm_val);
}
return out_shape;
}
diff --git a/runtime/onert/core/src/util/ShapeInference.test.cc b/runtime/onert/core/src/util/ShapeInference.test.cc
new file mode 100644
index 000000000..96579bfa2
--- /dev/null
+++ b/runtime/onert/core/src/util/ShapeInference.test.cc
@@ -0,0 +1,544 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/ShapeInference.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+
+TEST(ShapeInference, Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{3};
+ auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.dim(0), 1);
+ ASSERT_EQ(infered_out_shape.dim(1), 299);
+ ASSERT_EQ(infered_out_shape.dim(2), 299);
+ ASSERT_EQ(infered_out_shape.dim(3), 3);
+}
+
+TEST(ShapeInference, neg_Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{5, 3};
+ ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
+}
+
+TEST(ShapeInference, Pool2DNodeSame)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::SAME};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeValid)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::VALID};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeExplicit)
+{
+ Shape in_shape{10, 3, 5, 20};
+
+ Stride stride{3, 7};
+ Padding padding{4, 3, 2, 1};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, neg_Pool2DNode_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{0, 7};
+ Padding padding{PaddingType::SAME};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ ASSERT_THROW(onert::shape_inference::inferPoolShape(in_shape, avg_pool_param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Conv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
+
+ operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param =
+ operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+}
+
+TEST(ShapeInference, neg_Conv2D_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
+
+ operation::Conv2D::Param param{Stride{0, 0}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, DepthwiseConv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ auto infered_out_shape =
+ onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+}
+
+TEST(ShapeInference, neg_DepthwiseConv2D_InvalidSride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 0}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Concat)
+{
+  // Concatenation along a positive axis: only the axis dimension is summed,
+  // all other dimensions are carried over unchanged.
+  {
+    Shape lhs{10, 20, 30, 3, 50};
+    Shape mid{10, 20, 30, 2, 50};
+    Shape rhs{10, 20, 30, 2, 50};
+
+    operation::Concat::Param param{3};
+    auto out_shape = onert::shape_inference::inferConcatShape({lhs, mid, rhs}, param);
+
+    ASSERT_EQ(out_shape.rank(), 5);
+    ASSERT_EQ(out_shape.dim(0), 10);
+    ASSERT_EQ(out_shape.dim(1), 20);
+    ASSERT_EQ(out_shape.dim(2), 30);
+    ASSERT_EQ(out_shape.dim(3), 7); // 3 + 2 + 2
+    ASSERT_EQ(out_shape.dim(4), 50);
+  }
+  // Negative axis: -1 resolves to the last dimension.
+  {
+    Shape lhs{10, 20, 2};
+    Shape rhs{10, 20, 3};
+
+    operation::Concat::Param param{-1};
+    auto out_shape = onert::shape_inference::inferConcatShape({lhs, rhs}, param);
+
+    ASSERT_EQ(out_shape.rank(), 3);
+    ASSERT_EQ(out_shape.dim(0), 10);
+    ASSERT_EQ(out_shape.dim(1), 20);
+    ASSERT_EQ(out_shape.dim(2), 5); // 2 + 3
+  }
+  // Negative axis: -3 resolves to the first dimension of a rank-3 shape.
+  {
+    Shape lhs{2, 20, 2};
+    Shape rhs{3, 20, 2};
+
+    operation::Concat::Param param{-3};
+    auto out_shape = onert::shape_inference::inferConcatShape({lhs, rhs}, param);
+
+    ASSERT_EQ(out_shape.rank(), 3);
+    ASSERT_EQ(out_shape.dim(0), 5); // 2 + 3
+    ASSERT_EQ(out_shape.dim(1), 20);
+    ASSERT_EQ(out_shape.dim(2), 2);
+  }
+}
+
+TEST(ShapeInference, neg_Concat)
+{
+  // Inputs whose non-axis dimensions disagree must be rejected.
+  {
+    operation::Concat::Param param{2};
+    Shape first{10, 1, 3};
+    Shape second{10, 2, 4}; // dim(1) differs: 1 vs 2
+
+    EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({first, second}, param));
+  }
+  // Inputs of different rank must be rejected.
+  {
+    operation::Concat::Param param{2};
+    Shape first{10, 2, 3, 4};
+    Shape second{10, 2, 4}; // rank 3, but rank 4 is required
+
+    EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({first, second}, param));
+  }
+}
+
+TEST(ShapeInference, ExpandDims)
+{
+  Shape in_shape{30, 40};
+
+  // Inserting a length-1 dimension at `axis` must yield `expected`.
+  auto check = [&](int32_t axis, Shape &expected) {
+    auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
+
+    // Compare against expected.rank() instead of a hard-coded 3 (as the other
+    // check helpers in this file do) so the helper stays correct if in_shape
+    // or the expected shapes are ever changed.
+    ASSERT_EQ(actual.rank(), expected.rank());
+    for (int32_t dim = 0; dim < expected.rank(); dim++)
+      ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+  };
+
+  { // positive boundary: insert before the first dimension
+    int32_t axis = 0;
+    Shape expected{1, 30, 40};
+    check(axis, expected);
+  }
+  { // positive boundary: insert after the last dimension
+    int32_t axis = 2;
+    Shape expected{30, 40, 1};
+    check(axis, expected);
+  }
+  { // inside
+    int32_t axis = 1;
+    Shape expected{30, 1, 40};
+    check(axis, expected);
+  }
+  { // negative boundary: -1 appends
+    int32_t axis = -1;
+    Shape expected{30, 40, 1};
+    check(axis, expected);
+  }
+  { // negative boundary: -3 prepends
+    int32_t axis = -3;
+    Shape expected{1, 30, 40};
+    check(axis, expected);
+  }
+}
+
+TEST(ShapeInference, neg_ExpandDims)
+{
+  Shape in_shape{30, 40};
+
+  // Axis one past the largest valid positive value.
+  {
+    const int32_t axis = 3;
+    ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+  }
+  // Axis one past the smallest valid negative value.
+  {
+    const int32_t axis = -4;
+    ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+  }
+}
+
+TEST(ShapeInference, FullyConnected)
+{
+  Shape in_shape{3, 4, 5, 6};
+  Shape ker_shape{3, 10};
+  const auto out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
+
+  // 3*4*5*6 = 360 input elements; with a {3, 10} kernel the expected result is
+  // {360 / 10, 3} = {36, 3}.
+  ASSERT_EQ(out_shape.rank(), 2);
+  ASSERT_EQ(out_shape.dim(0), 36);
+  ASSERT_EQ(out_shape.dim(1), 3);
+}
+
+TEST(ShapeInference, Transpose)
+{
+  // Checks that permuting `in_shape` with `perm` produces `expected`.
+  auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
+    // pre-conditions
+    ASSERT_EQ(in_shape.rank(), perm.size());
+    ASSERT_EQ(expected.rank(), perm.size());
+    auto inferred_out_shape =
+      onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size());
+    // post-conditions
+    ASSERT_EQ(inferred_out_shape.rank(), perm.size());
+    for (int32_t dim = 0; dim < expected.rank(); dim++)
+    {
+      ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
+    }
+  };
+  // 2-D (removed stale commented-out `int32_t rank` locals left from an older API)
+  {
+    Shape in_shape{2, 3};
+    std::vector<int> perm = {1, 0};
+    Shape expected{3, 2};
+    check(in_shape, perm, expected);
+  }
+  // 3-D
+  {
+    Shape in_shape{1, 2, 3};
+    std::vector<int> perm = {2, 0, 1};
+    Shape expected{3, 1, 2};
+    check(in_shape, perm, expected);
+  }
+  // 4-D
+  {
+    Shape in_shape{1, 2, 3, 4};
+    std::vector<int> perm = {1, 3, 0, 2};
+    Shape expected{2, 4, 1, 3};
+    check(in_shape, perm, expected);
+  }
+}
+
+TEST(ShapeInference, neg_Transpose)
+{
+  Shape in_shape{1, 2, 3};
+  // Permutation length must equal the input rank.
+  // (Removed stale commented-out `int32_t rank` locals left from an older API.)
+  {
+    std::vector<int> perm = {2, 0, 1, 0};
+    ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+                 std::runtime_error);
+  }
+  // Every permutation entry must be a valid axis index (here: < 3).
+  {
+    std::vector<int> perm = {2, 0, 3};
+    ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+                 std::runtime_error);
+  }
+}
+
+TEST(ShapeInference, Gather)
+{
+  // Verifies inferGatherShape: the axis dimension of `input` is replaced by
+  // the full shape of `indices`.
+  auto verify = [&](Shape &input, Shape &indices, Shape &expected, int32_t axis) {
+    const int input_rank = input.rank();
+    auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, input_rank);
+
+    ASSERT_EQ(actual.rank(), expected.rank());
+
+    for (int32_t dim = 0; dim < expected.rank(); dim++)
+      ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+  };
+
+  // 2-D input, 3-D indices, axis 0
+  {
+    Shape input{3, 4};
+    Shape indices{1, 1, 2};
+    Shape expected{1, 1, 2, 4};
+    verify(input, indices, expected, /*axis=*/0);
+  }
+
+  // 2-D input, 3-D indices, axis 1
+  {
+    Shape input{3, 4};
+    Shape indices{1, 2, 1};
+    Shape expected{3, 1, 2, 1};
+    verify(input, indices, expected, /*axis=*/1);
+  }
+
+  // 3-D input, 2-D indices, axis 0
+  {
+    Shape input{2, 3, 4};
+    Shape indices{1, 2};
+    Shape expected{1, 2, 3, 4};
+    verify(input, indices, expected, /*axis=*/0);
+  }
+
+  // 3-D input, 2-D indices, axis 2
+  {
+    Shape input{2, 3, 4};
+    Shape indices{2, 1};
+    Shape expected{2, 3, 2, 1};
+    verify(input, indices, expected, /*axis=*/2);
+  }
+
+  // 4-D input, 1-D indices, axis 0
+  {
+    Shape input{1, 2, 3, 4};
+    Shape indices{2};
+    Shape expected{2, 2, 3, 4};
+    verify(input, indices, expected, /*axis=*/0);
+  }
+}
+
+TEST(ShapeInference, BCQFullyConnected)
+{
+  auto verify = [&](Shape &in_shape, Shape &cluster_shape, std::vector<int> cluster,
+                    Shape &expected) {
+    auto actual =
+      onert::shape_inference::inferBCQFullyConnectedShape(in_shape, cluster_shape, cluster.data());
+    ASSERT_EQ(actual.rank(), expected.rank());
+
+    for (int32_t dim = 0; dim < expected.rank(); dim++)
+      ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+  };
+
+  // Output rows match the sum of the second entry of each cluster pair
+  // (10 + 10 + 10 = 30); the column count follows the input.
+  {
+    Shape in_shape{10, 1};
+    Shape cluster_shape{3, 2};
+    std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+    Shape expected{30, 1};
+    verify(in_shape, cluster_shape, cluster, expected);
+  }
+
+  // A single cluster pair {3, 50} yields 50 output rows.
+  {
+    Shape in_shape{1, 1};
+    Shape cluster_shape{1, 2};
+    std::vector<int> cluster = {3, 50};
+    Shape expected{50, 1};
+    verify(in_shape, cluster_shape, cluster, expected);
+  }
+}
+
+TEST(ShapeInference, BCQGather)
+{
+  auto verify = [&](Shape &indices_shape, Shape &cluster_shape, std::vector<int> cluster,
+                    uint32_t hidden_size, uint32_t axis, int rank, Shape &expected) {
+    operation::BCQGather::Param param{hidden_size, axis};
+    auto actual = onert::shape_inference::inferBCQGatherShape(indices_shape, cluster_shape,
+                                                              cluster.data(), rank, param);
+    ASSERT_EQ(actual.rank(), expected.rank());
+
+    for (int32_t dim = 0; dim < expected.rank(); dim++)
+      ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+  };
+
+  // axis 0: the indices shape is kept and hidden_size (10) is appended.
+  {
+    Shape indices_shape{5, 1};
+    Shape cluster_shape{3, 2};
+    std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+    Shape expected{5, 1, 10};
+    verify(indices_shape, cluster_shape, cluster, /*hidden_size=*/10, /*axis=*/0, /*rank=*/2,
+           expected);
+  }
+
+  // axis 1: the leading dimension becomes 30 — presumably the summed cluster
+  // sizes (10 + 10 + 10) — followed by the indices shape.
+  {
+    Shape indices_shape{5, 1};
+    Shape cluster_shape{3, 2};
+    std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+    Shape expected{30, 5, 1};
+    verify(indices_shape, cluster_shape, cluster, /*hidden_size=*/10, /*axis=*/1, /*rank=*/2,
+           expected);
+  }
+}
diff --git a/runtime/onert/core/src/util/EnvConfigSource.cc b/runtime/onert/core/src/util/TracingCtx.cc
index 0d25b7353..c05baee60 100644
--- a/runtime/onert/core/src/util/EnvConfigSource.cc
+++ b/runtime/onert/core/src/util/TracingCtx.cc
@@ -1,5 +1,6 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,27 +15,16 @@
* limitations under the License.
*/
-#include "util/EnvConfigSource.h"
-
-#include <cstdlib>
+#include "util/TracingCtx.h"
namespace onert
{
namespace util
{
-std::string EnvConfigSource::get(const std::string &key) const
-{
- const char *value = std::getenv(key.c_str());
- if (value != nullptr)
- {
- return value;
- }
- else
- {
- return GeneralConfigSource::get(key);
- }
-}
+// Initialize the static members used to assign tracing session ids.
+std::mutex TracingCtx::_session_id_mutex;
+uint32_t TracingCtx::_next_session_id = 0;
} // namespace util
} // namespace onert