summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorElias Ellison <eellison@fb.com>2019-01-31 15:37:52 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-31 15:41:22 -0800
commita386c28fcd7232ddf45d376812d2d0a5729b292c (patch)
tree9fc13fbfd551409a2cd39f27a496b24af8da5bbb /torch
parentdfb081a7e4d8cbef53084eb17968e837a825b248 (diff)
downloadpytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.tar.gz
pytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.tar.bz2
pytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.zip
Remove constant propagation expect files (#16348)
Summary: Remove constant prop expect files, and express graph conditions via python bindings. First diff in larger effort to remove expect files Pull Request resolved: https://github.com/pytorch/pytorch/pull/16348 Differential Revision: D13906929 Pulled By: eellison fbshipit-source-id: 7963caa3ccbc7bfc0006a160c952aa173d1ce633
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/jit/python_ir.cpp106
1 files changed, 64 insertions, 42 deletions
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 3e73bde7c9..1c887097a2 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -69,35 +69,55 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
}
}
-std::vector<Node*> findAllNodes(Block* block, Symbol kind) {
+std::vector<Node*> findAllNodes(
+ c10::ArrayRef<torch::jit::Block*> blocks,
+ Symbol kind,
+ bool recurse = true) {
std::vector<Node*> ret;
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- auto nodes = findAllNodes(b, kind);
- ret.insert(ret.end(), nodes.begin(), nodes.end());
- }
- if (n->kind() == kind) {
- ret.push_back(n);
+ for (Block* block : blocks) {
+ for (Node* n : block->nodes()) {
+ if (n->kind() == kind) {
+ ret.push_back(n);
+ }
+ if (recurse) {
+ auto nodes = findAllNodes(n->blocks(), kind, recurse);
+ ret.insert(ret.end(), nodes.begin(), nodes.end());
+ }
}
}
return ret;
}
-Node* findNode(Block* block, Symbol kind) {
- for (Node* n : block->nodes()) {
- for (Block* b : n->blocks()) {
- auto node = findNode(b, kind);
- if (node != nullptr) {
- return node;
+std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+ std::vector<Block*> blocks = {block};
+ return findAllNodes(blocks, kind, recurse);
+}
+
+Node* findNode(
+ c10::ArrayRef<torch::jit::Block*> blocks,
+ Symbol kind,
+ bool recurse = true) {
+ for (Block* block : blocks) {
+ for (Node* n : block->nodes()) {
+ if (n->kind() == kind) {
+ return n;
+ }
+ if (recurse) {
+ auto node = findNode(n->blocks(), kind, recurse);
+ if (node != nullptr) {
+ return node;
+ }
}
- }
- if (n->kind() == kind) {
- return n;
}
}
return nullptr;
}
+Node* findNode(Block* block, Symbol kind, bool recurse = true) {
+ std::vector<Block*> blocks = {block};
+ return findNode(blocks, kind, recurse);
+}
+
// execute a Python function, used for Ops we can't optimize but that we want to
// optimize around
struct ConcretePythonOp : public PythonOp {
@@ -269,14 +289,15 @@ void initPythonIRBindings(PyObject* module_) {
})
.def(
"findNode",
- [](Graph& g, const std::string& kind) {
- return findNode(g.block(), Symbol::fromQualString(kind));
- })
+ [](Graph& g, const std::string& kind, bool recurse) {
+ return findNode(g.block(), Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
.def(
"findAllNodes",
- [](Graph& g, const std::string& kind) {
- return findAllNodes(g.block(), Symbol::fromQualString(kind));
- })
+ [](Graph& g, const std::string& kind, bool recurse) {
+ return findAllNodes(
+ g.block(), Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
.def("addInput", [](Graph& g) { return g.addInput(); })
.def("copy", [](Graph& g) { return g.copy(); })
.GS(eraseInput)
@@ -361,7 +382,18 @@ void initPythonIRBindings(PyObject* module_) {
py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
.def("nodes", [](Block& b) {
return py::make_iterator(b.nodes().begin(), b.nodes().end());
- });
+ })
+ .def(
+ "findNode",
+ [](Block& b, const std::string& kind, bool recurse) {
+ return findNode(&b, Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+ .def(
+ "findAllNodes",
+ [](Block& b, const std::string& kind, bool recurse) {
+ return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true);
+
#define NS(name) def(#name, &Node ::name)
py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
@@ -400,26 +432,16 @@ void initPythonIRBindings(PyObject* module_) {
.def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
.def(
"findNode",
- [](Node& n, const std::string& kind) {
- Node* node;
- for (Block* b : n.blocks()) {
- node = findNode(b, Symbol::fromQualString(kind));
- if (node != nullptr) {
- return node;
- }
- }
- return node;
- })
+ [](Node& n, const std::string& kind, bool recurse) {
+ return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
+ }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
.def(
"findAllNodes",
- [](Node& n, const std::string& kind) {
- std::vector<Node*> ret;
- for (Block* b : n.blocks()) {
- auto nodes = findAllNodes(b, Symbol::fromQualString(kind));
- ret.insert(ret.end(), nodes.begin(), nodes.end());
- }
- return ret;
- })
+ [](Node& n, const std::string& kind, bool recurse) {
+ return findAllNodes(
+ n.blocks(), Symbol::fromQualString(kind), recurse);
+ }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
+ .def("input", [](Node& n) { return n.input(); })
.def("output", [](Node& n) { return n.output(); })
.NS(addInput)
.NS(replaceInput)