diff options
author | Roy Li <royboy@fb.com> | 2019-04-21 21:12:21 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-21 21:16:07 -0700 |
commit | ab78449e8c30d9d5d6fd18a33d8b98dafe58d82c (patch) | |
tree | 67369a1d6b0547c6b078c62da69a1b4b77117d06 | |
parent | a044ba1af5efad9c7dfdfc9eb44c045b6492ec46 (diff) | |
download | pytorch-ab78449e8c30d9d5d6fd18a33d8b98dafe58d82c.tar.gz pytorch-ab78449e8c30d9d5d6fd18a33d8b98dafe58d82c.tar.bz2 pytorch-ab78449e8c30d9d5d6fd18a33d8b98dafe58d82c.zip |
Add ScalarType argument to Type::options() (#19270)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270
ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791
Differential Revision: D14938707
Pulled By: li-roy
fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
-rw-r--r-- | aten/src/ATen/Context.h | 5 | ||||
-rw-r--r-- | aten/src/ATen/core/Type.h | 15 | ||||
-rw-r--r-- | aten/src/ATen/function_wrapper.py | 3 | ||||
-rw-r--r-- | aten/src/ATen/templates/Type.h | 15 | ||||
-rw-r--r-- | test/cpp/api/tensor_options.cpp | 6 | ||||
-rw-r--r-- | test/cpp/api/tensor_options_cuda.cpp | 8 | ||||
-rw-r--r-- | tools/autograd/templates/Functions.h | 4 | ||||
-rw-r--r-- | torch/csrc/autograd/engine.cpp | 4 | ||||
-rw-r--r-- | torch/csrc/autograd/function.h | 2 | ||||
-rw-r--r-- | torch/csrc/autograd/input_metadata.h | 8 | ||||
-rw-r--r-- | torch/csrc/autograd/python_function.cpp | 3 | ||||
-rw-r--r-- | torch/csrc/autograd/python_function.h | 1 | ||||
-rw-r--r-- | torch/csrc/autograd/python_legacy_variable.cpp | 3 | ||||
-rw-r--r-- | torch/csrc/autograd/python_variable_indexing.cpp | 8 | ||||
-rw-r--r-- | torch/csrc/autograd/variable.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/cuda/comm.cpp | 2 | ||||
-rw-r--r-- | torch/csrc/jit/passes/shape_analysis.cpp | 5 | ||||
-rw-r--r-- | torch/csrc/utils/tensor_new.cpp | 55 |
18 files changed, 76 insertions, 73 deletions
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 925b327304..b77728e322 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -176,6 +176,11 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&); CAFFE2_API Allocator* getCPUAllocator(); +static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + p, s, /*is_variable*/false); +} + static inline DeprecatedTypeProperties& CPU(ScalarType s) { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( Backend::CPU, s, /*is_variable*/false); diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 2d1c627817..ad7eba20ee 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -176,9 +176,8 @@ struct CAFFE2_API Type { return this != &other; } - /// Constructs the `TensorOptions` from a type and a `device_index`. - TensorOptions options(int16_t device_index = -1) const { - return TensorOptions().dtype(typeMeta()) + TensorOptions options(ScalarType s, int16_t device_index = -1) const { + return TensorOptions().dtype(s) .device(device_type(), device_index) .layout(layout()) .is_variable(is_variable()); @@ -186,20 +185,16 @@ struct CAFFE2_API Type { /// Constructs the `TensorOptions` from a type and a Device. Asserts that /// the device type matches the device type of the type. - TensorOptions options(c10::optional<Device> device_opt) const { + TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const { if (!device_opt.has_value()) { - return options(-1); + return options(s, -1); } else { Device device = device_opt.value(); AT_ASSERT(device.type() == device_type()); - return options(device.index()); + return options(s, device.index()); } } - operator TensorOptions() const { - return options(); - } - // example // virtual Tensor * add(Tensor & a, Tensor & b) = 0; virtual Tensor abs(const Tensor & self) const = 0; diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 8143be29e1..81d40813da 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -1604,7 +1604,8 @@ def create_derived(backend_type_env, declarations): # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the # ScalarType before constructing the scalar_tensor to avoid overflow checking. elif ret['type'] == 'accreal' or ret['type'] == 'real': - return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());' + return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), ' + 'options(ScalarType::${ScalarName}));') case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call)) else: # we using int64_t for long in the API, so correct it here... diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h index 14900acc7c..0e1051ed7e 100644 --- a/aten/src/ATen/templates/Type.h +++ b/aten/src/ATen/templates/Type.h @@ -119,9 +119,8 @@ struct CAFFE2_API Type { return this != &other; } - /// Constructs the `TensorOptions` from a type and a `device_index`. - TensorOptions options(int16_t device_index = -1) const { - return TensorOptions().dtype(typeMeta()) + TensorOptions options(ScalarType s, int16_t device_index = -1) const { + return TensorOptions().dtype(s) .device(device_type(), device_index) .layout(layout()) .is_variable(is_variable()); @@ -129,20 +128,16 @@ struct CAFFE2_API Type { /// Constructs the `TensorOptions` from a type and a Device. Asserts that /// the device type matches the device type of the type. - TensorOptions options(c10::optional<Device> device_opt) const { + TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const { if (!device_opt.has_value()) { - return options(-1); + return options(s, -1); } else { Device device = device_opt.value(); AT_ASSERT(device.type() == device_type()); - return options(device.index()); + return options(s, device.index()); } } - operator TensorOptions() const { - return options(); - } - // example // virtual Tensor * add(Tensor & a, Tensor & b) = 0; ${pure_virtual_type_method_declarations} diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp index cebb374203..6e345b306e 100644 --- a/test/cpp/api/tensor_options.cpp +++ b/test/cpp/api/tensor_options.cpp @@ -66,10 +66,10 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) { options = TensorOptions(kInt); REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided); - options = TensorOptions(getNonVariableType(Backend::SparseCPU, kFloat)); + options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kFloat)); REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse); - options = TensorOptions(getNonVariableType(Backend::SparseCPU, kByte)); + options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)); REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse); } @@ -77,7 +77,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) { auto options = empty(5, kDouble).options(); REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided); - options = empty(5, getNonVariableType(Backend::SparseCPU, kByte)).options(); + options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options(); REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse); } diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp index a048ce5ac6..04229f58c2 100644 --- a/test/cpp/api/tensor_options_cuda.cpp +++ b/test/cpp/api/tensor_options_cuda.cpp @@ -42,17 +42,17 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) { options = CUDA(kInt).options(); REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided); - options = getNonVariableType(Backend::SparseCUDA, kFloat).options(); + options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(); REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse); - options = getNonVariableType(Backend::SparseCUDA, kByte).options(); + options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options(); REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse); options = CUDA(kFloat).options(/*device=*/5); REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided); options = - getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5); + getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5); REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse); } @@ -60,7 +60,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) { auto options = empty(5, device(kCUDA).dtype(kDouble)).options(); REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided); - options = empty(5, getNonVariableType(Backend::SparseCUDA, kByte)).options(); + options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options(); REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse); if (torch::cuda::device_count() > 1) { diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h index 99b6ab3a2c..bd81dc13ed 100644 --- a/tools/autograd/templates/Functions.h +++ b/tools/autograd/templates/Functions.h @@ -35,13 +35,13 @@ struct TypeAndSize { /* implicit */ TypeAndSize(const Tensor & t) : sizes(t.sizes().vec()) - , type(&t.dispatch_type()) {} + , type(&t.type()) {} Tensor zeros() { return at::zeros(sizes, *type); } private: std::vector<int64_t> sizes; - Type* type; + at::DeprecatedTypeProperties* type; }; ${autograd_function_declarations} diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index b4716ddb3a..db9b0c4f0b 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -334,7 +334,7 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, const return outputs; } -static bool is_compatible_type(const at::Type& expected, const at::Type& actual) { +static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) { // Types are compatible if they exactly match or if the gradient is a sparse // version of the expected type. return expected == actual || (actual.is_sparse() && @@ -372,7 +372,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const } grads[i] = at::sum_to(std::move(grads[i]), metadata.shape()); } - if (!is_compatible_type(metadata.type(), grads[i].dispatch_type())) { + if (!is_compatible_type(metadata.type(), grads[i].type())) { std::stringstream ss; ss << "invalid gradient at index " << i << " - expected type "; ss << metadata.type() << " but got " << grads[i].type(); diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 380dfd9cc7..512cfb60cb 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -130,7 +130,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> { /// Adds the type and shape metadata for a new input. Returns the index of /// of the new input. uint32_t add_input_metadata( - const at::Type& type + const at::DeprecatedTypeProperties& type , at::IntArrayRef shape , at::Device device) noexcept { uint32_t input_nr = input_metadata_.size(); diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index df767de639..94655989d2 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -12,17 +12,17 @@ namespace torch { namespace autograd { struct InputMetadata { InputMetadata() = default; - InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device) + InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device) : type_{&type} , shape_{shape}, device_{device} { } InputMetadata(const at::Tensor& t) - : InputMetadata(t.dispatch_type(), t.sizes(), t.device()) { } + : InputMetadata(t.type(), t.sizes(), t.device()) { } bool is_valid() const { return type_ != nullptr; } - const at::Type& type() const { + const at::DeprecatedTypeProperties& type() const { AT_ASSERT(type_); return *type_; } @@ -40,7 +40,7 @@ struct InputMetadata { } private: - const at::Type* type_ = nullptr; + const at::DeprecatedTypeProperties* type_ = nullptr; at::DimVector shape_; at::Device device_ = at::kCPU; }; diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 50fc16094e..cc4e23fd5a 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -46,6 +46,7 @@ namespace torch { namespace autograd { VariableInfo::VariableInfo(const Variable& var) : type(&var.dispatch_type()) , device(var.device()) + , scalar_type(var.scalar_type()) , size(var.sizes().vec()) , requires_grad(var.requires_grad()) { } @@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var) Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { // NB: This will NOT work if we ever get mixed device gradients device_guard.reset_device(device); - return at::zeros(size, type->options()); + return at::zeros(size, type->options(scalar_type)); } auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list { diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 52ead11ba6..a2d1c6c87e 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -25,6 +25,7 @@ struct VariableInfo { at::Type* type; at::Device device = at::kCPU; + at::ScalarType scalar_type = at::kFloat; std::vector<int64_t> size; bool requires_grad; }; diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index 932e4b241c..b61e6fbdb1 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -46,7 +46,8 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject if (!data || data == Py_None) { // For legacy serialization code, create an empty tensor. This is also used // by nn.Parameter() with no arguments. - auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options()); + auto scalar_type = torch::tensors::get_default_scalar_type(); + auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type)); tensor = static_cast<Variable&>(var).data(); } else if (THPVariable_Check(data)) { tensor = ((THPVariable*)data)->cdata.data(); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index eff82a92eb..1415a48544 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -110,15 +110,15 @@ static Variable sequenceToVariable(const at::Type& type, PyObject* seq) { return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq); } -static Variable valueToTensor(const at::Type & type, PyObject* value) { +static Variable valueToTensor(const at::Type & type, const ScalarType scalar_type, PyObject* value) { if (THPVariable_Check(value)) { return reinterpret_cast<THPVariable*>(value)->cdata; } if (THPUtils_checkLong(value) || PyBool_Check(value)) { - return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options()); + return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options(scalar_type)); } if (PyFloat_Check(value)) { - return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options()); + return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options(scalar_type)); } throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString()); } @@ -334,7 +334,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - auto value = valueToTensor(self_.dispatch_type(), py_value); + auto value = valueToTensor(self_.dispatch_type(), self_.scalar_type(), py_value); // handle simple types: integers, slices, ellipsis, bool if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast) diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index a330230f01..9fa2033c15 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -214,7 +214,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const { fn->storage_offset = data().storage_offset(); fn->set_next_edges(collect_next_edges(diff_view_meta->base_)); fn->add_input_metadata( - diff_view_meta->base_.dispatch_type() + diff_view_meta->base_.type() , sizes() // Note: sizes(), not base_.sizes(), is intentional , diff_view_meta->base_.device()); diff_view_meta->grad_fn_ = std::move(fn); diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 53faa6baa5..a1743355bf 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) { tensors.push_back(tensor); for (auto device : devices.slice(1)) { _device_guard.set_index(device); - tensors.push_back(at::empty(tensor.sizes(), type.options())); + tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type()))); } nccl::broadcast(tensors); } else { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index c221dd9f01..39b87b6a25 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -157,10 +157,9 @@ class ShapePropagator { return *iv; } if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) { - auto backend = - type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA; + auto attype = type->device().is_cpu() ? + at::CPU(type->scalarType()) : at::CUDA(type->scalarType()); at::DeviceGuard device_guard(type->device()); - auto& attype = at::getNonVariableType(backend, type->scalarType()); auto t = at::empty_strided(type->sizes(), type->strides(), attype.options()) .zero_(); diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 908a27fd4a..6c23d0a13c 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -53,32 +53,32 @@ void maybe_initialize_cuda(const Device device) { } } -Tensor dispatch_zeros(const Type& type, optional<Device> device, IntArrayRef sizes) { +Tensor dispatch_zeros(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::zeros(sizes, type.options(std::move(device))); + return torch::zeros(sizes, type.options(scalar_type, std::move(device))); } -Tensor dispatch_ones(const Type& type, optional<Device> device, IntArrayRef sizes) { +Tensor dispatch_ones(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::ones(sizes, type.options(std::move(device))); + return torch::ones(sizes, type.options(scalar_type, std::move(device))); } -Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) { +Tensor dispatch_full(const Type& type, const ScalarType scalar_type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::full(sizes, fill_value, type.options(std::move(device))); + return torch::full(sizes, fill_value, type.options(scalar_type, std::move(device))); } -Tensor new_with_sizes(const Type& type, optional<Device> device, IntArrayRef sizes) { +Tensor new_with_sizes(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::empty(sizes, type.options(std::move(device))); + return torch::empty(sizes, type.options(scalar_type, std::move(device))); } -Tensor new_with_storage(const Type& type, Storage storage) { - auto tensor = at::empty({}, type.options()); +Tensor new_with_storage(const Type& type, const ScalarType scalar_type, Storage storage) { + auto tensor = at::empty({}, type.options(scalar_type)); tensor.set_(std::move(storage)); return tensor; } @@ -281,7 +281,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj if (r.idx == 0) { auto deviceOptional = r.deviceOptional(0); check_legacy_ctor_device(type, deviceOptional); - return at::empty({0}, type.options(r.deviceOptional(0))); + return at::empty({0}, type.options(scalar_type, r.deviceOptional(0))); } else if (r.idx == 1) { auto cdata = reinterpret_cast<void*>(r.toInt64(0)); return type.unsafeTensorFromTH(cdata, true); @@ -304,7 +304,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj // unless the sequences is a torch.Size return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0)); } throw std::runtime_error("new(): invalid arguments"); } @@ -323,7 +323,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje auto deviceOptional = r.deviceOptional(0); check_legacy_ctor_device(type, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, type.options()); + return at::empty({0}, type.options(scalar_type)); } else if (r.idx == 1) { auto cdata = reinterpret_cast<void*>(r.toInt64(0)); return type.unsafeTensorFromTH(cdata, true); @@ -350,7 +350,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje // unless the sequences is a torch.Size return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0)); } throw std::runtime_error("new(): invalid arguments"); } @@ -384,9 +384,9 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar auto deviceOptional = r.deviceOptional(0); check_legacy_ctor_device(type, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, type.options()); + return at::empty({0}, type.options(scalar_type)); } else if (r.idx == 1) { - return new_with_storage(type, r.storage(0)); + return new_with_storage(type, scalar_type, r.storage(0)); } else if (r.idx == 2) { auto cdata = reinterpret_cast<void*>(r.toInt64(0)); return type.unsafeTensorFromTH(cdata, true); @@ -401,7 +401,7 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar // unless the sequences is a torch.Size return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0)); } else if (r.idx == 5) { auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(type, deviceOptional); @@ -430,9 +430,9 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg auto deviceOptional = r.deviceOptional(0); check_legacy_ctor_device(type, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, type.options()); + return at::empty({0}, type.options(scalar_type)); } else if (r.idx == 1) { - return new_with_storage(type, r.storage(0)); + return new_with_storage(type, scalar_type, r.storage(0)); } else if (r.idx == 2) { auto cdata = reinterpret_cast<void*>(r.toInt64(0)); return type.unsafeTensorFromTH(cdata, true); @@ -447,7 +447,7 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg // unless the sequences is a torch.Size return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0)); } else if (r.idx == 5) { auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(type, deviceOptional); @@ -504,8 +504,9 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type, return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5)); } else if (r.idx == 2) { const auto& type = typeWithDefault(r, 1, 2, default_type, scalar_type); + const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(2)); - return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3)); + return at::sparse_coo_tensor(r.intlist(0), type.options(actual_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3)); } throw std::runtime_error("sparse_coo_tensor(): invalid arguments"); } @@ -603,7 +604,8 @@ Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObj auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type); - return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4)); + const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); + return new_with_sizes(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4)); } throw std::runtime_error("new_empty(): invalid arguments"); } @@ -617,7 +619,8 @@ Tensor new_full(const Type& type, ScalarType scalar_type, PyObject* args, PyObje auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { const auto& actual_type = typeWithDefault(r, 2, 3, type, scalar_type); - return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4)); + const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type); + return dispatch_full(actual_type, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4)); } throw std::runtime_error("new_full(): invalid arguments"); } @@ -631,7 +634,8 @@ Tensor new_ones(const Type& type, ScalarType scalar_type, PyObject* args, PyObje auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type); - return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); + const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); + return dispatch_ones(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); } throw std::runtime_error("new_ones(): invalid arguments"); } @@ -645,7 +649,8 @@ Tensor new_zeros(const Type& type, ScalarType scalar_type, PyObject* args, PyObj auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type); - return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); + const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); + return dispatch_zeros(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); } throw std::runtime_error("new_zeros(): invalid arguments"); } |