summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/ade/ade/source/passes/communications.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/ade/ade/source/passes/communications.cpp')
m---------inference-engine/thirdparty/ade0
-rw-r--r--inference-engine/thirdparty/ade/ade/source/passes/communications.cpp545
2 files changed, 0 insertions, 545 deletions
diff --git a/inference-engine/thirdparty/ade b/inference-engine/thirdparty/ade
new file mode 160000
+Subproject 0ba3b01dae7262f7828dc6fa65ef3a89fb371cd
diff --git a/inference-engine/thirdparty/ade/ade/source/passes/communications.cpp b/inference-engine/thirdparty/ade/ade/source/passes/communications.cpp
deleted file mode 100644
index fb0da6323..000000000
--- a/inference-engine/thirdparty/ade/ade/source/passes/communications.cpp
+++ /dev/null
@@ -1,545 +0,0 @@
-// Copyright (C) 2018 Intel Corporation
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-
-#include <passes/communications.hpp>
-
-#include <iterator>
-#include <unordered_map>
-#include <unordered_set>
-#include <atomic>
-#include <stdexcept>
-
-#include <typed_graph.hpp>
-
-#include <communication/comm_buffer.hpp>
-#include <communication/comm_interface.hpp>
-#include <communication/callback_connector.hpp>
-
-#include <memory/memory_descriptor.hpp>
-#include <memory/memory_descriptor_view.hpp>
-
-#include <util/algorithm.hpp>
-#include <util/chain_range.hpp>
-
-#include <memory/alloc.hpp>
-
-namespace
-{
-
-using NodeHasher = ade::HandleHasher<ade::Node>;
-
-struct CacheEntry final
-{
- std::unordered_set<ade::NodeHandle, NodeHasher> commNodes;
- std::unordered_set<ade::NodeHandle, NodeHasher> producers;
- std::unordered_set<ade::NodeHandle, NodeHasher> consumers;
-};
-
-using Cache = std::unordered_map<ade::MemoryDescriptorView*, CacheEntry>;
-
-struct CallbackCacheEntry final
-{
- std::unordered_set<ade::NodeHandle, NodeHasher> producers;
- std::unordered_set<ade::NodeHandle, NodeHasher> consumers;
-};
-
-using CallbackCache = std::unordered_map<ade::NodeHandle, CallbackCacheEntry, NodeHasher>;
-
-
-ade::MemoryDescriptorView* findParentView(ade::MemoryDescriptorView* view)
-{
- ASSERT(nullptr != view);
-
- auto parent = view->getParentView();
- if (nullptr != parent)
- {
- return findParentView(parent);
- }
- return view;
-}
-
-void visitProducer(Cache& cache,
- CallbackCache& callbackCache,
- const ade::NodeHandle& commNode,
- const ade::NodeHandle& node,
- ade::passes::ConnectCommChannels::Context& ctx)
-{
- ASSERT(nullptr != node);
- ASSERT(ctx.graph.metadata(node).contains<ade::meta::DataObject>());
- auto memDesc = findParentView(ctx.graph.metadata(node).get<ade::meta::DataObject>().dataRef.getView());
- ASSERT(nullptr != memDesc);
- bool connectedToNode = false;
- for (auto edge: node->inEdges())
- {
- auto srcNode = edge->srcNode();
- if (ctx.graph.metadata(srcNode).contains<ade::meta::NodeInfo>())
- {
- connectedToNode = true;
- callbackCache[commNode].producers.insert(srcNode);
- }
- else if (ctx.graph.metadata(srcNode).contains<ade::meta::DataObject>())
- {
- visitProducer(cache, callbackCache, commNode, srcNode, ctx);
- }
- }
-
- if (connectedToNode)
- {
- cache[memDesc].producers.insert(node);
- cache[memDesc].commNodes.insert(commNode);
- }
-}
-
-void visitConsumer(Cache& cache,
- CallbackCache& callbackCache,
- const ade::NodeHandle& commNode,
- const ade::NodeHandle& node,
- ade::passes::ConnectCommChannels::Context& ctx)
-{
- ASSERT(nullptr != node);
- ASSERT(ctx.graph.metadata(node).contains<ade::meta::DataObject>());
- auto memDesc = findParentView(ctx.graph.metadata(node).get<ade::meta::DataObject>().dataRef.getView());
- ASSERT(nullptr != memDesc);
- bool connectedToNode = false;
- for (auto edge: node->outEdges())
- {
- auto dstNode = edge->dstNode();
- if (ctx.graph.metadata(dstNode).contains<ade::meta::NodeInfo>())
- {
- connectedToNode = true;
- callbackCache[commNode].consumers.insert(dstNode);
- }
- else if (ctx.graph.metadata(dstNode).contains<ade::meta::DataObject>())
- {
- visitConsumer(cache, callbackCache, commNode, dstNode, ctx);
- }
- }
-
- if (connectedToNode)
- {
- cache[memDesc].consumers.insert(node);
- cache[memDesc].commNodes.insert(commNode);
- }
-}
-
-struct DataObject final
-{
- ade::MemoryDescriptorRef memory_ref;
- std::vector<ade::NodeHandle> commNodes;
- std::vector<ade::NodeHandle> producers;
- std::vector<ade::NodeHandle> consumers;
-};
-
-struct CallbackObject final
-{
- ade::NodeHandle commNode;
- std::vector<ade::NodeHandle> producers;
- std::vector<ade::NodeHandle> consumers;
-};
-
-struct CommObjects
-{
- std::vector<DataObject> dataObjects;
- std::vector<CallbackObject> callbackObjects;
-};
-
-CommObjects collectDataObjects(ade::passes::ConnectCommChannels::Context& ctx)
-{
- Cache cache;
- CallbackCache callbackCache;
- for (auto node: ctx.graph.nodes())
- {
- auto meta = ctx.graph.metadata(node);
- if (meta.contains<ade::meta::CommNode>())
- {
- for (auto edge: node->inEdges())
- {
- auto srcNode = edge->srcNode();
- visitProducer(cache, callbackCache, node, srcNode, ctx);
- }
-
- for (auto edge: node->outEdges())
- {
- auto dstNode = edge->dstNode();
- visitConsumer(cache, callbackCache, node, dstNode, ctx);
- }
- }
- }
-
- CommObjects ret;
- for (auto& obj: cache)
- {
- DataObject newObj;
- newObj.memory_ref = *obj.first;
- newObj.commNodes.reserve(obj.second.commNodes.size());
- newObj.producers.reserve(obj.second.producers.size());
- newObj.consumers.reserve(obj.second.consumers.size());
- util::copy(obj.second.commNodes, std::back_inserter(newObj.commNodes));
- util::copy(obj.second.producers, std::back_inserter(newObj.producers));
- util::copy(obj.second.consumers, std::back_inserter(newObj.consumers));
- ASSERT(!newObj.commNodes.empty());
- ASSERT(!newObj.producers.empty());
- ASSERT(!newObj.consumers.empty());
- ret.dataObjects.emplace_back(std::move(newObj));
- }
-
- for (auto& obj: callbackCache)
- {
- CallbackObject newObj;
- newObj.commNode = obj.first;
- newObj.producers.reserve(obj.second.producers.size());
- newObj.consumers.reserve(obj.second.consumers.size());
- util::copy(obj.second.producers, std::back_inserter(newObj.producers));
- util::copy(obj.second.consumers, std::back_inserter(newObj.consumers));
- ASSERT(!newObj.producers.empty());
- ASSERT(!newObj.consumers.empty());
- ret.callbackObjects.emplace_back(std::move(newObj));
- }
- return ret;
-}
-
-/// Fill common part of the BufferDesc
-template<typename T>
-ade::ICommChannel::BufferDesc fillBufferDesc(T& elem)
-{
- auto memRef = elem.memory_ref;
- ASSERT(nullptr != memRef);
- ade::ICommChannel::BufferDesc bufferDesc;
-
- // Fill common part of the BufferDesc
- bufferDesc.writersCount = util::checked_cast<int>(elem.producers.size());
- bufferDesc.readersCount = util::checked_cast<int>(elem.consumers.size());
- bufferDesc.memoryRef = memRef;
- return bufferDesc;
-}
-
-class HostBufferImpl final : public ade::IDataBuffer
-{
-public:
- HostBufferImpl(std::size_t elementSize,
- const ade::memory::DynMdSize& size,
- const ade::memory::DynMdSize& alignment);
-
- HostBufferImpl(const ade::MemoryDescriptorRef& memRef);
-
- ~HostBufferImpl();
-
- // IDataBuffer interface
- virtual MapId map(const Span& span, Access access) override;
- virtual void unmap(const MapId& id) override;
- virtual void finalizeWrite(const ade::IDataBuffer::Span& span) override;
- virtual void finalizeRead(const ade::IDataBuffer::Span& span) override;
- virtual Size alignment(const Span& span) override;
-
-private:
- struct Deleter
- {
- void operator()(void* ptr) const
- {
- ASSERT(nullptr != ptr);
- ade::aligned_free(ptr);
- }
- };
-
- std::atomic<int> m_accessCount = {0};
- ade::memory::DynMdSize m_size;
- ade::memory::DynMdSize m_alignment;
- ade::memory::DynMdView<void> m_view;
- std::unique_ptr<void, Deleter> m_memory;
- ade::MemoryDescriptorRef m_memRef;
-};
-
-HostBufferImpl::HostBufferImpl(std::size_t elementSize,
- const ade::memory::DynMdSize& size,
- const ade::memory::DynMdSize& alignment):
- m_size(size),
- m_alignment(alignment),
- m_view(util::alloc_view<ade::memory::MaxDimensions>
- (elementSize,
- util::memory_range(size.data(), size.dims_count()),
- util::memory_range(alignment.data(), alignment.dims_count()),
- [](std::size_t size, std::size_t align)
- {
- auto ptr = ade::aligned_alloc(size, align);
- if (nullptr == ptr)
- {
- throw_error(std::bad_alloc());
- }
- return ptr;
- })),
- m_memory(m_view.mem.data)
-{
-
-}
-
-HostBufferImpl::HostBufferImpl(const ade::MemoryDescriptorRef& memRef):
- m_size(memRef.span().size()),
- m_memRef(memRef)
-{
-
-}
-
-HostBufferImpl::~HostBufferImpl()
-{
- ASSERT(0 == m_accessCount);
-}
-
-ade::IDataBuffer::MapId HostBufferImpl::map(const Span& span, Access /*access*/)
-{
- auto view = (nullptr != m_view ? m_view : m_memRef.getExternalView());
- ASSERT(nullptr != view);
- ASSERT(span.dims_count() == m_size.dims_count());
- auto accessCount = ++m_accessCount;
- ASSERT(accessCount > 0);
- return MapId{view.slice(span), 0};
-}
-
-void HostBufferImpl::unmap(const MapId& /*id*/)
-{
- auto accessCount = --m_accessCount;
- ASSERT(accessCount >= 0);
-}
-
-void HostBufferImpl::finalizeWrite(const ade::IDataBuffer::Span& /*span*/)
-{
- //Nothing
-}
-
-void HostBufferImpl::finalizeRead(const ade::IDataBuffer::Span& /*span*/)
-{
- //Nothing
-}
-
-ade::IDataBuffer::Size HostBufferImpl::alignment(const Span& span)
-{
- ASSERT(span.dims_count() == m_size.dims_count());
- // TODO: report actual alignment
- Size ret;
- ret.redim(span.dims_count());
- util::fill(ret, 1);
- return ret;
-}
-
-}
-
-void ade::passes::ConnectCommChannels::operator()(ade::passes::ConnectCommChannels::Context ctx) const
-{
- // Step 1:
- // Collect all data objects directly or indirectly connected to comm nodes
- // group them by MemoryDescriptor and by a commnode
- const auto commObjects = collectDataObjects(ctx);
-
- // Step 2:
- // Check comm channels and callbacks validity
- {
- // Step 2.1
- // Check comm channels validity
- for (auto& elem: commObjects.dataObjects)
- {
- for (auto node: util::chain(util::toRange(elem.producers),
- util::toRange(elem.consumers)))
- {
- auto meta = ctx.graph.metadata(node);
- if (!meta.contains<ade::meta::DataObject>() ||
- !meta.contains<ade::meta::CommChannel>() ||
- nullptr == meta.get<ade::meta::CommChannel>().channel)
- {
- throw_error(std::runtime_error("Comm channel wasn't setup properly"));
- }
- }
- }
-
- // Step 2.2
- // Check comm callbacks validity
- for (auto& elem: commObjects.callbackObjects)
- {
- for (auto node: elem.consumers)
- {
- auto meta = ctx.graph.metadata(node);
- if (!meta.contains<ade::meta::CommConsumerCallback>() ||
- nullptr == meta.get<ade::meta::CommConsumerCallback>().callback)
- {
- throw_error(std::runtime_error("Consumer callback metadata error"));
- }
- }
- }
- }
-
- // Step 3:
- // Connect comm channels
- for (auto& elem: commObjects.dataObjects)
- {
- ade::ICommChannel::BufferDesc bufferDesc = fillBufferDesc(elem);
-
- // Step 3.1:
- // Collect buffer preferences
- ade::ICommChannel::BufferPrefs summary;
- summary.preferredAlignment.redim(bufferDesc.memoryRef.span().dims_count());
- util::fill(summary.preferredAlignment, 1);
- for (auto node: util::chain(util::toRange(elem.producers),
- util::toRange(elem.consumers)))
- {
- auto meta = ctx.graph.metadata(node);
- auto channel = meta.get<ade::meta::CommChannel>().channel;
- ASSERT(nullptr != channel);
- ade::ICommChannel::BufferPrefs prefs = channel->getBufferPrefs(bufferDesc);
- ASSERT(prefs.preferredAlignment.dims_count() == summary.preferredAlignment.dims_count());
- for (auto i: util::iota(summary.preferredAlignment.dims_count()))
- {
- ASSERT(prefs.preferredAlignment[i] > 0);
- // TODO: assert alignment is power of 2
- summary.preferredAlignment[i] =
- std::max(summary.preferredAlignment[i],
- prefs.preferredAlignment[i]);
- }
- }
-
- // Step 3.2:
- // Try to get buffer from channels
- std::unique_ptr<ade::IDataBuffer> buffer;
- for (auto node: util::chain(util::toRange(elem.producers),
- util::toRange(elem.consumers)))
- {
- ASSERT(nullptr == buffer);
- auto meta = ctx.graph.metadata(node);
- auto channel = meta.get<ade::meta::CommChannel>().channel;
- ASSERT(nullptr != channel);
- buffer = channel->getBuffer(bufferDesc, summary);
- if (nullptr != buffer)
- {
- break;
- }
- }
-
- if (nullptr == buffer)
- {
- // Step 3.3:
- // Buffer wasn't allocated by plugins, allocate it by framework
- if (nullptr == bufferDesc.memoryRef.getExternalView())
- {
- buffer.reset(new HostBufferImpl(bufferDesc.memoryRef.elementSize(),
- bufferDesc.memoryRef.size(),
- summary.preferredAlignment));
- }
- else
- {
- // Use existing buffer (e.g. from non-virtual object)
- buffer.reset(new HostBufferImpl(bufferDesc.memoryRef));
- }
- }
-
- // Step 3.4:
- // Notify plugins about buffer object
- ASSERT(nullptr != buffer);
- for (auto node: util::chain(util::toRange(elem.producers),
- util::toRange(elem.consumers)))
- {
- auto meta = ctx.graph.metadata(node);
- auto channel = meta.get<ade::meta::CommChannel>().channel;
- channel->setBuffer(ade::DataBufferView(buffer.get(), bufferDesc.memoryRef.span()), bufferDesc);
- }
- std::shared_ptr<ade::IDataBuffer> sharedBuffer(std::move(buffer));
- for (auto commNode: elem.commNodes)
- {
- auto meta = ctx.graph.metadata(commNode);
- meta.get<ade::meta::CommNode>().addDataBuffer(sharedBuffer);
- }
- }
-
- // Step 4
- // Connect comm objects callbacks
- {
- // Multiple comm nodes can be attached to single producer data object
- // so we need to collect and merge them
- std::unordered_map<ade::NodeHandle, std::vector<std::function<void()>>, NodeHasher> producerCallbacks;
- for (auto& elem: commObjects.callbackObjects)
- {
- ASSERT(nullptr != elem.commNode);
- ASSERT(!elem.producers.empty() && !elem.consumers.empty());
-
- ade::CallbackConnector<> connector(util::checked_cast<int>(elem.producers.size()),
- util::checked_cast<int>(elem.consumers.size()));
-
- // Step 4.1
- // Collect callbacks from consumers
- for (auto& consumer: elem.consumers)
- {
- auto meta = ctx.graph.metadata(consumer);
- auto callback = std::move(meta.get<ade::meta::CommConsumerCallback>().callback);
- ASSERT(nullptr != callback);
- connector.addConsumerCallback(std::move(callback));
- }
-
- // Step 4.2
- // Create producer callbacks
- auto resetter = connector.finalize();
- if (nullptr != resetter)
- {
- auto meta = ctx.graph.metadata();
- if (!meta.contains<ade::meta::Finalizers>())
- {
- meta.set(ade::meta::Finalizers());
- }
- meta.get<ade::meta::Finalizers>().finalizers.emplace_back(std::move(resetter));
- }
-
- // Step 4.3
- // Collect producer callbacks
- for (auto& producer: elem.producers)
- {
- auto callback = connector.getProducerCallback();
- ASSERT(nullptr != callback);
- producerCallbacks[producer].emplace_back(std::move(callback));
- }
- }
-
- // Step 4.4
- // Assign producer callbacks
- for (auto& elem: producerCallbacks)
- {
- auto producer = elem.first;
-
- auto callbacks = std::move(elem.second);
- ASSERT(!callbacks.empty());
-
- auto meta = ctx.graph.metadata(producer);
- if (!meta.contains<ade::meta::CommProducerCallback>())
- {
- meta.set(ade::meta::CommProducerCallback());
- }
-
- if (1 == callbacks.size())
- {
- // Assign directly
- meta.get<ade::meta::CommProducerCallback>().callback = callbacks[0];
- }
- else
- {
- // Create wrapper to call all callbacks
- struct Connector final
- {
- std::vector<std::function<void()>> callbacks;
-
- void operator()() const
- {
- ASSERT(!callbacks.empty());
- for (auto& callback: callbacks)
- {
- ASSERT(nullptr != callback);
- callback();
- }
- }
- };
-
- meta.get<ade::meta::CommProducerCallback>().callback = Connector{std::move(callbacks)};
- }
- }
- }
-}
-
-const char* ade::passes::ConnectCommChannels::name()
-{
- return "ade::passes::ConnectCommChannels";
-}