diff options
author | Elias Ellison <eellison@fb.com> | 2019-01-31 15:37:52 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-31 15:41:22 -0800 |
commit | a386c28fcd7232ddf45d376812d2d0a5729b292c (patch) | |
tree | 9fc13fbfd551409a2cd39f27a496b24af8da5bbb /torch | |
parent | dfb081a7e4d8cbef53084eb17968e837a825b248 (diff) | |
download | pytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.tar.gz pytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.tar.bz2 pytorch-a386c28fcd7232ddf45d376812d2d0a5729b292c.zip |
Remove constant propagation expect files (#16348)
Summary:
Remove constant prop expect files, and express graph conditions via python bindings.
First diff in larger effort to remove expect files
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16348
Differential Revision: D13906929
Pulled By: eellison
fbshipit-source-id: 7963caa3ccbc7bfc0006a160c952aa173d1ce633
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/jit/python_ir.cpp | 106 |
1 files changed, 64 insertions, 42 deletions
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 3e73bde7c9..1c887097a2 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -69,35 +69,55 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) { } } -std::vector<Node*> findAllNodes(Block* block, Symbol kind) { +std::vector<Node*> findAllNodes( + c10::ArrayRef<torch::jit::Block*> blocks, + Symbol kind, + bool recurse = true) { std::vector<Node*> ret; - for (Node* n : block->nodes()) { - for (Block* b : n->blocks()) { - auto nodes = findAllNodes(b, kind); - ret.insert(ret.end(), nodes.begin(), nodes.end()); - } - if (n->kind() == kind) { - ret.push_back(n); + for (Block* block : blocks) { + for (Node* n : block->nodes()) { + if (n->kind() == kind) { + ret.push_back(n); + } + if (recurse) { + auto nodes = findAllNodes(n->blocks(), kind, recurse); + ret.insert(ret.end(), nodes.begin(), nodes.end()); + } } } return ret; } -Node* findNode(Block* block, Symbol kind) { - for (Node* n : block->nodes()) { - for (Block* b : n->blocks()) { - auto node = findNode(b, kind); - if (node != nullptr) { - return node; +std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) { + std::vector<Block*> blocks = {block}; + return findAllNodes(blocks, kind, recurse); +} + +Node* findNode( + c10::ArrayRef<torch::jit::Block*> blocks, + Symbol kind, + bool recurse = true) { + for (Block* block : blocks) { + for (Node* n : block->nodes()) { + if (n->kind() == kind) { + return n; + } + if (recurse) { + auto node = findNode(n->blocks(), kind, recurse); + if (node != nullptr) { + return node; + } } - } - if (n->kind() == kind) { - return n; } } return nullptr; } +Node* findNode(Block* block, Symbol kind, bool recurse = true) { + std::vector<Block*> blocks = {block}; + return findNode(blocks, kind, recurse); +} + // execute a Python function, used for Ops we can't optimize but that we want to // optimize around struct ConcretePythonOp : public PythonOp { @@ -269,14 +289,15 @@ void initPythonIRBindings(PyObject* module_) { }) .def( "findNode", - [](Graph& g, const std::string& kind) { - return findNode(g.block(), Symbol::fromQualString(kind)); - }) + [](Graph& g, const std::string& kind, bool recurse) { + return findNode(g.block(), Symbol::fromQualString(kind), recurse); + }, "Find Node", py::arg("kind"), py::arg("recurse") = true) .def( "findAllNodes", - [](Graph& g, const std::string& kind) { - return findAllNodes(g.block(), Symbol::fromQualString(kind)); - }) + [](Graph& g, const std::string& kind, bool recurse) { + return findAllNodes( + g.block(), Symbol::fromQualString(kind), recurse); + }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true) .def("addInput", [](Graph& g) { return g.addInput(); }) .def("copy", [](Graph& g) { return g.copy(); }) .GS(eraseInput) @@ -361,7 +382,18 @@ void initPythonIRBindings(PyObject* module_) { py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block") .def("nodes", [](Block& b) { return py::make_iterator(b.nodes().begin(), b.nodes().end()); - }); + }) + .def( + "findNode", + [](Block& b, const std::string& kind, bool recurse) { + return findNode(&b, Symbol::fromQualString(kind), recurse); + }, "Find Node", py::arg("kind"), py::arg("recurse") = true) + .def( + "findAllNodes", + [](Block& b, const std::string& kind, bool recurse) { + return findAllNodes(&b, Symbol::fromQualString(kind), recurse); + }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true); + #define NS(name) def(#name, &Node ::name) py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node") @@ -400,26 +432,16 @@ void initPythonIRBindings(PyObject* module_) { .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); }) .def( "findNode", - [](Node& n, const std::string& kind) { - Node* node; - for (Block* b : n.blocks()) { - node = findNode(b, Symbol::fromQualString(kind)); - if (node != nullptr) { - return node; - } - } - return node; - }) + [](Node& n, const std::string& kind, bool recurse) { + return findNode(n.blocks(), Symbol::fromQualString(kind), recurse); + }, "Find Node", py::arg("kind"), py::arg("recurse") = true) .def( "findAllNodes", - [](Node& n, const std::string& kind) { - std::vector<Node*> ret; - for (Block* b : n.blocks()) { - auto nodes = findAllNodes(b, Symbol::fromQualString(kind)); - ret.insert(ret.end(), nodes.begin(), nodes.end()); - } - return ret; - }) + [](Node& n, const std::string& kind, bool recurse) { + return findAllNodes( + n.blocks(), Symbol::fromQualString(kind), recurse); + }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true) + .def("input", [](Node& n) { return n.input(); }) .def("output", [](Node& n) { return n.output(); }) .NS(addInput) .NS(replaceInput) |