author     Roy Li <royboy@fb.com>  2019-04-21 21:12:21 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-21 21:16:07 -0700
commit     ab78449e8c30d9d5d6fd18a33d8b98dafe58d82c (patch)
tree       67369a1d6b0547c6b078c62da69a1b4b77117d06
parent     a044ba1af5efad9c7dfdfc9eb44c045b6492ec46 (diff)
Add ScalarType argument to Type::options() (#19270)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270
ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791

Differential Revision: D14938707

Pulled By: li-roy

fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
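At a typical call site, the change looks like the following; a minimal before/after sketch (make_zeros is an illustrative helper, not part of this patch):

// Hedged sketch of the call-site migration: options() no longer reads the
// dtype off the Type, so callers supply the ScalarType explicitly.
#include <ATen/ATen.h>

at::Tensor make_zeros(const at::Type& type, at::ScalarType scalar_type,
                      at::IntArrayRef sizes) {
  // Before this patch: at::zeros(sizes, type.options());
  return at::zeros(sizes, type.options(scalar_type));
}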
-rw-r--r--  aten/src/ATen/Context.h                          |  5
-rw-r--r--  aten/src/ATen/core/Type.h                        | 15
-rw-r--r--  aten/src/ATen/function_wrapper.py                |  3
-rw-r--r--  aten/src/ATen/templates/Type.h                   | 15
-rw-r--r--  test/cpp/api/tensor_options.cpp                  |  6
-rw-r--r--  test/cpp/api/tensor_options_cuda.cpp             |  8
-rw-r--r--  tools/autograd/templates/Functions.h             |  4
-rw-r--r--  torch/csrc/autograd/engine.cpp                   |  4
-rw-r--r--  torch/csrc/autograd/function.h                   |  2
-rw-r--r--  torch/csrc/autograd/input_metadata.h             |  8
-rw-r--r--  torch/csrc/autograd/python_function.cpp          |  3
-rw-r--r--  torch/csrc/autograd/python_function.h            |  1
-rw-r--r--  torch/csrc/autograd/python_legacy_variable.cpp   |  3
-rw-r--r--  torch/csrc/autograd/python_variable_indexing.cpp |  8
-rw-r--r--  torch/csrc/autograd/variable.cpp                 |  2
-rw-r--r--  torch/csrc/cuda/comm.cpp                         |  2
-rw-r--r--  torch/csrc/jit/passes/shape_analysis.cpp         |  5
-rw-r--r--  torch/csrc/utils/tensor_new.cpp                  | 55
18 files changed, 76 insertions, 73 deletions
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 925b327304..b77728e322 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -176,6 +176,11 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
CAFFE2_API Allocator* getCPUAllocator();
+static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ p, s, /*is_variable*/false);
+}
+
static inline DeprecatedTypeProperties& CPU(ScalarType s) {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
Backend::CPU, s, /*is_variable*/false);
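A minimal usage sketch of the new helper, assuming this revision's headers (the function below is illustrative, not from the patch):

#include <ATen/Context.h>

void sketch() {
  // Resolves the non-variable DeprecatedTypeProperties for a backend/dtype
  // pair; parallels the existing at::CPU(s)/at::CUDA(s) helpers but is
  // parameterized on Backend, which the updated tests below rely on.
  at::DeprecatedTypeProperties& props =
      at::getNonVariableDeprecatedTypeProperties(at::Backend::SparseCPU,
                                                 at::kFloat);
  at::TensorOptions opts = props.options();  // CPU device, float, sparse
  (void)opts;
}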
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 2d1c627817..ad7eba20ee 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -176,9 +176,8 @@ struct CAFFE2_API Type {
return this != &other;
}
- /// Constructs the `TensorOptions` from a type and a `device_index`.
- TensorOptions options(int16_t device_index = -1) const {
- return TensorOptions().dtype(typeMeta())
+ TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+ return TensorOptions().dtype(s)
.device(device_type(), device_index)
.layout(layout())
.is_variable(is_variable());
@@ -186,20 +185,16 @@ struct CAFFE2_API Type {
/// Constructs the `TensorOptions` from a type and a Device. Asserts that
/// the device type matches the device type of the type.
- TensorOptions options(c10::optional<Device> device_opt) const {
+ TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
if (!device_opt.has_value()) {
- return options(-1);
+ return options(s, -1);
} else {
Device device = device_opt.value();
AT_ASSERT(device.type() == device_type());
- return options(device.index());
+ return options(s, device.index());
}
}
- operator TensorOptions() const {
- return options();
- }
-
// example
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
virtual Tensor abs(const Tensor & self) const = 0;
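Note the removed operator TensorOptions(): a Type can no longer convert implicitly where TensorOptions is expected. A hedged sketch of the resulting breakage and the fix (kFloat is an arbitrary example dtype):

#include <ATen/ATen.h>

at::Tensor explicit_dtype(const at::Type& t) {
  // No longer compiles after this patch (implicit conversion removed):
  //   at::Tensor a = at::empty({2, 2}, t);
  // The dtype must now be spelled out through the new overload:
  return at::empty({2, 2}, t.options(at::kFloat));
}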
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 8143be29e1..81d40813da 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -1604,7 +1604,8 @@ def create_derived(backend_type_env, declarations):
# e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
# ScalarType before constructing the scalar_tensor to avoid overflow checking.
elif ret['type'] == 'accreal' or ret['type'] == 'real':
- return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
+ return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
+ 'options(ScalarType::${ScalarName}));')
case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
else:
# we use int64_t for long in the API, so correct it here...
diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h
index 14900acc7c..0e1051ed7e 100644
--- a/aten/src/ATen/templates/Type.h
+++ b/aten/src/ATen/templates/Type.h
@@ -119,9 +119,8 @@ struct CAFFE2_API Type {
return this != &other;
}
- /// Constructs the `TensorOptions` from a type and a `device_index`.
- TensorOptions options(int16_t device_index = -1) const {
- return TensorOptions().dtype(typeMeta())
+ TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+ return TensorOptions().dtype(s)
.device(device_type(), device_index)
.layout(layout())
.is_variable(is_variable());
@@ -129,20 +128,16 @@ struct CAFFE2_API Type {
/// Constructs the `TensorOptions` from a type and a Device. Asserts that
/// the device type matches the device type of the type.
- TensorOptions options(c10::optional<Device> device_opt) const {
+ TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
if (!device_opt.has_value()) {
- return options(-1);
+ return options(s, -1);
} else {
Device device = device_opt.value();
AT_ASSERT(device.type() == device_type());
- return options(device.index());
+ return options(s, device.index());
}
}
- operator TensorOptions() const {
- return options();
- }
-
// example
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
${pure_virtual_type_method_declarations}
diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp
index cebb374203..6e345b306e 100644
--- a/test/cpp/api/tensor_options.cpp
+++ b/test/cpp/api/tensor_options.cpp
@@ -66,10 +66,10 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) {
options = TensorOptions(kInt);
REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
- options = TensorOptions(getNonVariableType(Backend::SparseCPU, kFloat));
+ options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kFloat));
REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse);
- options = TensorOptions(getNonVariableType(Backend::SparseCPU, kByte));
+ options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte));
REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
}
@@ -77,7 +77,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) {
auto options = empty(5, kDouble).options();
REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
- options = empty(5, getNonVariableType(Backend::SparseCPU, kByte)).options();
+ options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options();
REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
}
diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp
index a048ce5ac6..04229f58c2 100644
--- a/test/cpp/api/tensor_options_cuda.cpp
+++ b/test/cpp/api/tensor_options_cuda.cpp
@@ -42,17 +42,17 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
options = CUDA(kInt).options();
REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);
- options = getNonVariableType(Backend::SparseCUDA, kFloat).options();
+ options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options();
REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);
- options = getNonVariableType(Backend::SparseCUDA, kByte).options();
+ options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options();
REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);
options = CUDA(kFloat).options(/*device=*/5);
REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);
options =
- getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5);
+ getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5);
REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
}
@@ -60,7 +60,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
- options = empty(5, getNonVariableType(Backend::SparseCUDA, kByte)).options();
+ options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options();
REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse);
if (torch::cuda::device_count() > 1) {
diff --git a/tools/autograd/templates/Functions.h b/tools/autograd/templates/Functions.h
index 99b6ab3a2c..bd81dc13ed 100644
--- a/tools/autograd/templates/Functions.h
+++ b/tools/autograd/templates/Functions.h
@@ -35,13 +35,13 @@ struct TypeAndSize {
/* implicit */
TypeAndSize(const Tensor & t)
: sizes(t.sizes().vec())
- , type(&t.dispatch_type()) {}
+ , type(&t.type()) {}
Tensor zeros() { return at::zeros(sizes, *type); }
private:
std::vector<int64_t> sizes;
- Type* type;
+ at::DeprecatedTypeProperties* type;
};
${autograd_function_declarations}
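A hedged sketch of the TypeAndSize pattern after the change (zero_grad_later is illustrative, and the generated namespace is assumed from where this template lands):

// TypeAndSize snapshots a tensor's metadata at save time; it now stores
// DeprecatedTypeProperties, which still carries the dtype, so zeros()
// can rebuild a matching zero-filled gradient later.
at::Tensor zero_grad_later(const at::Tensor& t) {
  torch::autograd::generated::TypeAndSize meta(t);  // sizes + t.type()
  return meta.zeros();  // same shape, dtype, and backend as t
}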
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index b4716ddb3a..db9b0c4f0b 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -334,7 +334,7 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, const
return outputs;
}
-static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) {
// Types are compatible if they exactly match or if the gradient is a sparse
// version of the expected type.
return expected == actual || (actual.is_sparse() &&
@@ -372,7 +372,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
}
grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
}
- if (!is_compatible_type(metadata.type(), grads[i].dispatch_type())) {
+ if (!is_compatible_type(metadata.type(), grads[i].type())) {
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected type ";
ss << metadata.type() << " but got " << grads[i].type();
diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h
index 380dfd9cc7..512cfb60cb 100644
--- a/torch/csrc/autograd/function.h
+++ b/torch/csrc/autograd/function.h
@@ -130,7 +130,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
/// Adds the type and shape metadata for a new input. Returns the index
/// of the new input.
uint32_t add_input_metadata(
- const at::Type& type
+ const at::DeprecatedTypeProperties& type
, at::IntArrayRef shape
, at::Device device) noexcept {
uint32_t input_nr = input_metadata_.size();
diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h
index df767de639..94655989d2 100644
--- a/torch/csrc/autograd/input_metadata.h
+++ b/torch/csrc/autograd/input_metadata.h
@@ -12,17 +12,17 @@ namespace torch { namespace autograd {
struct InputMetadata {
InputMetadata() = default;
- InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device)
+ InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device)
: type_{&type} , shape_{shape}, device_{device} { }
InputMetadata(const at::Tensor& t)
- : InputMetadata(t.dispatch_type(), t.sizes(), t.device()) { }
+ : InputMetadata(t.type(), t.sizes(), t.device()) { }
bool is_valid() const {
return type_ != nullptr;
}
- const at::Type& type() const {
+ const at::DeprecatedTypeProperties& type() const {
AT_ASSERT(type_);
return *type_;
}
@@ -40,7 +40,7 @@ struct InputMetadata {
}
private:
- const at::Type* type_ = nullptr;
+ const at::DeprecatedTypeProperties* type_ = nullptr;
at::DimVector shape_;
at::Device device_ = at::kCPU;
};
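A small round-trip sketch, assuming this revision's header (record is illustrative):

#include <torch/csrc/autograd/input_metadata.h>

void record(const at::Tensor& t) {
  // Captures t.type(), t.sizes(), t.device() via the converting ctor.
  torch::autograd::InputMetadata meta(t);
  if (meta.is_valid()) {
    // type() now yields DeprecatedTypeProperties, the same handle the
    // engine's gradient compatibility check compares against.
    const at::DeprecatedTypeProperties& props = meta.type();
    (void)props;
  }
}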
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 50fc16094e..cc4e23fd5a 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -46,6 +46,7 @@ namespace torch { namespace autograd {
VariableInfo::VariableInfo(const Variable& var)
: type(&var.dispatch_type())
, device(var.device())
+ , scalar_type(var.scalar_type())
, size(var.sizes().vec())
, requires_grad(var.requires_grad()) {
}
@@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var)
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
// NB: This will NOT work if we ever get mixed device gradients
device_guard.reset_device(device);
- return at::zeros(size, type->options());
+ return at::zeros(size, type->options(scalar_type));
}
auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
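A hedged sketch of the VariableInfo round trip (zeros_like_info is illustrative):

// VariableInfo now snapshots the scalar type alongside the dispatch Type,
// so zeros() can rebuild complete TensorOptions on its own.
at::Tensor zeros_like_info(const torch::autograd::Variable& var,
                           at::OptionalDeviceGuard& guard) {
  torch::autograd::VariableInfo info(var);  // type, device, dtype, sizes
  return info.zeros(guard);                 // dtype = var.scalar_type()
}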
diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h
index 52ead11ba6..a2d1c6c87e 100644
--- a/torch/csrc/autograd/python_function.h
+++ b/torch/csrc/autograd/python_function.h
@@ -25,6 +25,7 @@ struct VariableInfo {
at::Type* type;
at::Device device = at::kCPU;
+ at::ScalarType scalar_type = at::kFloat;
std::vector<int64_t> size;
bool requires_grad;
};
diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp
index 932e4b241c..b61e6fbdb1 100644
--- a/torch/csrc/autograd/python_legacy_variable.cpp
+++ b/torch/csrc/autograd/python_legacy_variable.cpp
@@ -46,7 +46,8 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
if (!data || data == Py_None) {
// For legacy serialization code, create an empty tensor. This is also used
// by nn.Parameter() with no arguments.
- auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options());
+ auto scalar_type = torch::tensors::get_default_scalar_type();
+ auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type));
tensor = static_cast<Variable&>(var).data();
} else if (THPVariable_Check(data)) {
tensor = ((THPVariable*)data)->cdata.data();
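Restated as a standalone hedged sketch (default_empty is illustrative):

at::Tensor default_empty() {
  // The default dtype is now fetched separately from the default tensor
  // type and threaded into options() explicitly.
  auto scalar_type = torch::tensors::get_default_scalar_type();
  return at::empty(
      {0}, torch::tensors::get_default_tensor_type().options(scalar_type));
}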
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index eff82a92eb..1415a48544 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -110,15 +110,15 @@ static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq);
}
-static Variable valueToTensor(const at::Type & type, PyObject* value) {
+static Variable valueToTensor(const at::Type & type, const ScalarType scalar_type, PyObject* value) {
if (THPVariable_Check(value)) {
return reinterpret_cast<THPVariable*>(value)->cdata;
}
if (THPUtils_checkLong(value) || PyBool_Check(value)) {
- return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
+ return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options(scalar_type));
}
if (PyFloat_Check(value)) {
- return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options());
+ return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options(scalar_type));
}
throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
}
@@ -334,7 +334,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
OptionalDeviceGuard device_guard(device_of(self_));
- auto value = valueToTensor(self_.dispatch_type(), py_value);
+ auto value = valueToTensor(self_.dispatch_type(), self_.scalar_type(), py_value);
// handle simple types: integers, slices, ellipsis, bool
if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
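Conceptually, assignment now derives the scalar tensor's dtype from the target tensor itself; a hedged sketch (value_for is illustrative):

at::Tensor value_for(const at::Tensor& self, int64_t v) {
  // Mirrors the new valueToTensor contract: dispatch Type plus the
  // tensor's own scalar type, so `t[idx] = 3` preserves t's dtype.
  return at::scalar_tensor(
      at::Scalar(v), self.dispatch_type().options(self.scalar_type()));
}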
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index a330230f01..9fa2033c15 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -214,7 +214,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
fn->storage_offset = data().storage_offset();
fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
fn->add_input_metadata(
- diff_view_meta->base_.dispatch_type()
+ diff_view_meta->base_.type()
, sizes() // Note: sizes(), not base_.sizes(), is intentional
, diff_view_meta->base_.device());
diff_view_meta->grad_fn_ = std::move(fn);
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index 53faa6baa5..a1743355bf 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
tensors.push_back(tensor);
for (auto device : devices.slice(1)) {
_device_guard.set_index(device);
- tensors.push_back(at::empty(tensor.sizes(), type.options()));
+ tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type())));
}
nccl::broadcast(tensors);
} else {
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index c221dd9f01..39b87b6a25 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -157,10 +157,9 @@ class ShapePropagator {
return *iv;
}
if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
- auto backend =
- type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
+ auto attype = type->device().is_cpu() ?
+ at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
at::DeviceGuard device_guard(type->device());
- auto& attype = at::getNonVariableType(backend, type->scalarType());
auto t =
at::empty_strided(type->sizes(), type->strides(), attype.options())
.zero_();
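The replacement pattern as a standalone hedged sketch (options_for is illustrative):

at::TensorOptions options_for(const at::Device& device, at::ScalarType s) {
  // at::CPU(s) / at::CUDA(s) return DeprecatedTypeProperties with the
  // dtype baked in, so options() needs no further arguments here.
  auto& attype = device.is_cpu() ? at::CPU(s) : at::CUDA(s);
  return attype.options();
}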
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 908a27fd4a..6c23d0a13c 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -53,32 +53,32 @@ void maybe_initialize_cuda(const Device device) {
}
}
-Tensor dispatch_zeros(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_zeros(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
maybe_initialize_cuda(type);
AutoNoGIL no_gil;
- return torch::zeros(sizes, type.options(std::move(device)));
+ return torch::zeros(sizes, type.options(scalar_type, std::move(device)));
}
-Tensor dispatch_ones(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_ones(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
maybe_initialize_cuda(type);
AutoNoGIL no_gil;
- return torch::ones(sizes, type.options(std::move(device)));
+ return torch::ones(sizes, type.options(scalar_type, std::move(device)));
}
-Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_full(const Type& type, const ScalarType scalar_type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
maybe_initialize_cuda(type);
AutoNoGIL no_gil;
- return torch::full(sizes, fill_value, type.options(std::move(device)));
+ return torch::full(sizes, fill_value, type.options(scalar_type, std::move(device)));
}
-Tensor new_with_sizes(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor new_with_sizes(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
maybe_initialize_cuda(type);
AutoNoGIL no_gil;
- return torch::empty(sizes, type.options(std::move(device)));
+ return torch::empty(sizes, type.options(scalar_type, std::move(device)));
}
-Tensor new_with_storage(const Type& type, Storage storage) {
- auto tensor = at::empty({}, type.options());
+Tensor new_with_storage(const Type& type, const ScalarType scalar_type, Storage storage) {
+ auto tensor = at::empty({}, type.options(scalar_type));
tensor.set_(std::move(storage));
return tensor;
}
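A hedged sketch of calling the reworked helpers from within tensor_new.cpp (ones_2x3 is illustrative; c10::nullopt leaves the device defaulted):

at::Tensor ones_2x3(const at::Type& type, at::ScalarType scalar_type) {
  // Every dispatch_* helper now threads the ScalarType alongside the Type
  // so the resulting TensorOptions carry an explicit dtype.
  return dispatch_ones(type, scalar_type, c10::nullopt, {2, 3});
}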
@@ -281,7 +281,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
if (r.idx == 0) {
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
- return at::empty({0}, type.options(r.deviceOptional(0)));
+ return at::empty({0}, type.options(scalar_type, r.deviceOptional(0)));
} else if (r.idx == 1) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
@@ -304,7 +304,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
// unless the sequence is a torch.Size
return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
}
- return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+ return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
}
throw std::runtime_error("new(): invalid arguments");
}
@@ -323,7 +323,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::OptionalDeviceGuard device_guard(deviceOptional);
- return at::empty({0}, type.options());
+ return at::empty({0}, type.options(scalar_type));
} else if (r.idx == 1) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
@@ -350,7 +350,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
// unless the sequence is a torch.Size
return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
}
- return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+ return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
}
throw std::runtime_error("new(): invalid arguments");
}
@@ -384,9 +384,9 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::OptionalDeviceGuard device_guard(deviceOptional);
- return at::empty({0}, type.options());
+ return at::empty({0}, type.options(scalar_type));
} else if (r.idx == 1) {
- return new_with_storage(type, r.storage(0));
+ return new_with_storage(type, scalar_type, r.storage(0));
} else if (r.idx == 2) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
@@ -401,7 +401,7 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
// unless the sequence is a torch.Size
return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
}
- return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+ return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
} else if (r.idx == 5) {
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
@@ -430,9 +430,9 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
auto deviceOptional = r.deviceOptional(0);
check_legacy_ctor_device(type, deviceOptional);
at::OptionalDeviceGuard device_guard(deviceOptional);
- return at::empty({0}, type.options());
+ return at::empty({0}, type.options(scalar_type));
} else if (r.idx == 1) {
- return new_with_storage(type, r.storage(0));
+ return new_with_storage(type, scalar_type, r.storage(0));
} else if (r.idx == 2) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
@@ -447,7 +447,7 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
// unless the sequence is a torch.Size
return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
}
- return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+ return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
} else if (r.idx == 5) {
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(type, deviceOptional);
@@ -504,8 +504,9 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type,
return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
} else if (r.idx == 2) {
const auto& type = typeWithDefault(r, 1, 2, default_type, scalar_type);
+ const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
- return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
+ return at::sparse_coo_tensor(r.intlist(0), type.options(actual_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3));
}
throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
}
@@ -603,7 +604,8 @@ Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObj
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
- return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
+ const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+ return new_with_sizes(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
}
throw std::runtime_error("new_empty(): invalid arguments");
}
@@ -617,7 +619,8 @@ Tensor new_full(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 2, 3, type, scalar_type);
- return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
+ const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type);
+ return dispatch_full(actual_type, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
}
throw std::runtime_error("new_full(): invalid arguments");
}
@@ -631,7 +634,8 @@ Tensor new_ones(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
- return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+ const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+ return dispatch_ones(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
}
throw std::runtime_error("new_ones(): invalid arguments");
}
@@ -645,7 +649,8 @@ Tensor new_zeros(const Type& type, ScalarType scalar_type, PyObject* args, PyObj
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
- return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+ const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+ return dispatch_zeros(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
}
throw std::runtime_error("new_zeros(): invalid arguments");
}