summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2018-09-07 08:46:12 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-07 08:55:59 -0700
commit110191e5c7347d898782c6589e6925d253b9d9e9 (patch)
treee454091d9034209da630dd8e303057e9040f4633 /tools
parent52b37d8b66f442c2e45a5593505b3f01ef71fa66 (diff)
downloadpytorch-110191e5c7347d898782c6589e6925d253b9d9e9.tar.gz
pytorch-110191e5c7347d898782c6589e6925d253b9d9e9.tar.bz2
pytorch-110191e5c7347d898782c6589e6925d253b9d9e9.zip
Remove detach from TensorImpl, handle via Type. (#11337)
Summary: This is so that TensorImpl does not have to depend on Tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/11337 Differential Revision: D9684421 Pulled By: gchanan fbshipit-source-id: d2af93420ca6d493429c251cfe5a34e9289c4484
Diffstat (limited to 'tools')
-rw-r--r--tools/autograd/gen_variable_type.py2
-rw-r--r--tools/autograd/templates/VariableType.cpp40
2 files changed, 41 insertions, 1 deletions
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 05affcbaa6..d6bcb0821e 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -31,7 +31,7 @@ from .gen_autograd_functions import uses_single_grad
# These functions are written manually in templates/VariableType.cpp
MANUAL_IMPLEMENTATIONS = {
- 'contiguous', 'resize_', 'resize_as_'
+ 'contiguous', 'resize_', 'resize_as_', 'detach', 'detach_',
}
# These functions we don't want to record for tracing, because we always want
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 031fd9e450..da7be8824e 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -446,6 +446,46 @@ Tensor VariableType::contiguous(const Tensor & self) const {
return self.clone();
}
+Tensor VariableType::detach(const Tensor & self) const {
+ profiler::RecordFunction profiler("detach");
+ torch::jit::Node* node = nullptr;
+ if (jit::tracer::isTracing()) {
+ auto& graph = jit::tracer::getTracingState()->graph;
+ node = graph->create(jit::aten::detach, /*outputs=*/0);
+ jit::tracer::recordSourceLocation(node);
+ jit::tracer::addInputs(node, "self", self);
+ graph->appendNode(node);
+
+ }
+ // <NON_GENERATED_CODE>
+ auto result = as_variable_ref(const_cast<Tensor&>(self)).detach();
+ // </NON_GENERATED_CODE>
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, result);
+ }
+ return result;
+}
+
+Tensor & VariableType::detach_(Tensor & self) const {
+ profiler::RecordFunction profiler("detach_");
+ torch::jit::Node* node = nullptr;
+ if (jit::tracer::isTracing()) {
+ auto& graph = jit::tracer::getTracingState()->graph;
+ node = graph->create(jit::aten::detach, /*outputs=*/0);
+ jit::tracer::recordSourceLocation(node);
+ jit::tracer::addInputs(node, "self", self);
+ graph->appendNode(node);
+ jit::tracer::ensureUnique("detach_", self);
+ }
+ // <NON_GENERATED_CODE>
+ as_variable_ref(self).detach_();
+ // </NON_GENERATED_CODE>
+ if (jit::tracer::isTracing()) {
+ jit::tracer::addOutput(node, self);
+ }
+ return self;
+}
+
static std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
std::vector<std::vector<int64_t>> args_sizes(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {