path: root/torch/csrc
author    gchanan <gregchanan@gmail.com>  2018-04-06 15:12:05 -0400
committer GitHub <noreply@github.com>     2018-04-06 15:12:05 -0400
commit    87e369111a5defce56b4f0b3fac53970fa0d5462 (patch)
tree      ec66c22523a8ac9932c15ed510132b1d8ca28687 /torch/csrc
parent    fc7aa5c3be8c7a6581e649a50f11c9cb038d7b0b (diff)
Add string-style devices to all tensors. (#6283)
* Add string-style devices to all tensors.

Previously, tensors only had a 'get_device' method, which would throw an exception on a CPU tensor. This made it necessary to write if/else code for logic that was meant to be device agnostic. This PR implements the following:

1) Adds a 'device' property to all tensors that returns a string representation of the device. For CPU tensors this is 'cpu'; for CUDA tensors it is 'cuda:X', where X is the CUDA device ordinal.

2) Adds a DeviceSpec class. This is just a helper class for separating device_type and device_index specification and for allowing partial specification. For example, you can call DeviceSpec('cuda'), DeviceSpec('cuda:0'), or DeviceSpec('cuda', 1). It also has backwards-compatibility support for specifying integers, which are treated as CUDA devices. DeviceSpecs have the following properties:
   a) device_type: string representation of the device type (i.e. 'cpu' or 'cuda')
   b) device_index: integer device index (None if not specified)
   c) cuda_device_index: for backwards compatibility; behaves roughly like `get_device` did previously. I.e., if a function previously took integers for CUDA devices, it can now take DeviceSpecs (or strings) and can maintain the old functionality by calling `old_index = DeviceSpec(old).cuda_device_index`.

3) Tensor methods and torch.* functions that took integer devices can now take integers, strings, or DeviceSpecs. For example:
   torch.randn((2,3), dtype=torch.cuda.float32, device='cuda:1')

TODO in future PRs:
A) Split out cuda from dtype so you don't need to over-specify cuda-ness.
B) We currently only support strings/DeviceSpecs in tensor methods and torch.* functions. We should have equivalents like torch.cuda.device(...), torch.cuda.device_of, etc. at the torch.* level that work on strings/DeviceSpecs.

* Add deviceInt64 to python arg parser.
* device_str.
* Remove device_str.
* Remove device prefix from attributes.
* Use const char * instead of string.
* Move AutoGPU index out of Device.
* Comment on is_default.
* Rename torch.DeviceSpec to torch.device.
* Comment.
* Fix tests.
* Fix flake8.
* Fix sparse_coo_tensor parameter name.
* Improve error message.
* Remove device_ prefix from C++ device object.
* Allocate static strings.
* Return NotImplemented from rich compare.
* Move torch::Device to THPDevice.
* Remove cuda index.
* Py_RETURN_NOTIMPLEMENTED doesn't exist in Python 2.
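By way of illustration, a minimal sketch of the API described above (not part of the diff; assumes a CUDA build with at least two devices, and uses the pre-split cuda dtype per TODO A):

import torch

# torch.device (renamed from DeviceSpec) accepts strings, (type, index)
# pairs, or another device object.
d1 = torch.device('cuda:1')
d2 = torch.device('cuda', 1)
assert d1 == d2
assert d1.type == 'cuda' and d1.index == 1
assert torch.device('cpu').index is None   # partial specification

# Every tensor now has a .device property, CPU tensors included.
t = torch.randn(2, 3)
print(str(t.device))   # 'cpu'

# Factory functions take integers, strings, or device objects
# (example from the message above).
g = torch.randn((2, 3), dtype=torch.cuda.float32, device='cuda:1')
print(str(g.device))   # 'cuda:1'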
Diffstat (limited to 'torch/csrc')
-rw-r--r--  torch/csrc/Device.cpp                    198
-rw-r--r--  torch/csrc/Device.h                       19
-rw-r--r--  torch/csrc/Module.cpp                      2
-rw-r--r--  torch/csrc/autograd/python_variable.cpp   17
-rw-r--r--  torch/csrc/utils/device.cpp               31
-rw-r--r--  torch/csrc/utils/device.h                 17
-rw-r--r--  torch/csrc/utils/python_arg_parser.cpp    15
-rw-r--r--  torch/csrc/utils/python_arg_parser.h      50
-rw-r--r--  torch/csrc/utils/tensor_new.cpp          100
9 files changed, 398 insertions, 51 deletions
diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp
new file mode 100644
index 0000000000..442af5c948
--- /dev/null
+++ b/torch/csrc/Device.cpp
@@ -0,0 +1,198 @@
+#include "Device.h"
+
+#include <cstring>
+#include <structmember.h>
+#include <sstream>
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/python_strings.h"
+
+PyObject *THPDevice_New(const torch::Device& device)
+{
+ auto type = (PyTypeObject*)&THPDeviceType;
+ auto self = THPObjectPtr{type->tp_alloc(type, 0)};
+ if (!self) throw python_error();
+ auto self_ = reinterpret_cast<THPDevice*>(self.get());
+ self_->device = device;
+ return self.release();
+}
+
+static const char* cuda_str = "cuda";
+static const char* cpu_str = "cpu";
+
+static inline const char* deviceTypeString(torch::DeviceType device_type) {
+ switch (device_type) {
+ case torch::DeviceType::CUDA:
+ return cuda_str;
+ case torch::DeviceType::CPU:
+ return cpu_str;
+ default:
+ throw std::runtime_error("unexpected device type");
+ }
+}
+
+PyObject *THPDevice_repr(THPDevice *self)
+{
+ std::ostringstream oss;
+ oss << "Device(device_type=\'" << deviceTypeString(self->device.type) << "\'";
+ if (!self->device.is_default) {
+ oss << ", device_index=" << self->device.index;
+ }
+ oss << ")";
+ return THPUtils_packString(oss.str().c_str());
+}
+
+PyObject *THPDevice_str(THPDevice *self)
+{
+ std::ostringstream oss;
+ if (!self->device.is_default) {
+ oss << deviceTypeString(self->device.type) << ":" << self->device.index;
+ } else {
+ oss << deviceTypeString(self->device.type);
+ }
+ return THPUtils_packString(oss.str().c_str());
+}
+
+PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
+{
+ HANDLE_TH_ERRORS
+ static torch::PythonArgParser parser({
+ "Device(Device device)",
+ "Device(String device_type, int64_t? device_index=-1)"
+ });
+ torch::ParsedArgs<2> parsed_args;
+ auto r = parser.parse(args, kwargs, parsed_args);
+ if (r.idx == 0) {
+ auto device = r.device(0);
+ return THPDevice_New(device);
+ } else if (r.idx == 1) {
+ auto as_device = r.device(0); // this works, because device can take strings
+ auto device_type = r.string(0);
+ if (!as_device.is_default) {
+ throw std::runtime_error("device_type (string) must not include an index because index "
+ "was passed explicitly: " + device_type);
+ }
+
+ auto is_default = r.isNone(1);
+ auto device_index = r.toInt64WithDefault(1, -1);
+ // make sure this is constructible
+ auto device = torch::Device(as_device.type, device_index, is_default);
+ return THPDevice_New(device);
+ }
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPDevice_type(THPDevice *self)
+{
+ HANDLE_TH_ERRORS
+ return THPUtils_packString(deviceTypeString(self->device.type));
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPDevice_index(THPDevice *self)
+{
+ HANDLE_TH_ERRORS
+ if (self->device.is_default) {
+ Py_RETURN_NONE;
+ } else {
+ return THPUtils_packInt64(self->device.index);
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
+ HANDLE_TH_ERRORS
+ if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
+ // Py_RETURN_NOTIMPLEMENTED not in python 2.
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ THPDevice *da = reinterpret_cast<THPDevice*>(a);
+ THPDevice *db = reinterpret_cast<THPDevice*>(b);
+
+ switch(op) {
+ case Py_EQ:
+ if (da->device == db->device) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+ case Py_NE:
+ if (da->device == db->device) {
+ Py_RETURN_FALSE;
+ } else {
+ Py_RETURN_TRUE;
+ }
+ case Py_LT:
+ case Py_LE:
+ case Py_GT:
+ case Py_GE:
+ throw torch::TypeError("comparison not implemented");
+ default:
+ throw torch::TypeError("unexpected comparison op");
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+typedef PyObject *(*getter)(PyObject *, void *);
+
+static struct PyGetSetDef THPDevice_properties[] = {
+ {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
+ {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
+ {nullptr}
+};
+
+PyTypeObject THPDeviceType = {
+ PyVarObject_HEAD_INIT(nullptr, 0)
+ "torch.Device", /* tp_name */
+ sizeof(THPDevice), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ 0, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_reserved */
+ (reprfunc)THPDevice_repr, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ (reprfunc)THPDevice_str, /* tp_str */
+ 0, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ nullptr, /* tp_doc */
+ 0, /* tp_traverse */
+ 0, /* tp_clear */
+ (richcmpfunc)THPDevice_rc, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ 0, /* tp_iter */
+ 0, /* tp_iternext */
+ 0, /* tp_methods */
+ 0, /* tp_members */
+ THPDevice_properties, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ THPDevice_pynew, /* tp_new */
+};
+
+void THPDevice_init(PyObject *module)
+{
+ if (PyType_Ready(&THPDeviceType) < 0) {
+ throw python_error();
+ }
+ Py_INCREF(&THPDeviceType);
+ if (PyModule_AddObject(module, "device", (PyObject *)&THPDeviceType) != 0) {
+ throw python_error();
+ }
+}
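For reference, a sketch of the comparison semantics THPDevice_rc implements above (an illustrative session, not part of the PR):

>>> import torch
>>> torch.device('cuda', 0) == torch.device('cuda:0')
True
>>> torch.device('cpu') == 'cpu'    # non-device operand: NotImplemented is
False                               # returned, so Python falls back to identity
>>> torch.device('cpu') < torch.device('cuda')
TypeError: comparison not implemented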
diff --git a/torch/csrc/Device.h b/torch/csrc/Device.h
new file mode 100644
index 0000000000..76057601e1
--- /dev/null
+++ b/torch/csrc/Device.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <Python.h>
+#include "torch/csrc/utils/device.h"
+
+struct THPDevice {
+ PyObject_HEAD
+ torch::Device device;
+};
+
+extern PyTypeObject THPDeviceType;
+
+inline bool THPDevice_Check(PyObject *obj) {
+ return Py_TYPE(obj) == &THPDeviceType;
+}
+
+PyObject * THPDevice_New(const torch::Device& device);
+
+void THPDevice_init(PyObject *module);
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 40499737a1..6263e6c474 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -18,6 +18,7 @@
#include "THP.h"
#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Device.h"
#include "torch/csrc/Dtype.h"
#include "torch/csrc/DataLoader.h"
#include "torch/csrc/Generator.h"
@@ -463,6 +464,7 @@ static PyObject* initModule() {
THPSize_init(module);
THPDtype_init(module);
THPLayout_init(module);
+ THPDevice_init(module);
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index d377b0e98a..7ea866f1f0 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -1,6 +1,7 @@
#include "torch/csrc/autograd/python_variable.h"
#include "THP.h"
+#include "torch/csrc/Device.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Size.h"
@@ -27,6 +28,7 @@
#include <list>
#include <memory>
#include <structmember.h>
+#include <sstream>
using namespace at;
using namespace torch;
@@ -387,6 +389,20 @@ static PyObject * THPVariable_layout(THPVariable* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
+static PyObject * THPVariable_device(THPVariable* self, PyObject* args) {
+ HANDLE_TH_ERRORS
+ auto& self_ = self->cdata;
+ if (self_.type().is_cuda()) {
+ torch::Device device(torch::DeviceType::CUDA, self_.get_device(), false);
+ return THPDevice_New(device);
+ }
+ else {
+ torch::Device device(torch::DeviceType::CPU, -1, true);
+ return THPDevice_New(device);
+ }
+ END_HANDLE_TH_ERRORS
+}
+
static struct PyGetSetDef THPVariable_properties[] = {
{"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
{"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
@@ -407,6 +423,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
{"dtype", (getter)THPVariable_dtype, NULL, NULL, NULL},
{"layout", (getter)THPVariable_layout, NULL, NULL, NULL},
+ {"device", (getter)THPVariable_device, NULL, NULL, NULL},
{nullptr}
};
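Assuming the getter above is wired up as shown, the new property behaves roughly like this (hypothetical session; the CUDA line needs a GPU):

>>> import torch
>>> torch.ones(3).device            # no longer raises on CPU tensors
Device(device_type='cpu')
>>> torch.ones(3).cuda(0).device    # repr from THPDevice_repr; str() gives 'cuda:0'
Device(device_type='cuda', device_index=0)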
diff --git a/torch/csrc/utils/device.cpp b/torch/csrc/utils/device.cpp
new file mode 100644
index 0000000000..330a0989a5
--- /dev/null
+++ b/torch/csrc/utils/device.cpp
@@ -0,0 +1,31 @@
+#include "device.h"
+#include <stdexcept>
+#include <string>
+
+namespace torch {
+
+Device::Device(DeviceType type, int64_t index, bool is_default)
+ : type(type), index(index), is_default(is_default) {
+ if (!is_default) {
+ switch (type) {
+ case DeviceType::CPU:
+ if (index != 0) {
+ throw std::runtime_error("cpu device index must be 0, got " + std::to_string(index));
+ }
+ break;
+ case DeviceType::CUDA:
+ if (index < 0) {
+ throw std::runtime_error("device index must be positive, got " + std::to_string(index));
+ }
+ break;
+ default:
+ throw std::runtime_error("unexpected DeviceType");
+ }
+ }
+}
+
+bool Device::operator==(const Device& rhs) const {
+ return this->type == rhs.type && this->index == rhs.index && this->is_default == rhs.is_default;
+}
+
+}
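At the Python layer, the constructor checks above surface as errors like these (a hedged sketch; messages taken from the runtime_errors above):

>>> torch.device('cpu', 0)     # fine: a CPU index, if given, must be 0
>>> torch.device('cpu', 1)
RuntimeError: cpu device index must be 0, got 1
>>> torch.device('cuda', -2)
RuntimeError: device index must be non-negative, got -2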
diff --git a/torch/csrc/utils/device.h b/torch/csrc/utils/device.h
new file mode 100644
index 0000000000..91c53781e8
--- /dev/null
+++ b/torch/csrc/utils/device.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <cstdint>
+
+namespace torch {
+
+enum class DeviceType {CPU=0, CUDA=1};
+
+struct Device {
+ DeviceType type;
+ int64_t index;
+ bool is_default; // is default device for type.
+ Device(DeviceType type, int64_t index, bool is_default);
+ bool operator==(const Device& rhs) const;
+};
+
+}
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index 4f2064852c..4327dda5b6 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -25,6 +25,8 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"PyObject*", ParameterType::PYOBJECT},
{"Dtype", ParameterType::DTYPE},
{"Layout", ParameterType::LAYOUT},
+ {"Device", ParameterType::DEVICE},
+ {"String", ParameterType::STRING},
};
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
@@ -111,6 +113,9 @@ bool FunctionParameter::check(PyObject* obj) {
case ParameterType::PYOBJECT: return true;
case ParameterType::DTYPE: return THPDtype_Check(obj);
case ParameterType::LAYOUT: return THPLayout_Check(obj);
+ case ParameterType::DEVICE:
+ return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
+ case ParameterType::STRING: return THPUtils_checkString(obj);
default: throw std::runtime_error("unknown parameter type");
}
}
@@ -129,6 +134,8 @@ std::string FunctionParameter::type_name() const {
case ParameterType::PYOBJECT: return "object";
case ParameterType::DTYPE: return "torch.dtype";
case ParameterType::LAYOUT: return "torch.layout";
+ case ParameterType::DEVICE: return "torch.device";
+ case ParameterType::STRING: return "str";
default: throw std::runtime_error("unknown parameter type");
}
}
@@ -175,6 +182,14 @@ void FunctionParameter::set_default_str(const std::string& str) {
} else {
throw std::runtime_error("invalid default value for dtype: " + str);
}
+ } else if (type_ == ParameterType::DEVICE) {
+ if (str != "None") {
+ throw std::runtime_error("invalid device: " + str);
+ }
+ } else if (type_ == ParameterType::STRING) {
+ if (str != "None" || str != "") {
+ throw std::runtime_error("invalid default string: " + str);
+ }
}
}
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index c45f5fa111..590fd051f4 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -26,22 +26,25 @@
#include <vector>
#include <ATen/ATen.h>
-#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Device.h"
#include "torch/csrc/Dtype.h"
+#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/Generator.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/tensor/python_tensor.h"
+#include "torch/csrc/utils/device.h"
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_numbers.h"
+#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/numpy_stub.h"
namespace torch {
enum class ParameterType {
TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
- BOOL, STORAGE, PYOBJECT, DTYPE, LAYOUT
+ BOOL, STORAGE, PYOBJECT, DTYPE, LAYOUT, DEVICE, STRING
};
struct FunctionParameter;
@@ -93,6 +96,9 @@ struct PythonArgs {
inline const THPDtype& dtype(int i);
inline const THPDtype& dtypeWithDefault(int i, const THPDtype& default_dtype);
inline const THPLayout& layout(int i);
+ inline Device device(int i);
+ inline int64_t deviceInt64(int i);
+ inline std::string string(int i);
inline PyObject* pyobject(int i);
inline int64_t toInt64(int i);
inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -272,6 +278,46 @@ inline const THPLayout& PythonArgs::layout(int i) {
return *reinterpret_cast<THPLayout*>(args[i]);
}
+static std::string cuda_str = "cuda";
+static std::string cpu_str = "cpu";
+static std::string cuda_prefix = "cuda:";
+static std::string cpu_prefix = "cpu:";
+
+inline Device PythonArgs::device(int i) {
+ if (!args[i]) return Device(DeviceType::CPU, -1, true); // TODO: use CUDA if default type is a cuda type.
+ if (THPDevice_Check(args[i])) {
+ auto device = reinterpret_cast<THPDevice*>(args[i]);
+ return device->device;
+ }
+ if (THPUtils_checkLong(args[i])) {
+ auto index = THPUtils_unpackLong(args[i]);
+ return Device(DeviceType::CUDA, index, index == -1);
+ }
+ std::string device_str = THPUtils_unpackString(args[i]);
+ if (device_str == cpu_str) {
+ return Device(DeviceType::CPU, -1, true);
+ } else if (device_str == cuda_str) {
+ return Device(DeviceType::CUDA, -1, true);
+ } else if (device_str.compare(0, cpu_prefix.length(), cpu_prefix) == 0) {
+ auto device_index = std::stoi(device_str.substr(cpu_prefix.length()));
+ return Device(DeviceType::CPU, device_index, false);
+ } else if (device_str.compare(0, cuda_prefix.length(), cuda_prefix) == 0) {
+ auto device_index = std::stoi(device_str.substr(cuda_prefix.length()));
+ return Device(DeviceType::CUDA, device_index, false);
+ }
+ throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
+}
+
+inline int64_t PythonArgs::deviceInt64(int i) {
+ auto dev = device(i);
+ return (dev.is_default || dev.type == DeviceType::CPU) ? -1 : dev.index;
+}
+
+inline std::string PythonArgs::string(int i) {
+ if (!args[i]) return "";
+ return THPUtils_unpackString(args[i]);
+}
+
inline int64_t PythonArgs::toInt64(int i) {
if (!args[i]) return signature.params[i].default_int;
return THPUtils_unpackLong(args[i]);
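To make the bridge concrete: the three accepted spellings all funnel through PythonArgs::device, and deviceInt64 collapses them back to the legacy AutoGPU integer. A sketch (assumes two CUDA devices and the pre-split cuda dtype):

# device=1, device='cuda:1' and device=torch.device('cuda', 1) all parse to
# Device(CUDA, 1); 'cpu', None and -1 all collapse to deviceInt64 == -1.
a = torch.randn((2, 3), dtype=torch.cuda.float32, device=1)          # legacy int
b = torch.randn((2, 3), dtype=torch.cuda.float32, device='cuda:1')   # string
c = torch.randn((2, 3), dtype=torch.cuda.float32,
                device=torch.device('cuda', 1))                      # device object
assert a.device == b.device == c.device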
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index fbafa9fb3a..25d6145792 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -221,33 +221,33 @@ static Tensor legacy_new_from_sequence(const Type & type, int device, PyObject*
static Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new(*, int64_t? device=-1)",
- "new(IntList size, *, int64_t? device=-1)",
+ "new(*, Device? device=None)",
+ "new(IntList size, *, Device? device=None)",
"new(*, int64_t cdata)|hidden",
- "new(Tensor indices, Tensor values, *, int64_t? device=-1)",
- "new(Tensor indices, Tensor values, IntList size, *, int64_t? device=-1)",
+ "new(Tensor indices, Tensor values, *, Device? device=None)",
+ "new(Tensor indices, Tensor values, IntList size, *, Device? device=None)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- AutoGPU auto_gpu(r.toInt64(0));
+ AutoGPU auto_gpu(r.deviceInt64(0));
return type.tensor();
} else if (r.idx == 1) {
PyObject* arg = r.pyobject(0);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
- return new_with_sizes(type, r.toInt64(1), r.intlist(0));
+ return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
} else if (r.idx == 2) {
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return type.unsafeTensorFromTH(cdata, true);
} else if (r.idx == 3) {
- AutoGPU auto_gpu(r.toInt64(2));
+ AutoGPU auto_gpu(r.deviceInt64(2));
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1));
} else if (r.idx == 4) {
- AutoGPU auto_gpu(r.toInt64(3));
+ AutoGPU auto_gpu(r.deviceInt64(3));
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
}
throw std::runtime_error("new(): invalid arguments");
@@ -255,12 +255,12 @@ static Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObje
Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new(*, int64_t? device=-1)",
- "new(IntList size, *, int64_t? device=-1)",
+ "new(*, Device? device=None)",
+ "new(IntList size, *, Device? device=None)",
"new(Storage storage)",
"new(*, int64_t cdata)|hidden",
"new(Tensor other)",
- "new(PyObject* data, *, int64_t? device=-1)",
+ "new(PyObject* data, *, Device? device=None)",
});
if (type.is_sparse()) {
@@ -270,16 +270,16 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- AutoGPU auto_gpu(r.toInt64(0));
+ AutoGPU auto_gpu(r.deviceInt64(0));
return type.tensor();
} else if (r.idx == 1) {
PyObject* arg = r.pyobject(0);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
- return new_with_sizes(type, r.toInt64(1), r.intlist(0));
+ return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
} else if (r.idx == 2) {
return new_with_storage(type, *r.storage(0));
} else if (r.idx == 3) {
@@ -288,44 +288,44 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
} else if (r.idx == 4) {
return new_with_tensor(type, r.tensor(0));
} else if (r.idx == 5) {
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
throw std::runtime_error("new(): invalid arguments");
}
static Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new(*, int64_t? device=-1)",
- "new(IntList size, *, int64_t? device=-1)",
+ "new(*, Device? device=None)",
+ "new(IntList size, *, Device? device=None)",
"new(*, int64_t cdata)|hidden",
- "new(Tensor indices, Tensor values, *, int64_t? device=-1)",
- "new(Tensor indices, Tensor values, IntList size, *, int64_t? device=-1)",
+ "new(Tensor indices, Tensor values, *, Device? device=None)",
+ "new(Tensor indices, Tensor values, IntList size, *, Device? device=None)",
});
ParsedArgs<5> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- AutoGPU auto_gpu(r.toInt64(0));
+ AutoGPU auto_gpu(r.deviceInt64(0));
return type.tensor();
} else if (r.idx == 1) {
PyObject* arg = r.pyobject(0);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
- return new_with_sizes(type, r.toInt64(1), r.intlist(0));
+ return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
} else if (r.idx == 2) {
- auto cdata = reinterpret_cast<void*>(r.toInt64(0));
+ auto cdata = reinterpret_cast<void*>(r.deviceInt64(0));
return type.unsafeTensorFromTH(cdata, true);
} else if (r.idx == 3) {
// Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
// have a device (we should infer it).
- AutoGPU auto_gpu(r.toInt64(2));
+ AutoGPU auto_gpu(r.deviceInt64(2));
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1));
} else if (r.idx == 4) {
// Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
// have a device (we should infer it).
- AutoGPU auto_gpu(r.toInt64(3));
+ AutoGPU auto_gpu(r.deviceInt64(3));
return type.sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
}
throw std::runtime_error("new(): invalid arguments");
@@ -333,12 +333,12 @@ static Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObjec
Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new(*, int64_t? device=-1)",
- "new(IntList size, *, int64_t? device=-1)",
+ "new(*, Device? device=None)",
+ "new(IntList size, *, Device? device=None)",
"new(Storage storage)",
"new(*, int64_t cdata)|hidden",
"new(Tensor other)", // this doesn't have a dtype/device because it creates an alias.
- "new(PyObject* data, *, int64_t? device=-1)",
+ "new(PyObject* data, *, Device? device=None)",
});
if (type.is_sparse()) {
@@ -348,16 +348,16 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
ParsedArgs<3> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- AutoGPU auto_gpu(r.toInt64(0));
+ AutoGPU auto_gpu(r.deviceInt64(0));
return type.tensor();
} else if (r.idx == 1) {
PyObject* arg = r.pyobject(0);
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
// new(sequence) binds to this signature but should be treated differently
// unless the sequences is a torch.Size
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
- return new_with_sizes(type, r.toInt64(1), r.intlist(0));
+ return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
} else if (r.idx == 2) {
return new_with_storage(type, *r.storage(0));
} else if (r.idx == 3) {
@@ -366,7 +366,7 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
} else if (r.idx == 4) {
return new_with_tensor(type, r.tensor(0));
} else if (r.idx == 5) {
- return legacy_new_from_sequence(type, r.toInt64(1), r.pyobject(0));
+ return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
}
throw std::runtime_error("new(): invalid arguments");
}
@@ -386,8 +386,8 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
const auto& default_sparse_type = type.toBackend(sparse_backend);
static PythonArgParser parser({
- "sparse_coo_tensor(PyObject* indices, PyObject* values, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
- "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "sparse_coo_tensor(PyObject* indices, PyObject* values, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
+ "sparse_coo_tensor(PyObject* indices, PyObject* values, IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<6> parsed_args;
@@ -397,7 +397,7 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
const auto& sparse_type = typeWithDefault(r, 2, default_sparse_type);
const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
const auto& index_type = dense_type.toScalarType(kLong);
- AutoGPU autogpu(r.toInt64(3));
+ AutoGPU autogpu(r.deviceInt64(3));
// explanation of booleans: allow variables, do type conversion of them, copy numpy data
Tensor indices = internal_new_from_data(index_type, -1, r.pyobject(0), false, true, false);
Tensor values = internal_new_from_data(dense_type, -1, r.pyobject(1), false, true, type_inference);
@@ -408,7 +408,7 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
const auto& sparse_type = typeWithDefault(r, 3, default_sparse_type);
const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
const auto& index_type = dense_type.toScalarType(kLong);
- AutoGPU autogpu(r.toInt64(4));
+ AutoGPU autogpu(r.deviceInt64(4));
// explanation of booleans: allow variables, do type conversion of them, copy numpy data
Tensor indices = internal_new_from_data(index_type, -1, r.pyobject(0), false, true, false);
Tensor values = internal_new_from_data(dense_type, -1, r.pyobject(1), false, true, type_inference);
@@ -420,14 +420,15 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "tensor(PyObject* data, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
bool type_inference = r.isNone(1);
- return set_requires_grad(internal_new_from_data(typeWithDefault(r, 1, type), r.toInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
+ return set_requires_grad(internal_new_from_data(
+ typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
}
throw std::runtime_error("tensor(): invalid arguments");
}
@@ -435,69 +436,70 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new_tensor(PyObject* data, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "new_tensor(PyObject* data, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
- return set_requires_grad(new_from_data_copy(typeWithDefault(r, 1, type), r.toInt64(2), r.pyobject(0)), r.toBool(3));
+ return set_requires_grad(new_from_data_copy(
+ typeWithDefault(r, 1, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
}
throw std::runtime_error("new_tensor(): invalid arguments");
}
Tensor new_empty(const at::Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new_empty(IntList size, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "new_empty(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, type);
- return set_requires_grad(new_with_sizes(actual_type, r.toInt64(2), r.intlist(0)), r.toBool(3));
+ return set_requires_grad(new_with_sizes(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
}
throw std::runtime_error("new_empty(): invalid arguments");
}
Tensor new_full(const at::Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new_full(IntList size, Scalar fill_value, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "new_full(IntList size, Scalar fill_value, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<5> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 2, type);
- return set_requires_grad(dispatch_full(actual_type, r.scalar(1), r.toInt64(3), r.intlist(0)), r.toBool(4));
+ return set_requires_grad(dispatch_full(actual_type, r.scalar(1), r.deviceInt64(3), r.intlist(0)), r.toBool(4));
}
throw std::runtime_error("new_full(): invalid arguments");
}
Tensor new_ones(const at::Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new_ones(IntList size, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "new_ones(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, type);
- return set_requires_grad(dispatch_ones(actual_type, r.toInt64(2), r.intlist(0)), r.toBool(3));
+ return set_requires_grad(dispatch_ones(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
}
throw std::runtime_error("new_ones(): invalid arguments");
}
Tensor new_zeros(const at::Type& type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
- "new_zeros(IntList size, *, Dtype dtype=None, int64_t? device=-1, bool requires_grad=False)",
+ "new_zeros(IntList size, *, Dtype dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 0) {
const auto& actual_type = typeWithDefault(r, 1, type);
- return set_requires_grad(dispatch_zeros(actual_type, r.toInt64(2), r.intlist(0)), r.toBool(3));
+ return set_requires_grad(dispatch_zeros(actual_type, r.deviceInt64(2), r.intlist(0)), r.toBool(3));
}
throw std::runtime_error("new_zeros(): invalid arguments");
}
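Taken together, the legacy constructors keep working with integers while also accepting the new spellings, e.g. (hypothetical session):

>>> x = torch.FloatTensor(2, 3, device='cpu')   # legacy new(); device may now be a string
>>> y = x.new_ones((2, 2), device='cpu')        # new_* variants take strings too
>>> y.device
Device(device_type='cpu')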