author      Gregory Chanan <gchanan@fb.com>                                     2018-09-07 08:46:12 -0700
committer   Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-09-07 08:55:59 -0700
commit      110191e5c7347d898782c6589e6925d253b9d9e9
tree        e454091d9034209da630dd8e303057e9040f4633 /tools
parent      52b37d8b66f442c2e45a5593505b3f01ef71fa66
Remove detach from TensorImpl, handle via Type. (#11337)
Summary:
This is so that TensorImpl does not have to depend on Tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11337
Differential Revision: D9684421
Pulled By: gchanan
fbshipit-source-id: d2af93420ca6d493429c251cfe5a34e9289c4484
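The dependency issue the summary refers to, sketched as a minimal compilable C++ fragment (simplified stand-in classes, not the real ATen declarations):

```cpp
// Simplified stand-ins to show the dependency direction only; these are
// not the actual ATen classes. The key point: TensorImpl never names
// Tensor once detach is routed through Type.
struct Tensor;  // a forward declaration is all Type's header needs

struct TensorImpl {
  virtual ~TensorImpl() = default;
  // detach() used to be declared here with a Tensor return type, which
  // forced TensorImpl to depend on Tensor. It has been removed.
};

struct Type {
  virtual ~Type() = default;
  // Type already deals in Tensor, so hosting detach here adds no new
  // dependency; VariableType (in the diff below) overrides these.
  virtual Tensor detach(const Tensor& self) const = 0;
  virtual Tensor& detach_(Tensor& self) const = 0;
};
```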
Diffstat (limited to 'tools')
-rw-r--r--  tools/autograd/gen_variable_type.py       |  2 +-
-rw-r--r--  tools/autograd/templates/VariableType.cpp | 40 ++++++++++++++++++++++++++++++++++++
2 files changed, 41 insertions(+), 1 deletion(-)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 05affcbaa6..d6bcb0821e 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -31,7 +31,7 @@ from .gen_autograd_functions import uses_single_grad
 
 # These functions are written manually in templates/VariableType.cpp
 MANUAL_IMPLEMENTATIONS = {
-    'contiguous', 'resize_', 'resize_as_'
+    'contiguous', 'resize_', 'resize_as_', 'detach', 'detach_',
 }
 
 # These functions we don't want to record for tracing, because we always want
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 031fd9e450..da7be8824e 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -446,6 +446,46 @@ Tensor VariableType::contiguous(const Tensor & self) const {
   return self.clone();
 }
 
+Tensor VariableType::detach(const Tensor & self) const {
+  profiler::RecordFunction profiler("detach");
+  torch::jit::Node* node = nullptr;
+  if (jit::tracer::isTracing()) {
+    auto& graph = jit::tracer::getTracingState()->graph;
+    node = graph->create(jit::aten::detach, /*outputs=*/0);
+    jit::tracer::recordSourceLocation(node);
+    jit::tracer::addInputs(node, "self", self);
+    graph->appendNode(node);
+
+  }
+  // <NON_GENERATED_CODE>
+  auto result = as_variable_ref(const_cast<Tensor&>(self)).detach();
+  // </NON_GENERATED_CODE>
+  if (jit::tracer::isTracing()) {
+    jit::tracer::addOutput(node, result);
+  }
+  return result;
+}
+
+Tensor & VariableType::detach_(Tensor & self) const {
+  profiler::RecordFunction profiler("detach_");
+  torch::jit::Node* node = nullptr;
+  if (jit::tracer::isTracing()) {
+    auto& graph = jit::tracer::getTracingState()->graph;
+    node = graph->create(jit::aten::detach, /*outputs=*/0);
+    jit::tracer::recordSourceLocation(node);
+    jit::tracer::addInputs(node, "self", self);
+    graph->appendNode(node);
+    jit::tracer::ensureUnique("detach_", self);
+  }
+  // <NON_GENERATED_CODE>
+  as_variable_ref(self).detach_();
+  // </NON_GENERATED_CODE>
+  if (jit::tracer::isTracing()) {
+    jit::tracer::addOutput(node, self);
+  }
+  return self;
+}
+
 static std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
   std::vector<std::vector<int64_t>> args_sizes(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
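Both manual implementations above follow the same record-while-tracing shape the generator emits for other ops: if a trace is active, create a graph node and record the inputs, run the actual operation, then record the output. A self-contained toy model of that control flow (every name here is invented for illustration; this is not the real tracer API):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the tracing state. Only the control flow mirrors the
// real code above: record inputs -> run the op -> record the output.
static bool g_tracing = false;             // cf. jit::tracer::isTracing()
static std::vector<std::string> g_graph;   // cf. the tracing state's graph

double toy_detach(double self) {
  if (g_tracing) {
    g_graph.push_back("node: detach(self)");   // create node, record inputs
  }
  double result = self;                        // the "non-generated" op itself
  if (g_tracing) {
    g_graph.push_back("node output: result");  // record output for later ops
  }
  return result;
}

int main() {
  g_tracing = true;
  toy_detach(3.0);
  for (const auto& line : g_graph) {
    std::cout << line << "\n";
  }
}
```

The in-place variant additionally calls jit::tracer::ensureUnique before mutating self, presumably because rewriting a traced value in place is only sound when the tracer can treat the old and new values as one.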