summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDuc Ngo <duc@fb.com>2018-10-18 12:30:31 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-10-18 12:32:40 -0700
commit2c566a17c763ab000ebeba1d3c01762bae814e42 (patch)
treea600faa4bb11fbd5ea76757edfbe05adc27bbba0
parent9c617140f79f82628484e7545b228be990994196 (diff)
downloadpytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.tar.gz
pytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.tar.bz2
pytorch-2c566a17c763ab000ebeba1d3c01762bae814e42.zip
nomnigraph - simplify subgraph matching APIs (#12681)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12681 - Get rid of NodeMatchCriteria as a template parameter, which was too generic. So MatchNode<NodeMatchCriteria> becomes MatchNode<GraphType>, and MatchStore stores the predicate on GraphType::NodeRef. - Similarly, get rid of NNNodeMatchCriteria Now one can just pass in a function pointer NodeRef -> bool to NNMatchNode constructor directly like this mg.createNode(is<Relu>) - Merge static utilities in SubgraphMatcher class into MatchGraph class - Rename MatchNode to MatchPredicate Change use cases and tests to make it work Reviewed By: ZolotukhinM Differential Revision: D10386907 fbshipit-source-id: 43874bd154e3d7c29ce07b4b74eca8a7a9f3078a
-rw-r--r--caffe2/core/nomnigraph/Representations/NeuralNet.cc62
-rw-r--r--caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h6
-rw-r--r--caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h82
-rw-r--r--caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h284
-rw-r--r--caffe2/core/nomnigraph/tests/NeuralNetTest.cc19
-rw-r--r--caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc105
-rw-r--r--caffe2/python/pybind_state_nomni.cc28
7 files changed, 250 insertions, 336 deletions
diff --git a/caffe2/core/nomnigraph/Representations/NeuralNet.cc b/caffe2/core/nomnigraph/Representations/NeuralNet.cc
index d3ba0b2b38..06879d8c0b 100644
--- a/caffe2/core/nomnigraph/Representations/NeuralNet.cc
+++ b/caffe2/core/nomnigraph/Representations/NeuralNet.cc
@@ -180,49 +180,29 @@ void coalesceInsertedDataDependencies(repr::NNModule* m) {
}
}
-std::ostream& operator<<(
- std::ostream& oss,
- const NNNodeMatchCriteria& criteria) {
- return oss << criteria.debugString;
-}
-
-NNNodeMatchCriteria criteriaSingleOutputAndConsumer() {
- return NNNodeMatchCriteria(
- [](NNGraph::NodeRef nodeRef) {
- auto nodeOutputs = nn::getOutputs(nodeRef);
- NOM_REQUIRE_OR_RET_FALSE(nodeOutputs.size() == 1);
- auto nodeConsumers = nn::getConsumers(nodeOutputs.front());
- return nodeConsumers.size() == 1;
- },
- "Single output and consumer");
-}
-
-NNNodeMatchCriteria criteriaSingleConsumer() {
- return NNNodeMatchCriteria(
- [](NNGraph::NodeRef nodeRef) {
- auto nodeOutputs = nn::getOutputs(nodeRef);
- NNGraph::NodeRef nodeConsumer = nullptr;
- for (auto nodeOutput : nodeOutputs) {
- for (auto consumer : nn::getConsumers(nodeOutput)) {
- if (nodeConsumer && consumer && consumer != nodeConsumer) {
- return false;
- }
- nodeConsumer = consumer;
- }
- }
- return true;
- },
- "Single consumer");
-}
-
-NNNodeMatchCriteria matchTensor(const std::string& debugString) {
- return matchOp<nom::repr::Tensor>(debugString);
+bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef) {
+ auto nodeOutputs = nn::getOutputs(nodeRef);
+ NOM_REQUIRE_OR_RET_FALSE(nodeOutputs.size() == 1);
+ auto nodeConsumers = nn::getConsumers(nodeOutputs.front());
+ return nodeConsumers.size() == 1;
+}
+
+bool hasUniqueConsumer(NNGraph::NodeRef nodeRef) {
+ auto nodeOutputs = nn::getOutputs(nodeRef);
+ NNGraph::NodeRef nodeConsumer = nullptr;
+ for (auto nodeOutput : nodeOutputs) {
+ for (auto consumer : nn::getConsumers(nodeOutput)) {
+ if (nodeConsumer && consumer && consumer != nodeConsumer) {
+ return false;
+ }
+ nodeConsumer = consumer;
+ }
+ }
+ return true;
}
-NNMatchNode matchExternalTensorNode(const std::string& debugString) {
- return NNMatchNode(matchTensor(debugString))
- .nonTerminal()
- .excludeFromSubgraph();
+NNMatchPredicate matchExternalTensorNode() {
+ return NNMatchPredicate(nn::is<Tensor>).nonTerminal().excludeFromSubgraph();
}
} // namespace nn
diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h
index b12e57d751..61fbebcf9b 100644
--- a/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h
+++ b/caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h
@@ -15,6 +15,7 @@
#include <iterator>
#include <list>
#include <unordered_set>
+#include <utility>
#include <vector>
#include <assert.h>
@@ -240,6 +241,11 @@ class Graph {
return createNodeInternal(Node<T, U...>(std::move(data)));
}
+ template <class Arg>
+ NodeRef createNode(Arg&& arg) {
+ return createNode(T(std::forward<Arg>(arg)));
+ }
+
NodeRef createNode() {
return createNodeInternal(Node<T, U...>());
}
diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h
index 523f29225a..9a73463e3c 100644
--- a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h
+++ b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h
@@ -297,9 +297,9 @@ struct C10_EXPORT
}
};
-template <typename T, typename N>
-inline bool is(N n) {
- return is_impl<T, N>::impl(n);
+template <typename T>
+inline bool is(NNGraph::NodeRef n) {
+ return is_impl<T, NNGraph::NodeRef>::impl(n);
}
// This is just a way to fix issues when the dyn_cast<> implementation
@@ -433,82 +433,18 @@ CAFFE2_API void coalesceInsertedDataDependencies(repr::NNModule* m);
template <NNGraph* G>
struct C10_EXPORT NodeHelper {};
-struct NNNodeMatchCriteria {
- std::function<bool(NNGraph::NodeRef)> predicate;
- std::string debugString;
-
- NNNodeMatchCriteria(
- const std::function<bool(NNGraph::NodeRef)>& predicate,
- const std::string& debugString = "No debug string specified")
- : predicate(predicate), debugString(debugString){};
-
- NNNodeMatchCriteria() = default;
- NNNodeMatchCriteria(const NNNodeMatchCriteria&) = default;
- NNNodeMatchCriteria& operator=(const NNNodeMatchCriteria&) = default;
- NNNodeMatchCriteria(NNNodeMatchCriteria&&) = default;
-
- NNNodeMatchCriteria andCriteria(const NNNodeMatchCriteria& other) {
- auto thisPredicate = predicate;
- auto otherPredicate = other.predicate;
- return NNNodeMatchCriteria(
- [thisPredicate, otherPredicate](NNGraph::NodeRef node) {
- return thisPredicate(node) && otherPredicate(node);
- },
- debugString + " and " + other.debugString);
- }
-};
-
-CAFFE2_API std::ostream& operator<<(
- std::ostream& oss,
- const NNNodeMatchCriteria& criteria);
-
-using NNMatchGraph = nom::matcher::MatchGraph<NNNodeMatchCriteria>;
-using NNMatchNode = nom::matcher::MatchNode<NNNodeMatchCriteria>;
+using NNMatchGraph = nom::matcher::MatchGraph<NNGraph>;
+using NNMatchPredicate = nom::matcher::MatchPredicate<NNGraph>;
-// Commonly used criteria.
+// Commonly used node predicate.
// The node has a single output and the output has a single consumer.
-CAFFE2_API NNNodeMatchCriteria criteriaSingleOutputAndConsumer();
+CAFFE2_API bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef);
// The node has a unique consumer (there may be multiple edges from output
// to the single consumer).
-CAFFE2_API NNNodeMatchCriteria criteriaSingleConsumer();
-
-template <typename NodeType>
-NNNodeMatchCriteria matchOp(const std::string& debugString = "matchOp") {
- return NNNodeMatchCriteria(
- [](NNGraph::NodeRef nodeRef) { return is<NodeType>(nodeRef); },
- debugString);
-}
-
-template <typename NodeType>
-NNNodeMatchCriteria matchOp(
- const std::function<bool(const NodeType&)> predicate,
- const std::string& debugString = "matchOpWithPredicate") {
- return NNNodeMatchCriteria(
- [predicate](NNGraph::NodeRef nodeRef) {
- NOM_REQUIRE_OR_RET_FALSE(is<NodeType>(nodeRef));
- NodeType* node = get<NodeType>(nodeRef);
- return predicate(*node);
- },
- debugString);
-};
-
-CAFFE2_API NNNodeMatchCriteria
-matchTensor(const std::string& debugString = "matchTensor");
-
-CAFFE2_API NNMatchNode
-matchExternalTensorNode(const std::string& debugString = "matchExternalTensor");
-
-struct CAFFE2_API NNNodeMatch {
- static bool isMatch(
- const NNGraph::NodeRef& node,
- const NNNodeMatchCriteria& criteria) {
- return criteria.predicate(node);
- }
-};
+CAFFE2_API bool hasUniqueConsumer(NNGraph::NodeRef nodeRef);
-using NNSubgraphMatcher =
- nom::matcher::SubgraphMatcher<NNGraph, NNNodeMatchCriteria, NNNodeMatch>;
+CAFFE2_API NNMatchPredicate matchExternalTensorNode();
} // namespace nn
diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h
index a303324fbb..c237b6e9f9 100644
--- a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h
+++ b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h
@@ -15,9 +15,9 @@ namespace nom {
namespace matcher {
/**
- * MatchGraph is a graph of MatchNode.
+ * MatchGraph is a graph of MatchPredicate.
*
- * MatchNode needs a NodeMatchCriteria (a predicate for node matching) and
+ * MatchPredicate needs a predicate for node matching and
* - includeInSubgraph: whether this node and nodes/edges reachable from it
* should be included in the return matched subgraph (if the pattern matches).
* This is useful in case we would like to specify a matching pattern but do not
@@ -28,19 +28,21 @@ namespace matcher {
* from the node when doing subgraph matching.
*/
-template <typename NodeMatchCriteria>
-class MatchNode {
+template <typename GraphType>
+class MatchPredicate {
public:
+ using Predicate = std::function<bool(typename GraphType::NodeRef)>;
+
static const int kStarCount = -1;
- MatchNode(const NodeMatchCriteria& criteria) : criteria_(criteria) {}
+ MatchPredicate(const Predicate& criteria) : criteria_(criteria) {}
- MatchNode() = default;
- MatchNode(const MatchNode&) = default;
- MatchNode& operator=(const MatchNode&) = default;
- MatchNode(MatchNode&&) = default;
+ MatchPredicate() = default;
+ MatchPredicate(const MatchPredicate&) = default;
+ MatchPredicate& operator=(const MatchPredicate&) = default;
+ MatchPredicate(MatchPredicate&&) = default;
- NodeMatchCriteria getCriteria() const {
+ Predicate getCriteria() const {
return criteria_;
}
@@ -48,21 +50,21 @@ class MatchNode {
return count_;
}
- MatchNode<NodeMatchCriteria>& count(int count) {
+ MatchPredicate<GraphType>& count(int count) {
count_ = count;
return *this;
}
- MatchNode<NodeMatchCriteria>& starCount() {
+ MatchPredicate<GraphType>& starCount() {
return count(kStarCount);
}
- MatchNode<NodeMatchCriteria>& nonTerminal() {
+ MatchPredicate<GraphType>& nonTerminal() {
nonTerminal_ = true;
return *this;
}
- MatchNode<NodeMatchCriteria>& excludeFromSubgraph() {
+ MatchPredicate<GraphType>& excludeFromSubgraph() {
includeInSubgraph_ = false;
return *this;
}
@@ -75,131 +77,44 @@ class MatchNode {
return includeInSubgraph_;
}
+ std::string getDebugString() const {
+ return debugString_;
+ }
+
+ void setDebugString(const std::string& debugString) {
+ debugString_ = debugString;
+ }
+
private:
- NodeMatchCriteria criteria_;
+ Predicate criteria_;
int count_ = 1;
bool includeInSubgraph_ = true;
bool nonTerminal_ = false;
+ std::string debugString_;
};
-template <typename NodeMatchCriteria>
-using MatchGraph = Graph<MatchNode<NodeMatchCriteria>>;
-
-template <typename NodeMatchCriteria>
-using MatchNodeRef = typename MatchGraph<NodeMatchCriteria>::NodeRef;
-
-// TODO: Reuse convertToDotString once convertToDotString can work
-// with subgraph.
-template <typename NodeMatchCriteria>
-std::string debugString(
- MatchNodeRef<NodeMatchCriteria> rootCriteriaRef,
- bool invertGraphTraversal) {
- std::ostringstream out;
- auto rootNode = rootCriteriaRef->data();
- out << "{rootCriteria = '" << rootNode.getCriteria() << "'";
- if (rootNode.getCount() != 1) {
- out << ", count = " << rootNode.getCount();
- }
- if (rootNode.isNonTerminal()) {
- out << ", nonTerminal = " << rootNode.isNonTerminal();
- }
- auto edges = invertGraphTraversal ? rootCriteriaRef->getInEdges()
- : rootCriteriaRef->getOutEdges();
- if (!edges.empty()) {
- out << ", childrenCriteria = [";
- for (auto& child : edges) {
- auto nextNode = invertGraphTraversal ? child->tail() : child->head();
- out << debugString<NodeMatchCriteria>(nextNode, invertGraphTraversal)
- << ", ";
- }
- out << "]";
- }
- out << "}";
- return out.str();
-}
+template <typename GraphType>
+class SubgraphMatchResult;
-template <typename NodeMatchCriteria, typename GraphType>
-class SubgraphMatchResult {
+// MatchGraph is a graph of MatchPredicate and it contains utilities for
+// subgraph matching.
+// (TODO) the subgraph matching methods currently still
+// requires a root match node to be passed in. We should improve the matching
+// algorithm to eliminate this requirement.
+template <typename GraphType>
+class MatchGraph : public Graph<MatchPredicate<GraphType>> {
public:
- // Map from match node to corresponding node in the graph to be scanned.
- using MatchNodeMap = std::unordered_map<
- MatchNodeRef<NodeMatchCriteria>,
- typename GraphType::NodeRef>;
-
- static SubgraphMatchResult<NodeMatchCriteria, GraphType> notMatched(
- const std::string& debugMessage) {
- return SubgraphMatchResult<NodeMatchCriteria, GraphType>(
- false, debugMessage);
- }
-
- static SubgraphMatchResult<NodeMatchCriteria, GraphType> notMatched() {
- return SubgraphMatchResult<NodeMatchCriteria, GraphType>(
- false, "Debug message is not enabled");
- }
-
- static SubgraphMatchResult<NodeMatchCriteria, GraphType> matched(
- bool ownSubgraph = false) {
- return SubgraphMatchResult<NodeMatchCriteria, GraphType>(
- true, "Matched", ownSubgraph);
- }
-
- bool isMatch() const {
- return isMatch_;
- }
-
- std::string getDebugMessage() const {
- return debugMessage_;
- }
-
- std::shared_ptr<typename GraphType::SubgraphType> getMatchedSubgraph() const {
- return matchedSubgraph_;
- }
-
- std::shared_ptr<MatchNodeMap> getMatchNodeMap() const {
- return matchNodeMap_;
- }
-
- private:
- SubgraphMatchResult(
- bool isMatch,
- const std::string& debugMessage,
- bool ownSubgraph = false)
- : isMatch_(isMatch),
- debugMessage_(debugMessage),
- matchedSubgraph_(
- ownSubgraph ? std::shared_ptr<typename GraphType::SubgraphType>(
- new typename GraphType::SubgraphType())
- : nullptr),
- matchNodeMap_(
- ownSubgraph ? std::shared_ptr<MatchNodeMap>(new MatchNodeMap())
- : nullptr) {}
-
- const bool isMatch_;
- const std::string debugMessage_;
- const std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph_;
- const std::shared_ptr<MatchNodeMap> matchNodeMap_;
-};
-
-/*
- * Utilities for subgraph matching.
- */
-template <
- typename GraphType,
- typename NodeMatchCriteria,
- typename NodeMatcherClass>
-struct SubgraphMatcher {
- using SubgraphMatchResultType =
- SubgraphMatchResult<NodeMatchCriteria, GraphType>;
+ using SubgraphMatchResultType = SubgraphMatchResult<GraphType>;
using ReplaceGraphOperation = std::function<bool(
GraphType&,
typename GraphType::NodeRef,
const SubgraphMatchResultType&)>;
- static bool isNodeMatch(
+ bool isNodeMatch(
typename GraphType::NodeRef node,
- const NodeMatchCriteria& criteria) {
- return NodeMatcherClass::isMatch(node, criteria);
+ const MatchPredicate<GraphType>& matchPredicate) const {
+ return matchPredicate.getCriteria()(node);
}
// Check if there can be a subgraph that matches the given criteria that
@@ -207,11 +122,11 @@ struct SubgraphMatcher {
// The flag invertGraphTraversal specify if we should follow out edges or
// in edges. The default is true which is useful for a functional
// intepretation of a dataflow graph.
- static SubgraphMatchResultType isSubgraphMatch(
+ SubgraphMatchResultType isSubgraphMatch(
typename GraphType::NodeRef root,
- const MatchNodeRef<NodeMatchCriteria>& rootCriteriaRef,
+ const typename MatchGraph::NodeRef& rootCriteriaRef,
bool invertGraphTraversal = true,
- bool debug = false) {
+ bool debug = false) const {
// Create a matched result that owns a matched subgraph object and pass
// the subgraph object around to construct it during matching.
auto matchedResult = SubgraphMatchResultType::matched(true);
@@ -236,11 +151,11 @@ struct SubgraphMatcher {
// is aborted. This maybe useful in certain cases when we want to terminate
// the subgraph search early.
// invertGraphTraversal flag: see documentation in isSubgraphMatch
- static void replaceSubgraph(
+ void replaceSubgraph(
GraphType& graph,
- const MatchNodeRef<NodeMatchCriteria>& criteria,
+ const typename MatchGraph::NodeRef& criteria,
const ReplaceGraphOperation& replaceFunction,
- bool invertGraphTraversal = true) {
+ bool invertGraphTraversal = true) const {
for (auto nodeRef : graph.getMutableNodes()) {
// Make sure the node is still in the graph.
if (!graph.hasNode(nodeRef)) {
@@ -259,15 +174,15 @@ struct SubgraphMatcher {
}
private:
- static SubgraphMatchResultType isSubgraphMatchInternal(
+ SubgraphMatchResultType isSubgraphMatchInternal(
std::shared_ptr<typename SubgraphMatchResultType::MatchNodeMap>
matchedNodes,
std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph,
typename GraphType::NodeRef root,
- const MatchNodeRef<NodeMatchCriteria>& rootCriteriaRef,
+ const typename MatchGraph::NodeRef& rootCriteriaRef,
bool includeInSubgraph,
bool invertGraphTraversal,
- bool debug) {
+ bool debug) const {
auto rootCriteriaNode = rootCriteriaRef->data();
if (rootCriteriaNode.getCount() == 1) {
@@ -283,8 +198,7 @@ struct SubgraphMatcher {
std::ostringstream debugMessage;
debugMessage << "Subgraph root at " << root << " is not the same as "
<< matchedNode << " which previously matched criteria "
- << debugString<NodeMatchCriteria>(
- rootCriteriaRef, invertGraphTraversal);
+ << debugString(rootCriteriaRef, invertGraphTraversal);
return SubgraphMatchResultType::notMatched(debugMessage.str());
} else {
return SubgraphMatchResultType::notMatched();
@@ -292,13 +206,12 @@ struct SubgraphMatcher {
}
}
- if (!isNodeMatch(root, rootCriteriaNode.getCriteria())) {
+ if (!isNodeMatch(root, rootCriteriaNode)) {
if (debug) {
std::ostringstream debugMessage;
debugMessage << "Subgraph root at " << root
<< " does not match criteria "
- << debugString<NodeMatchCriteria>(
- rootCriteriaRef, invertGraphTraversal);
+ << debugString(rootCriteriaRef, invertGraphTraversal);
return SubgraphMatchResultType::notMatched(debugMessage.str());
} else {
return SubgraphMatchResultType::notMatched();
@@ -334,8 +247,7 @@ struct SubgraphMatcher {
: criteriaEdges[criteriaIdx]->head();
int expectedCount = childrenCriteriaRef->data().getCount();
- bool isStarCount =
- expectedCount == MatchNode<NodeMatchCriteria>::kStarCount;
+ bool isStarCount = expectedCount == MatchPredicate<GraphType>::kStarCount;
int countMatch = 0;
@@ -368,7 +280,7 @@ struct SubgraphMatcher {
std::ostringstream debugMessage;
debugMessage << "Child node at " << child
<< " does not match child criteria "
- << debugString<NodeMatchCriteria>(
+ << debugString(
childrenCriteriaRef, invertGraphTraversal)
<< ". We expected " << expectedCount
<< " matches but only found " << countMatch << ".";
@@ -394,8 +306,7 @@ struct SubgraphMatcher {
std::ostringstream debugMessage;
debugMessage << "Expected " << expectedCount
<< " matches for child criteria "
- << debugString<NodeMatchCriteria>(
- childrenCriteriaRef, invertGraphTraversal)
+ << debugString(childrenCriteriaRef, invertGraphTraversal)
<< " but only found " << countMatch;
return SubgraphMatchResultType::notMatched(debugMessage.str());
} else {
@@ -423,6 +334,93 @@ struct SubgraphMatcher {
}
return SubgraphMatchResultType::matched();
}
+
+ // TODO: Reuse convertToDotString once convertToDotString can work
+ // with subgraph.
+ std::string debugString(
+ typename MatchGraph::NodeRef rootCriteriaRef,
+ bool invertGraphTraversal) const {
+ std::ostringstream out;
+ auto rootNode = rootCriteriaRef->data();
+ out << "{root = '" << rootNode.getDebugString() << "'";
+ if (rootNode.getCount() != 1) {
+ out << ", count = " << rootNode.getCount();
+ }
+ if (rootNode.isNonTerminal()) {
+ out << ", nonTerminal = " << rootNode.isNonTerminal();
+ }
+ auto edges = invertGraphTraversal ? rootCriteriaRef->getInEdges()
+ : rootCriteriaRef->getOutEdges();
+ if (!edges.empty()) {
+ out << ", childrenCriteria = [";
+ for (auto& child : edges) {
+ auto nextNode = invertGraphTraversal ? child->tail() : child->head();
+ out << debugString(nextNode, invertGraphTraversal) << ", ";
+ }
+ out << "]";
+ }
+ out << "}";
+ return out.str();
+ }
+};
+
+template <typename GraphType>
+class SubgraphMatchResult {
+ public:
+ // Map from match node to corresponding node in the graph to be scanned.
+ using MatchNodeMap = std::unordered_map<
+ typename MatchGraph<GraphType>::NodeRef,
+ typename GraphType::NodeRef>;
+
+ static SubgraphMatchResult<GraphType> notMatched(
+ const std::string& debugMessage) {
+ return SubgraphMatchResult<GraphType>(false, debugMessage);
+ }
+
+ static SubgraphMatchResult<GraphType> notMatched() {
+ return SubgraphMatchResult<GraphType>(
+ false, "Debug message is not enabled");
+ }
+
+ static SubgraphMatchResult<GraphType> matched(bool ownSubgraph = false) {
+ return SubgraphMatchResult<GraphType>(true, "Matched", ownSubgraph);
+ }
+
+ bool isMatch() const {
+ return isMatch_;
+ }
+
+ std::string getDebugMessage() const {
+ return debugMessage_;
+ }
+
+ std::shared_ptr<typename GraphType::SubgraphType> getMatchedSubgraph() const {
+ return matchedSubgraph_;
+ }
+
+ std::shared_ptr<MatchNodeMap> getMatchNodeMap() const {
+ return matchNodeMap_;
+ }
+
+ private:
+ SubgraphMatchResult(
+ bool isMatch,
+ const std::string& debugMessage,
+ bool ownSubgraph = false)
+ : isMatch_(isMatch),
+ debugMessage_(debugMessage),
+ matchedSubgraph_(
+ ownSubgraph ? std::shared_ptr<typename GraphType::SubgraphType>(
+ new typename GraphType::SubgraphType())
+ : nullptr),
+ matchNodeMap_(
+ ownSubgraph ? std::shared_ptr<MatchNodeMap>(new MatchNodeMap())
+ : nullptr) {}
+
+ const bool isMatch_;
+ const std::string debugMessage_;
+ const std::shared_ptr<typename GraphType::SubgraphType> matchedSubgraph_;
+ const std::shared_ptr<MatchNodeMap> matchNodeMap_;
};
} // namespace matcher
diff --git a/caffe2/core/nomnigraph/tests/NeuralNetTest.cc b/caffe2/core/nomnigraph/tests/NeuralNetTest.cc
index 874da120b5..7aff67880c 100644
--- a/caffe2/core/nomnigraph/tests/NeuralNetTest.cc
+++ b/caffe2/core/nomnigraph/tests/NeuralNetTest.cc
@@ -45,30 +45,29 @@ TEST(NeuralNetGraph, ReplaceGraph) {
auto mg = NNMatchGraph();
auto matchSumInput =
mg.createNode(std::move(matchExternalTensorNode().count(2)));
- auto matchSum = mg.createNode(matchOp<Sum>("matchSum"));
+ auto matchSum = mg.createNode(nn::is<Sum>);
mg.createEdge(matchSumInput, matchSum);
- auto matchSumOutput = mg.createNode(matchTensor("matchSumOutput"));
+ auto matchSumOutput = mg.createNode(nn::is<Tensor>);
mg.createEdge(matchSum, matchSumOutput);
- auto matchRelu = mg.createNode(matchOp<Relu>("matchRelu"));
+ auto matchRelu = mg.createNode(nn::is<Relu>);
mg.createEdge(matchSumOutput, matchRelu);
auto matchRoot = matchRelu;
- EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(sum, matchRoot).isMatch());
- EXPECT_FALSE(
- NNSubgraphMatcher::isSubgraphMatch(reluOutput, matchRoot).isMatch());
- EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(input1, matchRoot).isMatch());
+ EXPECT_FALSE(mg.isSubgraphMatch(sum, matchRoot).isMatch());
+ EXPECT_FALSE(mg.isSubgraphMatch(reluOutput, matchRoot).isMatch());
+ EXPECT_FALSE(mg.isSubgraphMatch(input1, matchRoot).isMatch());
- EXPECT_TRUE(NNSubgraphMatcher::isSubgraphMatch(relu, matchRoot).isMatch());
+ EXPECT_TRUE(mg.isSubgraphMatch(relu, matchRoot).isMatch());
- NNSubgraphMatcher::replaceSubgraph(
+ mg.replaceSubgraph(
graph,
matchRoot,
[&matchSumOutput](
NNGraph& g,
NNGraph::NodeRef relu,
- const NNSubgraphMatcher::SubgraphMatchResultType& matchResult) {
+ const NNMatchGraph::SubgraphMatchResultType& matchResult) {
auto fusedNode = g.createNode(util::make_unique<SumRelu>());
auto sumNode =
getProducer(matchResult.getMatchNodeMap()->at(matchSumOutput));
diff --git a/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc b/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc
index ee677665c6..8266863f37 100644
--- a/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc
+++ b/caffe2/core/nomnigraph/tests/SubgraphMatcherTest.cc
@@ -1,4 +1,5 @@
#include <algorithm>
+#include <functional>
#include "test_util.h"
@@ -12,21 +13,9 @@ namespace matcher {
using NodeType = std::string;
using Criteria = std::string;
-
-// Node matches a criteria (string) if the data string is the same as the
-// criteria. Special case: "*" will match any thing.
-struct TestNodeMatch {
- static bool isMatch(
- const nom::Graph<NodeType>::NodeRef& node,
- const Criteria& criteria) {
- return criteria == "*" || criteria == node->data();
- }
-};
-
using TestGraph = Graph<NodeType>;
-using TestMatcher = SubgraphMatcher<TestGraph, Criteria, TestNodeMatch>;
-using TestMatchGraph = MatchGraph<Criteria>;
-using TestMatchNode = MatchNode<Criteria>;
+using TestMatchGraph = MatchGraph<TestGraph>;
+using TestMatchPredicate = MatchPredicate<TestGraph>;
// Have just one TestMatchGraph in the tests to make it less verbose to create
// the match graphs.
@@ -36,12 +25,25 @@ void reset() {
graph = TestMatchGraph();
}
+// Node matches a criteria (string) if the data string is the same as the
+// criteria. Special case: "*" will match any thing.
+TestMatchPredicate testMatchPredicate(const Criteria& criteria) {
+ return TestMatchPredicate([criteria](TestGraph::NodeRef node) {
+ return criteria == "*" || criteria == node->data();
+ });
+}
+
+Criteria any() {
+ return Criteria("*");
+}
+
// Helper methods to make it less verbose to create match graphs.
TestMatchGraph::NodeRef Tree(
const Criteria& root,
const std::vector<TestMatchGraph::NodeRef>& children = {},
int count = 1) {
- auto result = graph.createNode(std::move(TestMatchNode(root).count(count)));
+ auto result =
+ graph.createNode(std::move(testMatchPredicate(root).count(count)));
for (auto& child : children) {
graph.createEdge(result, child);
}
@@ -50,11 +52,7 @@ TestMatchGraph::NodeRef Tree(
TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) {
return graph.createNode(
- std::move(TestMatchNode(root).count(count).nonTerminal()));
-}
-
-Criteria any() {
- return Criteria("*");
+ std::move(testMatchPredicate(root).count(count).nonTerminal()));
}
std::map<std::string, std::string> TestGraphNodePrinter(
@@ -181,32 +179,32 @@ struct DataFlowTestGraphCriteria {
DataFlowTestGraphCriteria() {
auto matchOpCInputs =
- graph.createNode(std::move(TestMatchNode(Criteria("input"))
+ graph.createNode(std::move(testMatchPredicate(Criteria("input"))
.starCount()
.nonTerminal()
.excludeFromSubgraph()));
- auto matchOpC = graph.createNode(Criteria("opC"));
+ auto matchOpC = graph.createNode(testMatchPredicate("opC"));
graph.createEdge(matchOpCInputs, matchOpC);
- matchOpCOutput = graph.createNode(any());
+ matchOpCOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpC, matchOpCOutput);
- auto matchOpB = graph.createNode(Criteria("opB"));
+ auto matchOpB = graph.createNode(testMatchPredicate("opB"));
graph.createEdge(matchOpCOutput, matchOpB);
graph.createEdge(matchOpCOutput, matchOpB);
- auto matchOpBOutput = graph.createNode(any());
+ auto matchOpBOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpB, matchOpBOutput);
- auto matchOpF = graph.createNode(Criteria("opF"));
+ auto matchOpF = graph.createNode(testMatchPredicate("opF"));
graph.createEdge(matchOpBOutput, matchOpF);
- auto matchOpFOutput = graph.createNode(any());
+ auto matchOpFOutput = graph.createNode(testMatchPredicate(any()));
graph.createEdge(matchOpF, matchOpFOutput);
- matchOpG = graph.createNode(Criteria("opG"));
- auto matchDataI = graph.createNode(
- std::move(TestMatchNode(any()).nonTerminal().excludeFromSubgraph()));
+ matchOpG = graph.createNode(testMatchPredicate("opG"));
+ auto matchDataI = graph.createNode(std::move(
+ testMatchPredicate(any()).nonTerminal().excludeFromSubgraph()));
graph.createEdge(matchOpFOutput, matchOpG);
graph.createEdge(matchDataI, matchOpG);
}
@@ -220,7 +218,7 @@ bool isSubgraphMatch(
TestGraph::NodeRef nodeRef,
const TestMatchGraph::NodeRef& criteria,
bool invertGraphTraversal = true) {
- return TestMatcher::isSubgraphMatch(nodeRef, criteria, invertGraphTraversal)
+ return graph.isSubgraphMatch(nodeRef, criteria, invertGraphTraversal)
.isMatch();
}
} // namespace matcher
@@ -231,15 +229,15 @@ using namespace nom::matcher;
// Simple test cases for node matching criteria.
TEST(SubgraphMatcher, IsNodeMatch) {
- TestGraph graph;
- auto n1 = graph.createNode("Hello");
- auto n2 = graph.createNode("Le");
- graph.createEdge(n1, n2);
-
- EXPECT_TRUE(TestMatcher::isNodeMatch(n1, "Hello"));
- EXPECT_FALSE(TestMatcher::isNodeMatch(n1, "G"));
- EXPECT_TRUE(TestMatcher::isNodeMatch(n2, "Le"));
- EXPECT_FALSE(TestMatcher::isNodeMatch(n2, "le"));
+ TestGraph g;
+ auto n1 = g.createNode("Hello");
+ auto n2 = g.createNode("Le");
+ g.createEdge(n1, n2);
+
+ EXPECT_TRUE(graph.isNodeMatch(n1, testMatchPredicate("Hello")));
+ EXPECT_FALSE(graph.isNodeMatch(n1, testMatchPredicate("G")));
+ EXPECT_TRUE(graph.isNodeMatch(n2, testMatchPredicate("Le")));
+ EXPECT_FALSE(graph.isNodeMatch(n2, testMatchPredicate("le")));
}
// Test subtree matching with a simple tree graph.
@@ -321,7 +319,8 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
- subtree = Tree(any(), {Tree(Criteria("2"), {}, TestMatchNode::kStarCount)});
+ subtree =
+ Tree(any(), {Tree(Criteria("2"), {}, TestMatchPredicate::kStarCount)});
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
@@ -349,23 +348,23 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("2")),
Tree(Criteria("3"), {}, 2),
Tree(Criteria("4"), {}, 2),
- Tree(Criteria("5"), {}, TestMatchNode::kStarCount)
+ Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
- Tree(Criteria("3"), {}, TestMatchNode::kStarCount),
+ Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
Tree(Criteria("4"), {}, 2),
- Tree(Criteria("5"), {}, TestMatchNode::kStarCount)
+ Tree(Criteria("5"), {}, TestMatchPredicate::kStarCount)
});
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
Tree(Criteria("2")),
- Tree(Criteria("3"), {}, TestMatchNode::kStarCount),
+ Tree(Criteria("3"), {}, TestMatchPredicate::kStarCount),
});
// Fails because there are unmatched edges.
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
@@ -569,16 +568,16 @@ TEST(SubgraphMatcher, IsSubtreeMatchRealistic) {
TEST(SubgraphMatcher, ReplaceGraphRealistic) {
reset();
- auto graph = DataFlowTestGraph();
+ auto testGraph = DataFlowTestGraph();
auto subtree = DataFlowTestGraphCriteria();
- TestMatcher::replaceSubgraph(
- graph.graph,
+ graph.replaceSubgraph(
+ testGraph.graph,
subtree.matchOpG,
[subtree](
TestGraph& g,
TestGraph::NodeRef opG,
- const TestMatcher::SubgraphMatchResultType& matchResult) {
+ const TestMatchGraph::SubgraphMatchResultType& matchResult) {
auto fusedNode = g.createNode("opFused");
auto opC = getInNode(
matchResult.getMatchNodeMap()->at(subtree.matchOpCOutput), 0);
@@ -595,10 +594,10 @@ TEST(SubgraphMatcher, ReplaceGraphRealistic) {
// - fused node
// - output node
// - dataC2 node
- auto nodes = graph.graph.getMutableNodes();
+ auto nodes = testGraph.graph.getMutableNodes();
// Test that the graph is transformed as expected.
- EXPECT_EQ(nodes.size(), graph.numInputs + 4);
+ EXPECT_EQ(nodes.size(), testGraph.numInputs + 4);
TestGraph::NodeRef opFused;
TestGraph::NodeRef dataI;
TestGraph::NodeRef dataOut;
@@ -614,9 +613,9 @@ TEST(SubgraphMatcher, ReplaceGraphRealistic) {
EXPECT_EQ(getInNode(dataOut, 0), opFused);
- EXPECT_EQ(opFused->getInEdges().size(), graph.numInputs + 1);
+ EXPECT_EQ(opFused->getInEdges().size(), testGraph.numInputs + 1);
EXPECT_EQ(getInNode(opFused, 0), dataI);
- for (int i = 1; i <= graph.numInputs; i++) {
+ for (int i = 1; i <= testGraph.numInputs; i++) {
EXPECT_EQ(getInNode(opFused, i)->data(), "input");
}
diff --git a/caffe2/python/pybind_state_nomni.cc b/caffe2/python/pybind_state_nomni.cc
index fba800e69d..7c1a2b795b 100644
--- a/caffe2/python/pybind_state_nomni.cc
+++ b/caffe2/python/pybind_state_nomni.cc
@@ -382,9 +382,8 @@ void addNomnigraphMethods(pybind11::module& m) {
py::class_<nn::NNMatchGraph> nnMatchGraph(m, "NNMatchGraph");
nnMatchGraph.def(py::init<>());
- using MatchNodeType =
- nom::Node<nom::matcher::MatchNode<nn::NNNodeMatchCriteria>>;
- py::class_<MatchNodeType> nnMatchNode(m, "MatchNodeRef");
+ using MatchPredicateType = nom::Node<nn::NNMatchPredicate>;
+ py::class_<MatchPredicateType> nnMatchPredicate(m, "MatchPredicateRef");
nnMatchGraph
.def(
@@ -396,13 +395,12 @@ void addNomnigraphMethods(pybind11::module& m) {
"createNode",
[](nn::NNMatchGraph* g, GenericOperator& op, bool strict) {
auto opName = op.getName();
- auto match =
- nn::NNNodeMatchCriteria([opName](NNGraph::NodeRef node) {
- NOM_REQUIRE_OR_RET_FALSE(nn::is<NeuralNetOperator>(node));
- auto nnOp = nn::get<NeuralNetOperator>(node);
- return opName == nnOp->getName();
- });
- auto node = nom::matcher::MatchNode<nn::NNNodeMatchCriteria>(match);
+ auto match = [opName](NNGraph::NodeRef node) {
+ NOM_REQUIRE_OR_RET_FALSE(nn::is<NeuralNetOperator>(node));
+ auto nnOp = nn::get<NeuralNetOperator>(node);
+ return opName == nnOp->getName();
+ };
+ auto node = nn::NNMatchPredicate(match);
if (!strict) {
node.nonTerminal();
}
@@ -414,7 +412,7 @@ void addNomnigraphMethods(pybind11::module& m) {
.def(
"createNode",
[](nn::NNMatchGraph* g, nom::repr::Tensor& tensor, bool strict) {
- auto node = nn::NNMatchNode(nn::matchTensor());
+ auto node = nn::NNMatchPredicate(nn::is<nom::repr::Tensor>);
if (!strict) {
node.nonTerminal();
}
@@ -426,9 +424,8 @@ void addNomnigraphMethods(pybind11::module& m) {
.def(
"createNode",
[](nn::NNMatchGraph* g, bool strict) {
- auto match = nn::NNNodeMatchCriteria(
- [](NNGraph::NodeRef node) { return true; });
- auto node = nom::matcher::MatchNode<nn::NNNodeMatchCriteria>(match);
+ auto match = [](NNGraph::NodeRef node) { return true; };
+ auto node = nn::NNMatchPredicate(match);
if (!strict) {
node.nonTerminal();
}
@@ -444,8 +441,7 @@ void addNomnigraphMethods(pybind11::module& m) {
m.def("matchSubgraph", [](NNGraph::NodeRef node, nn::NNMatchGraph* mg) {
// Get root node or node in root cycle
auto match_node = *nom::algorithm::tarjans(mg).back().getNodes().begin();
- auto result =
- nn::NNSubgraphMatcher::isSubgraphMatch(node, match_node, false);
+ auto result = mg->isSubgraphMatch(node, match_node, false);
if (result.isMatch()) {
return *result.getMatchedSubgraph();
}