diff options
author | Duc Ngo <duc@fb.com> | 2018-10-18 12:30:31 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-18 12:32:40 -0700 |
commit | 2c566a17c763ab000ebeba1d3c01762bae814e42 (patch) | |
tree | a600faa4bb11fbd5ea76757edfbe05adc27bbba0 | |
parent | 9c617140f79f82628484e7545b228be990994196 (diff) | |
download | pytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.tar.gz pytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.tar.bz2 pytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.zip |
nomnigraph - simplify subgraph matching APIs (#12681)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12681
- Get rid of NodeMatchCriteria as a template parameter, which was too generic. So MatchNode<NodeMatchCriteria> becomes MatchNode<GraphType>, and MatchStore stores the predicate on GraphType::NodeRef.
- Similarly, get rid of NNNodeMatchCriteria
Now one can just pass in a function pointer NodeRef -> bool to NNMatchNode constructor directly like this
mg.createNode(is<Relu>)
- Merge static utilities in SubgraphMatcher class into MatchGraph class
- Rename MatchNode to MatchPredicate
Change use cases and tests to make it work
Reviewed By: ZolotukhinM
Differential Revision: D10386907
fbshipit-source-id: 43874bd154e3d7c29ce07b4b74eca8a7a9f3078a
7 files changed, 250 insertions, 336 deletions
diff --git a/caffe2/core/nomnigraph/Representations/NeuralNet.cc b/caffe2/core/nomnigraph/Representations/NeuralNet.cc index d3ba0b2b38..06879d8c0b 100644 --- a/caffe2/core/nomnigraph/Representations/NeuralNet.cc +++ b/caffe2/core/nomnigraph/Representations/NeuralNet.cc @@ -180,49 +180,29 @@ void coalesceInsertedDataDependencies(repr::NNModule* m) { } } -std::ostream& operator<<( - std::ostream& oss, - const NNNodeMatchCriteria& criteria) { - return oss << criteria.debugString; -} - -NNNodeMatchCriteria criteriaSingleOutputAndConsumer() { - return NNNodeMatchCriteria( - [](NNGraph::NodeRef nodeRef) { - auto nodeOutputs = nn::getOutputs(nodeRef); - NOM_REQUIRE_OR_RET_FALSE(nodeOutputs.size() == 1); - auto nodeConsumers = nn::getConsumers(nodeOutputs.front()); - return nodeConsumers.size() == 1; - }, - "Single output and consumer"); -} - -NNNodeMatchCriteria criteriaSingleConsumer() { - return NNNodeMatchCriteria( - [](NNGraph::NodeRef nodeRef) { - auto nodeOutputs = nn::getOutputs(nodeRef); - NNGraph::NodeRef nodeConsumer = nullptr; - for (auto nodeOutput : nodeOutputs) { - for (auto consumer : nn::getConsumers(nodeOutput)) { - if (nodeConsumer && consumer && consumer != nodeConsumer) { - return false; - } - nodeConsumer = consumer; - } - } - return true; - }, - "Single consumer"); -} - -NNNodeMatchCriteria matchTensor(const std::string& debugString) { - return matchOp<nom::repr::Tensor>(debugString); +bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef) { + auto nodeOutputs = nn::getOutputs(nodeRef); + NOM_REQUIRE_OR_RET_FALSE(nodeOutputs.size() == 1); + auto nodeConsumers = nn::getConsumers(nodeOutputs.front()); + return nodeConsumers.size() == 1; +} + +bool hasUniqueConsumer(NNGraph::NodeRef nodeRef) { + auto nodeOutputs = nn::getOutputs(nodeRef); + NNGraph::NodeRef nodeConsumer = nullptr; + for (auto nodeOutput : nodeOutputs) { + for (auto consumer : nn::getConsumers(nodeOutput)) { + if (nodeConsumer && consumer && consumer != nodeConsumer) { + return false; + } + nodeConsumer = consumer; + } + } + return true; } -NNMatchNode matchExternalTensorNode(const std::string& debugString) { - return NNMatchNode(matchTensor(debugString)) - .nonTerminal() - .excludeFromSubgraph(); +NNMatchPredicate matchExternalTensorNode() { + return NNMatchPredicate(nn::is<Tensor>).nonTerminal().excludeFromSubgraph(); } } // namespace nn diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h index b12e57d751..61fbebcf9b 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h @@ -15,6 +15,7 @@ #include <iterator> #include <list> #include <unordered_set> +#include <utility> #include <vector> #include <assert.h> @@ -240,6 +241,11 @@ class Graph { return createNodeInternal(Node<T, U...>(std::move(data))); } + template <class Arg> + NodeRef createNode(Arg&& arg) { + return createNode(T(std::forward<Arg>(arg))); + } + NodeRef createNode() { return createNodeInternal(Node<T, U...>()); } diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h index 523f29225a..9a73463e3c 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h @@ -297,9 +297,9 @@ struct C10_EXPORT } }; -template <typename T, typename N> -inline bool is(N n) { - return is_impl<T, N>::impl(n); +template <typename T> +inline bool is(NNGraph::NodeRef n) { + return is_impl<T, NNGraph::NodeRef>::impl(n); } // This is just a way to fix issues when the dyn_cast<> implementation @@ -433,82 +433,18 @@ CAFFE2_API void coalesceInsertedDataDependencies(repr::NNModule* m); template <NNGraph* G> struct C10_EXPORT NodeHelper {}; -struct NNNodeMatchCriteria { - std::function<bool(NNGraph::NodeRef)> predicate; - std::string debugString; - - NNNodeMatchCriteria( - const std::function<bool(NNGraph::NodeRef)>& predicate, - const std::string& debugString = "No debug string specified") - : predicate(predicate), debugString(debugString){}; - - NNNodeMatchCriteria() = default; - NNNodeMatchCriteria(const NNNodeMatchCriteria&) = default; - NNNodeMatchCriteria& operator=(const NNNodeMatchCriteria&) = default; - NNNodeMatchCriteria(NNNodeMatchCriteria&&) = default; - - NNNodeMatchCriteria andCriteria(const NNNodeMatchCriteria& other) { - auto thisPredicate = predicate; - auto otherPredicate = other.predicate; - return NNNodeMatchCriteria( - [thisPredicate, otherPredicate](NNGraph::NodeRef node) { - return thisPredicate(node) && otherPredicate(node); - }, - debugString + " and " + other.debugString); - } -}; - -CAFFE2_API std::ostream& operator<<( - std::ostream& oss, - const NNNodeMatchCriteria& criteria); - -using NNMatchGraph = nom::matcher::MatchGraph<NNNodeMatchCriteria>; -using NNMatchNode = nom::matcher::MatchNode<NNNodeMatchCriteria>; +using NNMatchGraph = nom::matcher::MatchGraph<NNGraph>; +using NNMatchPredicate = nom::matcher::MatchPredicate<NNGraph>; -// Commonly used criteria. +// Commonly used node predicate. // The node has a single output and the output has a single consumer. -CAFFE2_API NNNodeMatchCriteria criteriaSingleOutputAndConsumer(); +CAFFE2_API bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef); // The node has a unique consumer (there may be multiple edges from output // to the single consumer). -CAFFE2_API NNNodeMatchCriteria criteriaSingleConsumer(); - -template <typename NodeType> -NNNodeMatchCriteria matchOp(const std::string& debugString = "matchOp") { - return NNNodeMatchCriteria( - [](NNGraph::NodeRef nodeRef) { return is<NodeType>(nodeRef); }, - debugString); -} - -template <typename NodeType> -NNNodeMatchCriteria matchOp( - const std::function<bool(const NodeType&)> predicate, - const std::string& debugString = "matchOpWithPredicate") { - return NNNodeMatchCriteria( - [predicate](NNGraph::NodeRef nodeRef) { - NOM_REQUIRE_OR_RET_FALSE(is<NodeType>(nodeRef)); - NodeType* node = get<NodeType>(nodeRef); - return predicate(*node); - }, - debugString); -}; - -CAFFE2_API NNNodeMatchCriteria -matchTensor(const std::string& debugString = "matchTensor"); - -CAFFE2_API NNMatchNode -matchExternalTensorNode(const std::string& debugString = "matchExternalTensor"); - -struct CAFFE2_API NNNodeMatch { - static bool isMatch( - const NNGraph::NodeRef& node, - const NNNodeMatchCriteria& criteria) { - return criteria.predicate(node); - } -}; +CAFFE2_API bool hasUniqueConsumer(NNGraph::NodeRef nodeRef); -using NNSubgraphMatcher = - nom::matcher::SubgraphMatcher<NNGraph, NNNodeMatchCriteria, NNNodeMatch>; +CAFFE2_API NNMatchPredicate matchExternalTensorNode(); } // namespace nn diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h index a303324fbb..c237b6e9f9 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h @@ -15,9 +15,9 @@ namespace nom { namespace matcher { /** - * MatchGraph is a graph of MatchNode. + * MatchGraph is a graph of MatchPredicate. * - * MatchNode needs a NodeMatchCriteria (a predicate for node matching) and + * MatchPredicate needs a predicate for node matching and * - includeInSubgraph: whether this node and nodes/edges reachable from it * should be included in the return matched subgraph (if the pattern matches). * This is useful in case we would like to specify a matching pattern but do not @@ -28,19 +28,21 @@ namespace matcher { * from the node when doing subgraph matching. */ -template <typename NodeMatchCriteria> -class MatchNode { +template <typename GraphType> +class MatchPredicate { public: + using Predicate = std::function<bool(typename GraphType::NodeRef)>; + static const int kStarCount = -1; - MatchNode(const NodeMatchCriteria& criteria) : criteria_(criteria) {} + MatchPredicate(const Predicate& criteria) : criteria_(criteria) {} - MatchNode() = default; - MatchNode(const MatchNode&) = default; - MatchNode& operator=(const MatchNode&) = default; - MatchNode(MatchNode&&) = default; + MatchPredicate() = default; + MatchPredicate(const MatchPredicate&) = default; + MatchPredicate& operator=(const MatchPredicate&) = default; + MatchPredicate(MatchPredicate&&) = default; - NodeMatchCriteria getCriteria() const { + Predicate getCriteria() const { return criteria_; } @@ -48,21 +50,21 @@ class MatchNode { return count_; } - MatchNode<NodeMatchCriteria>& count(int count) { + MatchPredicate<GraphType>& count(int count) { count_ = count; return *this; } - MatchNode<NodeMatchCriteria>& starCount() { + MatchPredicate<GraphType>& starCount() { return count(kStarCount); } - MatchNode<NodeMatchCriteria>& nonTerminal() { + MatchPredicate<GraphType>& nonTerminal() { nonTerminal_ = true; return *this; } - MatchNode<NodeMatchCriteria>& excludeFromSubgraph() { + MatchPredicate<GraphType>& excludeFromSubgraph() { includeInSubgraph_ = false; return *this; } @@ -75,131 +77,44 @@ class MatchNode { return includeInSubgraph_; } + std::string getDebugString() const { + return debugString_; + } + + void setDebugString(const std::string& debugString) { + debugString_ = debugString; + } + private: - NodeMatchCriteria criteria_; + Predicate criteria_; int count_ = 1; bool includeInSubgraph_ = true; bool nonTerminal_ = false; + std::string debugString_; }; -template <typename NodeMatchCriteria> -using MatchGraph = Graph<MatchNode<NodeMatchCriteria>>; - -template <typename NodeMatchCriteria> -using MatchNodeRef = typename MatchGraph<NodeMatchCriteria>::NodeRef; - -// TODO: Reuse convertToDotString once convertToDotString can work -// with subgraph. -template <typename NodeMatchCriteria> -std::string debugString( - MatchNodeRef<NodeMatchCriteria> rootCriteriaRef, - bool invertGraphTraversal) { - std::ostringstream out; - auto rootNode = rootCriteriaRef->data(); - out << "{rootCriteria = '" << rootNode.getCriteria() << "'"; - if (rootNode.getCount() != 1) { - out << ", count = " << rootNode.getCount(); - } - if (rootNode.isNonTerminal()) { - out << ", nonTerminal = " << rootNode.isNonTerminal(); - } - auto edges = invertGraphTraversal ? rootCriteriaRef->getInEdges() - : rootCriteriaRef->getOutEdges(); - if (!edges.empty()) { - out << ", childrenCriteria = ["; - for (auto& child : edges) { - auto nextNode = invertGraphTraversal ? child->tail() : child->head(); - out << debugString<NodeMatchCriteria>(nextNode, invertGraphTraversal) - << ", "; - } - out << "]"; - } - out << "}"; - return out.str(); -} +template <typename GraphType> +class SubgraphMatchResult; -template <typename NodeMatchCriteria, typename GraphType> -class SubgraphMatchResult { +// MatchGraph is a graph of MatchPredicate and it contains utilities for +// subgraph matching. +// (TODO) the subgraph matching methods currently still +// requires a root match node to be passed in. We should improve the matching +// algorithm to eliminate this requirement. +template <typename GraphType> +class MatchGraph : public Graph<MatchPredicate<GraphType>> { public: - // Map from match node to corresponding node in the graph to be scanned. - using MatchNodeMap = std::unordered_map< - MatchNodeRef<NodeMatchCriteria>, - typename GraphType::NodeRef>; - - static SubgraphMatchResult<NodeMatchCriteria, GraphType> notMatched( - const std::string& debugMessage) { - return SubgraphMatchResult<NodeMatchCriteria, GraphType>( - false, debugMessage); - } - - static SubgraphMatchResult<NodeMatchCriteria, GraphType> notMatched() { - return SubgraphMatchResult<NodeMatchCriteria, GraphType>( - false, "Debug message is not enabled"); - } - - static SubgraphMatchResult<NodeMatchCriteria, GraphType> matched( - bool ownSubgraph = false) { - return SubgraphMatchResult<NodeMatchCriteria, GraphType>( - true, "Matched", ownSubgraph); - } - - bool isMatch() const { - return isMatch_; - } - - std::string getDebugMessage() const { - return debugMessage_; - } - - std::shared_ptr<typename GraphType::SubgraphType> getMatchedSubgraph() const { - return matchedSubgraph_; - } - - std::shared_ptr<MatchNodeMap> getMatchNodeMap() const { - return matchNodeMap_; - } - - private: - SubgraphMatchResult( - bool isMatch, - const std::string& debugMessage, - bool ownSubgraph = false) - : isMatch_(isMatch), - debugMessage_(debugMessage), - matchedSubgraph_( - ownSubgraph ? std::shared_ptr<typename GraphType::SubgraphType>( - new typename GraphType::SubgraphType()) - : nullptr), - matchNodeMap_( - ownSubgraph ? std::shared_ptr<MatchNodeMap>(new MatchNodeMap()) - : nullptr) {} - - const bool isMatch_; - const std::string debugMessage_; - const std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph_; - const std::shared_ptr<MatchNodeMap> matchNodeMap_; -}; - -/* - * Utilities for subgraph matching. - */ -template < - typename GraphType, - typename NodeMatchCriteria, - typename NodeMatcherClass> -struct SubgraphMatcher { - using SubgraphMatchResultType = - SubgraphMatchResult<NodeMatchCriteria, GraphType>; + using SubgraphMatchResultType = SubgraphMatchResult<GraphType>; using ReplaceGraphOperation = std::function<bool( GraphType&, typename GraphType::NodeRef, const SubgraphMatchResultType&)>; - static bool isNodeMatch( + bool isNodeMatch( typename GraphType::NodeRef node, - const NodeMatchCriteria& criteria) { - return NodeMatcherClass::isMatch(node, criteria); + const MatchPredicate<GraphType>& matchPredicate) const { + return matchPredicate.getCriteria()(node); } // Check if there can be a subgraph that matches the given criteria that @@ -207,11 +122,11 @@ struct SubgraphMatcher { // The flag invertGraphTraversal specify if we should follow out edges or // in edges. The default is true which is useful for a functional // intepretation of a dataflow graph. - static SubgraphMatchResultType isSubgraphMatch( + SubgraphMatchResultType isSubgraphMatch( typename GraphType::NodeRef root, - const MatchNodeRef<NodeMatchCriteria>& rootCriteriaRef, + const typename MatchGraph::NodeRef& rootCriteriaRef, bool invertGraphTraversal = true, - bool debug = false) { + bool debug = false) const { // Create a matched result that owns a matched subgraph object and pass // the subgraph object around to construct it during matching. auto matchedResult = SubgraphMatchResultType::matched(true); @@ -236,11 +151,11 @@ struct SubgraphMatcher { // is aborted. This maybe useful in certain cases when we want to terminate // the subgraph search early. // invertGraphTraversal flag: see documentation in isSubgraphMatch - static void replaceSubgraph( + void replaceSubgraph( GraphType& graph, - const MatchNodeRef<NodeMatchCriteria>& criteria, + const typename MatchGraph::NodeRef& criteria, const ReplaceGraphOperation& replaceFunction, - bool invertGraphTraversal = true) { + bool invertGraphTraversal = true) const { for (auto nodeRef : graph.getMutableNodes()) { // Make sure the node is still in the graph. if (!graph.hasNode(nodeRef)) { @@ -259,15 +174,15 @@ struct SubgraphMatcher { } private: - static SubgraphMatchResultType isSubgraphMatchInternal( + SubgraphMatchResultType isSubgraphMatchInternal( std::shared_ptr<typename SubgraphMatchResultType::MatchNodeMap> matchedNodes, std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph, typename GraphType::NodeRef root, - const MatchNodeRef<NodeMatchCriteria>& rootCriteriaRef, + const typename MatchGraph::NodeRef& rootCriteriaRef, bool includeInSubgraph, bool invertGraphTraversal, - bool debug) { + bool debug) const { auto rootCriteriaNode = rootCriteriaRef->data(); if (rootCriteriaNode.getCount() == 1) { @@ -283,8 +198,7 @@ struct SubgraphMatcher { std::ostringstream debugMessage; debugMessage << "Subgraph root at " << root << " is not the same as " << matchedNode << " which previously matched criteria " - << debugString<NodeMatchCriteria>( - rootCriteriaRef, invertGraphTraversal); + << debugString(rootCriteriaRef, invertGraphTraversal); return SubgraphMatchResultType::notMatched(debugMessage.str()); } else { return SubgraphMatchResultType::notMatched(); @@ -292,13 +206,12 @@ struct SubgraphMatcher { } } - if (!isNodeMatch(root, rootCriteriaNode.getCriteria())) { + if (!isNodeMatch(root, rootCriteriaNode)) { if (debug) { std::ostringstream debugMessage; debugMessage << "Subgraph root at " << root << " does not match criteria " - << debugString<NodeMatchCriteria>( - rootCriteriaRef, invertGraphTraversal); + << debugString(rootCriteriaRef, invertGraphTraversal); return SubgraphMatchResultType::notMatched(debugMessage.str()); } else { return SubgraphMatchResultType::notMatched(); @@ -334,8 +247,7 @@ struct SubgraphMatcher { : criteriaEdges[criteriaIdx]->head(); int expectedCount = childrenCriteriaRef->data().getCount(); - bool isStarCount = - expectedCount == MatchNode<NodeMatchCriteria>::kStarCount; + bool isStarCount = expectedCount == MatchPredicate<GraphType>::kStarCount; int countMatch = 0; @@ -368,7 +280,7 @@ struct SubgraphMatcher { std::ostringstream debugMessage; debugMessage << "Child node at " << child << " does not match child criteria " - << debugString<NodeMatchCriteria>( + << debugString( childrenCriteriaRef, invertGraphTraversal) << ". We expected " << expectedCount << " matches but only found " << countMatch << "."; @@ -394,8 +306,7 @@ struct SubgraphMatcher { std::ostringstream debugMessage; debugMessage << "Expected " << expectedCount << " matches for child criteria " - << debugString<NodeMatchCriteria>( - childrenCriteriaRef, invertGraphTraversal) + << debugString(childrenCriteriaRef, invertGraphTraversal) << " but only found " << countMatch; return SubgraphMatchResultType::notMatched(debugMessage.str()); } else { @@ -423,6 +334,93 @@ struct SubgraphMatcher { } return SubgraphMatchResultType::matched(); } + + // TODO: Reuse convertToDotString once convertToDotString can work + // with subgraph. + std::string debugString( + typename MatchGraph::NodeRef rootCriteriaRef, + bool invertGraphTraversal) const { + std::ostringstream out; + auto rootNode = rootCriteriaRef->data(); + out << "{root = '" << rootNode.getDebugString() << "'"; + if (rootNode.getCount() != 1) { + out << ", count = " << rootNode.getCount(); + } + if (rootNode.isNonTerminal()) { + out << ", nonTerminal = " << rootNode.isNonTerminal(); + } + auto edges = invertGraphTraversal ? rootCriteriaRef->getInEdges() + : rootCriteriaRef->getOutEdges(); + if (!edges.empty()) { + out << ", childrenCriteria = ["; + for (auto& child : edges) { + auto nextNode = invertGraphTraversal ? child->tail() : child->head(); + out << debugString(nextNode, invertGraphTraversal) << ", "; + } + out << "]"; + } + out << "}"; + return out.str(); + } +}; + +template <typename GraphType> +class SubgraphMatchResult { + public: + // Map from match node to corresponding node in the graph to be scanned. + using MatchNodeMap = std::unordered_map< + typename MatchGraph<GraphType>::NodeRef, + typename GraphType::NodeRef>; + + static SubgraphMatchResult<GraphType> notMatched( + const std::string& debugMessage) { + return SubgraphMatchResult<GraphType>(false, debugMessage); + } + + static SubgraphMatchResult<GraphType> notMatched() { + return SubgraphMatchResult<GraphType>( + false, "Debug message is not enabled"); + } + + static SubgraphMatchResult<GraphType> matched(bool ownSubgraph = false) { + return SubgraphMatchResult<GraphType>(true, "Matched", ownSubgraph); + } + + bool isMatch() const { + return isMatch_; + } + + std::string getDebugMessage() const { + return debugMessage_; + } + + std::shared_ptr<typename GraphType::SubgraphType> getMatchedSubgraph() const { + return matchedSubgraph_; + } + + std::shared_ptr<MatchNodeMap> getMatchNodeMap() const { + return matchNodeMap_; + } + + private: + SubgraphMatchResult( + bool isMatch, + const std::string& debugMessage, + bool ownSubgraph = false) + : isMatch_(isMatch), + debugMessage_(debugMessage), + matchedSubgraph_( + ownSubgraph ? std::shared_ptr<typename GraphType::SubgraphType>( + new typename GraphType::SubgraphType()) + : nullptr), + matchNodeMap_( + ownSubgraph ? std::shared_ptr<MatchNodeMap>(new MatchNodeMap()) + : nullptr) {} + + const bool isMatch_; + const std::string debugMessage_; + const std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph_; + const std::shared_ptr<MatchNodeMap> matchNodeMap_; }; } // namespace matcher diff --git a/caffe2/core/nomnigraph/tests/NeuralNetTest.cc b/caffe2/core/nomnigraph/tests/NeuralNetTest.cc index 874da120b5..7aff67880c 100644 --- a/caffe2/core/nomnigraph/tests/NeuralNetTest.cc +++ b/caffe2/core/nomnigraph/tests/NeuralNetTest.cc @@ -45,30 +45,29 @@ TEST(NeuralNetGraph, ReplaceGraph) { auto mg = NNMatchGraph(); auto matchSumInput = mg.createNode(std::move(matchExternalTensorNode().count(2))); - auto matchSum = mg.createNode(matchOp<Sum>("matchSum")); + auto matchSum = mg.createNode(nn::is<Sum>); mg.createEdge(matchSumInput, matchSum); - auto matchSumOutput = mg.createNode(matchTensor("matchSumOutput")); + auto matchSumOutput = mg.createNode(nn::is<Tensor>); mg.createEdge(matchSum, matchSumOutput); - auto matchRelu = mg.createNode(matchOp<Relu>("matchRelu")); + auto matchRelu = mg.createNode(nn::is<Relu>); mg.createEdge(matchSumOutput, matchRelu); auto matchRoot = matchRelu; - EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(sum, matchRoot).isMatch()); - EXPECT_FALSE( - NNSubgraphMatcher::isSubgraphMatch(reluOutput, matchRoot).isMatch()); - EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(input1, matchRoot).isMatch()); + EXPECT_FALSE(mg.isSubgraphMatch(sum, matchRoot).isMatch()); + EXPECT_FALSE(mg.isSubgraphMatch(reluOutput, matchRoot).isMatch()); + EXPECT_FALSE(mg.isSubgraphMatch(input1, matchRoot).isMatch()); - EXPECT_TRUE(NNSubgraphMatcher::isSubgraphMatch(relu, matchRoot).isMatch()); + EXPECT_TRUE(mg.isSubgraphMatch(relu, matchRoot).isMatch()); - NNSubgraphMatcher::replaceSubgraph( + mg.replaceSubgraph( graph, matchRoot, [&matchSumOutput]( NNGraph& g, NNGraph::NodeRef relu, - const NNSubgraphMatcher::SubgraphMatchResultType& matchResult) { + const NNMatchGraph::SubgraphMatchResultType& matchResult) { auto fusedNode = g.createNode(util::make_unique<SumRelu>()); auto sumNode = getProducer(matchResult.getMatchNodeMap()->at(matchSumOutput)); diff --git a/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc b/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc index ee677665c6..8266863f37 100644 --- a/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc +++ b/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc @@ -1,4 +1,5 @@ #include <algorithm> +#include <functional> #include "test_util.h" @@ -12,21 +13,9 @@ namespace matcher { using NodeType = std::string; using Criteria = std::string; - -// Node matches a criteria (string) if the data string is the same as the -// criteria. Special case: "*" will match any thing. -struct TestNodeMatch { - static bool isMatch( - const nom::Graph<NodeType>::NodeRef& node, - const Criteria& criteria) { - return criteria == "*" || criteria == node->data(); - } -}; - using TestGraph = Graph<NodeType>; -using TestMatcher = SubgraphMatcher<TestGraph, Criteria, TestNodeMatch>; -using TestMatchGraph = MatchGraph<Criteria>; -using TestMatchNode = MatchNode<Criteria>; +using TestMatchGraph = MatchGraph<TestGraph>; +using TestMatchPredicate = MatchPredicate<TestGraph>; // Have just one TestMatchGraph in the tests to make it less verbose to create // the match graphs. @@ -36,12 +25,25 @@ void reset() { graph = TestMatchGraph(); } +// Node matches a criteria (string) if the data string is the same as the +// criteria. Special case: "*" will match any thing. +TestMatchPredicate testMatchPredicate(const Criteria& criteria) { + return TestMatchPredicate([criteria](TestGraph::NodeRef node) { + return criteria == "*" || criteria == node->data(); + }); +} + +Criteria any() { + return Criteria("*"); +} + // Helper methods to make it less verbose to create match graphs. TestMatchGraph::NodeRef Tree( const Criteria& root, const std::vector<TestMatchGraph::NodeRef>& children = {}, int count = 1) { - auto result = graph.createNode(std::move(TestMatchNode(root).count(count))); + auto result = + graph.createNode(std::move(testMatchPredicate(root).count(count))); for (auto& child : children) { graph.createEdge(result, child); } @@ -50,11 +52,7 @@ TestMatchGraph::NodeRef Tree( TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) { return graph.createNode( - std::move(TestMatchNode(root).count(count).nonTerminal())); -} - -Criteria any() { - return Criteria("*"); + std::move(testMatchPredicate(root).count(count).nonTerminal())); } std::map<std::string, std::string> TestGraphNodePrinter( @@ -181,32 +179,32 @@ struct DataFlowTestGraphCriteria { DataFlowTestGraphCriteria() { auto matchOpCInputs = - graph.createNode(std::move(TestMatchNode(Criteria("input")) + graph.createNode(std::move(testMatchPredicate(Criteria("input")) .starCount() .nonTerminal() .excludeFromSubgraph())); - auto matchOpC = graph.createNode(Criteria("opC")); + auto matchOpC = graph.createNode(testMatchPredicate("opC")); graph.createEdge(matchOpCInputs, matchOpC); - matchOpCOutput = graph.createNode(any()); + matchOpCOutput = graph.createNode(testMatchPredicate(any())); graph.createEdge(matchOpC, matchOpCOutput); - auto matchOpB = graph.createNode(Criteria("opB")); + auto matchOpB = graph.createNode(testMatchPredicate("opB")); graph.createEdge(matchOpCOutput, matchOpB); graph.createEdge(matchOpCOutput, matchOpB); - auto matchOpBOutput = graph.createNode(any()); + auto matchOpBOutput = graph.createNode(testMatchPredicate(any())); graph.createEdge(matchOpB, matchOpBOutput); - auto matchOpF = graph.createNode(Criteria("opF")); + auto matchOpF = graph.createNode(testMatchPredicate("opF")); graph.createEdge(matchOpBOutput, matchOpF); - auto matchOpFOutput = graph.createNode(any()); + auto matchOpFOutput = graph.createNode(testMatchPredicate(any())); graph.createEdge(matchOpF, matchOpFOutput); - matchOpG = graph.createNode(Criteria("opG")); - auto matchDataI = graph.createNode( - std::move(TestMatchNode(any()).nonTerminal().excludeFromSubgraph())); + matchOpG = graph.createNode(testMatchPredicate("opG")); + auto matchDataI = graph.createNode(std::move( + testMatchPredicate(any()).nonTerminal().excludeFromSubgraph())); graph.createEdge(matchOpFOutput, matchOpG); graph.createEdge(matchDataI, matchOpG); } @@ -220,7 +218,7 @@ bool isSubgraphMatch( TestGraph::NodeRef nodeRef, const TestMatchGraph::NodeRef& criteria, bool invertGraphTraversal = true) { - return TestMatcher::isSubgraphMatch(nodeRef, criteria, invertGraphTraversal) + return graph.isSubgraphMatch(nodeRef, criteria, invertGraphTraversal) .isMatch(); } } // namespace matcher @@ -231,15 +229,15 @@ using namespace nom::matcher; // Simple test cases for node matching criteria. TEST(SubgraphMatcher, IsNodeMatch) { - TestGraph graph; - auto n1 = graph.createNode("Hello"); - auto n2 = graph.createNode("Le"); - graph.createEdge(n1, n2); - - EXPECT_TRUE(TestMatcher::isNodeMatch(n1, "Hello")); - EXPECT_FALSE(TestMatcher::isNodeMatch(n1, "G")); - EXPECT_TRUE(TestMatcher::isNodeMatch(n2, "Le")); - EXPECT_FALSE(TestMatcher::isNodeMatch(n2, "le")); + TestGraph g; + auto n1 = g.createNode("Hello"); + auto n2 = g.createNode("Le"); + g.createEdge(n1, n2); + + EXPECT_TRUE(graph.isNodeMatch(n1, testMatchPredicate("Hello"))); + EXPECT_FALSE(graph.isNodeMatch(n1, testMatchPredicate("G"))); + EXPECT_TRUE(graph.isNodeMatch(n2, testMatchPredicate("Le"))); + EXPECT_FALSE(graph.isNodeMatch(n2, testMatchPredicate("le"))); } // Test subtree matching with a simple tree graph. @@ -321,7 +319,8 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); - subtree = Tree(any(), {Tree(Criteria("2"), {}, TestMatchNode::kStarCount)}); + subtree = + Tree(any(), {Tree(Criteria("2"), {}, TestMatchPredicate::kStarCount)}); EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); @@ -349,23 +348,23 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("2")), Tree(Criteria("3"), {}, 2), Tree(Criteria("4"), {}, 2), - Tree(Criteria("5"), {}, TestMatchNode::kStarCount) + Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount) }); EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { Tree(Criteria("2")), - Tree(Criteria("3"), {}, TestMatchNode::kStarCount), + Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount), Tree(Criteria("4"), {}, 2), - Tree(Criteria("5"), {}, TestMatchNode::kStarCount) + Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount) }); EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { Tree(Criteria("2")), - Tree(Criteria("3"), {}, TestMatchNode::kStarCount), + Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount), }); // Fails because there are unmatched edges. EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); @@ -569,16 +568,16 @@ TEST(SubgraphMatcher, IsSubtreeMatchRealistic) { TEST(SubgraphMatcher, ReplaceGraphRealistic) { reset(); - auto graph = DataFlowTestGraph(); + auto testGraph = DataFlowTestGraph(); auto subtree = DataFlowTestGraphCriteria(); - TestMatcher::replaceSubgraph( - graph.graph, + graph.replaceSubgraph( + testGraph.graph, subtree.matchOpG, [subtree]( TestGraph& g, TestGraph::NodeRef opG, - const TestMatcher::SubgraphMatchResultType& matchResult) { + const TestMatchGraph::SubgraphMatchResultType& matchResult) { auto fusedNode = g.createNode("opFused"); auto opC = getInNode( matchResult.getMatchNodeMap()->at(subtree.matchOpCOutput), 0); @@ -595,10 +594,10 @@ TEST(SubgraphMatcher, ReplaceGraphRealistic) { // - fused node // - output node // - dataC2 node - auto nodes = graph.graph.getMutableNodes(); + auto nodes = testGraph.graph.getMutableNodes(); // Test that the graph is transformed as expected. - EXPECT_EQ(nodes.size(), graph.numInputs + 4); + EXPECT_EQ(nodes.size(), testGraph.numInputs + 4); TestGraph::NodeRef opFused; TestGraph::NodeRef dataI; TestGraph::NodeRef dataOut; @@ -614,9 +613,9 @@ TEST(SubgraphMatcher, ReplaceGraphRealistic) { EXPECT_EQ(getInNode(dataOut, 0), opFused); - EXPECT_EQ(opFused->getInEdges().size(), graph.numInputs + 1); + EXPECT_EQ(opFused->getInEdges().size(), testGraph.numInputs + 1); EXPECT_EQ(getInNode(opFused, 0), dataI); - for (int i = 1; i <= graph.numInputs; i++) { + for (int i = 1; i <= testGraph.numInputs; i++) { EXPECT_EQ(getInNode(opFused, i)->data(), "input"); } diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc index fba800e69d..7c1a2b795b 100644 --- a/caffe2/python/pybind_state_nomni.cc +++ b/caffe2/python/pybind_state_nomni.cc @@ -382,9 +382,8 @@ void addNomnigraphMethods(pybind11::module& m) { py::class_<nn::NNMatchGraph> nnMatchGraph(m, "NNMatchGraph"); nnMatchGraph.def(py::init<>()); - using MatchNodeType = - nom::Node<nom::matcher::MatchNode<nn::NNNodeMatchCriteria>>; - py::class_<MatchNodeType> nnMatchNode(m, "MatchNodeRef"); + using MatchPredicateType = nom::Node<nn::NNMatchPredicate>; + py::class_<MatchPredicateType> nnMatchPredicate(m, "MatchPredicateRef"); nnMatchGraph .def( @@ -396,13 +395,12 @@ void addNomnigraphMethods(pybind11::module& m) { "createNode", [](nn::NNMatchGraph* g, GenericOperator& op, bool strict) { auto opName = op.getName(); - auto match = - nn::NNNodeMatchCriteria([opName](NNGraph::NodeRef node) { - NOM_REQUIRE_OR_RET_FALSE(nn::is<NeuralNetOperator>(node)); - auto nnOp = nn::get<NeuralNetOperator>(node); - return opName == nnOp->getName(); - }); - auto node = nom::matcher::MatchNode<nn::NNNodeMatchCriteria>(match); + auto match = [opName](NNGraph::NodeRef node) { + NOM_REQUIRE_OR_RET_FALSE(nn::is<NeuralNetOperator>(node)); + auto nnOp = nn::get<NeuralNetOperator>(node); + return opName == nnOp->getName(); + }; + auto node = nn::NNMatchPredicate(match); if (!strict) { node.nonTerminal(); } @@ -414,7 +412,7 @@ void addNomnigraphMethods(pybind11::module& m) { .def( "createNode", [](nn::NNMatchGraph* g, nom::repr::Tensor& tensor, bool strict) { - auto node = nn::NNMatchNode(nn::matchTensor()); + auto node = nn::NNMatchPredicate(nn::is<nom::repr::Tensor>); if (!strict) { node.nonTerminal(); } @@ -426,9 +424,8 @@ void addNomnigraphMethods(pybind11::module& m) { .def( "createNode", [](nn::NNMatchGraph* g, bool strict) { - auto match = nn::NNNodeMatchCriteria( - [](NNGraph::NodeRef node) { return true; }); - auto node = nom::matcher::MatchNode<nn::NNNodeMatchCriteria>(match); + auto match = [](NNGraph::NodeRef node) { return true; }; + auto node = nn::NNMatchPredicate(match); if (!strict) { node.nonTerminal(); } @@ -444,8 +441,7 @@ void addNomnigraphMethods(pybind11::module& m) { m.def("matchSubgraph", [](NNGraph::NodeRef node, nn::NNMatchGraph* mg) { // Get root node or node in root cycle auto match_node = *nom::algorithm::tarjans(mg).back().getNodes().begin(); - auto result = - nn::NNSubgraphMatcher::isSubgraphMatch(node, match_node, false); + auto result = mg->isSubgraphMatch(node, match_node, false); if (result.isMatch()) { return *result.getMatchedSubgraph(); } |