diff options
author | Sam Gross <colesbury@gmail.com> | 2018-02-27 17:58:09 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-27 17:58:09 -0500 |
commit | 48a3349c29f1a8a6e1d9d9ad6108e5c15aaeb583 (patch) | |
tree | e22f4217cce1df0b6640f31506363296cbec5e7a | |
parent | 6b95ca4edacb5ab739daa89fdef3f69f22f4f24f (diff) | |
download | pytorch-48a3349c29f1a8a6e1d9d9ad6108e5c15aaeb583.tar.gz pytorch-48a3349c29f1a8a6e1d9d9ad6108e5c15aaeb583.tar.bz2 pytorch-48a3349c29f1a8a6e1d9d9ad6108e5c15aaeb583.zip |
Delete dead Tensor code paths (#5417)
This deletes most of the dead Tensor code paths, including the TensorMethods cwrap and generic/Tensor.cpp.
This also moves the THNN.cwrap/.cpp generation to generate_code which can use ninja if installed.
65 files changed, 212 insertions, 10632 deletions
diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h index 8b7ded9f5e..a0feb61874 100644 --- a/aten/src/TH/generic/THVector.h +++ b/aten/src/TH/generic/THVector.h @@ -58,7 +58,4 @@ TH_API void THVector_(cinv)(real *y, const real *x, const ptrdiff_t n); #endif /* floating point only part */ -/* Initialize the dispatch pointers */ -TH_API void THVector_(vectorDispatchInit)(void); - #endif diff --git a/aten/src/TH/generic/THVectorDispatch.cpp b/aten/src/TH/generic/THVectorDispatch.cpp index 5145234dc9..572d5261ea 100644 --- a/aten/src/TH/generic/THVectorDispatch.cpp +++ b/aten/src/TH/generic/THVectorDispatch.cpp @@ -255,25 +255,30 @@ void THVector_(normal_fill)(real *data, THVector_(normal_fill_DISPATCHPTR)(data, size, generator, mean, stddev); } -/* This needs to be called in order to initialize the dispatch pointers at runtime. - * This function simply checks what SIMD extensions are available, and then walks the dispatch table +/* + * This struct's constructor initalizes the dispatch tables. It simply checks + * what SIMD extensions are available, and then walks the dispatch table * to choose the best function. * NOTE: As implemented, it will initialize the dispatch pointer to the first supported function. * This means that in the dispatch tables, implementations supporting more recent extensions * need to come first */ -void THVector_(vectorDispatchInit)(void) -{ - uint32_t hostSimdExts = detectHostSIMDExtensions(); - INIT_DISPATCH_PTR(fill); - INIT_DISPATCH_PTR(cadd); - INIT_DISPATCH_PTR(adds); - INIT_DISPATCH_PTR(cmul); - INIT_DISPATCH_PTR(muls); - INIT_DISPATCH_PTR(cdiv); - INIT_DISPATCH_PTR(divs); - INIT_DISPATCH_PTR(copy); - INIT_DISPATCH_PTR(normal_fill); -} +struct THVector_(startup) { + THVector_(startup)() { + uint32_t hostSimdExts = detectHostSIMDExtensions(); + INIT_DISPATCH_PTR(fill); + INIT_DISPATCH_PTR(cadd); + INIT_DISPATCH_PTR(adds); + INIT_DISPATCH_PTR(cmul); + INIT_DISPATCH_PTR(muls); + INIT_DISPATCH_PTR(cdiv); + INIT_DISPATCH_PTR(divs); + INIT_DISPATCH_PTR(copy); + INIT_DISPATCH_PTR(normal_fill); + } +}; + +// Declare a global instance to force static initialization +static THVector_(startup) THVector_(g_startup); #endif @@ -100,7 +100,6 @@ from tools.setup_helpers.nccl import WITH_NCCL, WITH_SYSTEM_NCCL, NCCL_LIB_DIR, NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB from tools.setup_helpers.nnpack import WITH_NNPACK from tools.setup_helpers.nvtoolext import NVTOOLEXT_HOME -from tools.setup_helpers.split_types import split_types from tools.setup_helpers.generate_code import generate_code from tools.setup_helpers.ninja_builder import NinjaBuilder, ninja_build_ext from tools.setup_helpers.dist_check import WITH_DISTRIBUTED, \ @@ -221,10 +220,6 @@ def build_libs(libs): if subprocess.call(build_libs_cmd + libs, env=my_env) != 0: sys.exit(1) - if 'ATen' in libs: - from tools.nnwrap import generate_wrappers as generate_nn_wrappers - generate_nn_wrappers() - class build_deps(Command): user_options = [] @@ -394,29 +389,6 @@ class build_ext(build_ext_parent): _C_LIB = os.path.join(build_temp, build_dir, lib_filename).replace('\\', '/') - THNN.extra_link_args += [_C_LIB] - if WITH_CUDA: - THCUNN.extra_link_args += [_C_LIB] - else: - # To generate .obj files for those .h files for the export class - # a header file cannot build, so it has to be copied to someplace as a source file - temp_dir = 'torch/csrc/generated' - hfile_list = ['torch/csrc/cuda/AutoGPU.h'] - hname_list = [os.path.basename(hfile) for hfile in hfile_list] - rname_list = [os.path.splitext(hname)[0] - for hname in hname_list] - cfile_list = [temp_dir + '/' + rname + - '_cpu_win.cpp' for rname in rname_list] - - if not os.path.exists(temp_dir): - os.mkdir(temp_dir) - - for hfile, cfile in zip(hfile_list, cfile_list): - if os.path.exists(cfile): - os.remove(cfile) - shutil.copyfile(hfile, cfile) - - C.sources += cfile_list if WITH_NINJA: # before we start the normal build make sure all generated code # gets built @@ -541,7 +513,6 @@ main_sources = [ "torch/csrc/assertions.cpp", "torch/csrc/byte_order.cpp", "torch/csrc/utils.cpp", - "torch/csrc/expand_utils.cpp", "torch/csrc/utils/invalid_arguments.cpp", "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", @@ -616,11 +587,11 @@ main_sources = [ "torch/csrc/autograd/functions/special.cpp", "torch/csrc/autograd/functions/utils.cpp", "torch/csrc/autograd/functions/init.cpp", + "torch/csrc/nn/THNN.cpp", "torch/csrc/tensor/python_tensor.cpp", "torch/csrc/onnx/onnx.pb.cpp", "torch/csrc/onnx/onnx.cpp", ] -main_sources += split_types("torch/csrc/Tensor.cpp", ninja_global) try: import numpy as np @@ -680,15 +651,13 @@ if WITH_CUDA: "torch/csrc/cuda/Module.cpp", "torch/csrc/cuda/Storage.cpp", "torch/csrc/cuda/Stream.cpp", - "torch/csrc/cuda/AutoGPU.cpp", "torch/csrc/cuda/utils.cpp", "torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/python_comm.cpp", - "torch/csrc/cuda/expand_utils.cpp", "torch/csrc/cuda/lazy_init.cpp", "torch/csrc/cuda/serialization.cpp", + "torch/csrc/nn/THCUNN.cpp", ] - main_sources += split_types("torch/csrc/cuda/Tensor.cpp", ninja_global) if WITH_NCCL: if WITH_SYSTEM_NCCL: @@ -762,17 +731,6 @@ if not IS_WINDOWS: ) extensions.append(DL) -THNN = Extension("torch._thnn._THNN", - sources=['torch/csrc/nn/THNN.cpp'], - language='c++', - extra_compile_args=extra_compile_args, - include_dirs=include_dirs, - extra_link_args=extra_link_args + [ - ATEN_LIB, - make_relative_rpath('../lib'), - ] - ) -extensions.append(THNN) if WITH_CUDA: thnvrtc_link_flags = extra_link_args + [make_relative_rpath('lib')] @@ -797,18 +755,6 @@ if WITH_CUDA: ) extensions.append(THNVRTC) - THCUNN = Extension("torch._thnn._THCUNN", - sources=['torch/csrc/nn/THCUNN.cpp'], - language='c++', - extra_compile_args=extra_compile_args, - include_dirs=include_dirs, - extra_link_args=extra_link_args + [ - ATEN_LIB, - make_relative_rpath('../lib'), - ] - ) - extensions.append(THCUNN) - version = '0.4.0a0' if os.getenv('PYTORCH_BUILD_VERSION'): assert os.getenv('PYTORCH_BUILD_NUMBER') is not None diff --git a/test/test_torch.py b/test/test_torch.py index e93a0999d8..8a85dceb7c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4817,7 +4817,7 @@ class TestTorch(TestCase): def test_print(self): for t in torch._tensor_classes: - if IS_WINDOWS and t in [torch.cuda.sparse.HalfTensor, torch.cuda.HalfTensor]: + if IS_WINDOWS and t == torch.cuda.HalfTensor: return # CUDA HalfTensor is not supported on Windows yet if t == torch.HalfTensor: continue # HalfTensor does not support fill diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 172ee489a0..ccc6acbd9f 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -94,7 +94,10 @@ ${py_methods} static PyMethodDef torch_functions[] = { {"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL}, + {"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, ${py_method_defs} diff --git a/tools/cwrap/plugins/AutoGPU.py b/tools/cwrap/plugins/AutoGPU.py index c8198cee0c..da37508cce 100644 --- a/tools/cwrap/plugins/AutoGPU.py +++ b/tools/cwrap/plugins/AutoGPU.py @@ -7,24 +7,8 @@ class AutoGPU(CWrapPlugin): self.has_self = has_self self.condition = condition - DEFINES = """ -#ifdef THC_GENERIC_FILE -#define THCP_AUTO_GPU 1 -#else -#define THCP_AUTO_GPU 0 -#endif -""" - def process_pre_arg_assign(self, template, option): if not option.get('auto_gpu', True): return template - call = 'THCPAutoGPU __autogpu_guard = THCPAutoGPU(args{});'.format( - ', (PyObject*)self' if self.has_self else '') - - if self.condition is not None: - call = "#if {0}\n {1}\n#endif\n".format(self.condition, call) - + call = 'AutoGPU auto_gpu(get_device(args));' return [call] + template - - def process_full_file(self, code): - return self.DEFINES + code diff --git a/tools/cwrap/plugins/StandaloneExtension.py b/tools/cwrap/plugins/NNExtension.py index 7085066dd5..679621a43f 100644 --- a/tools/cwrap/plugins/StandaloneExtension.py +++ b/tools/cwrap/plugins/NNExtension.py @@ -7,14 +7,15 @@ MODULE_HEAD = """ #include <Python.h> #include <exception> -#include "THP_API.h" +#include "THP.h" +#include "torch/csrc/utils/auto_gpu.h" #include "torch/csrc/nn/type_checks.h" """ -with open(os.path.join(os.path.dirname(__file__), 'templates', 'module_tail.cpp'), 'r') as f: +with open(os.path.join(os.path.dirname(__file__), 'templates', 'nn_tail.cpp'), 'r') as f: MODULE_TAIL = Template(f.read()) -REGISTER_METHOD_TEMPLATE = Template(' {"$name", (PyCFunction)$name, METH_VARARGS, NULL},\n') +REGISTER_METHOD_TEMPLATE = Template(' {"$name", (PyCFunction)$name, METH_STATIC | METH_VARARGS, NULL},\n') MODULE_METHODS_TEMPLATE = Template(""" static PyMethodDef module_methods[] = { @@ -24,7 +25,7 @@ $METHODS """) -class StandaloneExtension(CWrapPlugin): +class NNExtension(CWrapPlugin): TYPE_UNPACK = { 'THFloatTensor*': Template('THNN_FloatTensor_Unpack($arg)'), diff --git a/tools/cwrap/plugins/__init__.py b/tools/cwrap/plugins/__init__.py index a810f2770e..74ce79c17b 100644 --- a/tools/cwrap/plugins/__init__.py +++ b/tools/cwrap/plugins/__init__.py @@ -420,7 +420,7 @@ class CWrapPlugin(object): return template -from .StandaloneExtension import StandaloneExtension +from .NNExtension import NNExtension from .NullableArguments import NullableArguments from .OptionalArguments import OptionalArguments from .ArgcountChecker import ArgcountChecker diff --git a/tools/cwrap/plugins/templates/module_tail.cpp b/tools/cwrap/plugins/templates/module_tail.cpp deleted file mode 100644 index 8c60cb1c7d..0000000000 --- a/tools/cwrap/plugins/templates/module_tail.cpp +++ /dev/null @@ -1,36 +0,0 @@ - -#if PY_MAJOR_VERSION != 2 -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - "$full_name", - NULL, - -1, - module_methods -}; -#endif - -#if PY_MAJOR_VERSION == 2 -PyMODINIT_FUNC init$short_name() -#else -PyMODINIT_FUNC PyInit_$short_name() -#endif -{ -#if PY_MAJOR_VERSION == 2 -#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;} -#else -#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL -#endif - PyObject *module; - -#if PY_MAJOR_VERSION == 2 - ASSERT_TRUE(module = Py_InitModule("$full_name", module_methods)); -#else - ASSERT_TRUE(module = PyModule_Create(&module_def)); -#endif - -#if PY_MAJOR_VERSION != 2 - return module; -#endif - -#undef ASSERT_TRUE -} diff --git a/tools/cwrap/plugins/templates/nn_tail.cpp b/tools/cwrap/plugins/templates/nn_tail.cpp new file mode 100644 index 0000000000..247dc1efe8 --- /dev/null +++ b/tools/cwrap/plugins/templates/nn_tail.cpp @@ -0,0 +1,21 @@ +namespace torch { namespace nn { + +static PyTypeObject thnn_type; + +void init_$short_name(PyObject* c_module) { + ((PyObject*)&thnn_type)->ob_refcnt = 1; + thnn_type.tp_flags = Py_TPFLAGS_DEFAULT; + thnn_type.tp_methods = module_methods; + thnn_type.tp_name = "torch._C.$short_name"; + if (PyType_Ready(&thnn_type) < 0) { + throw python_error(); + } + + PyObject* type_obj = (PyObject*)&thnn_type; + Py_INCREF(type_obj); + if (PyModule_AddObject(c_module, "$short_name", type_obj) < 0) { + throw python_error(); + } +} + +}} // namespace torch::nn diff --git a/tools/nnwrap/generate_wrappers.py b/tools/nnwrap/generate_wrappers.py index 887a1a180a..0d0547955f 100644 --- a/tools/nnwrap/generate_wrappers.py +++ b/tools/nnwrap/generate_wrappers.py @@ -2,7 +2,7 @@ import os import sys from string import Template, ascii_lowercase from ..cwrap import cwrap -from ..cwrap.plugins import StandaloneExtension, NullableArguments, AutoGPU +from ..cwrap.plugins import NNExtension, NullableArguments, AutoGPU from ..shared import import_module BASE_PATH = os.path.realpath(os.path.join(__file__, '..', '..', '..')) @@ -109,7 +109,7 @@ def wrap_nn(): with open('torch/csrc/nn/THNN.cwrap', 'w') as f: f.write(wrapper) cwrap('torch/csrc/nn/THNN.cwrap', plugins=[ - StandaloneExtension('torch._thnn._THNN'), + NNExtension('torch._C._THNN'), NullableArguments(), ]) @@ -124,7 +124,7 @@ def wrap_cunn(): with open('torch/csrc/nn/THCUNN.cwrap', 'w') as f: f.write(wrapper) cwrap('torch/csrc/nn/THCUNN.cwrap', plugins=[ - StandaloneExtension('torch._thnn._THCUNN'), + NNExtension('torch._C._THCUNN'), NullableArguments(), AutoGPU(has_self=False), ]) diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index d9e55dfee5..f2770f0e1c 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -1,8 +1,7 @@ import os import sys -import glob -source_files = set(['.py', '.cpp', '.h']) +source_files = {'.py', '.cpp', '.h'} # TODO: This is a little inaccurate, because it will also pick @@ -18,11 +17,12 @@ def all_generator_source(): inputs = [ - 'torch/csrc/generic/TensorMethods.cwrap', + 'torch/lib/THNN.h', + 'torch/lib/THCUNN.h', 'torch/lib/tmp_install/share/ATen/Declarations.yaml', 'tools/autograd/derivatives.yaml', 'tools/autograd/deprecated.yaml', -] + glob.glob('torch/csrc/generic/methods/*.cwrap') +] outputs = [ 'torch/csrc/autograd/generated/Functions.cpp', @@ -69,28 +69,14 @@ def generate_code(ninja_global=None): # cwrap depends on pyyaml, so we can't import it earlier root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, root) - from tools.cwrap import cwrap - from tools.cwrap.plugins.THPPlugin import THPPlugin - from tools.cwrap.plugins.ArgcountSortPlugin import ArgcountSortPlugin - from tools.cwrap.plugins.AutoGPU import AutoGPU - from tools.cwrap.plugins.BoolOption import BoolOption - from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin - from tools.cwrap.plugins.NullableArguments import NullableArguments - - from tools.cwrap.plugins.WrapDim import WrapDim - from tools.cwrap.plugins.AssertNDim import AssertNDim - - from tools.cwrap.plugins.Broadcast import Broadcast - from tools.cwrap.plugins.ProcessorSpecificPlugin import ProcessorSpecificPlugin from tools.autograd.gen_autograd import gen_autograd from tools.jit.gen_jit_dispatch import gen_jit_dispatch - thp_plugin = THPPlugin() + from tools.nnwrap import generate_wrappers as generate_nn_wrappers + + # Build THNN/THCUNN.cwrap and then THNN/THCUNN.cpp. These are primarily + # used by the legacy NN bindings. + generate_nn_wrappers() - cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[ - ProcessorSpecificPlugin(), BoolOption(), thp_plugin, - AutoGPU(condition='IS_CUDA'), ArgcountSortPlugin(), KwargsPlugin(), - AssertNDim(), WrapDim(), Broadcast() - ]) # Build ATen based Variable classes autograd_gen_dir = 'torch/csrc/autograd/generated' jit_gen_dir = 'torch/csrc/jit/generated' @@ -104,6 +90,7 @@ def generate_code(ninja_global=None): 'torch/lib/tmp_install/share/ATen/Declarations.yaml', jit_gen_dir) + # called from ninja if __name__ == "__main__": generate_code(None) diff --git a/torch/__init__.py b/torch/__init__.py index 9366b6f1ca..861177c04b 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -148,7 +148,6 @@ from ._tensor_str import set_printoptions ################################################################################ from .storage import _StorageBase -from .tensor import _TensorBase class DoubleStorage(_C.DoubleStorageBase, _StorageBase): @@ -183,87 +182,6 @@ class ByteStorage(_C.ByteStorageBase, _StorageBase): pass -class DoubleTensor(_C.DoubleTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return DoubleStorage - - -class FloatTensor(_C.FloatTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return FloatStorage - - -class HalfTensor(_C.HalfTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return HalfStorage - - -class LongTensor(_C.LongTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return LongStorage - - -class IntTensor(_C.IntTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return IntStorage - - -class ShortTensor(_C.ShortTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return ShortStorage - - -class CharTensor(_C.CharTensorBase, _TensorBase): - - def is_signed(self): - # TODO - return False - - @classmethod - def storage_type(cls): - return CharStorage - - -class ByteTensor(_C.ByteTensorBase, _TensorBase): - - def is_signed(self): - return False - - @classmethod - def storage_type(cls): - return ByteStorage - - _storage_classes = { DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, ByteStorage, HalfStorage @@ -272,10 +190,6 @@ _storage_classes = { # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings() _tensor_classes = set() -_integer_tensor_classes = { - LongTensor, IntTensor, ShortTensor, CharTensor, ByteTensor -} - ################################################################################ # Import interface functions defined in Python @@ -302,8 +216,6 @@ def manager_path(): _C._initExtension(manager_path()) del manager_path -_C._initialize_tensor_type_bindings() - for name in dir(_C._VariableFunctions): globals()[name] = getattr(_C._VariableFunctions, name) @@ -319,21 +231,6 @@ del IntStorageBase del ShortStorageBase del CharStorageBase del ByteStorageBase -del DoubleTensorBase -del FloatTensorBase -del LongTensorBase -del IntTensorBase -del ShortTensorBase -del CharTensorBase -del ByteTensorBase - -del SparseDoubleTensorBase -del SparseFloatTensorBase -del SparseLongTensorBase -del SparseIntTensorBase -del SparseShortTensorBase -del SparseCharTensorBase -del SparseByteTensorBase ################################################################################ # Import most common subpackages diff --git a/torch/_thnn/__init__.py b/torch/_thnn/__init__.py index a2ffa80742..4f75f3b924 100644 --- a/torch/_thnn/__init__.py +++ b/torch/_thnn/__init__.py @@ -34,7 +34,8 @@ class Backend(object): if self.backend is None: with self.loading_lock: if self.backend is None: - self.backend = load_backend(self.lib_prefix, self.lib_name, + lib = getattr(torch._C, self.lib_name) + self.backend = load_backend(self.lib_prefix, lib, self.functions, self.mixins) return self.backend @@ -52,7 +53,7 @@ _thnn_headers = parse_header(THNN_H_PATH) _thcunn_headers = parse_header(THCUNN_H_PATH) for t in ['Float', 'Double']: - backend = Backend(t, 'torch._thnn._THNN', _thnn_headers) + backend = Backend(t, '_THNN', _thnn_headers) type2backend.backends['THNN{}Backend'.format(t)] = backend type2backend.backends['torch.{}Tensor'.format(t)] = backend @@ -60,7 +61,7 @@ for t in ['Float', 'Double']: for t in ['Half', '', 'Double']: - backend = Backend('Cuda' + t, 'torch._thnn._THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,)) + backend = Backend('Cuda' + t, '_THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,)) type2backend.backends['THNNCuda{}Backend'.format(t)] = backend py_name = 'Float' if t == '' else t type2backend.backends['torch.cuda.{}Tensor'.format(py_name)] = backend diff --git a/torch/_thnn/utils.py b/torch/_thnn/utils.py index 66d527a704..dc0dd4b4b6 100644 --- a/torch/_thnn/utils.py +++ b/torch/_thnn/utils.py @@ -102,11 +102,10 @@ def parse_header(path): def load_backend(t, lib, generic_functions, mixins=tuple()): - lib_handle = importlib.import_module(lib) backend_name = 'THNN{}Backend'.format(t) backend = type(backend_name, mixins + (THNNBackendBase,), {})() for function in generic_functions: full_fn_name = '{}{}'.format(t, function.name) - fn = getattr(lib_handle, full_fn_name) + fn = getattr(lib, full_fn_name) backend.register_method(function.name, fn) return backend diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index bb5bb8be16..ed4baf6a8a 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -113,15 +113,6 @@ THPDtype* getDtype(const at::Type& type) { throw std::invalid_argument("unsupported at::Type"); } -at::Tensor createTensor(PyObject *data) -{ - if (THPVariable_Check(data)) { - return ((THPVariable*)data)->cdata; - } - auto tensor_type = pytype_to_attype.at(Py_TYPE(data)); - auto tensor = ((THPVoidTensor *)data)->cdata; - return tensor_type->unsafeTensorFromTH(tensor, true); // Calls retain on underlying TH Tensor -} PyObject* createPyObject(const at::Tensor& tensor) { auto type = getPyTypeObject(tensor); diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 77be760f58..e5241e83fc 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -27,9 +27,6 @@ PyObject* createPyObject(const at::Storage& storage); PyTypeObject* getPyTypeObject(const at::Tensor& tensor); at::Type& getATenType(PyTypeObject* type); THPDtype* getDtype(const at::Type& type); -//rename to createPyObject when THPP is removed -// Creates a at::Tensor from a PyObject. Does NOT steal the PyObject reference. -at::Tensor createTensor(PyObject* data); std::unique_ptr<at::Storage> createStorage(PyObject* obj); bool isStorage(PyObject* obj); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index f5ec0903be..ab173c9ecf 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -34,52 +34,17 @@ #define WITH_NUMPY_IMPORT_ARRAY #include "THP.h" -#include "ModuleSparse.cpp" #include "DataLoader.cpp" namespace py = pybind11; PyObject* module; -PyObject* tensor_classes; -PyObject *THPDefaultTensorClass = NULL; THPGenerator *THPDefaultGenerator = NULL; //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -static bool THPModule_loadClasses(PyObject *self) -{ -#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; } - PyObject *torch_module = PyImport_ImportModule("torch"); - if (!torch_module) { - THPUtils_setError("class loader couldn't access torch module"); - return false; - } - - ASSERT_NOT_NULL(tensor_classes = PyObject_GetAttrString(torch_module, "_tensor_classes")); - if (!THPDoubleTensor_postInit(torch_module)) return false; - if (!THPFloatTensor_postInit(torch_module)) return false; - if (!THPHalfTensor_postInit(torch_module)) return false; - if (!THPLongTensor_postInit(torch_module)) return false; - if (!THPIntTensor_postInit(torch_module)) return false; - if (!THPShortTensor_postInit(torch_module)) return false; - if (!THPCharTensor_postInit(torch_module)) return false; - if (!THPByteTensor_postInit(torch_module)) return false; - - THPDoubleStorage_postInit(torch_module); - THPFloatStorage_postInit(torch_module); - THPHalfStorage_postInit(torch_module); - THPLongStorage_postInit(torch_module); - THPIntStorage_postInit(torch_module); - THPShortStorage_postInit(torch_module); - THPCharStorage_postInit(torch_module); - THPByteStorage_postInit(torch_module); - - return true; -#undef ASSERT_NOT_NULL -} - static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) { static std::vector<std::string> names; @@ -104,43 +69,31 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) } Py_RETURN_NONE; } - -static bool THPModule_assignStateless(PyObject *self) -{ -#define INIT_STATELESS(type) \ - stateless = PyObject_CallFunctionObjArgs((PyObject*)&TH_CONCAT_2(type, TensorStatelessType), NULL); \ - if (!stateless) { \ - return false; \ - } \ - if (PyObject_SetAttrString(TH_CONCAT_3(THP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \ - return false; \ - } - PyObject *stateless; - INIT_STATELESS(Double); - INIT_STATELESS(Float); - INIT_STATELESS(Half); - INIT_STATELESS(Long); - INIT_STATELESS(Int); - INIT_STATELESS(Short); - INIT_STATELESS(Char); - INIT_STATELESS(Byte); - return true; -#undef INIT_STATELESS -} // // Callback for python part. Used for additional initialization of python classes -static PyObject * THPModule_initExtension(PyObject *self, PyObject *shm_manager_path) +static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manager_path) { HANDLE_TH_ERRORS if (!THPUtils_checkString(shm_manager_path)) { THPUtils_setError("initialization error - expected bytes/string object as shm_manager_path!"); return NULL; } + torch::tensor::initialize_python_bindings(); std::string path = THPUtils_unpackString(shm_manager_path); libshm_init(path.c_str()); - if (!THPModule_loadClasses(self)) return NULL; - if (!THPModule_assignStateless(self)) return NULL; - if (!THPAutograd_initFunctions(self)) return NULL; + + auto module = THPObjectPtr(PyImport_ImportModule("torch")); + if (!module) throw python_error(); + + THPDoubleStorage_postInit(module); + THPFloatStorage_postInit(module); + THPHalfStorage_postInit(module); + THPLongStorage_postInit(module); + THPIntStorage_postInit(module); + THPShortStorage_postInit(module); + THPCharStorage_postInit(module); + THPByteStorage_postInit(module); + THPAutograd_initFunctions(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -158,19 +111,10 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) Py_RETURN_NONE; } -bool THPModule_isTensor(PyObject *obj) -{ - int result = PySet_Contains(tensor_classes, (PyObject*)Py_TYPE(obj)); - if (result == -1) - throw std::logic_error("FATAL: tensor_classes isn't a set!"); - return result; -} - PyObject * THPModule_setDefaultTensorType(PyObject *_unused, PyObject *type) { HANDLE_TH_ERRORS torch::tensor::py_set_default_tensor_type(type); - THPDefaultTensorClass = type; Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -182,219 +126,6 @@ PyObject * THPModule_fromNumpy(PyObject *_unused, PyObject *array) END_HANDLE_TH_ERRORS } -/** - * STATELESS FUNCTIONS - **/ - -static PyObject * findTensor(PyObject *args, PyObject *kwargs) { - for (Py_ssize_t i = 0; i < PyTuple_Size(args); i++) { - PyObject *item = PyTuple_GET_ITEM(args, i); - if (THPModule_isTensor(item) || THPVariable_Check(item)) { - return item; - } - } - if (kwargs) { - Py_ssize_t pos = 0; - PyObject *key, *value; - while (PyDict_Next(kwargs, &pos, &key, &value)) { - if (THPModule_isTensor(value) || THPVariable_Check(value)) { - return value; - } - } - } - return THPDefaultTensorClass; -} - -static PyObject * dispatchStateless(PyObject *args, PyObject *kwargs, const char *name) { - PyObject *tensor = findTensor(args, kwargs); - return THPUtils_dispatchStateless(tensor, name, args, kwargs); -} - -#define IMPLEMENT_STATELESS(name) \ -static PyObject * TH_CONCAT_2(THPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \ -{ \ - return dispatchStateless(args, kwargs, #name); \ -} - -IMPLEMENT_STATELESS(sigmoid) -IMPLEMENT_STATELESS(log) -IMPLEMENT_STATELESS(log1p) -IMPLEMENT_STATELESS(lgamma) -IMPLEMENT_STATELESS(digamma) -IMPLEMENT_STATELESS(polygamma) -IMPLEMENT_STATELESS(erf) -IMPLEMENT_STATELESS(erfinv) -IMPLEMENT_STATELESS(exp) -IMPLEMENT_STATELESS(expm1) -IMPLEMENT_STATELESS(cos) -IMPLEMENT_STATELESS(acos) -IMPLEMENT_STATELESS(cosh) -IMPLEMENT_STATELESS(sin) -IMPLEMENT_STATELESS(asin) -IMPLEMENT_STATELESS(sinh) -IMPLEMENT_STATELESS(tan) -IMPLEMENT_STATELESS(atan) -IMPLEMENT_STATELESS(tanh) -IMPLEMENT_STATELESS(sqrt) -IMPLEMENT_STATELESS(rsqrt) -IMPLEMENT_STATELESS(ceil) -IMPLEMENT_STATELESS(floor) -IMPLEMENT_STATELESS(round) -IMPLEMENT_STATELESS(abs) -IMPLEMENT_STATELESS(trunc) -IMPLEMENT_STATELESS(frac) -IMPLEMENT_STATELESS(mean) -IMPLEMENT_STATELESS(std) -IMPLEMENT_STATELESS(var) -IMPLEMENT_STATELESS(norm) -IMPLEMENT_STATELESS(reciprocal) -IMPLEMENT_STATELESS(neg) -IMPLEMENT_STATELESS(add) -IMPLEMENT_STATELESS(mul) -IMPLEMENT_STATELESS(div) -IMPLEMENT_STATELESS(fmod) -IMPLEMENT_STATELESS(min) -IMPLEMENT_STATELESS(max) -IMPLEMENT_STATELESS(dot) -IMPLEMENT_STATELESS(sum) -IMPLEMENT_STATELESS(prod) -IMPLEMENT_STATELESS(remainder) -IMPLEMENT_STATELESS(cumsum) -IMPLEMENT_STATELESS(cumprod) -IMPLEMENT_STATELESS(clamp) -IMPLEMENT_STATELESS(equal) -IMPLEMENT_STATELESS(eye) -IMPLEMENT_STATELESS(diag) -IMPLEMENT_STATELESS(numel) -IMPLEMENT_STATELESS(sign) -IMPLEMENT_STATELESS(trace) -IMPLEMENT_STATELESS(tril) -IMPLEMENT_STATELESS(triu) -IMPLEMENT_STATELESS(zero) -IMPLEMENT_STATELESS(kthvalue) -IMPLEMENT_STATELESS(mode) -IMPLEMENT_STATELESS(median) -IMPLEMENT_STATELESS(cross) -IMPLEMENT_STATELESS(sort) -IMPLEMENT_STATELESS(topk) -IMPLEMENT_STATELESS(t) -IMPLEMENT_STATELESS(transpose) -IMPLEMENT_STATELESS(squeeze) -IMPLEMENT_STATELESS(unsqueeze) -IMPLEMENT_STATELESS(renorm) -IMPLEMENT_STATELESS(dist) -IMPLEMENT_STATELESS(linspace) -IMPLEMENT_STATELESS(logspace) -IMPLEMENT_STATELESS(histc) -IMPLEMENT_STATELESS(atan2) -IMPLEMENT_STATELESS(pow) -IMPLEMENT_STATELESS(lerp) -IMPLEMENT_STATELESS(zeros) -IMPLEMENT_STATELESS(zeros_like) -IMPLEMENT_STATELESS(ones) -IMPLEMENT_STATELESS(ones_like) -IMPLEMENT_STATELESS(index_select) -IMPLEMENT_STATELESS(take) -IMPLEMENT_STATELESS(ger) -IMPLEMENT_STATELESS(mv) -IMPLEMENT_STATELESS(mm) -IMPLEMENT_STATELESS(bmm) -// TODO: this doesn't implement options that return numbers! -IMPLEMENT_STATELESS(multinomial) -IMPLEMENT_STATELESS(normal) -IMPLEMENT_STATELESS(_standard_gamma) -IMPLEMENT_STATELESS(_dirichlet_grad) -IMPLEMENT_STATELESS(bernoulli) -IMPLEMENT_STATELESS(range) -IMPLEMENT_STATELESS(arange) -IMPLEMENT_STATELESS(gather) -IMPLEMENT_STATELESS(rand) -IMPLEMENT_STATELESS(randn) -IMPLEMENT_STATELESS(masked_select) -IMPLEMENT_STATELESS(gesv) -IMPLEMENT_STATELESS(gels) -IMPLEMENT_STATELESS(trtrs) -IMPLEMENT_STATELESS(symeig) -IMPLEMENT_STATELESS(eig) -IMPLEMENT_STATELESS(svd) -IMPLEMENT_STATELESS(inverse) -IMPLEMENT_STATELESS(potrf) -IMPLEMENT_STATELESS(potrs) -IMPLEMENT_STATELESS(potri) -IMPLEMENT_STATELESS(pstrf) -IMPLEMENT_STATELESS(qr) -IMPLEMENT_STATELESS(geqrf) -IMPLEMENT_STATELESS(orgqr) -IMPLEMENT_STATELESS(ormqr) -IMPLEMENT_STATELESS(btrifact) -IMPLEMENT_STATELESS(btrifact_with_info) -IMPLEMENT_STATELESS(btrisolve) -IMPLEMENT_STATELESS(gt) -IMPLEMENT_STATELESS(lt) -IMPLEMENT_STATELESS(ge) -IMPLEMENT_STATELESS(le) -IMPLEMENT_STATELESS(eq) -IMPLEMENT_STATELESS(ne) - -IMPLEMENT_STATELESS(addmm) -IMPLEMENT_STATELESS(addmv) -IMPLEMENT_STATELESS(addr) -IMPLEMENT_STATELESS(addbmm) -IMPLEMENT_STATELESS(baddbmm) -IMPLEMENT_STATELESS(addcmul) -IMPLEMENT_STATELESS(addcdiv) - -#undef IMPLEMENT_STATELESS - -// In nonzero, the first argument might be a LongTensor that will be used -// for indices output, so we should pick a function based on second -// tensor's type. -static PyObject * THPModule_nonzero(PyObject *_unused, PyObject *args, PyObject *kwargs) -{ - PyObject *tensor = THPDefaultTensorClass; - if (PyTuple_Size(args) == 1) - tensor = PyTuple_GET_ITEM(args, 0); - else if (PyTuple_Size(args) == 2) - tensor = PyTuple_GET_ITEM(args, 1); - return THPUtils_dispatchStateless(tensor, "nonzero", args, kwargs); -} - -static PyObject * THPModule_randperm(PyObject *_unused, PyObject *args, PyObject *kwargs) -{ - PyObject *tensor = THPLongTensorClass; - PyObject *out; - if (kwargs && (out = PyDict_GetItemString(kwargs, "out"))) - tensor = out; - return THPUtils_dispatchStateless(tensor, "randperm", args, kwargs); -} - -static PyObject * THPModule_cat(PyObject *_unused, PyObject *args, PyObject *kwargs) -{ - PyObject *tensor = THPDefaultTensorClass; - THPObjectPtr iterator; - THPObjectPtr item; - PyObject *first_arg=nullptr; - if (args && PyTuple_GET_SIZE(args) > 0) { - first_arg = PyTuple_GET_ITEM(args, 0); - } else if (kwargs && PyTuple_GET_ITEM(args, 0)) { - first_arg = PyDict_GetItemString(kwargs, "seq"); - } - - if (first_arg) { - if (THPModule_isTensor(first_arg)) { - tensor = first_arg; - } else if (PySequence_Check(first_arg)) { - item = PySequence_GetItem(first_arg, 0); - if (item && (THPModule_isTensor(item) || THPVariable_Check(item))) { - tensor = item; - } - } - PyErr_Clear(); - } - - return THPUtils_dispatchStateless(tensor, "cat", args, kwargs); -} - PyObject *THPModule_safeCall(PyObject *_unused, PyObject *args, PyObject *kwargs) { PyObject *result = NULL; @@ -600,30 +331,13 @@ PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) { Py_RETURN_TRUE; } -static PyObject* THPModule_initializeTensorTypeBindings(PyObject *_unused) -{ - HANDLE_TH_ERRORS - torch::tensor::initialize_python_bindings(nullptr); - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -#ifdef WITH_CUDA -extern PyObject * THCSPModule_initExtension(PyObject *self); -#endif - static PyMethodDef TorchMethods[] = { {"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, NULL}, {"_autograd_init", (PyCFunction)THPAutograd_initExtension, METH_NOARGS, NULL}, {"_add_docstr", (PyCFunction)THPModule_addDocStr, METH_VARARGS, NULL}, - {"_sparse_init", (PyCFunction)THSPModule_initExtension, METH_NOARGS, NULL}, {"_init_names", (PyCFunction)THPModule_initNames, METH_O, NULL}, {"_has_distributed",(PyCFunction)THPModule_hasDistributed, METH_NOARGS, NULL}, {"_initialize_dtypes",(PyCFunction)THPModule_initializeDtypes, METH_NOARGS, NULL}, - {"_initialize_tensor_type_bindings", (PyCFunction)THPModule_initializeTensorTypeBindings, METH_NOARGS, NULL}, -#ifdef WITH_CUDA - {"_cuda_sparse_init", (PyCFunction)THCSPModule_initExtension, METH_NOARGS, NULL}, -#endif {"_safe_call", (PyCFunction)THPModule_safeCall, METH_VARARGS | METH_KEYWORDS, NULL}, {"_set_default_tensor_type", (PyCFunction)THPModule_setDefaultTensorType, METH_O, NULL}, {"_infer_size", (PyCFunction)THPModule_inferSize, METH_VARARGS, NULL}, @@ -643,141 +357,11 @@ static PyMethodDef TorchMethods[] = { {"_to_dlpack", (PyCFunction)THPModule_toDLPack, METH_O, NULL}, {"_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, NULL}, {"set_flush_denormal", (PyCFunction)THPModule_setFlushDenormal, METH_O, NULL}, - - {"sigmoid", (PyCFunction)THPModule_sigmoid, METH_VARARGS | METH_KEYWORDS, NULL}, - {"log", (PyCFunction)THPModule_log, METH_VARARGS | METH_KEYWORDS, NULL}, - {"log1p", (PyCFunction)THPModule_log1p, METH_VARARGS | METH_KEYWORDS, NULL}, - {"lgamma", (PyCFunction)THPModule_lgamma, METH_VARARGS | METH_KEYWORDS, NULL}, - {"digamma", (PyCFunction)THPModule_digamma, METH_VARARGS | METH_KEYWORDS, NULL}, - {"polygamma", (PyCFunction)THPModule_polygamma, METH_VARARGS | METH_KEYWORDS, NULL}, - {"erf", (PyCFunction)THPModule_erf, METH_VARARGS | METH_KEYWORDS, NULL}, - {"erfinv", (PyCFunction)THPModule_erfinv, METH_VARARGS | METH_KEYWORDS, NULL}, - {"exp", (PyCFunction)THPModule_exp, METH_VARARGS | METH_KEYWORDS, NULL}, - {"expm1", (PyCFunction)THPModule_expm1, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cos", (PyCFunction)THPModule_cos, METH_VARARGS | METH_KEYWORDS, NULL}, - {"acos", (PyCFunction)THPModule_acos, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cosh", (PyCFunction)THPModule_cosh, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sin", (PyCFunction)THPModule_sin, METH_VARARGS | METH_KEYWORDS, NULL}, - {"asin", (PyCFunction)THPModule_asin, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sinh", (PyCFunction)THPModule_sinh, METH_VARARGS | METH_KEYWORDS, NULL}, - {"tan", (PyCFunction)THPModule_tan, METH_VARARGS | METH_KEYWORDS, NULL}, - {"atan", (PyCFunction)THPModule_atan, METH_VARARGS | METH_KEYWORDS, NULL}, - {"tanh", (PyCFunction)THPModule_tanh, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sqrt", (PyCFunction)THPModule_sqrt, METH_VARARGS | METH_KEYWORDS, NULL}, - {"rsqrt", (PyCFunction)THPModule_rsqrt, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ceil", (PyCFunction)THPModule_ceil, METH_VARARGS | METH_KEYWORDS, NULL}, - {"floor", (PyCFunction)THPModule_floor, METH_VARARGS | METH_KEYWORDS, NULL}, - {"round", (PyCFunction)THPModule_round, METH_VARARGS | METH_KEYWORDS, NULL}, - {"abs", (PyCFunction)THPModule_abs, METH_VARARGS | METH_KEYWORDS, NULL}, - {"trunc", (PyCFunction)THPModule_trunc, METH_VARARGS | METH_KEYWORDS, NULL}, - {"frac", (PyCFunction)THPModule_frac, METH_VARARGS | METH_KEYWORDS, NULL}, - {"mean", (PyCFunction)THPModule_mean, METH_VARARGS | METH_KEYWORDS, NULL}, - {"std", (PyCFunction)THPModule_std, METH_VARARGS | METH_KEYWORDS, NULL}, - {"var", (PyCFunction)THPModule_var, METH_VARARGS | METH_KEYWORDS, NULL}, - {"norm", (PyCFunction)THPModule_norm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"reciprocal", (PyCFunction)THPModule_reciprocal, METH_VARARGS | METH_KEYWORDS, NULL}, - {"neg", (PyCFunction)THPModule_neg, METH_VARARGS | METH_KEYWORDS, NULL}, - {"add", (PyCFunction)THPModule_add, METH_VARARGS | METH_KEYWORDS, NULL}, - {"mul", (PyCFunction)THPModule_mul, METH_VARARGS | METH_KEYWORDS, NULL}, - {"div", (PyCFunction)THPModule_div, METH_VARARGS | METH_KEYWORDS, NULL}, - {"fmod", (PyCFunction)THPModule_fmod, METH_VARARGS | METH_KEYWORDS, NULL}, - {"min", (PyCFunction)THPModule_min, METH_VARARGS | METH_KEYWORDS, NULL}, - {"max", (PyCFunction)THPModule_max, METH_VARARGS | METH_KEYWORDS, NULL}, - {"dot", (PyCFunction)THPModule_dot, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sum", (PyCFunction)THPModule_sum, METH_VARARGS | METH_KEYWORDS, NULL}, - {"prod", (PyCFunction)THPModule_prod, METH_VARARGS | METH_KEYWORDS, NULL}, - {"remainder", (PyCFunction)THPModule_remainder, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cumsum", (PyCFunction)THPModule_cumsum, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cumprod", (PyCFunction)THPModule_cumprod, METH_VARARGS | METH_KEYWORDS, NULL}, - {"clamp", (PyCFunction)THPModule_clamp, METH_VARARGS | METH_KEYWORDS, NULL}, - {"equal", (PyCFunction)THPModule_equal, METH_VARARGS | METH_KEYWORDS, NULL}, - {"eye", (PyCFunction)THPModule_eye, METH_VARARGS | METH_KEYWORDS, NULL}, - {"diag", (PyCFunction)THPModule_diag, METH_VARARGS | METH_KEYWORDS, NULL}, - {"numel", (PyCFunction)THPModule_numel, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sign", (PyCFunction)THPModule_sign, METH_VARARGS | METH_KEYWORDS, NULL}, - {"trace", (PyCFunction)THPModule_trace, METH_VARARGS | METH_KEYWORDS, NULL}, - {"tril", (PyCFunction)THPModule_tril, METH_VARARGS | METH_KEYWORDS, NULL}, - {"triu", (PyCFunction)THPModule_triu, METH_VARARGS | METH_KEYWORDS, NULL}, - {"zero", (PyCFunction)THPModule_zero, METH_VARARGS | METH_KEYWORDS, NULL}, - {"gt", (PyCFunction)THPModule_gt, METH_VARARGS | METH_KEYWORDS, NULL}, - {"lt", (PyCFunction)THPModule_lt, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ge", (PyCFunction)THPModule_ge, METH_VARARGS | METH_KEYWORDS, NULL}, - {"le", (PyCFunction)THPModule_le, METH_VARARGS | METH_KEYWORDS, NULL}, - {"eq", (PyCFunction)THPModule_eq, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ne", (PyCFunction)THPModule_ne, METH_VARARGS | METH_KEYWORDS, NULL}, - {"kthvalue", (PyCFunction)THPModule_kthvalue, METH_VARARGS | METH_KEYWORDS, NULL}, - {"mode", (PyCFunction)THPModule_mode, METH_VARARGS | METH_KEYWORDS, NULL}, - {"median", (PyCFunction)THPModule_median, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cross", (PyCFunction)THPModule_cross, METH_VARARGS | METH_KEYWORDS, NULL}, - {"sort", (PyCFunction)THPModule_sort, METH_VARARGS | METH_KEYWORDS, NULL}, - {"topk", (PyCFunction)THPModule_topk, METH_VARARGS | METH_KEYWORDS, NULL}, - {"t", (PyCFunction)THPModule_t, METH_VARARGS | METH_KEYWORDS, NULL}, - {"transpose", (PyCFunction)THPModule_transpose, METH_VARARGS | METH_KEYWORDS, NULL}, - {"squeeze", (PyCFunction)THPModule_squeeze, METH_VARARGS | METH_KEYWORDS, NULL}, - {"unsqueeze", (PyCFunction)THPModule_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL}, - {"nonzero", (PyCFunction)THPModule_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, - {"renorm", (PyCFunction)THPModule_renorm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"dist", (PyCFunction)THPModule_dist, METH_VARARGS | METH_KEYWORDS, NULL}, - {"linspace", (PyCFunction)THPModule_linspace, METH_VARARGS | METH_KEYWORDS, NULL}, - {"logspace", (PyCFunction)THPModule_logspace, METH_VARARGS | METH_KEYWORDS, NULL}, - {"histc", (PyCFunction)THPModule_histc, METH_VARARGS | METH_KEYWORDS, NULL}, - {"atan2", (PyCFunction)THPModule_atan2, METH_VARARGS | METH_KEYWORDS, NULL}, - {"pow", (PyCFunction)THPModule_pow, METH_VARARGS | METH_KEYWORDS, NULL}, - {"lerp", (PyCFunction)THPModule_lerp, METH_VARARGS | METH_KEYWORDS, NULL}, - {"zeros", (PyCFunction)THPModule_zeros, METH_VARARGS | METH_KEYWORDS, NULL}, - {"zeros_like", (PyCFunction)THPModule_zeros_like, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ones", (PyCFunction)THPModule_ones, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ones_like", (PyCFunction)THPModule_ones_like, METH_VARARGS | METH_KEYWORDS, NULL}, - {"index_select", (PyCFunction)THPModule_index_select, METH_VARARGS | METH_KEYWORDS, NULL}, - {"take", (PyCFunction)THPModule_take, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addmm", (PyCFunction)THPModule_addmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addmv", (PyCFunction)THPModule_addmv, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addr", (PyCFunction)THPModule_addr, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ger", (PyCFunction)THPModule_ger, METH_VARARGS | METH_KEYWORDS, NULL}, - {"mv", (PyCFunction)THPModule_mv, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addbmm", (PyCFunction)THPModule_addbmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"baddbmm", (PyCFunction)THPModule_baddbmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addcmul", (PyCFunction)THPModule_addcmul, METH_VARARGS | METH_KEYWORDS, NULL}, - {"addcdiv", (PyCFunction)THPModule_addcdiv, METH_VARARGS | METH_KEYWORDS, NULL}, - {"mm", (PyCFunction)THPModule_mm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"bmm", (PyCFunction)THPModule_bmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"multinomial", (PyCFunction)THPModule_multinomial, METH_VARARGS | METH_KEYWORDS, NULL}, - {"normal", (PyCFunction)THPModule_normal, METH_VARARGS | METH_KEYWORDS, NULL}, - {"_standard_gamma", (PyCFunction)THPModule__standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL}, - {"_dirichlet_grad", (PyCFunction)THPModule__dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL}, - {"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL}, - {"rand", (PyCFunction)THPModule_rand, METH_VARARGS | METH_KEYWORDS, NULL}, - {"randn", (PyCFunction)THPModule_randn, METH_VARARGS | METH_KEYWORDS, NULL}, - {"randperm", (PyCFunction)THPModule_randperm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"range", (PyCFunction)THPModule_range, METH_VARARGS | METH_KEYWORDS, NULL}, - {"arange", (PyCFunction)THPModule_arange, METH_VARARGS | METH_KEYWORDS, NULL}, - {"gather", (PyCFunction)THPModule_gather, METH_VARARGS | METH_KEYWORDS, NULL}, - {"cat", (PyCFunction)THPModule_cat, METH_VARARGS | METH_KEYWORDS, NULL}, - {"masked_select", (PyCFunction)THPModule_masked_select, METH_VARARGS | METH_KEYWORDS, NULL}, - {"gesv", (PyCFunction)THPModule_gesv, METH_VARARGS | METH_KEYWORDS, NULL}, - {"gels", (PyCFunction)THPModule_gels, METH_VARARGS | METH_KEYWORDS, NULL}, - {"trtrs", (PyCFunction)THPModule_trtrs, METH_VARARGS | METH_KEYWORDS, NULL}, - {"symeig", (PyCFunction)THPModule_symeig, METH_VARARGS | METH_KEYWORDS, NULL}, - {"eig", (PyCFunction)THPModule_eig, METH_VARARGS | METH_KEYWORDS, NULL}, - {"svd", (PyCFunction)THPModule_svd, METH_VARARGS | METH_KEYWORDS, NULL}, - {"inverse", (PyCFunction)THPModule_inverse, METH_VARARGS | METH_KEYWORDS, NULL}, - {"potrf", (PyCFunction)THPModule_potrf, METH_VARARGS | METH_KEYWORDS, NULL}, - {"potrs", (PyCFunction)THPModule_potrs, METH_VARARGS | METH_KEYWORDS, NULL}, - {"potri", (PyCFunction)THPModule_potri, METH_VARARGS | METH_KEYWORDS, NULL}, - {"pstrf", (PyCFunction)THPModule_pstrf, METH_VARARGS | METH_KEYWORDS, NULL}, - {"qr", (PyCFunction)THPModule_qr, METH_VARARGS | METH_KEYWORDS, NULL}, - {"geqrf", (PyCFunction)THPModule_geqrf, METH_VARARGS | METH_KEYWORDS, NULL}, - {"orgqr", (PyCFunction)THPModule_orgqr, METH_VARARGS | METH_KEYWORDS, NULL}, - {"ormqr", (PyCFunction)THPModule_ormqr, METH_VARARGS | METH_KEYWORDS, NULL}, - {"btrifact", (PyCFunction)THPModule_btrifact, METH_VARARGS | METH_KEYWORDS, NULL}, - {"btrifact_with_info", (PyCFunction)THPModule_btrifact_with_info, METH_VARARGS | METH_KEYWORDS, NULL}, - {"btrisolve", (PyCFunction)THPModule_btrisolve, METH_VARARGS | METH_KEYWORDS, NULL}, - // Sparse functions - {"smm", (PyCFunction)THSPModule_sspmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"saddmm", (PyCFunction)THSPModule_sspaddmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"dsmm", (PyCFunction)THSPModule_spmm, METH_VARARGS | METH_KEYWORDS, NULL}, - {"hsmm", (PyCFunction)THSPModule_hspmm, METH_VARARGS | METH_KEYWORDS, NULL}, + // {"smm", (PyCFunction)THSPModule_sspmm, METH_VARARGS | METH_KEYWORDS, NULL}, + // {"saddmm", (PyCFunction)THSPModule_sspaddmm, METH_VARARGS | METH_KEYWORDS, NULL}, + // {"dsmm", (PyCFunction)THSPModule_spmm, METH_VARARGS | METH_KEYWORDS, NULL}, + // {"hsmm", (PyCFunction)THSPModule_hspmm, METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} }; @@ -790,15 +374,6 @@ bool THCPShortStorage_init(PyObject *module); bool THCPCharStorage_init(PyObject *module); bool THCPByteStorage_init(PyObject *module); -bool THCPDoubleTensor_init(PyObject *module); -bool THCPFloatTensor_init(PyObject *module); -bool THCPHalfTensor_init(PyObject *module); -bool THCPLongTensor_init(PyObject *module); -bool THCPIntTensor_init(PyObject *module); -bool THCPShortTensor_init(PyObject *module); -bool THCPCharTensor_init(PyObject *module); -bool THCPByteTensor_init(PyObject *module); - bool THCPStream_init(PyObject *module); #ifdef WITH_CUDA @@ -810,14 +385,14 @@ void initModule(PyObject *module); }} // namespace torch::cuda #endif -bool THCSPDoubleTensor_init(PyObject *module); -bool THCSPFloatTensor_init(PyObject *module); -bool THCSPHalfTensor_init(PyObject *module); -bool THCSPLongTensor_init(PyObject *module); -bool THCSPIntTensor_init(PyObject *module); -bool THCSPShortTensor_init(PyObject *module); -bool THCSPCharTensor_init(PyObject *module); -bool THCSPByteTensor_init(PyObject *module); +namespace torch { namespace nn { + +void init__THNN(PyObject*); +#ifdef WITH_CUDA +void init__THCUNN(PyObject*); +#endif + +}} // namespace torch::nn bool THDPDoubleStorage_init(PyObject *module); bool THDPFloatStorage_init(PyObject *module); @@ -828,15 +403,6 @@ bool THDPShortStorage_init(PyObject *module); bool THDPCharStorage_init(PyObject *module); bool THDPByteStorage_init(PyObject *module); -bool THDPDoubleTensor_init(PyObject *module); -bool THDPFloatTensor_init(PyObject *module); -//bool THDPHalfTensor_init(PyObject *module); -bool THDPLongTensor_init(PyObject *module); -bool THDPIntTensor_init(PyObject *module); -bool THDPShortTensor_init(PyObject *module); -bool THDPCharTensor_init(PyObject *module); -bool THDPByteTensor_init(PyObject *module); - static std::vector<PyMethodDef> methods; #ifdef WITH_DISTRIBUTED @@ -914,23 +480,6 @@ static PyObject* initModule() { ASSERT_TRUE(THPCharStorage_init(module)); ASSERT_TRUE(THPByteStorage_init(module)); - ASSERT_TRUE(THPDoubleTensor_init(module)); - ASSERT_TRUE(THPFloatTensor_init(module)); - ASSERT_TRUE(THPHalfTensor_init(module)); - ASSERT_TRUE(THPLongTensor_init(module)); - ASSERT_TRUE(THPIntTensor_init(module)); - ASSERT_TRUE(THPShortTensor_init(module)); - ASSERT_TRUE(THPCharTensor_init(module)); - ASSERT_TRUE(THPByteTensor_init(module)); - - ASSERT_TRUE(THSPDoubleTensor_init(module)); - ASSERT_TRUE(THSPFloatTensor_init(module)); - ASSERT_TRUE(THSPLongTensor_init(module)); - ASSERT_TRUE(THSPIntTensor_init(module)); - ASSERT_TRUE(THSPShortTensor_init(module)); - ASSERT_TRUE(THSPCharTensor_init(module)); - ASSERT_TRUE(THSPByteTensor_init(module)); - #ifdef WITH_CUDA // This will only initialise base classes and attach them to library namespace // They won't be ready for real usage until importing cuda module, that will @@ -945,25 +494,7 @@ static PyObject* initModule() { ASSERT_TRUE(THCPCharStorage_init(module)); ASSERT_TRUE(THCPByteStorage_init(module)); - ASSERT_TRUE(THCPDoubleTensor_init(module)); - ASSERT_TRUE(THCPFloatTensor_init(module)); - ASSERT_TRUE(THCPHalfTensor_init(module)); - ASSERT_TRUE(THCPLongTensor_init(module)); - ASSERT_TRUE(THCPIntTensor_init(module)); - ASSERT_TRUE(THCPShortTensor_init(module)); - ASSERT_TRUE(THCPCharTensor_init(module)); - ASSERT_TRUE(THCPByteTensor_init(module)); - ASSERT_TRUE(THCPStream_init(module)); - - ASSERT_TRUE(THCSPDoubleTensor_init(module)); - ASSERT_TRUE(THCSPFloatTensor_init(module)); - ASSERT_TRUE(THCSPHalfTensor_init(module)); - ASSERT_TRUE(THCSPLongTensor_init(module)); - ASSERT_TRUE(THCSPIntTensor_init(module)); - ASSERT_TRUE(THCSPShortTensor_init(module)); - ASSERT_TRUE(THCSPCharTensor_init(module)); - ASSERT_TRUE(THCSPByteTensor_init(module)); #endif #ifdef WITH_CUDNN @@ -984,15 +515,6 @@ static PyObject* initModule() { ASSERT_TRUE(THDPShortStorage_init(module)); ASSERT_TRUE(THDPCharStorage_init(module)); ASSERT_TRUE(THDPByteStorage_init(module)); - - ASSERT_TRUE(THDPDoubleTensor_init(module)); - ASSERT_TRUE(THDPFloatTensor_init(module)); - //ASSERT_TRUE(THDPHalfTensor_init(module)); - ASSERT_TRUE(THDPLongTensor_init(module)); - ASSERT_TRUE(THDPIntTensor_init(module)); - ASSERT_TRUE(THDPShortTensor_init(module)); - ASSERT_TRUE(THDPCharTensor_init(module)); - ASSERT_TRUE(THDPByteTensor_init(module)); #endif // force ATen to initialize because it handles @@ -1008,6 +530,11 @@ static PyObject* initModule() { if (_import_array() < 0) return NULL; #endif + torch::nn::init__THNN(module); +#ifdef WITH_CUDA + torch::nn::init__THCUNN(module); +#endif + return module; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/Module.h b/torch/csrc/Module.h index ce690fc3ff..e206d10e3e 100644 --- a/torch/csrc/Module.h +++ b/torch/csrc/Module.h @@ -3,11 +3,6 @@ #define THP_STATELESS_ATTRIBUTE_NAME "_torch" -extern PyObject *THPDefaultTensorClass; extern THPGenerator *THPDefaultGenerator; -#ifdef _THP_CORE -bool THPModule_isTensor(PyObject *obj); -#endif - #endif diff --git a/torch/csrc/ModuleSparse.cpp b/torch/csrc/ModuleSparse.cpp deleted file mode 100644 index 49f53980c1..0000000000 --- a/torch/csrc/ModuleSparse.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "THP.h" - -PyObject* sparse_tensor_classes; - -//////////////////////////////////////////////////////////////////////////////// -// SPARSE MODULE INITIALIZATION -//////////////////////////////////////////////////////////////////////////////// - -static bool THSPModule_loadClasses(PyObject *sparse_module) -{ - if (!THSPDoubleTensor_postInit(sparse_module)) return false; - if (!THSPFloatTensor_postInit(sparse_module)) return false; - if (!THSPLongTensor_postInit(sparse_module)) return false; - if (!THSPIntTensor_postInit(sparse_module)) return false; - if (!THSPShortTensor_postInit(sparse_module)) return false; - if (!THSPCharTensor_postInit(sparse_module)) return false; - if (!THSPByteTensor_postInit(sparse_module)) return false; - return true; -} - -static bool THSPModule_assignStateless() -{ -#define INIT_STATELESS(type) \ - stateless = PyObject_Call((PyObject*)&TH_CONCAT_3(Sparse, type, TensorStatelessType), arg, NULL); \ - if (!stateless) { \ - THPUtils_setError("stateless method initialization error"); \ - return false; \ - } \ - if (PyObject_SetAttrString(TH_CONCAT_3(THSP,type,TensorClass), THP_STATELESS_ATTRIBUTE_NAME, stateless) == -1) { \ - THPUtils_setError("stateless method initialization error (on assignment)");\ - } - PyObject *arg = PyTuple_New(0); - PyObject *stateless; - INIT_STATELESS(Double); - INIT_STATELESS(Float); - INIT_STATELESS(Long); - INIT_STATELESS(Int); - INIT_STATELESS(Short); - INIT_STATELESS(Char); - INIT_STATELESS(Byte); - Py_DECREF(arg); - return true; -#undef INIT_STATELESS -} - -// Callback for python part. Used for additional initialization of python classes -PyObject *THSPModule_initExtension(PyObject *self) -{ - PyObject *module = PyImport_ImportModule("torch.sparse"); - if (!module) return NULL; - if (!THSPModule_loadClasses(module)) return NULL; - if (!THSPModule_assignStateless()) return NULL; - Py_RETURN_NONE; -} - -//////////////////////////////////////////////////////////////////////////////// -// Sparse Stateless Functions -//////////////////////////////////////////////////////////////////////////////// - -bool THPModule_isSparseTensor(PyObject *obj) -{ - int result = PySet_Contains(sparse_tensor_classes, (PyObject*)Py_TYPE(obj)); - if (result == -1) - throw std::logic_error("FATAL: sparse_tensor_classes isn't a set!"); - return result; -} - - -#define IMPLEMENT_SPARSE_STATELESS(name) \ -static PyObject * TH_CONCAT_2(THSPModule_, name)(PyObject *_unused, PyObject *args, PyObject *kwargs) \ -{ \ - PyObject *tensor = THSPFloatTensorClass; \ - PyObject *key, *value; \ - Py_ssize_t pos = 0; \ - for (int i = 0; i < PyTuple_Size(args); i++) { \ - PyObject *item = PyTuple_GET_ITEM(args, i); \ - if (THPModule_isTensor(item) || THPVariable_Check(item)) { \ - tensor = item; \ - goto dispatch; \ - } \ - } \ - if (kwargs) { \ - while (PyDict_Next(kwargs, &pos, &key, &value)) { \ - if (THPModule_isTensor(value) || THPVariable_Check(value)) { \ - tensor = value; \ - goto dispatch; \ - } \ - } \ - } \ - \ -dispatch: \ - return THPUtils_dispatchStateless(tensor, #name, args, kwargs); \ -} - -IMPLEMENT_SPARSE_STATELESS(spmm); -IMPLEMENT_SPARSE_STATELESS(sspmm); -IMPLEMENT_SPARSE_STATELESS(sspaddmm); -IMPLEMENT_SPARSE_STATELESS(hspmm); - -#undef IMPLEMENT_SPARSE_STATELESS diff --git a/torch/csrc/Tensor.h b/torch/csrc/Tensor.h deleted file mode 100644 index c3fa78d99f..0000000000 --- a/torch/csrc/Tensor.h +++ /dev/null @@ -1,71 +0,0 @@ -#ifndef THP_TENSOR_INC -#define THP_TENSOR_INC - -#define THPTensor TH_CONCAT_3(THP,Real,Tensor) -#define THPTensorStr TH_CONCAT_STRING_3(torch.,Real,Tensor) -#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass) -#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME) - -#define THPDoubleTensor_Check(obj) PyObject_IsInstance(obj, THPDoubleTensorClass) -#define THPFloatTensor_Check(obj) PyObject_IsInstance(obj, THPFloatTensorClass) -#define THPHalfTensor_Check(obj) PyObject_IsInstance(obj, THPHalfTensorClass) -#define THPLongTensor_Check(obj) PyObject_IsInstance(obj, THPLongTensorClass) -#define THPIntTensor_Check(obj) PyObject_IsInstance(obj, THPIntTensorClass) -#define THPShortTensor_Check(obj) PyObject_IsInstance(obj, THPShortTensorClass) -#define THPCharTensor_Check(obj) PyObject_IsInstance(obj, THPCharTensorClass) -#define THPByteTensor_Check(obj) PyObject_IsInstance(obj, THPByteTensorClass) - -#define THPDoubleTensor_CData(obj) (obj)->cdata -#define THPFloatTensor_CData(obj) (obj)->cdata -#define THPHalfTensor_CData(obj) (obj)->cdata -#define THPLongTensor_CData(obj) (obj)->cdata -#define THPIntTensor_CData(obj) (obj)->cdata -#define THPShortTensor_CData(obj) (obj)->cdata -#define THPCharTensor_CData(obj) (obj)->cdata -#define THPByteTensor_CData(obj) (obj)->cdata - -#ifdef _THP_CORE -#define THPTensorType TH_CONCAT_3(THP,Real,TensorType) -#define THPTensorBaseStr TH_CONCAT_STRING_2(Real,TensorBase) -#define THPTensorStateless TH_CONCAT_2(Real,TensorStateless) -#define THPTensorStatelessType TH_CONCAT_2(Real,TensorStatelessType) -#define THPTensor_stateless_(NAME) TH_CONCAT_4(THP,Real,Tensor_stateless_,NAME) -#endif - -// Sparse Tensors -#define THSPTensor TH_CONCAT_3(THSP,Real,Tensor) -#define THSPTensorStr TH_CONCAT_STRING_3(torch.Sparse,Real,Tensor) -#define THSPTensorClass TH_CONCAT_3(THSP,Real,TensorClass) -#define THSPTensor_(NAME) TH_CONCAT_4(THSP,Real,Tensor_,NAME) - -#define THSPDoubleTensor_Check(obj) PyObject_IsInstance(obj, THSPDoubleTensorClass) -#define THSPFloatTensor_Check(obj) PyObject_IsInstance(obj, THSPFloatTensorClass) -#define THSPLongTensor_Check(obj) PyObject_IsInstance(obj, THSPLongTensorClass) -#define THSPIntTensor_Check(obj) PyObject_IsInstance(obj, THSPIntTensorClass) -#define THSPShortTensor_Check(obj) PyObject_IsInstance(obj, THSPShortTensorClass) -#define THSPCharTensor_Check(obj) PyObject_IsInstance(obj, THSPCharTensorClass) -#define THSPByteTensor_Check(obj) PyObject_IsInstance(obj, THSPByteTensorClass) - -#define THSPDoubleTensor_CData(obj) (obj)->cdata -#define THSPFloatTensor_CData(obj) (obj)->cdata -#define THSPLongTensor_CData(obj) (obj)->cdata -#define THSPIntTensor_CData(obj) (obj)->cdata -#define THSPShortTensor_CData(obj) (obj)->cdata -#define THSPCharTensor_CData(obj) (obj)->cdata -#define THSPByteTensor_CData(obj) (obj)->cdata - -#ifdef _THP_CORE -#define THSPTensorType TH_CONCAT_3(THSP,Real,TensorType) -#define THSPTensorBaseStr TH_CONCAT_STRING_3(Sparse,Real,TensorBase) -#define THSPTensorStateless TH_CONCAT_3(Sparse,Real,TensorStateless) -#define THSPTensorStatelessType TH_CONCAT_3(Sparse,Real,TensorStatelessType) -#define THSPTensor_stateless_(NAME) TH_CONCAT_4(THSP,Real,Tensor_stateless_,NAME) -#endif - -#include "generic/Tensor.h" -#include <TH/THGenerateAllTypes.h> - -#include "generic/Tensor.h" -#include <TH/THGenerateHalfType.h> - -#endif diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index f79e7e050a..7ff9a39f7d 100644 --- a/torch/csrc/autograd/autograd.h +++ b/torch/csrc/autograd/autograd.h @@ -2,7 +2,7 @@ #define THP_AUTOGRAD_H PyObject * THPAutograd_initExtension(PyObject *_unused); -bool THPAutograd_initFunctions(PyObject* module); +void THPAutograd_initFunctions(); namespace torch { namespace autograd { diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index aa22650cfc..fe5c52c0ae 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -82,10 +82,10 @@ static struct PyGetSetDef accumulate_grad_properties[] = { {nullptr} }; -bool THPAutograd_initFunctions(PyObject* _unused) +void THPAutograd_initFunctions() { THPObjectPtr module(PyModule_New("torch._C._functions")); - if (!module) return false; + if (!module) throw python_error(); static PyTypeObject AccumulateGradClass; addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties); @@ -110,10 +110,13 @@ bool THPAutograd_initFunctions(PyObject* _unused) generated::initialize_autogenerated_functions(); - THPObjectPtr parent(PyImport_ImportModule("torch._C")); - if (!parent) return false; - PyModule_AddObject(parent.get(), "_functions", module.release()); - return true; + auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C")); + if (!c_module) throw python_error(); + + Py_INCREF(module); + if (PyModule_AddObject(c_module, "_functions", module) < 0) { + throw python_error(); + } } namespace torch { namespace autograd { diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 6e246346fa..7d84562ff6 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -21,10 +21,6 @@ #include "torch/csrc/utils/auto_gpu.h" #include "torch/csrc/Exceptions.h" -#ifdef WITH_CUDA -#include "cuda/AutoGPU.h" -#endif - using namespace torch; using namespace torch::autograd; using namespace torch::jit; @@ -369,10 +365,6 @@ static void _wrap_outputs(THPFunction *self, if (THPVariable_Check(obj)) { return ((THPVariable*)obj)->cdata; } - if (THPModule_isTensor(obj)) { - // temporarily wrap tensors as variables until the classes are merged - return make_variable(createTensor(obj), /*requires_grad=*/false); - } throw TypeError("%s.forward: expected Variable (got %s) for return value %d", Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i); }; @@ -456,10 +448,6 @@ static void _save_variables(THPFunction* self) auto variable = (THPVariable*)obj; bool is_output = variable->cdata.grad_fn().get() == cdata_ptr; self->saved_variables.emplace_back(variable->cdata, is_output); - } else if (THPModule_isTensor(obj)) { - // TODO: remove once Variable and Tensor classes are merged - auto var = make_variable(createTensor(obj), /*requires_grad=*/false); - self->saved_variables.emplace_back(std::move(var), false); } else { throw TypeError( "save_for_backward can only save variables, but argument %d is of " diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 0f0144d1a6..ecb5900805 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -14,7 +14,6 @@ #include "torch/csrc/autograd/function.h" #include "torch/csrc/autograd/generated/VariableType.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" -#include "torch/csrc/cuda/AutoGPU.h" #include "torch/csrc/jit/tracer_state.h" #include "torch/csrc/tensor/python_tensor.h" #include "torch/csrc/utils/auto_gil.h" @@ -150,8 +149,6 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) // by nn.Parameter() with no arguments. auto var = torch::tensor::get_default_tensor_type().tensor(); tensor = static_cast<Variable&>(var).data(); - } else if (THPModule_isTensor(data)) { - tensor = torch::createTensor(data); } else if (THPVariable_Check(data)) { tensor = ((THPVariable*)data)->cdata.data(); } else { diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 8c2b3c2906..320c5cabd5 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -17,8 +17,6 @@ using namespace at; using namespace torch::autograd::utils; -extern bool THPModule_isTensor(PyObject *obj); - namespace torch { namespace autograd { Py_ssize_t THPVariable_length(PyObject* self) { @@ -116,9 +114,6 @@ static Variable valueToTensor(const Type & type, PyObject* value) { if (PyFloat_Check(value)) { return type.scalarTensor(Scalar(THPUtils_unpackDouble(value))); } - if (THPModule_isTensor(value)) { - return make_variable(createTensor(value), /*requires_grad=*/false); - } throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString()); } @@ -153,8 +148,6 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis dim++; } else if (THPVariable_Check(obj)) { handle_var(reinterpret_cast<THPVariable*>(obj)->cdata); - } else if (THPModule_isTensor(obj)) { - handle_var(make_variable(createTensor(obj), /*requires_grad=*/false)); } else if (PySequence_Check(obj)) { handle_var(sequenceToVariable(self.type(), obj)); } else { diff --git a/torch/csrc/copy_utils.h b/torch/csrc/copy_utils.h index 4ad0778c30..0aaec84f92 100644 --- a/torch/csrc/copy_utils.h +++ b/torch/csrc/copy_utils.h @@ -1,10 +1,8 @@ -#ifndef THP_COPY_UTILS_H -#define THP_COPY_UTILS_H +#pragma once #include <functional> #include <vector> #include "Types.h" -#include "expand_utils.h" typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction; struct THPCopyInfo { @@ -56,25 +54,6 @@ inline PyObject * THPStorageCopyMethod(const THPCopyList& v, PyObject *self, PyO return self; } -inline PyObject * THPTensorCopyMethod(const THPCopyList& v, PyObject *self, PyObject *args, PyObject *kwargs) -{ - PyObject *src; - int non_blocking = 0; - int broadcast = 1; - static char *kwlist[] = {"source", "non_blocking", "broadcast", NULL}; - // use int as parse type because bool not available in python2. - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|ii:copy_", kwlist, &src, &non_blocking, &broadcast)) { - return NULL; - } - - if (!THPCopy(v, self, src, non_blocking, broadcast)) { - return NULL; - } - - Py_INCREF(self); - return self; -} - template <typename StorageDst, typename StorageSrc> void THPInsertStorageCopyFunction( THPCopyList& copyList, @@ -101,49 +80,3 @@ void THPInsertStorageCopyFunction( PyTypeObject* srcType = THPTypeInfo<StorageSrc>::pyType(); copyList.push_back({ srcType, wrapper, non_blocking, false }); } - -template <typename TensorDst, typename TensorSrc> -void THPInsertTensorCopyFunction( - THPCopyList& copyList, - void (*copyFunc)(LIBRARY_STATE_TYPE TensorDst* x, TensorSrc* z), - bool non_blocking=false, - bool broadcast=true) -{ - auto wrapper = [copyFunc](PyObject* dst_, PyObject* src_, bool broadcast) { - TensorDst* dst = THPTypeInfo<TensorDst>::cdata(dst_); - TensorSrc* src = THPTypeInfo<TensorSrc>::cdata(src_); - - TensorSrc *src_save = src; - THPPointer<TensorSrc> src_guard(newForExpand<TensorSrc>(LIBRARY_STATE_NOARGS)); - - // support for "broadcast" parameter to copy_. - if (broadcast) { - bool expand_success = false; - try { - expand_inplace1<TensorSrc, TensorDst>(LIBRARY_STATE src_guard.get(), src, dst, "src", "dst", true); - expand_success = true; - } catch (std::exception &e) {} - if (expand_success) { - src = src_guard.get(); - } - } - - PyThreadState *_save = NULL; - try { - Py_UNBLOCK_THREADS; - copyFunc(LIBRARY_STATE dst, src); - Py_BLOCK_THREADS; - } catch (...) { - if (_save) { - Py_BLOCK_THREADS; - } - throw; - } - src = src_save; - }; - - PyTypeObject* srcType = THPTypeInfo<TensorSrc>::pyType(); - copyList.push_back({ srcType, wrapper, non_blocking, broadcast }); -} - -#endif diff --git a/torch/csrc/cuda/AutoGPU.cpp b/torch/csrc/cuda/AutoGPU.cpp deleted file mode 100644 index 227399190e..0000000000 --- a/torch/csrc/cuda/AutoGPU.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "AutoGPU.h" - -#include "THCP.h" -#include <THC/THC.h> - -static int getObjDevice(PyObject *obj) { - PyObject *obj_type = (PyObject*)Py_TYPE(obj); - if (obj_type == THCPDoubleTensorClass) { - return THCudaDoubleTensor_getDevice(LIBRARY_STATE ((THCPDoubleTensor*)obj)->cdata); - } else if (obj_type == THCPFloatTensorClass) { - return THCudaTensor_getDevice(LIBRARY_STATE ((THCPFloatTensor*)obj)->cdata); - } else if (obj_type == THCPHalfTensorClass) { - return THCudaHalfTensor_getDevice(LIBRARY_STATE ((THCPHalfTensor*)obj)->cdata); - } else if (obj_type == THCPLongTensorClass) { - return THCudaLongTensor_getDevice(LIBRARY_STATE ((THCPLongTensor*)obj)->cdata); - } else if (obj_type == THCPIntTensorClass) { - return THCudaIntTensor_getDevice(LIBRARY_STATE ((THCPIntTensor*)obj)->cdata); - } else if (obj_type == THCPShortTensorClass) { - return THCudaShortTensor_getDevice(LIBRARY_STATE ((THCPShortTensor*)obj)->cdata); - } else if (obj_type == THCPCharTensorClass) { - return THCudaCharTensor_getDevice(LIBRARY_STATE ((THCPCharTensor*)obj)->cdata); - } else if (obj_type == THCPByteTensorClass) { - return THCudaByteTensor_getDevice(LIBRARY_STATE ((THCPByteTensor*)obj)->cdata); - } else if (obj_type == THCSPDoubleTensorClass) { - return THCSDoubleTensor_getDevice(LIBRARY_STATE ((THCSPDoubleTensor*)obj)->cdata); - } else if (obj_type == THCSPFloatTensorClass) { - return THCSFloatTensor_getDevice(LIBRARY_STATE ((THCSPFloatTensor*)obj)->cdata); - } else if (obj_type == THCSPHalfTensorClass) { - return THCSHalfTensor_getDevice(LIBRARY_STATE ((THCSPHalfTensor*)obj)->cdata); - } else if (obj_type == THCSPLongTensorClass) { - return THCSLongTensor_getDevice(LIBRARY_STATE ((THCSPLongTensor*)obj)->cdata); - } else if (obj_type == THCSPIntTensorClass) { - return THCSIntTensor_getDevice(LIBRARY_STATE ((THCSPIntTensor*)obj)->cdata); - } else if (obj_type == THCSPShortTensorClass) { - return THCSShortTensor_getDevice(LIBRARY_STATE ((THCSPShortTensor*)obj)->cdata); - } else if (obj_type == THCSPCharTensorClass) { - return THCSCharTensor_getDevice(LIBRARY_STATE ((THCSPCharTensor*)obj)->cdata); - } else if (obj_type == THCSPByteTensorClass) { - return THCSByteTensor_getDevice(LIBRARY_STATE ((THCSPByteTensor*)obj)->cdata); - } - return -1; -} - -static int getObjDevice(PyObject *args, PyObject *self) { - if (self) { - int device = getObjDevice(self); - if (device != -1) { - return device; - } - } - if (args) { - for (int i = 0; i < PyTuple_Size(args); i++) { - int device = getObjDevice(PyTuple_GET_ITEM(args, i)); - if (device != -1) { - return device; - } - } - } - return -1; -} - -THCPAutoGPU::THCPAutoGPU(PyObject *args, PyObject *self) - : AutoGPU(getObjDevice(args, self)) { -} - -void THCPAutoGPU::setObjDevice(PyObject *obj) { - setDevice(getObjDevice(obj)); -} diff --git a/torch/csrc/cuda/AutoGPU.h b/torch/csrc/cuda/AutoGPU.h deleted file mode 100644 index 07e75aa431..0000000000 --- a/torch/csrc/cuda/AutoGPU.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef THCP_AUTOGPU_INC -#define THCP_AUTOGPU_INC - -#include <Python.h> -#include "THP_export.h" -#include "torch/csrc/utils/auto_gpu.h" - -class THP_CLASS THCPAutoGPU : public AutoGPU { -public: - explicit THCPAutoGPU(int device_id=-1) : AutoGPU(device_id) {} -#ifdef WITH_CUDA - THCPAutoGPU(PyObject *args, PyObject *self=NULL); - void setObjDevice(PyObject *obj); -#endif -}; - -#endif diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 63d96f28cb..0f5e96378c 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -18,42 +18,12 @@ #include "torch/csrc/autograd/generated/VariableType.h" #include "torch/csrc/utils/python_strings.h" #include "torch/csrc/cuda/python_comm.h" -#include "ModuleSparse.cpp" using namespace torch; THCState *state; //////////////////////////////////////////////////////////////////////////////// -// Class pointer cache -//////////////////////////////////////////////////////////////////////////////// - -static bool THCPModule_loadClasses(PyObject *torch_module) -{ -#define ASSERT_NOT_NULL(ptr) if (!(ptr)) { THPUtils_setError("couldn't load classes"); return false; } - if (!THCPDoubleTensor_postInit(torch_module)) return false; - if (!THCPFloatTensor_postInit(torch_module)) return false; - if (!THCPHalfTensor_postInit(torch_module)) return false; - if (!THCPLongTensor_postInit(torch_module)) return false; - if (!THCPIntTensor_postInit(torch_module)) return false; - if (!THCPShortTensor_postInit(torch_module)) return false; - if (!THCPCharTensor_postInit(torch_module)) return false; - if (!THCPByteTensor_postInit(torch_module)) return false; - - THCPDoubleStorage_postInit(torch_module); - THCPFloatStorage_postInit(torch_module); - THCPHalfStorage_postInit(torch_module); - THCPLongStorage_postInit(torch_module); - THCPIntStorage_postInit(torch_module); - THCPShortStorage_postInit(torch_module); - THCPCharStorage_postInit(torch_module); - THCPByteStorage_postInit(torch_module); - - return true; -#undef ASSERT_NOT_NULL -} - -//////////////////////////////////////////////////////////////////////////////// // CUDA management methods //////////////////////////////////////////////////////////////////////////////// @@ -345,49 +315,10 @@ PyObject * THCPModule_maxMemoryCached(PyObject *_unused, PyObject *arg) // Cuda module initialization //////////////////////////////////////////////////////////////////////////////// -bool THCPModule_initCuda(PyObject *torch_module) { - HANDLE_TH_ERRORS -#define ASSERT_TRUE(cond) if (!(cond)) { return false; } - state = at::globalContext().lazyInitCUDA(); - -#ifdef USE_MAGMA - THCMagma_init(state); - ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_magma", PyBool_FromLong(true)) != -1); -#else - ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_magma", PyBool_FromLong(false)) != -1); -#endif - -#ifdef CUDA_HALF_TENSOR - ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_half", PyBool_FromLong(true)) != -1); -#else - ASSERT_TRUE(PyObject_SetAttrString(torch_module, "has_half", PyBool_FromLong(false)) != -1); -#endif - - ASSERT_TRUE(THCPModule_loadClasses(torch_module)); - - ASSERT_TRUE(PyObject_SetAttrString(torch_module, "_state_cdata", PyLong_FromVoidPtr(state)) != -1); - - // TODO: register THCudaShutdown handler at exit - return true; -#undef ASSERT_TRUE - END_HANDLE_TH_ERRORS_RET(false) -} - -// Callback for python part. Used for additional initialization of python classes -PyObject * THCPModule_initExtension(PyObject *self) -{ - PyObject *torch_module = PyImport_ImportModule("torch.cuda"); - if (!torch_module) { - THPUtils_setError("class loader couldn't access torch module"); - return NULL; - } - if (!THCPModule_initCuda(torch_module)) { - return NULL; - } - +static void bindCudaDeviceProperties(PyObject* module) { // Add class and method to torch.cuda - auto m = py::handle(torch_module).cast<py::module>(); - py::class_<cudaDeviceProp>(m,"_CudaDeviceProperties") + auto m = py::handle(module).cast<py::module>(); + py::class_<cudaDeviceProp>(m, "_CudaDeviceProperties") .def_readonly("name", &cudaDeviceProp::name) .def_readonly("major", &cudaDeviceProp::major) .def_readonly("minor", &cudaDeviceProp::minor) @@ -405,8 +336,57 @@ PyObject * THCPModule_initExtension(PyObject *self) m.def("_get_device_properties", [](int device) -> cudaDeviceProp * { return at::globalContext().getDeviceProperties(device); }, py::return_value_policy::reference); +} + +// Callback for python part. Used for additional initialization of python classes +static PyObject * THCPModule_initExtension(PyObject *self) +{ + HANDLE_TH_ERRORS + state = at::globalContext().lazyInitCUDA(); + + auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); + if (!m) throw python_error(); + + // Register Storage Python objects with DynamicTypes.cpp + THCPDoubleStorage_postInit(m); + THCPFloatStorage_postInit(m); + THCPHalfStorage_postInit(m); + THCPLongStorage_postInit(m); + THCPIntStorage_postInit(m); + THCPShortStorage_postInit(m); + THCPCharStorage_postInit(m); + THCPByteStorage_postInit(m); + +#ifdef USE_MAGMA + THCMagma_init(state); + bool has_magma = true; +#else + bool has_magma = false; +#endif - Py_RETURN_TRUE; +#ifdef CUDA_HALF_TENSOR + bool has_half = true; +#else + bool has_half = false; +#endif + + auto set_module_attr = [&](const char* name, PyObject* v) { + if (PyObject_SetAttrString(m, name, v) < 0) { + throw python_error(); + } + }; + + set_module_attr("has_magma", has_magma ? Py_True : Py_False); + set_module_attr("has_half", has_half ? Py_True : Py_False); + + auto _state_cdata = THPObjectPtr(PyLong_FromVoidPtr(state)); + if (!_state_cdata) throw python_error(); + set_module_attr("_state_cdata", _state_cdata.get()); + + bindCudaDeviceProperties(m); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS } #ifdef WITH_NCCL diff --git a/torch/csrc/cuda/ModuleSparse.cpp b/torch/csrc/cuda/ModuleSparse.cpp deleted file mode 100644 index 7b39c8fea9..0000000000 --- a/torch/csrc/cuda/ModuleSparse.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "THCP.h" - -static bool THCSPModule_loadClasses(PyObject *sparse_module) -{ - if (!THCSPDoubleTensor_postInit(sparse_module)) return false; - if (!THCSPFloatTensor_postInit(sparse_module)) return false; -#ifdef CUDA_HALF_TENSOR - if (!THCSPHalfTensor_postInit(sparse_module)) return false; -#endif - if (!THCSPLongTensor_postInit(sparse_module)) return false; - if (!THCSPIntTensor_postInit(sparse_module)) return false; - if (!THCSPShortTensor_postInit(sparse_module)) return false; - if (!THCSPCharTensor_postInit(sparse_module)) return false; - if (!THCSPByteTensor_postInit(sparse_module)) return false; - return true; -} - -//////////////////////////////////////////////////////////////////////////////// -// Sparse Cuda module initialization -//////////////////////////////////////////////////////////////////////////////// - -bool THCSPModule_initCudaSparse(PyObject *module) { -#define ASSERT_TRUE(cond) if (!(cond)) { return false; } - ASSERT_TRUE(THCSPModule_loadClasses(module)); - return true; -#undef ASSERT_TRUE -} - -PyObject * THCSPModule_initExtension(PyObject *self) -{ - PyObject *module = PyImport_ImportModule("torch.cuda.sparse"); - if (!module) { - THPUtils_setError("class loader couldn't access torch.cuda.sparse module"); - return NULL; - } - if (!THCSPModule_initCudaSparse(module)) { - return NULL; - } - Py_RETURN_NONE; -} diff --git a/torch/csrc/cuda/THCP.h b/torch/csrc/cuda/THCP.h index 79850b5682..57372197f8 100644 --- a/torch/csrc/cuda/THCP.h +++ b/torch/csrc/cuda/THCP.h @@ -11,7 +11,6 @@ #include "torch/csrc/THP.h" #include "serialization.h" -#include "AutoGPU.h" #include "Module.h" #include "Storage.h" #include "Tensor.h" diff --git a/torch/csrc/cuda/Tensor.h b/torch/csrc/cuda/Tensor.h deleted file mode 100644 index e4ac7c8eba..0000000000 --- a/torch/csrc/cuda/Tensor.h +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef THCP_TENSOR_INC -#define THCP_TENSOR_INC - -#define THCPTensor TH_CONCAT_3(THCP,Real,Tensor) -#define THCPTensorStr TH_CONCAT_STRING_3(torch.cuda.,Real,Tensor) -#define THCPTensorClass TH_CONCAT_3(THCP,Real,TensorClass) -#define THCPTensor_(NAME) TH_CONCAT_4(THCP,Real,Tensor_,NAME) - -#define THCPDoubleTensor_Check(obj) PyObject_IsInstance(obj, THCPDoubleTensorClass) -#define THCPFloatTensor_Check(obj) PyObject_IsInstance(obj, THCPFloatTensorClass) -#define THCPHalfTensor_Check(obj) PyObject_IsInstance(obj, THCPHalfTensorClass) -#define THCPLongTensor_Check(obj) PyObject_IsInstance(obj, THCPLongTensorClass) -#define THCPIntTensor_Check(obj) PyObject_IsInstance(obj, THCPIntTensorClass) -#define THCPShortTensor_Check(obj) PyObject_IsInstance(obj, THCPShortTensorClass) -#define THCPCharTensor_Check(obj) PyObject_IsInstance(obj, THCPCharTensorClass) -#define THCPByteTensor_Check(obj) PyObject_IsInstance(obj, THCPByteTensorClass) - -#define THCPDoubleTensor_CData(obj) (obj)->cdata -#define THCPFloatTensor_CData(obj) (obj)->cdata -#define THCPHalfTensor_CData(obj) (obj)->cdata -#define THCPLongTensor_CData(obj) (obj)->cdata -#define THCPIntTensor_CData(obj) (obj)->cdata -#define THCPShortTensor_CData(obj) (obj)->cdata -#define THCPCharTensor_CData(obj) (obj)->cdata -#define THCPByteTensor_CData(obj) (obj)->cdata - -#ifdef _THP_CORE -#define THCPTensorType TH_CONCAT_3(THCP,Real,TensorType) -#define THCPTensorBaseStr TH_CONCAT_STRING_3(Cuda,Real,TensorBase) -#define THCPTensor_stateless_(NAME) TH_CONCAT_4(THCP,Real,Tensor_stateless_,NAME) -#define THCPTensorStatelessType TH_CONCAT_2(CReal,TensorStatelessType) -#define THCPTensorStateless TH_CONCAT_2(CReal,TensorStateless) -#define THCPTensorStatelessMethods TH_CONCAT_2(CReal,TensorStatelessMethods) -#endif - -#define THCSPTensor TH_CONCAT_3(THCSP,Real,Tensor) -#define THCSPTensorStr TH_CONCAT_STRING_3(torch.cuda.sparse.,Real,Tensor) -#define THCSPTensorClass TH_CONCAT_3(THCSP,Real,TensorClass) -#define THCSPTensor_(NAME) TH_CONCAT_4(THCSP,Real,Tensor_,NAME) - -#define THCSPDoubleTensor_Check(obj) PyObject_IsInstance(obj, THCSPDoubleTensorClass) -#define THCSPFloatTensor_Check(obj) PyObject_IsInstance(obj, THCSPFloatTensorClass) -#define THCSPHalfTensor_Check(obj) PyObject_IsInstance(obj, THCSPHalfTensorClass) -#define THCSPLongTensor_Check(obj) PyObject_IsInstance(obj, THCSPLongTensorClass) -#define THCSPIntTensor_Check(obj) PyObject_IsInstance(obj, THCSPIntTensorClass) -#define THCSPShortTensor_Check(obj) PyObject_IsInstance(obj, THCSPShortTensorClass) -#define THCSPCharTensor_Check(obj) PyObject_IsInstance(obj, THCSPCharTensorClass) -#define THCSPByteTensor_Check(obj) PyObject_IsInstance(obj, THCSPByteTensorClass) - -#define THCSPDoubleTensor_CData(obj) (obj)->cdata -#define THCSPFloatTensor_CData(obj) (obj)->cdata -#define THCSPHalfTensor_CData(obj) (obj)->cdata -#define THCSPLongTensor_CData(obj) (obj)->cdata -#define THCSPIntTensor_CData(obj) (obj)->cdata -#define THCSPShortTensor_CData(obj) (obj)->cdata -#define THCSPCharTensor_CData(obj) (obj)->cdata -#define THCSPByteTensor_CData(obj) (obj)->cdata - -#ifdef _THP_CORE -#define THCSPTensorType TH_CONCAT_3(THCSP,Real,TensorType) -#define THCSPTensorBaseStr TH_CONCAT_STRING_3(CudaSparse,Real,TensorBase) -#define THCSPTensor_stateless_(NAME) TH_CONCAT_4(THCP,Real,Tensor_stateless_,NAME) -#define THCSPTensorStatelessType TH_CONCAT_3(CudaSparse,Real,TensorStatelessType) -#define THCSPTensorStateless TH_CONCAT_3(CudaSparse,Real,TensorStateless) -#define THCSPTensorStatelessMethods TH_CONCAT_3(CudaSparse,Real,TensorStatelessMethods) -#endif - -#include "override_macros.h" - -#define THC_GENERIC_FILE "torch/csrc/generic/Tensor.h" -#include <THC/THCGenerateAllTypes.h> - -#endif diff --git a/torch/csrc/cuda/expand_utils.cpp b/torch/csrc/cuda/expand_utils.cpp deleted file mode 100644 index 580785658c..0000000000 --- a/torch/csrc/cuda/expand_utils.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "torch/csrc/cuda/THCP.h" - -// Declare/Define the expansion functions that have THCState. Note that we -// still need to define the CPU-type versions because the copy functions that -// copy from GPU to CPU type have a THCState. - -#define CUDA_EXPAND 1 - -#include "torch/csrc/expand_utils.h" -#include "torch/csrc/generic/expand_utils-inl.h" - -#undef CUDA_EXPAND diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index f5f88d7ac9..8163347280 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -6,12 +6,14 @@ #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/cuda/THCP.h" #include "torch/csrc/cuda/nccl.h" +#include "torch/csrc/Exceptions.h" #include <nccl.h> #include <sstream> #include <unordered_map> using namespace at; +using namespace torch; using namespace torch::cuda::nccl; using namespace torch::cuda::nccl::detail; @@ -304,12 +306,11 @@ static std::vector<at::Tensor> extract_tensors(PyObject* obj) { Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < length; i++) { PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); - if (THPVariable_Check(item)) { - auto var = (THPVariable*) item; - list.emplace_back(var->cdata.data()); - } else { - list.emplace_back(torch::createTensor(item)); + if (!THPVariable_Check(item)) { + throw TypeError("expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name); } + auto var = (THPVariable*) item; + list.emplace_back(var->cdata.data()); } return list; } diff --git a/torch/csrc/distributed/THDP.h b/torch/csrc/distributed/THDP.h index 058466d22e..33ab36ef45 100644 --- a/torch/csrc/distributed/THDP.h +++ b/torch/csrc/distributed/THDP.h @@ -6,7 +6,6 @@ #include "torch/csrc/THP.h" #include "Module.h" #include "Storage.h" -#include "Tensor.h" #include "../PtrWrapper.h" #ifdef _THP_CORE #include "utils.h" diff --git a/torch/csrc/expand_utils.cpp b/torch/csrc/expand_utils.cpp deleted file mode 100644 index a2cd55d869..0000000000 --- a/torch/csrc/expand_utils.cpp +++ /dev/null @@ -1,5 +0,0 @@ -#include "torch/csrc/THP.h" - -// Declare/Define the expansion functions that lack THCState. -#include "torch/csrc/expand_utils.h" -#include "torch/csrc/generic/expand_utils-inl.h" diff --git a/torch/csrc/expand_utils.h b/torch/csrc/expand_utils.h deleted file mode 100644 index 34dfdb995e..0000000000 --- a/torch/csrc/expand_utils.h +++ /dev/null @@ -1,205 +0,0 @@ -#ifndef THP_EXPAND_UTILS_H -#define THP_EXPAND_UTILS_H - -#include <sstream> -#include <Python.h> - -template <typename ExpandType> -ExpandType *newForExpand(LIBRARY_STATE_TYPE_NOARGS); - -template <typename TensorType> -void expand(LIBRARY_STATE_TYPE TensorType *r, TensorType *tensor, THLongStorage *sizes); - -template <typename TensorType1, typename TensorType2> -void expand2(LIBRARY_STATE_TYPE TensorType1 *r1, TensorType2 *r2, - TensorType1 *e1, TensorType2 *e2, - char *e1_name, char *e2_name) { - if (e1->nDimension <= 0) { - throw std::runtime_error(std::string("can't expand empty tensor ").append(e1_name)); - } - if (e2->nDimension <= 0) { - throw std::runtime_error(std::string("can't expand empty tensor ").append(e2_name)); - } - THLongStoragePtr sizes(THLongStorage_new()); - char error_buffer[1024]; - int ret = THLongStorage_inferSize2(sizes, - e1->size, e1->nDimension, - e2->size, e2->nDimension, - error_buffer, 1024); - if (ret != 0) { - throw std::runtime_error(error_buffer); - } - - expand(LIBRARY_STATE r1, e1, sizes); - expand(LIBRARY_STATE r2, e2, sizes); -} - -template <typename TensorType1, typename TensorType2, typename TensorType3> -void expand3(LIBRARY_STATE_TYPE TensorType1 *r1, TensorType2 *r2, TensorType3 *r3, - TensorType1 *e1, TensorType2 *e2, TensorType3 *e3, - char *e1_name, char *e2_name, char *e3_name) { - if (e1->nDimension <= 0) { - throw std::runtime_error(std::string("can't expand empty tensor ").append(e1_name)); - } - if (e2->nDimension <= 0) { - throw std::runtime_error(std::string("can't expand empty tensor ").append(e2_name)); - } - if (e3->nDimension <= 0) { - throw std::runtime_error(std::string("can't expand empty tensor ").append(e3_name)); - } - - int64_t *e_sizes[3]; - int64_t e_dims[3]; - - e_sizes[ 0 ] = e1->size; - e_sizes[ 1 ] = e2->size; - e_sizes[ 2 ] = e3->size; - e_dims[ 0 ] = e1->nDimension; - e_dims[ 1 ] = e2->nDimension; - e_dims[ 2 ] = e3->nDimension; - - THLongStoragePtr sizes(THLongStorage_new()); - char error_buffer[1024]; - int ret = THLongStorage_inferSizeN(sizes, - 3, - e_sizes, - e_dims, - error_buffer, - 1024); - - if(ret != 0) { - throw std::runtime_error(error_buffer); - } - - expand(LIBRARY_STATE r1, e1, sizes); - expand(LIBRARY_STATE r2, e2, sizes); - expand(LIBRARY_STATE r3, e3, sizes); -} - -template <typename ExpandType, typename TensorType> -void check_backincompat_expand_warn(ExpandType *to_expand, TensorType *tensor, - char *to_expand_name, char *tensor_name, bool fallback, - ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem) { - if (fallback && getBackCompatBroadcastWarn()) { - bool same_shape = THSize_isSameSizeAs(tensor->size, tensor->nDimension, - to_expand->size, to_expand->nDimension); - if (!same_shape && (tensor_nElem == to_expand_nElem)) { - std::ostringstream warn; - warn << tensor_name << " and " << to_expand_name << " do not have the same shape, but are " - << "broadcastable, and have the same number of elements. Changing behavior in a backwards incompatible " - << "manner to broadcasting rather than viewing as 1-dimensional."; - PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1); - } - } -} - -template <typename ExpandType, typename TensorType> -void expand_inplace(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor, - char *to_expand_name, char *tensor_name, bool fallback, - THLongStorage *tensor_size, ptrdiff_t to_expand_nElem, ptrdiff_t tensor_nElem, - bool warn_pointwise_fallback) { - try { - expand<ExpandType>(LIBRARY_STATE r, to_expand, tensor_size); - } catch (std::exception &e) { - if (warn_pointwise_fallback) { - std::ostringstream warn; - warn << to_expand_name << " is not broadcastable to " << tensor_name - << ", but they have the same number of elements. Falling back to deprecated pointwise behavior."; - PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1); - } - throw; - } -} - -template <typename ExpandType, typename TensorType> -void expand_inplace1(LIBRARY_STATE_TYPE ExpandType *r, ExpandType *to_expand, TensorType *tensor, - char *to_expand_name, char *tensor_name, bool fallback) { - ptrdiff_t to_expand_nElem = THSize_nElement(to_expand->nDimension, to_expand->size); - ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size); - bool to_expand_warn = fallback && (to_expand_nElem == tensor_nElem) && to_expand_nElem != 0; - THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension)); - THLongStorage_rawCopy(tensor_size.get(), tensor->size); - - expand_inplace(LIBRARY_STATE r, to_expand, tensor, to_expand_name, tensor_name, fallback, - tensor_size, to_expand_nElem, tensor_nElem, to_expand_warn); - check_backincompat_expand_warn<ExpandType, TensorType>(to_expand, tensor, to_expand_name, tensor_name, fallback, - to_expand_nElem, tensor_nElem); -} - -template <typename TensorType> -void expand_inplace2(LIBRARY_STATE_TYPE TensorType *r1, TensorType *r2, - TensorType *to_expand1, TensorType *to_expand2, TensorType *tensor, - char *to_expand1_name, char *to_expand2_name, char *tensor_name, bool fallback) { - ptrdiff_t tensor_nElem = THSize_nElement(tensor->nDimension, tensor->size); - ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size); - ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size); - bool to_expand1_warn = fallback && (tensor_nElem == to_expand1_nElem) && tensor_nElem != 0; - bool to_expand2_warn = fallback && (tensor_nElem == to_expand2_nElem) && tensor_nElem != 0; - THLongStoragePtr tensor_size(THLongStorage_newWithSize(tensor->nDimension)); - THLongStorage_rawCopy(tensor_size.get(), tensor->size); - - expand_inplace(LIBRARY_STATE r1, to_expand1, tensor, to_expand1_name, tensor_name, fallback, - tensor_size, to_expand1_nElem, tensor_nElem, to_expand1_warn && to_expand2_warn); - expand_inplace(LIBRARY_STATE r2, to_expand2, tensor, to_expand2_name, tensor_name, fallback, - tensor_size, to_expand2_nElem, tensor_nElem, to_expand1_warn && to_expand2_warn); - - check_backincompat_expand_warn<TensorType, TensorType>(to_expand1, tensor, to_expand1_name, tensor_name, fallback, - to_expand1_nElem, tensor_nElem); - check_backincompat_expand_warn<TensorType, TensorType>(to_expand2, tensor, to_expand2_name, tensor_name, fallback, - to_expand2_nElem, tensor_nElem); -} - -template <typename TensorType1, typename TensorType2> -void expand_outplace2(LIBRARY_STATE_TYPE TensorType1 *r1, TensorType2 *r2, - TensorType1 *to_expand1, TensorType2 *to_expand2, - char *to_expand1_name, char *to_expand2_name, bool fallback) { - ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size); - ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size); - bool expand_warn = fallback && (to_expand1_nElem == to_expand2_nElem) && to_expand1_nElem != 0; - try { - expand2<TensorType1, TensorType2>(LIBRARY_STATE r1, r2, to_expand1, to_expand2, to_expand1_name, to_expand2_name); - } catch (std::exception &e) { - if (expand_warn) { - std::ostringstream warn; - warn << to_expand1_name << " and " << to_expand2_name << " not broadcastable, but have the same number of " - << "elements. Falling back to deprecated pointwise behavior."; - PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1); - } - throw; - } - - check_backincompat_expand_warn<TensorType1, TensorType2>(to_expand1, to_expand2, to_expand1_name, to_expand2_name, - fallback, to_expand1_nElem, to_expand2_nElem); -} - -template <typename TensorType1, typename TensorType2, typename TensorType3> -void expand_outplace3(LIBRARY_STATE_TYPE TensorType1 *r1, TensorType2 *r2, TensorType3 *r3, - TensorType1 *to_expand1, TensorType2 *to_expand2, TensorType3 *to_expand3, - char *to_expand1_name, char *to_expand2_name, char *to_expand3_name, bool fallback) { - ptrdiff_t to_expand1_nElem = THSize_nElement(to_expand1->nDimension, to_expand1->size); - ptrdiff_t to_expand2_nElem = THSize_nElement(to_expand2->nDimension, to_expand2->size); - ptrdiff_t to_expand3_nElem = THSize_nElement(to_expand3->nDimension, to_expand3->size); - bool to_expand2_warn = fallback && (to_expand1_nElem == to_expand2_nElem) && to_expand1_nElem != 0; - bool to_expand3_warn = fallback && (to_expand1_nElem == to_expand3_nElem) && to_expand1_nElem != 0; - - try { - expand3<TensorType1, TensorType2, TensorType3>(LIBRARY_STATE r1, r2, r3, - to_expand1, to_expand2, to_expand3, - to_expand1_name, to_expand2_name, to_expand3_name); - } catch (std::exception &e) { - if(to_expand2_warn && to_expand3_warn) { - std::ostringstream warn; - warn << to_expand1_name << ", " << to_expand2_name << ", and " << to_expand3_name << " not broadcastable," - << " but have the same number of elements. Falling back to deprecated pointwise behavior."; - PyErr_WarnEx(PyExc_UserWarning, warn.str().c_str(), 1); - } - throw; - } - - check_backincompat_expand_warn<TensorType1, TensorType2>(to_expand1, to_expand2, to_expand1_name, to_expand2_name, - fallback, to_expand1_nElem, to_expand2_nElem); - check_backincompat_expand_warn<TensorType1, TensorType3>(to_expand1, to_expand3, to_expand1_name, to_expand3_name, - fallback, to_expand1_nElem, to_expand3_nElem); -} - -#endif diff --git a/torch/csrc/generic/SparseTensor.cpp b/torch/csrc/generic/SparseTensor.cpp deleted file mode 100644 index db636fd43c..0000000000 --- a/torch/csrc/generic/SparseTensor.cpp +++ /dev/null @@ -1,304 +0,0 @@ -PyObject *THSPTensorClass = NULL; - -static void THSTensor_(initStorage)(THSTensor* tensor) -{ - // Ensure that PyTorch's "storage is not NULL" invariant is upheld. - // See Note [Storage is not NULL] - if (!tensor->indices->storage) { -#ifdef THC_GENERIC_FILE - tensor->indices->storage = THCudaLongStorage_new(LIBRARY_STATE_NOARGS); -#else - tensor->indices->storage = THLongStorage_new(LIBRARY_STATE_NOARGS); -#endif - } - if (!tensor->values->storage) { - tensor->values->storage = THStorage_(new)(LIBRARY_STATE_NOARGS); - } -} - -PyObject * THSPTensor_(New)(THSTensor *tensor) -{ - THSTensorPtr ptr(tensor); - PyTypeObject *type = (PyTypeObject *)THSPTensorClass; - PyObject *obj = type->tp_alloc(type, 0); - if (obj) { - ((THSPTensor *)obj)->cdata = ptr.release(); - } - return obj; -} - -PyObject * THSPTensor_(NewEmpty)() -{ - THSTensorPtr tensor(THSTensor_(new)(LIBRARY_STATE_NOARGS)); - THSTensor_(initStorage)(tensor.get()); - return THSPTensor_(New)(tensor.release()); -} - -static void THSPTensor_(dealloc)(THSPTensor* self) -{ - if (self->cdata) - THSTensor_(free)(LIBRARY_STATE self->cdata); - Py_TYPE(self)->tp_free((PyObject*)self); -} - -static PyObject * THSPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ -#ifdef THC_GENERIC_FILE -#define THPIndexTensor_Check THCPLongTensor_Check -#define THPIndexTensor THCPLongTensor -#define THIndexTensor THCudaLongTensor -#else -#define THPIndexTensor_Check THPLongTensor_Check -#define THPIndexTensor THPLongTensor -#define THIndexTensor THLongTensor -#endif - HANDLE_TH_ERRORS - -#ifdef THC_GENERIC_FILE - THCPAutoGPU gpu_guard(args, NULL); -#endif - - Py_ssize_t num_args = args ? PyTuple_Size(args) : 0; - - THSPTensorPtr self((THSPTensor *)type->tp_alloc(type, 0)); - THPUtils_assert(self, "failed to allocate a " THSPTensorStr " object"); - THLongStoragePtr sizes; - - // Internally we allow constructing with a keyword only argument cdata - if (kwargs != NULL) { - Py_ssize_t num_kwargs = PyDict_Size(kwargs); - if (num_args == 0) { - // NB: This is an internal option, so we don't want to advertise it. - PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata"); - if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) { - THSTensor *ptr = (THSTensor*)PyLong_AsVoidPtr(cdata_ptr); - self->cdata = ptr; - return (PyObject*)self.release(); - } - } -#ifdef THC_GENERIC_FILE - PyObject *device_id = PyDict_GetItemString(kwargs, "device"); - if (device_id == Py_None) { - num_kwargs--; - } else if (device_id) { - THPUtils_assert(THPUtils_checkLong(device_id), "device argument " - " has to be an int, but got %s", THPUtils_typename(device_id)); - gpu_guard.setDevice(THPUtils_unpackLong(device_id)); - // simulate pop() and pretend this key was never there - num_kwargs--; - } -#endif -#ifdef THC_GENERIC_FILE - THPUtils_assert(num_kwargs == 0, THSPTensorStr " constructor only " - "accepts a 'device' keyword argument"); -#else - THPUtils_assert(num_kwargs == 0, THPTensorStr " constructor doesn't " - "accept any keyword arguments"); -#endif - } - - // torch.SparseTensor() - if (num_args == 0) { - self->cdata = THSTensor_(new)(LIBRARY_STATE_NOARGS); - } else { - PyObject *first_arg = PyTuple_GET_ITEM(args, 0); - - // torch.SparseTensor(size) - if (num_args == 1 && THPUtils_checkLong(first_arg)) { - int64_t size = THPUtils_unpackLong(first_arg); - self->cdata = THSTensor_(newWithSize1d)(LIBRARY_STATE size); - } - // torch.SparseTensor(torch.Size sizes) - else if (num_args == 1 && THPSize_Check(first_arg)) { - THLongStoragePtr sizes(THPUtils_unpackSize(first_arg)); - self->cdata = THSTensor_(newWithSize)(LIBRARY_STATE sizes.get(), nullptr); - } - // torch.SparseTensor(torch.LongTensor indices, torch.LongTensor values) - else if (num_args == 2 && THPIndexTensor_Check(first_arg)) { - PyObject *second_arg = PyTuple_GET_ITEM(args, 1); - if (!THPTensor_(Check)(second_arg)) goto invalid_arguments; - - THIndexTensor *indices = ((THPIndexTensor*)first_arg)->cdata; - THTensor *values = ((THPTensor*)second_arg)->cdata; - -#ifdef THC_GENERIC_FILE - THCAssertSameGPU(THSTensor_(checkGPU)(LIBRARY_STATE 0, 2, indices, values)); -#endif - - self->cdata = THSTensor_(newWithTensor)(LIBRARY_STATE indices, values); - } - // torch.SparseTensor(torch.LongTensor indices, - // torch.Tensor values, - // torch.Size sizes) - else if (num_args > 2 && THPIndexTensor_Check(first_arg)) { - PyObject *second_arg = PyTuple_GET_ITEM(args, 1); - PyObject *third_arg = PyTuple_GET_ITEM(args, 2); - if (!THPTensor_(Check)(second_arg)) goto invalid_arguments; - if (!THPSize_Check(third_arg)) goto invalid_arguments; - - THIndexTensor *indices = ((THPIndexTensor*)first_arg)->cdata; - THTensor *values = ((THPTensor*)second_arg)->cdata; - THLongStoragePtr sizes(THPUtils_unpackSize(third_arg)); - -#ifdef THC_GENERIC_FILE - THCAssertSameGPU(THSTensor_(checkGPU)(LIBRARY_STATE 0, 2, indices, values)); -#endif - - self->cdata = THSTensor_(newWithTensorAndSize)( - LIBRARY_STATE indices, values, sizes); - } - // torch.SparseTensor(int ...) - else if (THPUtils_tryUnpackLongVarArgs(args, 0, sizes)) { - self->cdata = THSTensor_(newWithSize)(LIBRARY_STATE sizes.get(), nullptr); - } - else goto invalid_arguments; // All other cases - } - - THSTensor_(initStorage)(self->cdata); - return (PyObject*)self.release(); - -invalid_arguments: - THPUtils_invalidArguments(args, NULL, THSPTensorStr " constructor", 6, - "no arguments", - "(int size)", - "(torch.Size sizes)", -#ifdef THC_GENERIC_FILE - "(torch.cuda.LongTensor indices, " THPTensorStr " values)", - "(torch.cuda.LongTensor indices, " THPTensorStr " values, torch.Size sizes)", -#else - "(torch.LongTensor indices, " THPTensorStr " values)", - "(torch.LongTensor indices, " THPTensorStr " values, torch.Size sizes)", -#endif - "(int ...)"); - return NULL; - END_HANDLE_TH_ERRORS -#undef THPIndexTensor_Check -#undef THPIndexTensor -#undef THIndexTensor -} - -// TODO: implement equality -PyTypeObject THSPTensorType = { - PyVarObject_HEAD_INIT(NULL, 0) - "torch._C.Sparse" THPTensorBaseStr, /* tp_name */ - sizeof(THSPTensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THSPTensor_(dealloc), /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0,//&THSPTensor_(mappingmethods), /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - NULL, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* will be assigned in init */ /* tp_methods */ - 0, /* will be assigned in init */ /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - THSPTensor_(pynew), /* tp_new */ -}; - -static struct PyMemberDef THSPTensor_(members)[] = { - {(char*)"_cdata", T_ULONGLONG, offsetof(THSPTensor, cdata), READONLY, NULL}, - {NULL} // Sentinel -}; - -typedef struct { - PyObject_HEAD -} THSPTensorStateless; - -PyTypeObject THSPTensorStatelessType = { - PyVarObject_HEAD_INIT(NULL, 0) - "torch._C.Sparse" THPTensorBaseStr ".stateless", /* tp_name */ - sizeof(THSPTensorStateless), /* tp_basicsize */ - 0, /* tp_itemsize */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved / tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - NULL, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - THSPTensor_stateless_(methods), /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - 0, /* tp_new */ - 0, /* tp_free */ - 0, /* tp_is_gc */ - 0, /* tp_bases */ - 0, /* tp_mro */ - 0, /* tp_cache */ - 0, /* tp_subclasses */ - 0, /* tp_weaklist */ -}; - -bool THSPTensor_(init)(PyObject *module) -{ - THSPTensorType.tp_methods = THSPTensor_(methods); - THSPTensorType.tp_members = THSPTensor_(members); - if (PyType_Ready(&THSPTensorType) < 0) - return false; - THSPTensorStatelessType.tp_new = PyType_GenericNew; - if (PyType_Ready(&THSPTensorStatelessType) < 0) - return false; - - PyModule_AddObject(module, THSPTensorBaseStr, (PyObject *)&THSPTensorType); - return true; -} - -bool THSPTensor_(postInit)(PyObject *module) -{ - THSPTensorClass = PyObject_GetAttrString(module, TH_CONCAT_STRING_2(Real,Tensor)); - if (!THSPTensorClass) return false; - bool is_cuda = false; -#ifdef THC_GENERIC_FILE - is_cuda = true; -#endif - const char *type_name = TH_CONCAT_STRING_2(Real,); - torch::registerPyTypeObject((PyTypeObject*)THSPTensorClass, type_name, is_cuda, true); - return true; -} diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 06f62dccd6..5de9e7e4d9 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -290,7 +290,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) size_t view_size = (size_t)THPUtils_unpackLong(_view_size); int64_t device = THPUtils_unpackLong(_device); - THCPAutoGPU __autogpu(device); + AutoGPU __autogpu(device); char *buffer; Py_ssize_t handle_size; diff --git a/torch/csrc/generic/Tensor.cpp b/torch/csrc/generic/Tensor.cpp deleted file mode 100644 index de02b2b255..0000000000 --- a/torch/csrc/generic/Tensor.cpp +++ /dev/null @@ -1,1799 +0,0 @@ -#ifndef TH_GENERIC_FILE - -#define TH_GENERIC_FILE "generic/Tensor.cpp" -#else - -#ifdef WITH_NUMPY - -#include "THHalf.h" - - -// COPY_FROM_ARRAY macros for Numpy -> TH assignment -// -// ELTYPE = data type of the Python array -// ARRAY = base pointer of the Python array -// STORAGE = pointer to a THStorage struct, type THStoragePtr -// SIZE = size in bytes of the array to copy -// CONVERSION = function to apply to each array element before assigning, for half<->float conversions - -#define COPY_FROM_ARRAY_CPU(ELTYPE, ARRAY, STORAGE, SIZE, CONVERSION) \ -{ \ - ELTYPE *arrdata = (ELTYPE*)PyArray_DATA(ARRAY); \ - real *data = STORAGE->data; \ - for (size_t i=0; i<SIZE; i++) { \ - data[i] = CONVERSION(arrdata[i]); \ - } \ -} - -#define COPY_FROM_HALF_ARRAY_CPU_HALF(ARRAY, STORAGE, SIZE) \ -{ \ - char *arrdata = (char*)PyArray_DATA(ARRAY); \ - memcpy(STORAGE->data, arrdata, SIZE * 2); \ -} - -#ifdef THC_REAL_IS_HALF -#define COPY_FROM_ARRAY_CUDA(ELTYPE, ARRAY, STORAGE, SIZE, CONVERSION) \ -{ \ - ELTYPE *arrdata = (ELTYPE*)PyArray_DATA(ARRAY); \ - std::unique_ptr<load_real[]> data_guard(new load_real[SIZE]); \ - load_real *data = data_guard.get(); \ - for (size_t i=0; i<SIZE; i++) { \ - data[i] = arrdata[i]; \ - } \ - THFloatStorage *cpu_storage = \ - THFloatStorage_newWithData(data_guard.get(), SIZE); \ - cpu_storage->flag &= ~TH_STORAGE_FREEMEM; \ - THCudaHalfStorage_copyFloat(LIBRARY_STATE STORAGE, cpu_storage); \ - THFloatStorage_free(cpu_storage); \ -} - -#else -#define COPY_FROM_ARRAY_CUDA(ELTYPE, ARRAY, STORAGE, SIZE, CONVERSION) \ -{ \ - ELTYPE *arrdata = (ELTYPE*)PyArray_DATA(ARRAY); \ - std::unique_ptr<load_real[]> data_guard(new load_real[SIZE]); \ - load_real *data = data_guard.get(); \ - for (size_t i=0; i<SIZE; i++) { \ - data[i] = CONVERSION(arrdata[i]); \ - } \ - THHostStorage *cpu_storage = \ - THHostStorage_(newWithData)(data_guard.get(), SIZE); \ - cpu_storage->flag &= ~TH_STORAGE_FREEMEM; \ - THCStorage_(copyCPU)(LIBRARY_STATE STORAGE, cpu_storage); \ - THHostStorage_(free)(cpu_storage); \ -} -#endif // THC_REAL_IS_HALF - -#define COPY_FROM_HALF_ARRAY_CUDA_HALF(ARRAY, STORAGE, SIZE) \ -{ \ - THHalf *arrdata = (THHalf*)PyArray_DATA(ARRAY); \ - THHostStorage *cpu_storage = \ - THHostStorage_(newWithData)(arrdata, SIZE); \ - cpu_storage->flag &= ~TH_STORAGE_FREEMEM; \ - THCStorage_(copyCPU)(LIBRARY_STATE STORAGE, cpu_storage); \ - THHostStorage_(free)(cpu_storage); \ -} - -#define IDENTITY(X) (X) - -// Fill in the conversions that we know at compile time (as determined by TH[C]_REAL_IS_HALF). -// We need to keep a COPY_FROM_HALF_ARRAY variant since we know the input type only at runtime. -#ifdef THC_GENERIC_FILE -#define COPY_FROM_ARRAY(ELTYPE, ARRAY, STORAGE, SIZE) COPY_FROM_ARRAY_CUDA(ELTYPE, ARRAY, STORAGE, SIZE, IDENTITY) -#ifdef THC_REAL_IS_HALF -#define COPY_FROM_HALF_ARRAY COPY_FROM_HALF_ARRAY_CUDA_HALF -#else -#define COPY_FROM_HALF_ARRAY(ARRAY, STORAGE, SIZE) COPY_FROM_ARRAY_CUDA(THHalf, ARRAY, STORAGE, SIZE, TH_half2float) -#endif // THC_REAL_IS_HALF -#else // THC_GENERIC_FILE -#ifdef TH_REAL_IS_HALF -#define COPY_FROM_ARRAY(ELTYPE, ARRAY, STORAGE, SIZE) COPY_FROM_ARRAY_CPU(ELTYPE, ARRAY, STORAGE, SIZE, TH_float2half) -#define COPY_FROM_HALF_ARRAY COPY_FROM_HALF_ARRAY_CPU_HALF -#else -#define COPY_FROM_ARRAY(ELTYPE, ARRAY, STORAGE, SIZE) COPY_FROM_ARRAY_CPU(ELTYPE, ARRAY, STORAGE, SIZE, IDENTITY) -#define COPY_FROM_HALF_ARRAY(ARRAY, STORAGE, SIZE) COPY_FROM_ARRAY_CPU(THHalf, ARRAY, STORAGE, SIZE, TH_half2float) -#endif // TH_REAL_IS_HALF -#endif // THC_GENERIC_FILE - -#endif // WITH_NUMPY - -PyObject *THPTensorClass = NULL; -THPCopyList THTensor_(copy_functions); - -PyObject * THPTensor_(NewEmpty)() -{ - return THPTensor_(New)(THTensor_(new)(LIBRARY_STATE_NOARGS)); -} - -PyObject * THPTensor_(New)(THTensor *tensor) -{ - THTensorPtr ptr(tensor); - if (!tensor->storage) { - tensor->storage = THStorage_(new)(LIBRARY_STATE_NOARGS); - } - PyTypeObject *type = (PyTypeObject *)THPTensorClass; - PyObject *obj = type->tp_alloc(type, 0); - if (obj) { - ((THPTensor *)obj)->cdata = ptr.release(); - } - return obj; -} - -static THTensor* THPTensor_(_new)() -{ - THTensorPtr tensor(THTensor_(new)(LIBRARY_STATE_NOARGS)); - if (!tensor->storage) { - tensor->storage = THStorage_(new)(LIBRARY_STATE_NOARGS); - } - return tensor.release(); -} - -static THTensor* THPTensor_(_newWithSize)(THLongStorage *size) -{ - THTensorPtr tensor(THTensor_(newWithSize)(LIBRARY_STATE size, NULL)); - // Ensure that PyTorch's "storage is not NULL" invariant is upheld - // See Note [Storage is not NULL] - if (!tensor->storage) { - tensor->storage = THStorage_(new)(LIBRARY_STATE_NOARGS); - } - return tensor.release(); -} - -static void THPTensor_(dealloc)(THPTensor* self) -{ - THTensor_(free)(LIBRARY_STATE self->cdata); - Py_TYPE(self)->tp_free((PyObject*)self); -} - -static std::string THPTensor_(indicesToString)(std::vector<size_t> &indices, - size_t depth) -{ - std::string index = "("; - for (size_t i = 0; i <= depth; ++i) { - index += std::to_string(indices[i]); - index += ", "; - } - index.erase(index.length()-2); // Remove trailing ", " - index += ")"; - return index; -} - -static void THPTensor_(setInconsistentDepthError)(std::vector<size_t> &sizes, - std::vector<size_t> &indices, size_t depth, size_t length) -{ - std::string error = "inconsistent sequence length at index "; - error += THPTensor_(indicesToString)(indices, depth); - error += " - expected "; - error += std::to_string(sizes[depth]); - error += " but got "; - error += std::to_string(length); - THPUtils_setError(error.c_str()); -} - -static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS - Py_ssize_t num_args = args ? PyTuple_Size(args) : 0; - - THPTensorPtr self((THPTensor *)type->tp_alloc(type, 0)); - if (!self) { - return NULL; - } - self->cdata = NULL; -#ifdef THC_GENERIC_FILE - THCPAutoGPU gpu_guard(args, NULL); -#endif - - // Internally we allow constructing with a keyword only argument cdata - if (kwargs != NULL) { - Py_ssize_t num_kwargs = PyDict_Size(kwargs); -#ifdef THC_GENERIC_FILE - PyObject *device_id = PyDict_GetItemString(kwargs, "device"); - if (device_id == Py_None) { - num_kwargs--; - } else if (device_id) { - THPUtils_assert(THPUtils_checkLong(device_id), "device argument " - " has to be an int, but got %s", THPUtils_typename(device_id)); - gpu_guard.setDevice(THPUtils_unpackLong(device_id)); - // simulate pop() and pretend this key was never there - num_kwargs--; - } -#endif - if (num_args == 0) { - PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata"); - if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) { - THTensor *ptr = (THTensor*)PyLong_AsVoidPtr(cdata_ptr); - self->cdata = ptr; - return (PyObject*)self.release(); - } - } - // This is an internal option, so we don't want to advertise it. -#ifdef THC_GENERIC_FILE - THPUtils_assert(num_kwargs == 0, THPTensorStr " constructor only " - "accepts a 'device' keyword argument") -#else - THPUtils_assert(num_kwargs == 0, THPTensorStr " constructor doesn't " - "accept any keyword arguments"); -#endif - } - - // torch.Tensor() - if (num_args == 0) { - self->cdata = THPTensor_(_new)(); - return (PyObject*)self.release(); - } - - PyObject *first_arg = PyTuple_GET_ITEM(args, 0); - - // torch.Tensor(torch.Tensor tensor) - if (num_args == 1 && THPTensor_(Check)(first_arg)) { - THTensor *tensor = ((THPTensor*)first_arg)->cdata; - self->cdata = THTensor_(newWithTensor)(LIBRARY_STATE tensor); - return (PyObject*)self.release(); - } - - // torch.Tensor(torch.Size sizes) - if (num_args == 1 && THPSize_Check(first_arg)) { - THLongStoragePtr sizes(THPUtils_unpackSize(first_arg)); - self->cdata = THPTensor_(_newWithSize)(sizes.get()); - return (PyObject *)self.release(); - } - - // TODO: implement storageOffset, sizes and strides - // torch.Tensor(torch.Storage data) - if (num_args == 1 && THPStorage_(Check)(first_arg)) { - THStorage *storage = ((THPStorage*)first_arg)->cdata; - self->cdata = THTensor_(newWithStorage1d)(LIBRARY_STATE storage, 0, storage->size, -1); - return (PyObject *)self.release(); - } - -#if defined(WITH_NUMPY) - // torch.Tensor(np.ndarray array) - if (num_args == 1 && PyArray_Check(first_arg)) { - auto tensor = torch::utils::tensor_from_numpy(first_arg); - tensor = tensor.toType(torch::getATenType(type)); - return torch::createPyObject(tensor); - } -#endif - - // torch.Tensor(Sequence data) - if (num_args == 1 && PySequence_Check(first_arg) && !THPVariable_Check(first_arg)) { - Py_ssize_t length = PySequence_Length(first_arg); - THPUtils_assert(length >= 0, "couldn't obtain the length of %s", - THPUtils_typename(first_arg)); - if (length == 0) { - self->cdata = THPTensor_(_new)(); - return (PyObject*)self.release(); - } - - Py_INCREF(first_arg); - THPObjectPtr item(first_arg); - std::vector<size_t> sizes; - while ((length = PySequence_Length(item)) >= 0) { - sizes.push_back(length); - // TODO: check for string in this case - THPUtils_assert(sizes.size() < 1000000, "already counted a million " - "dimensions in a given sequence. Most likely your items are also " - "sequences and there's no way to infer how many dimension should " - "the tensor have"); - THPUtils_assert(length > 0, "given sequence has an invalid size of " - "dimension %" PRId64 ": %" PRId64, (int64_t)sizes.size(), (int64_t)length); - item = PySequence_GetItem(item, 0); - if (!item) - return NULL; - } - // Last length check has set an error flag, so we need to clear it. - PyErr_Clear(); - - THLongStoragePtr sizes_storage(THLongStorage_newWithSize(sizes.size())); - int64_t *sizes_data = sizes_storage->data; - for (auto size: sizes) - *sizes_data++ = size; - THTensorPtr tensor(THTensor_(newWithSize)(LIBRARY_STATE sizes_storage, NULL)); - - int ndims = (int) sizes.size(); - std::vector<size_t> indices(ndims); - std::vector<THPObjectPtr> sequences(ndims); - Py_INCREF(first_arg); - item = first_arg; - for (size_t i = 0; i < sequences.size(); i++) { - PyObject *item_ptr = item.get(); - sequences[i] = std::move(item); - if (i < sequences.size()-1) { - item = PySequence_ITEM(item_ptr, 0); - if (!item) - return NULL; - } - } - - // half tensors don't have CPU counterparts so we have to buffer them as - // floats while loading -#ifndef THC_REAL_IS_HALF -#define load_real real -#define UNPACK_REAL(item) THPUtils_(unpackReal)(item) -#else -#define load_real float -#define UNPACK_REAL(item) THPFloatUtils_unpackReal(item) -#endif -#if !defined(THC_GENERIC_FILE) && !defined(THD_GENERIC_FILE) - real *data = tensor->storage->data; -#else - size_t numel = THTensor_(numel)(LIBRARY_STATE tensor); - std::unique_ptr<load_real[]> data_guard(new load_real[numel]); - load_real *data = data_guard.get(); -#endif - THPObjectPtr final_sequence; - while (true) { - final_sequence = std::move(sequences[ndims-1]); - try { - // We're taking a fast-track over the last dimension - for (size_t i = 0; i < sizes[ndims-1]; i++) { - indices[ndims-1] = i; - item = PySequence_ITEM(final_sequence, i); - // We've checked the length earlier, so it must have been an error - if (!item) - return NULL; - *data++ = UNPACK_REAL(item); - } - } catch(std::runtime_error &e) { - std::string index = THPTensor_(indicesToString)(indices, ndims-1); - THPUtils_setError("tried to construct a tensor from a %s%s sequence, " - "but found an item of type %s at index %s", - (ndims > 1 ? "nested " : ""), - THPUtils_typeTraits<real>::python_type_str, - THPUtils_typename(item.get()), - index.c_str()); - return NULL; - } -#ifdef THC_GENERIC_FILE -#ifdef THC_REAL_IS_HALF - THFloatStorage *cpu_storage = THFloatStorage_newWithData(data_guard.get(), numel); - cpu_storage->flag &= ~TH_STORAGE_FREEMEM; - THCudaHalfStorage_copyFloat(LIBRARY_STATE tensor->storage, cpu_storage); - THFloatStorage_free(cpu_storage); -#else - THHostStorage *cpu_storage = THHostStorage_(newWithData)(data_guard.get(), numel); - cpu_storage->flag &= ~TH_STORAGE_FREEMEM; - THCStorage_(copyCPU)(LIBRARY_STATE tensor->storage, cpu_storage); - THHostStorage_(free)(cpu_storage); -#endif -#endif -#undef UNPACK_REAL -#undef load_real - - // Update the counters - int dim = ndims-2; - size_t last_updated_dim = dim; - while (dim >= 0) { - last_updated_dim = dim; - if (++indices[dim] == sizes[dim]) - indices[dim--] = 0; - else - break; - } - // Check if we've just made a full cycle - if ((last_updated_dim == 0 && indices[0] == 0) || ndims == 1) - break; - // Update sequences - for (int i = last_updated_dim+1; i < ndims; i++) { - sequences[i] = PySequence_ITEM(sequences[i-1], indices[i-1]); - if (!sequences[i]) { - THPTensor_(setInconsistentDepthError)(sizes, indices, i, indices[i]); - return NULL; - } - if (!PySequence_Check(sequences[i])) { - std::string index_str = THPTensor_(indicesToString)(indices, i); - THPUtils_setError( - "an item of type %s at index %s doesn't implement a sequence protocol", - THPUtils_typename(sequences[i].get()), index_str.c_str()); - return NULL; - } - Py_ssize_t length = PySequence_Length(sequences[i]); - if (length < 0) { - std::string index_str = THPTensor_(indicesToString)(indices, i); - THPUtils_setError("could not obtain a length of %s at index %s", - THPUtils_typename(sequences[i].get()), index_str.c_str()); - return NULL; - } - if ((size_t)length != sizes[i]) { - THPTensor_(setInconsistentDepthError)(sizes, indices, i, length); - return NULL; - } - } - } - self->cdata = tensor.release(); - return (PyObject *)self.release(); - } - - // torch.Tensor(int ...) - THLongStoragePtr sizes; - if (THPUtils_tryUnpackLongVarArgs(args, 0, sizes)) { - self->cdata = THPTensor_(_newWithSize)(sizes.get()); - return (PyObject *)self.release(); - } - - THPUtils_invalidArguments(args, kwargs, THPTensorStr " constructor", 6, - "no arguments", - "(int ...)", - "(" THPTensorStr " viewed_tensor)", - "(torch.Size size)", - "(" THPStorageStr " data)", - "(Sequence data)"); - return NULL; - END_HANDLE_TH_ERRORS -} - -#ifdef WITH_NUMPY -#define IS_SCALAR(NAME) \ - ((is_long = THPUtils_checkLong(NAME)) || \ - (is_scalar_array = PyArray_CheckScalar(NAME))) -#define UNPACK_SCALAR(IDX_VARIABLE) \ - if (is_long) { \ - idx = THPUtils_unpackLong(IDX_VARIABLE); \ - } else { \ - PyArray_CastScalarToCtype(IDX_VARIABLE, &idx, NumpyLongArrDescr); \ - } -#else -#define IS_SCALAR(NAME) THPUtils_checkLong(NAME) -#define UNPACK_SCALAR(IDX_VARIABLE) idx = THPUtils_unpackLong(IDX_VARIABLE); -#endif - -#if defined(THC_GENERIC_FILE) -#define THIndexTensor THCudaLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME) -#define THPIndexTensor THCPLongTensor -#define THPIndexTensor_Check THCPLongTensor_Check -#define THPIndexTensorClass THCPLongTensorClass -#elif defined(THD_GENERIC_FILE) -#define THIndexTensor THDLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THDLongTensor_,NAME) -#define THPIndexTensor THDPLongTensor -#define THPIndexTensor_Check THDPLongTensor_Check -#define THPIndexTensorClass THDPLongTensorClass -#else -#define THIndexTensor THLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME) -#define THPIndexTensor THPLongTensor -#define THPIndexTensor_Check THPLongTensor_Check -#define THPIndexTensorClass THPLongTensorClass -#endif - -static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim, - THTensorPtr &tresult, THStorage* &sresult, int64_t &storage_offset) -{ -#ifdef WITH_NUMPY - static PyArray_Descr *NumpyLongArrDescr = PyArray_DescrFromType(NPY_INT64); - bool is_long, is_scalar_array; -#endif - // Indexing with a scalar - if(IS_SCALAR(index)) { - int64_t idx; - UNPACK_SCALAR(index); - int64_t dimsize = THTensor_(size)(LIBRARY_STATE tresult.get(), indexed_dim); - - // If the user provided negative idx, convert to positive equivalent - idx = (idx < 0) ? dimsize + idx : idx; - - if (dimsize <= 0) { - PyErr_SetString(PyExc_IndexError, "indexing an empty tensor"); - throw python_error(); - } - if (idx < 0 || idx >= dimsize) { - PyErr_Format(PyExc_IndexError, "index %lld is out of range for dimension " - "%lld (of size %lld)", (long long)idx, (long long)indexed_dim, (long long)dimsize); - throw python_error(); - } - - // If we are indexing a vector, set the storage to the storage underlying - // the vector, and the storage_offset to the location of the element at - // the specificed index. Otherwise, perform a selection - if(THTensor_(nDimension)(LIBRARY_STATE tresult.get()) == 1) { - sresult = tresult.get()->storage; - storage_offset = tresult->storageOffset + tresult->stride[0] * idx; - tresult = NULL; - } else { - THTensor_(select)(LIBRARY_STATE tresult.get(), NULL, indexed_dim, idx); - } - } else if (index == Py_None) { - // _indexOnce will never be called with tresult == NULL, except for a None index - // e.g. x = torch.Tensor(5); y = x[5, None] - if (!tresult) { - tresult = THTensor_(newWithStorage1d)(LIBRARY_STATE sresult, storage_offset, 1, 1); - sresult = NULL; - } else { - // Insert a singleton dimension at indexed_dim, then bump indexed_dim - THTensor_(unsqueeze1d)(LIBRARY_STATE tresult.get(), NULL, indexed_dim++); - } - // Indexing with a slice - } else if (PySlice_Check(index)) { - Py_ssize_t start, end, length, step; - if (!THPUtils_parseSlice(index, THTensor_(size)(LIBRARY_STATE tresult.get(), indexed_dim), &start, &end, &step, &length)) - throw python_error(); - if (step <= 0) { - PyErr_SetString(PyExc_ValueError, "slice step has to be greater than 0"); - throw python_error(); - } - if (length == 0) { - PyErr_SetString(PyExc_ValueError, "result of slicing is an empty tensor"); - throw python_error(); - } - // Modify the Tensor to point to the sliced components - tresult->storageOffset += tresult->stride[indexed_dim] * start; - tresult->stride[indexed_dim] *= step; - tresult->size[indexed_dim] = length; - indexed_dim++; - } else { - return false; - } - return true; -} - -#ifndef TH_REAL_IS_HALF - -static bool THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(PyObject *arg) { - if (PySequence_Check(arg) && !PyTuple_Check(arg)) { - auto fast = THPObjectPtr(PySequence_Fast(arg, NULL)); - for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(fast.get()); ++i) { - if (!THPUtils_checkLong(PySequence_Fast_GET_ITEM(fast.get(), i))) - return false; - } - return true; - } - return false; -} - -static bool THPTensor_(_checkBasicIntegerArrayIndexing)(THPTensor *indexed, PyObject *arg) { - int64_t ndim = THTensor_(nDimension)(LIBRARY_STATE indexed->cdata); - - if (PySequence_Check(arg) && PySequence_Size(arg) == ndim) { - THPObjectPtr fast = THPObjectPtr(PySequence_Fast(arg, NULL)); - for (Py_ssize_t i = 0; i < ndim; ++i) { - PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i); - if (!THPIndexTensor_Check(item) && !PySequence_Check(item)) { - return false; - } - } - return true; - } - return false; -} - -static bool THPTensor_(_checkAdvancedIndexing)(THPTensor *indexed, PyObject *arg) { - // Currently we only support two forms of advanced indexing: - // - // 1. Indexing with a single non-tuple sequence, not nested within a sequence, - // that is composed only of integer indexers, e.g. x[[0, 1, 4]] - // 2. "Basic Integer Array Indexing" the integer-array indexing strategy - // where we have ndim sequence/LongTensor arguments - // 3. Combining Advanced Indexing with ":", or "..." , with the limitation that - // the advanced indexing dimensions must be adjacent, i.e.: - // - // x[:, :, [1,2], [3,4], :] --> valid - // x[[1,2], [3,4]] --> valid - // x[[1,2], [3,4], ...] --> valid - // x[:, [1,2], :, [3,4], :] --> not valid - - // Verification, Step #1 - single non-tuple sequencer - if (THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(arg)) return true; - - // Verification, Step #2 -- ndim sequencers - if (THPTensor_(_checkBasicIntegerArrayIndexing)(indexed, arg)) return true; - - // Verification, Step #3 -- at least one sequencer, all the rest are - // ':' and/or a single '...', can be less than ndim indexers, all sequencers - // adjacent - - int64_t ndim = THTensor_(nDimension)(LIBRARY_STATE indexed->cdata); - if (PySequence_Check(arg) && PySequence_Size(arg) <= ndim + 1) { - THPObjectPtr fast = THPObjectPtr(PySequence_Fast(arg, NULL)); - - bool sequenceFound = false; - bool nonColonEllipsisFound = false; - bool ellipsisFound = false; - Py_ssize_t lastSeqDim = -1; - - // Note that we can have ndim + 1 Tensors in the case where we have an ellipsis, - // because Python semantics allow it to be "thrown away" so to speak. If this is - // the case, we have to shift the dimension we are considering (in the indexed - // tensor) by -1 afer encountering the Ellipsis when accessing properties of - // the indexed Tensor - bool extraIndexer = PySequence_Fast_GET_SIZE(fast.get()) == ndim + 1; - - for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(fast.get()); ++i) { - // see explanation above - int correspondingTensorDim = i + (extraIndexer && ellipsisFound ? -1 : 0); - - PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i); - if (THPIndexTensor_Check(item) || PySequence_Check(item)) { - sequenceFound = true; - - // non-adjacent sequencers not yet supported - if (i - 1 != lastSeqDim && lastSeqDim != -1) { - return false; - } - lastSeqDim = i; - - continue; - } - if (PySlice_Check(item)) { - int64_t dimSize = THTensor_(size)(LIBRARY_STATE indexed->cdata, correspondingTensorDim); - // Basically verify that the Slice is ':' and did not specify - // a specific start, end or step - Py_ssize_t start, end, length, step; - if (THPUtils_parseSlice(item, dimSize, &start, &end, &step, &length)) { - if (start != 0 || end != dimSize || step != 1 || length != dimSize) { - nonColonEllipsisFound = true; - break; - } - } - continue; - } - if (Py_TYPE(item) == &PyEllipsis_Type) { - if (ellipsisFound) { - // Can't have duplicate ellipsi - return false; - } - ellipsisFound = true; - continue; - } - nonColonEllipsisFound = true; - break; - } - - // Check if we have ndim+1 indexing objects, that we found an ellipsis - if (PySequence_Size(arg) == ndim + 1 && !ellipsisFound) { - return false; - } - - return sequenceFound && (!nonColonEllipsisFound); - } - return false; - - // Full NumPy advanced indexing requirements are coded up below. To fully support - // such indexing will require changes to the actual indexing logic, so we will - // leave this commented out as a reference - - /** - // Checks whether the specified selection object should trigger advanced - // indexing - - // Case 1: arg is a non-tuple sequence object - if (PySequence_Check(arg) && !PyTuple_Check(arg)) return true; - -#ifdef WITH_NUMPY - // Case 2: arg is an nd-array with type integer or bool - if (PyArray_Check(arg) && (PyArray_TYPE((PyArrayObject*)arg) == NPY_INT64 || PyArray_TYPE((PyArrayObject*)arg) == NPY_BOOL)) return true; -#endif - - // Case 3: arg is a tuple containing at least one sequence object, ndarray, or LongTensor - if (PyTuple_Check(arg)) { - for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) { - PyObject *item = PyTuple_GET_ITEM(arg, i); - if (PySequence_Check(item)) { - return true; - } -#ifdef WITH_NUMPY - if (PyArray_Check(item) && (PyArray_TYPE((PyArrayObject*)item) == NPY_INT64 || PyArray_TYPE((PyArrayObject*)item) == NPY_BOOL)) return true; -#endif - if (THPIndexTensor_Check(item)) return true; - } - } - - **/ -} - -// Exposed at the interpreter level -static PyObject* THPTensor_(checkAdvancedIndexing)(THPTensor *self, PyObject *arg) { - if (THPTensor_(_checkAdvancedIndexing)(self, arg)) { - Py_RETURN_TRUE; - } - Py_RETURN_FALSE; -} - -static bool THPTensor_(_convertToTensorIndexers)( - PyObject *index, - THTensorPtr& indexed, - Py_ssize_t& sequenceLength, - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted) { - - // At the top-level, each indexing element must be one of 3 things: - // - // 1. A LongTensor - // 2. A sequence that can be converted into a LongTensor - // 3. A empty slice object (i.e. ':') - // 4. An Ellipsis (i.e. '...') - // - // This function loops through all of the indexing elements. If we encounter - // a LongTensor, we record the dimension at which it occurs. If we encounter - // another sequence type, we attempt to convert it to a LongTensor, and record - // its position. - // - // Next, once we have all of the indexing Tensors, we attempt to broadcast them. - // If they can be broadcasted, we store each of the broadcasted Tensors in the - // output map, with the dimension of the original tensor as the key. - - // indexingDims stores the indices containing an advanced index sequence, and indexers - // stores the corresponding indexing object, such that the indexer at indexers[i] is - // associated with the dm at indexingDims[i]. This is pre-broadcast. Because we rely - // upon the THPIndexTensor constructor to handle sequence -> tensor conversions, we - // store THPTensors rather than THTensors. - - std::vector<Py_ssize_t> indexingDims; - std::vector<THPPointer<THPIndexTensor>> indexers; - - if (THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(index)) { - // Handle the special case where we only have a single indexer - THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs( - THPIndexTensorClass, index, 0, NULL); - if (!indexer) { - PyErr_Format(PyExc_IndexError, - "When performing advanced indexing the indexing objects must be LongTensors or " - "convertible to LongTensors"); - return false; - } - indexingDims.push_back(0); - indexers.push_back(THPPointer<THPIndexTensor>(indexer)); - } else { - // The top-level indexer should be a sequence, per the check above - THPObjectPtr fast(PySequence_Fast(index, NULL)); - sequenceLength = PySequence_Fast_GET_SIZE(fast.get()); - int ellipsisOffset = 0; - - for (Py_ssize_t i = 0; i < sequenceLength; ++i) { - PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i); - - // If this is an ellipsis, the all subsequent advanced indexing - // objects "positions" should be shifted, e.g. if we have a 5D Tensor - // x, and then x[..., [2, 3]], then the "position" of [2, 3] is 4, - // - // BUT ONLY IF, we don't have ndim other indexing objects, in which case - // the ellipsis creates a shift of -1 to counterbalance its "emptyness" - if (Py_TYPE(item) == &PyEllipsis_Type) { - if (sequenceLength != (THTensor_(nDimension)(LIBRARY_STATE indexed) + 1)) { - ellipsisOffset = THTensor_(nDimension)(LIBRARY_STATE indexed) - sequenceLength; - } else { - ellipsisOffset = -1; - } - continue; - } - - if (!PySlice_Check(item)) { - PyObject *obj = PySequence_Fast_GET_ITEM(fast.get(), i); - // Returns NULL upon conversion failure - THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs( - THPIndexTensorClass, obj, NULL); - if (!indexer) { - PyErr_Format(PyExc_IndexError, - "When performing advanced indexing the indexing objects must be LongTensors or " - "convertible to LongTensors. The indexing object at position %zd is of type %s " - "and cannot be converted", i, THPUtils_typename(obj)); - - return false; - } - indexingDims.push_back(i + ellipsisOffset); - indexers.push_back(THPPointer<THPIndexTensor>(indexer)); - } - } - } - - // Next, we need to verify that the Tensors are broadcastable. Keep these - // as raw pointer vectors - std::vector<THIndexTensor*> maybeBroadcasted; - std::vector<THIndexTensor*> candidates; - - // Extract the underlying Tensors for use in the expansion API call - for (const auto& indexer : indexers) { - maybeBroadcasted.emplace_back(THIndexTensor_(new)(LIBRARY_STATE_NOARGS)); - // borrow the underlying Tensor from the indexer map - candidates.emplace_back(indexer.get()->cdata); - } - - // Broadcast/Expand indexing Tensors as necessary - try { - THIndexTensor_(expandNd)(LIBRARY_STATE maybeBroadcasted.data(), candidates.data(), maybeBroadcasted.size()); - - // Broadcast succeeded, place Broadcasted Tensors into output map by the index at - // which they occurred, transferring ownership to that map object - for (unsigned int i = 0; i < indexingDims.size(); ++i) { - THPPointer<THIndexTensor> owned(maybeBroadcasted[i]); - broadcasted[indexingDims[i]] = std::move(owned); - } - - // Next, before doing any further work, we want to verify that all the indices - // are in bounds at each advanced index dimension. This occurs only on the CPU, - // as point gets on CUDA Tensors would be slow. CUDA out of bounds errors - // will trigger a device-side assert - -#if !defined(THC_GENERIC_FILE) - ptrdiff_t nElement = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get()); - THLongStoragePtr viewer(THLongStorage_newWithSize(1)); - THLongStorage_set(viewer.get(), 0, nElement); - for (auto& dimBroadcast : broadcasted) { - Py_ssize_t dim = dimBroadcast.first; - int64_t sizeAtDim = THTensor_(size)(LIBRARY_STATE indexed, dim); - - // Need to make contiguous to view as 1D :/ - THPPointer<THIndexTensor> contig(THIndexTensor_(newContiguous)(LIBRARY_STATE dimBroadcast.second.get())); - - // View as 1D + get1D makes me sad :( - THPPointer<THIndexTensor> flat(THIndexTensor_(newView)(LIBRARY_STATE contig.get(), viewer)); - for (ptrdiff_t i = 0; i < THIndexTensor_(nElement)(LIBRARY_STATE flat.get()); ++i) { - int64_t indexAtDim = THTensor_fastGet1d(flat.get(), i); - if (indexAtDim >= sizeAtDim) { - PyErr_Format(PyExc_IndexError, "index %lld from broadcast indexer is out of range " - "for dimension %lld (of size %lld)", - (long long)indexAtDim, (long long)dim, (long long)sizeAtDim); - - - return false; - } - } - } -#endif - } catch (std::exception& e) { - // Broadcasted failed, cleanup and return error. I'm not sure if there is a better - // way to do this where we don't have to manually clean up the memory - for (const auto& tensor : maybeBroadcasted) { - THIndexTensor_(free)(LIBRARY_STATE tensor); - } - PyErr_Format(PyExc_IndexError, "The advanced indexing objects could not be broadcast"); - - return false; - } - - return true; -} - -static inline int64_t THPTensor_(_indexToOffset)( - THTensorPtr& indexed, - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted, - ptrdiff_t index) -{ - // We need to translate an "index" into a linear offset within the Tensor indexed. - // We will perform the normal mod/divide loop, except in the case of an advance indexed - // dimension, we need to take special care to utilize the size and subset of indices - // specified by the Tensor at the advanced indexed dimension. We hereafter refer to - // this as the "broadcast" dimension, although in the case of a single indexer, the - // broadcast op is pretty much a no-op. - // - // For example, suppose we have a three-dimensional Tensor x of shape (5, 10, 15), - // and our indexing operation is x[:, (2, 4, 5), :]. - // - // For Linear Index 32: - // - // dim = 2 (size = 15): 32 % 15 = 2; 32 / 15 = 2 - // dim = 1 (size = 3): 2 % 3 = 2; 2 / 3 = 0 - // dim = 0 (size = 5): 0 % 5 = 0; end - // - // So we have selected the index (0, 2, 2). Now for the strides calculation. For the - // non-broadcast dimensions, we simply do the index * the stride. But for the broadcast - // dimension we need to get the corresponding subset index (i.e., pick from (2, 4, 5)) - // and use that before multiplying by the stride at that dimension. - // - // (assumes that x is contiguous) - // - // dim = 2 (stride = 1): 2 * stride = 2, offset = 2 - // dim = 1 (stride = 15): (broadcast[2] = 5) * stride = 75, offset = 77 - // dim = 0 (stride = 75): 0 * stride = 0, offset = 77 - // - // So we can see how this works. - // - // The other complication occurs when we have more than one advanced indexer. Consider - // the case: - // - // x = torch.Tensor(3, 4, 6, 3) - // x.stride = (72, 18, 3, 1) - // x[:, [0, 1], [2, 3], :] - // - // Because the advanced indexers are broadcast and iterated as one, we need to apply - // the same index in each of the advanced indexing dimensions. When we reach an advanced - // indexing element, we look to see if the next dimension we will consider is also part - // of the advanced indexing. If it is, we maintain the index: - // - // For Linear Index 16: - // - // dim = 3 (size = 3): 16 % 3 = 1; 16 / 3 = 5 - // dim = 2 (size = 2): 5 % 2 = 1; Do Not Update Index - // dim = 1 (size = 2): 5 % 2 = 1; 5 / 2 = 2 - // dim = 0 (size = 3): 2 % 3 = 2; end - // - // Then for the offsets: - // - // dim = 3 (stride = 1): 1 * stride = 1, offset: 1 - // dim = 2 (stride = 3): [2, 3][1] = 3 * stride = 9, offset = 10 - // dim = 1 (stride = 18): [0, 1][1] = 1 * stride = 18, offset = 28 - // dim = 0 (stride = 72): 2 * stride = 144, offset = 172 - // - // Special care needs to be taken to handle advanced indexers at the beginning, end. - - int64_t offset = 0; - for (int64_t i = THTensor_(nDimension)(LIBRARY_STATE indexed) - 1; i >= 0; --i) { - // Get size at dimension i, its the size of the indexed Tensor at that dimension if its - // not an advanced indexing dimension, otherwise its the size of the broadcast Tensor - ptrdiff_t sizeAtDim, indexAtDim, nextIndex; - int64_t strideAtDim = THTensor_(stride)(LIBRARY_STATE indexed, i); - - auto broadcast = broadcasted.find(i); - if (broadcast != broadcasted.end()) { - sizeAtDim = THIndexTensor_(nElement)(LIBRARY_STATE broadcast->second.get()); - indexAtDim = THTensor_fastGet1d(broadcast->second.get(), index % sizeAtDim); - - if (i > 0 && broadcasted.find(i - 1) != broadcasted.end()) { - nextIndex = index; - } else { - nextIndex = index / sizeAtDim; - } - } else { - sizeAtDim = THTensor_(size)(LIBRARY_STATE indexed, i); - indexAtDim = index % sizeAtDim; - nextIndex = index / sizeAtDim; - } - - offset += indexAtDim * strideAtDim; - index = nextIndex; - } - - // size at dim is a bad name, because its really the number of elements in the - // broadcast tensor, rather than the size of the indexed Tensor at that dim - - return offset; -} - -// Caller takes ownership of the returned IndexTensor -static THIndexTensor* THPTensor_(_calculateLinearIndices)( - THTensorPtr& indexed, - Py_ssize_t sequenceLength, - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted) { - - // Get the number of indices to generate - this is the product of the size at each dimension, - // that is not part of the advanced indexing, multiplied by the nElement of one of the broadcast - // Tensors. For example: - // - // x = torch.Tensor(10) - // x[[0, 2, 4], ] --> no dims not part of indexing, size = 3 - // - // x = torch.Tensor(5, 5) - // x[[0, 3, 3], [1]] --> no dims not part of indexing, size = 3 - // x[:, [2, 3]] --> dim_0 not part of indexing, size = 5 - // --> multiply by nElement of broadcast Tensor, nElement = 2 - // --> total_size = 10 - // - // x = torch.Tensor(5, 5, 5) - // x[[0, 1], :, :] --> dim_1, dim_2 not part of indexing, size = 5 * 5 = 25 - // --> multiply by nElement of broadcast Tensor, nElement = 2 - // --> total_size = 50 - - // TODO: should this be 1? what if there are no things to index? ???? - ptrdiff_t indexingElements = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get()); - for (Py_ssize_t i = 0; i < THTensor_(nDimension)(LIBRARY_STATE indexed.get()); ++i) { - indexingElements *= broadcasted.find(i) != broadcasted.end() ? - 1 : THTensor_(size)(LIBRARY_STATE indexed.get(), i); - } - - // The broadcasted advanced indexing tensor might not be one-dimensional, but we are - // generating a vector of indices, so we need to view the indexer as 1D prior to getting - // the value for the particular dimension. - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> flattenedBroadcasters; - THLongStorage *indexerSize = THLongStorage_newWithSize(1); - - // All broadcast Tensors have the same number of elements - ptrdiff_t dimIndexingElements = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get()); - THLongStorage_set(indexerSize, 0, dimIndexingElements); - - for (auto& broadcast : broadcasted) { - THIndexTensor *contig = THIndexTensor_(newContiguous)(LIBRARY_STATE broadcast.second.get()); - THPPointer<THIndexTensor> flat(THIndexTensor_(newView)(LIBRARY_STATE contig, indexerSize)); - flattenedBroadcasters[broadcast.first] = std::move(flat); - THIndexTensor_(free)(LIBRARY_STATE contig); - } - THLongStorage_free(indexerSize); - -#ifdef THC_GENERIC_FILE - // Call GPU kernel for index calculation - THCudaLongTensor *cudaIndices = - THCudaLongTensor_newWithSize1d(LIBRARY_STATE indexingElements); - int64_t baseOffset = THTensor_(storageOffset)(LIBRARY_STATE indexed); - - // Need to pass broadcast Tensors to API, pass NULL ptr for all empty - // (i.e. not-advanced indexed) dims - std::vector<THCudaLongTensor *> indexers( - THTensor_(nDimension)(LIBRARY_STATE indexed.get()), NULL); - - for (int i = 0; i < THTensor_(nDimension)(LIBRARY_STATE indexed.get()); ++i) { - if (flattenedBroadcasters.count(i) > 0) { - indexers[i] = flattenedBroadcasters[i].get(); - } - } - - THTensor_(calculateAdvancedIndexingOffsets)(LIBRARY_STATE cudaIndices, indexed, baseOffset, indexers.data()); - - return cudaIndices; -#else - THIndexTensor *linearIndices = THIndexTensor_(newWithSize1d)(LIBRARY_STATE indexingElements); - int64_t baseOffset = THTensor_(storageOffset)(LIBRARY_STATE indexed); - for (ptrdiff_t i = 0; i < indexingElements; ++i) { - int64_t linearIdx = THPTensor_(_indexToOffset)( - indexed, flattenedBroadcasters, i); - THTensor_fastSet1d(linearIndices, i, baseOffset + linearIdx); - } - return linearIndices; -#endif -} - -static bool THPTensor_(_advancedIndexCommonInit)( - PyObject *index, - THTensorPtr &indexed, - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted, - THIndexTensor **linearIndices, - THTensor **flattened) { - - // Precondition: index is an object that specifies advanced indexing. - // For now, we only support the simple integer-array indexing strategy - // where there are ndim(self) indexing sequences/LongTensors that can be - // broadcasted and iterated as one - // Precondition: tresult points to the Tensor we are indexing, and is also where - // we will store the output Tensor - - // First attempt to convert to Tensor indexers from the arbitrary - // python/tensor objects passed - - Py_ssize_t sequenceLength; - if (!THPTensor_(_convertToTensorIndexers)(index, indexed, sequenceLength, broadcasted)) { - return false; - } - - // At this point broadcasted should store our indexing Tensors. - // Our strategy is to view the indexed Tensor as a 1D Tensor, calculate - // the linear indices for each tuple of indexing elements, and then call - // indexSelect using those linear indices - *linearIndices = THPTensor_(_calculateLinearIndices)(indexed, sequenceLength, broadcasted); - - *flattened = THTensor_(newWithStorage1d)(LIBRARY_STATE - THTensor_(storage)(LIBRARY_STATE indexed.get()), - 0, - THStorage_(size)(LIBRARY_STATE - THTensor_(storage)(LIBRARY_STATE indexed.get())), - 1); - - return true; -} - -// Should called, written in such a way that if any of the parameters are not -// initialized we still don't crash -static void THPTensor_(_advancedIndexCommonCleanup)( - THIndexTensor *linearIndices, - THTensor *flattened) { - if (linearIndices) THIndexTensor_(free)(LIBRARY_STATE linearIndices); - if (flattened) THTensor_(free)(LIBRARY_STATE flattened); -} - -static bool THPTensor_(_advancedIndexGet)(PyObject *index, THTensorPtr &tresult) -{ - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted; - THIndexTensor *linearIndices = NULL; - THTensor *flattened = NULL; - bool success = THPTensor_(_advancedIndexCommonInit)( - index, tresult, broadcasted, &linearIndices, &flattened); - - if (success) { - THTensor *result = THTensor_(new)(LIBRARY_STATE_NOARGS); - - // Index Select makes a copy of the storage, thus it is enforcing NumPy semantics, which - // says that the array returned by advanced indexing is a copy, not a view - THTensor_(indexSelect)(LIBRARY_STATE result, flattened, 0, linearIndices); - - // Finally, we need to calculate the appropriate shape of the output Tensor - // The size at each dimension is unmodified from the input Tensor, except where - // there are advanced indexers. In this case, the n dimensions containing adjacent - // advanced indexers are reshaped to be the size of the broadcast indexer. - // - // Example, x = torch.Tensor(5, 10, 15) - // - // x[[0, 2, 4], [2, 3, 4], [1, 1, 2]] - // - // Broadcast Advanced Indexer Size: 1D Tensor of Size 3 - // Result Size: 1D Tensor of Size 3 - // - // x[:, [2, 4, 5], :] - // Broadcast Advanced Indexer Size: 1D Tensor of Size 3 - // Result Size: (5, 3, 15) - // - // x[:, [[0, 0], [1, 2]], [[1, 3], [2, 4]]] - // Broadcast Advanced Indexer Size: 2D Tensor (2, 2) - // Result Size: (5, 2, 2) - // - // x[:, [[1, 2, 3], [2, 3, 4]], :] - // Broadcast Advanced Indexer Size: 2D Tensor of Size (2, 3) - // Result Size: (5, 2, 3, 15) - - // First, calculate the number of dimensions of the output shape. This is the - // number of non-advanced indexed dimensions + the number of dimensions in the - // broadcast Tensor - int baseDims = THTensor_(nDimension)(LIBRARY_STATE tresult.get()) - broadcasted.size(); - - // Fast path, if we have ndim advanced indexers, the output shape is simply the - // broadcast shape - if (baseDims == 0) { - auto iter = broadcasted.begin(); - THTensor_(resizeNd)(LIBRARY_STATE result, - THIndexTensor_(nDimension)(LIBRARY_STATE iter->second.get()), - iter->second.get()->size, - NULL); - } else { - // We have at least one dimension that is not part of advanced indexing. This - // implementation is pretty much shit, there might be a better way of doing this... - THIndexTensor *broadcastShape = broadcasted.begin()->second.get(); - - int indexedDims = THIndexTensor_(nDimension)(LIBRARY_STATE broadcastShape); - THLongStorage *outputShape = THLongStorage_newWithSize(baseDims + indexedDims); - - int baseDimPtr = 0; - int outputDimPtr = 0; - bool insertedSubspace = false; - while (outputDimPtr != baseDims + indexedDims) { - auto iter = broadcasted.find(baseDimPtr); - if (iter == broadcasted.end()) { - outputShape->data[outputDimPtr] = THTensor_(size)(LIBRARY_STATE tresult.get(), baseDimPtr); - ++baseDimPtr; - ++outputDimPtr; - } else if (!insertedSubspace) { - for (int dim = 0; dim < indexedDims; ++dim) { - outputShape->data[outputDimPtr] = THIndexTensor_(size)(LIBRARY_STATE iter->second.get(), dim); - ++outputDimPtr; - } - insertedSubspace = true; - } else { - // ignore - ++baseDimPtr; - } - } - - THTensor_(resizeNd)(LIBRARY_STATE result, - baseDims + indexedDims, - outputShape->data, - NULL); - - THLongStorage_free(outputShape); - } - - // result ptr takes ownership of result tensor, and implicitly frees the - // indexed one - tresult = result; - } - - THPTensor_(_advancedIndexCommonCleanup)(linearIndices, flattened); - return success; -} - -static bool THPTensor_(_advancedIndexSet)(PyObject *index, THTensorPtr &dest, PyObject *src) -{ - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted; - THIndexTensor *linearIndices = NULL; - THTensor *flattened = NULL; - bool success = THPTensor_(_advancedIndexCommonInit)( - index, dest, broadcasted, &linearIndices, &flattened); - - if (success) { - if (THPUtils_(checkReal)(src)) { - real v = THPUtils_(unpackReal)(src); - THTensor_(indexFill)(LIBRARY_STATE flattened, 0, linearIndices, v); - } else if (THPTensor_(Check)(src)) { - // Because we are doing an index copy, we need to make sure of two things: - // 1. the src Tensor is 1D and - // 2. the src is made contiguous before being flattened into a 1D view, if - // necessary - - THTensor *contiguous = THTensor_(newContiguous)(LIBRARY_STATE ((THPTensor*)src)->cdata); - THTensor *cviewed = THTensor_(newWithStorage1d)(LIBRARY_STATE - THTensor_(storage)(LIBRARY_STATE contiguous), - THTensor_(storageOffset)(LIBRARY_STATE contiguous), - THTensor_(nElement)(LIBRARY_STATE contiguous), - 1); - - THTensor_(indexCopy)(LIBRARY_STATE flattened, 0, linearIndices, cviewed); - THTensor_(free)(LIBRARY_STATE contiguous); - THTensor_(free)(LIBRARY_STATE cviewed); - } else { - THPUtils_setError("can't assign %s to a " THPTensorStr " using a LongTensor " - "(only " THPTensorStr " or %s are supported)", - THPUtils_typename(src), THPUtils_typeTraits<real>::python_type_str); - success = false; - } - } - - THPTensor_(_advancedIndexCommonCleanup)(linearIndices, flattened); - return success; -} - -static bool THPTensor_(_advancedIndexAdd)(PyObject *index, THTensorPtr &dest, THTensorPtr &src) { - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted; - THIndexTensor *linearIndices = NULL; - THTensor *flattened = NULL; - bool success = THPTensor_(_advancedIndexCommonInit)( - index, dest, broadcasted, &linearIndices, &flattened); - - if (success) { - // Verify src tensor is contiguous before flattening - THTensor *contiguous = THTensor_(newContiguous)(LIBRARY_STATE src); - THTensor *cviewed = THTensor_(newWithStorage1d)(LIBRARY_STATE - THTensor_(storage)(LIBRARY_STATE contiguous), - THTensor_(storageOffset)(LIBRARY_STATE contiguous), - THTensor_(nElement)(LIBRARY_STATE contiguous), - 1); - - THTensor_(indexAdd)(LIBRARY_STATE flattened, 0, linearIndices, cviewed); - THTensor_(free)(LIBRARY_STATE contiguous); - THTensor_(free)(LIBRARY_STATE cviewed); - } - - THPTensor_(_advancedIndexCommonCleanup)(linearIndices, flattened); - return success; -} - -static bool THPTensor_(_advancedIndexSelect)(PyObject *index, THTensorPtr &dest, THTensorPtr &src) { - std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted; - THIndexTensor *linearIndices = NULL; - THTensor *flattened = NULL; - bool success = THPTensor_(_advancedIndexCommonInit)( - index, src, broadcasted, &linearIndices, &flattened); - - if (success) { - THTensor_(indexSelect)(LIBRARY_STATE dest, flattened, 0, linearIndices); - } - - THPTensor_(_advancedIndexCommonCleanup)(linearIndices, flattened); - return success; -} - -// Needed for autograd to support twice differentiable indexing -static PyObject* THPTensor_(advancedIndexAdd)(THPTensor *self, PyObject *args) { - HANDLE_TH_ERRORS - - THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "advancedIndexAdd takes exactly two " - "arguments (%d given)", (int) PyTuple_GET_SIZE(args)); - - THPUtils_assert(THPTensor_(_checkAdvancedIndexing)(self, PyTuple_GET_ITEM(args, 0)), - "first argument must be an indexer that triggers advanced indexing"); - - THPUtils_assert(THPTensor_(Check)(PyTuple_GET_ITEM(args, 1)), "Second argument " - "must be a Tensor"); - - THTensorPtr gradOutput(THTensor_(newWithTensor)( - LIBRARY_STATE ((THPTensor *)PyTuple_GET_ITEM(args, 1))->cdata)); - THTensorPtr dest(THTensor_(newWithTensor)(LIBRARY_STATE self->cdata)); - - bool success = THPTensor_(_advancedIndexAdd)(PyTuple_GET_ITEM(args, 0), dest, gradOutput); - if (!success) { - return NULL; - } - - Py_INCREF(self); - return (PyObject *)self; - END_HANDLE_TH_ERRORS -} - -// Needed for autograd to support backwards passes when there are overlapping -// indices -static PyObject* THPTensor_(advancedIndexSelect)(THPTensor *self, PyObject *args) { - HANDLE_TH_ERRORS - - THPUtils_assert(PyTuple_GET_SIZE(args) == 1, "advancedIndexSelect takes exactly one " - "argument (%d given)", (int) PyTuple_GET_SIZE(args)); - - THPUtils_assert(THPTensor_(_checkAdvancedIndexing)(self, PyTuple_GET_ITEM(args, 0)), - "first argument must be an indexer that triggers advanced indexing"); - - THTensorPtr dest(THTensor_(new)(LIBRARY_STATE_NOARGS)); - THTensorPtr src(THTensor_(newWithTensor)(LIBRARY_STATE self->cdata)); - - bool success = THPTensor_(_advancedIndexSelect)(PyTuple_GET_ITEM(args, 0), dest, src); - if (!success) { - return NULL; - } - - return THPTensor_(New)(dest.release()); - END_HANDLE_TH_ERRORS -} - -#endif // TH_REAL_IS_HALF - -// Handles indexing into a Tensor given a tuple, ellipses, sequence, etc. index -static bool THPTensor_(_index)(THPTensor *self, PyObject *index, - THTensorPtr &tresult, THStorage * &sresult, int64_t &storage_offset) -{ - // As a base case, we create a new Tensor that is a copy of the Tensor - // we are indexing - tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata); - sresult = NULL; - int indexed_dim = 0; - int invalid_indexer_dim = 0; - - if(PyTuple_Check(index)) { - // num_indexers is the number of indexing objects in the tuple, num_effective_indexers - // is the number of non-None, non-ellipses indexing objects - int64_t num_indexers = (int64_t)PyTuple_Size(index); - int64_t num_effective_indexers = num_indexers; - int64_t num_tensor_dim = THTensor_(nDimension)(LIBRARY_STATE self->cdata); - int64_t ellipsis_pos = -1; - for (int i = 0; i < num_indexers; i++) { - PyObject *indexer = PyTuple_GET_ITEM(index, i); - if (indexer == Py_Ellipsis) { - if (ellipsis_pos != -1) throw std::runtime_error("ellipsis can be used at most once"); - ellipsis_pos = i; - num_effective_indexers--; - } - if (indexer == Py_None) { - num_effective_indexers--; - } - } - if (num_effective_indexers > num_tensor_dim) { - PyErr_Format(PyExc_IndexError, - "trying to index %" PRId64 " dimensions of a %" PRId64 " dimensional tensor", - num_effective_indexers, num_tensor_dim); - return false; - } - - // Loop through the indices and perform the indiviudal indexing at each dim - bool valid = true; - for (int indexTuplePos = 0; indexTuplePos < num_indexers; indexTuplePos++) { - if (indexTuplePos == ellipsis_pos) { - // tresult can be NULL if ellipsis is the last item - // Note that the presence of the Ellipsis shifts the "indexed" dim by the number - // of dimensions minus the number of effective indexers - if (tresult) indexed_dim += (num_tensor_dim - num_effective_indexers); - continue; - } - PyObject *indexer = PyTuple_GET_ITEM(index, indexTuplePos); - valid = THPTensor_(_indexOnce)(indexer, indexed_dim, tresult, sresult, storage_offset); - if (!valid) { - tresult = NULL; - // overwrite this, so the message mentions the incorrect object - index = indexer; - invalid_indexer_dim = indexTuplePos; - break; - } - } - if (valid) return true; - } else if (index == Py_Ellipsis) { - // The result of indexing with an ellipsis only is just the entire existing - // Tensor - return true; - } else { - // index is a scalar, perform the indexing once on the 0th-dimension - if (THPTensor_(_indexOnce)(index, indexed_dim, tresult, sresult, storage_offset)) - return true; - } - - PyErr_Format(PyExc_TypeError, - "Performing basic indexing on a tensor and encountered an error indexing dim %d " - "with an object of type %s. The only supported types are integers, slices, " -#ifdef WITH_NUMPY - "numpy scalars, " -#endif - "or if indexing with a " -#ifndef THC_GENERIC_FILE - "torch.LongTensor or torch.ByteTensor " -#else - "torch.cuda.LongTensor or torch.cuda.ByteTensor " -#endif - "only a single Tensor may be passed.", - invalid_indexer_dim, THPUtils_typename(index)); - return false; -} -#undef IS_SCALAR -#undef UNPACK_SCALAR - -template<bool force_tensor> -static PyObject * THPTensor_(getValue)(THPTensor *self, PyObject *index) -{ - HANDLE_TH_ERRORS - -#ifndef TH_REAL_IS_HALF -#if defined(THC_GENERIC_FILE) - THCPByteTensor *mask = THCPByteTensor_Check(index) ? (THCPByteTensor*)index : NULL; - THCPAutoGPU __gpu_guard(NULL, (PyObject*)self); -#elif defined(THD_GENERIC_FILE) - THDPByteTensor *mask = THDPByteTensor_Check(index) ? (THDPByteTensor*)index : NULL; -#else - THPByteTensor *mask = THPByteTensor_Check(index) ? (THPByteTensor*)index : NULL; -#endif - if (mask) { - THTensorPtr t(THTensor_(new)(LIBRARY_STATE_NOARGS)); - THTensor_(maskedSelect)(LIBRARY_STATE t.get(), self->cdata, mask->cdata); - return THPTensor_(New)(t.release()); - } - if (THPIndexTensor_Check(index)) { - THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata; - - // TH will also throw an error, but its a Runtime Error that is less interpretable - // than doing it at this layer - if (THIndexTensor_(nDimension)(LIBRARY_STATE index_t) > 1) { - PyErr_Format(PyExc_IndexError, "Indexing a Tensor with a " -#ifndef THC_GENERIC_FILE - "torch.LongTensor " -#else - "torch.cuda.LongTensor " -#endif - "triggers index_select semantics, and thus we expect an empty tensor or a vector, " - "but the indexing Tensor passed has %lld dimensions", - (long long) THIndexTensor_(nDimension)(LIBRARY_STATE index_t)); - throw python_error(); - } - - THTensorPtr index_result(THTensor_(new)(LIBRARY_STATE_NOARGS)); - THTensor_(indexSelect)(LIBRARY_STATE index_result.get(), self->cdata, 0, index_t); - return THPTensor_(New)(index_result.release()); - } -#endif - - THTensorPtr tresult; - THStorage *sresult; - int64_t storage_offset; - - // Check and see if the indexing object triggers advanced indexing semantics -#ifndef TH_REAL_IS_HALF - if (THPTensor_(_checkAdvancedIndexing)(self, index)) { - tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata); - if (!THPTensor_(_advancedIndexGet)(index, tresult)) { - return NULL; - } - // TODO: needed? - return THPTensor_(New)(tresult.release()); - } -#endif // TH_REAL_IS_HALF - - if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset)) - return NULL; - if (tresult) - return THPTensor_(New)(tresult.release()); - if (sresult) { - if (force_tensor) { - return THPTensor_(New)(THTensor_(newWithStorage1d)(LIBRARY_STATE sresult, storage_offset, 1, -1)); - } else { - return THPUtils_(newReal)(THStorage_(get)(LIBRARY_STATE sresult, storage_offset)); - } - } - THPUtils_setError("An unknown error has occurred when indexing a tensor " - "in THPTensor_(getValue). Please report this in a github issue at: " - "https://github.com/pytorch/pytorch"); - return NULL; - END_HANDLE_TH_ERRORS -} - -template<bool force_tensor> -static int THPTensor_(setValue)(THPTensor *self, PyObject *index, PyObject *value) -{ - HANDLE_TH_ERRORS - -#ifndef TH_REAL_IS_HALF -#if defined(THC_GENERIC_FILE) - THCPByteTensor *mask = THCPByteTensor_Check(index) ? (THCPByteTensor*)index : NULL; - THCPAutoGPU __gpu_guard(NULL, (PyObject*)self); -#elif defined(THD_GENERIC_FILE) - THDPByteTensor *mask = THDPByteTensor_Check(index) ? (THDPByteTensor*)index : NULL; -#else - THPByteTensor *mask = THPByteTensor_Check(index) ? (THPByteTensor*)index : NULL; -#endif - if (mask) { - if (THPUtils_(checkReal)(value)) { - real v = THPUtils_(unpackReal)(value); - THTensor_(maskedFill)(LIBRARY_STATE self->cdata, mask->cdata, v); - } else if (THPTensor_(Check)(value)) { - THTensor_(maskedCopy)(LIBRARY_STATE self->cdata, mask->cdata, ((THPTensor*)value)->cdata); - } else { - THPUtils_setError("can't assign %s to a " THPTensorStr " using a mask " - "(only " THPTensorStr " or %s are supported)", - THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str); - } - return 0; - } - if (THPIndexTensor_Check(index)) { - THIndexTensor *index_t = ((THPIndexTensor*)index)->cdata; - - // TH will also throw an error, but its a Runtime Error that is less interpretable - // than doing it at this layer - if (THIndexTensor_(nDimension)(LIBRARY_STATE index_t) != 1) { - PyErr_Format(PyExc_IndexError, "Setting values by indexing a Tensor with a " -#ifndef THC_GENERIC_FILE - "torch.LongTensor " -#else - "torch.cuda.LongTensor " -#endif - "triggers index_fill or index_copy semantics, and thus we expect a vector, but " - "the indexing Tensor passed has %lld dimensions", - (long long) THIndexTensor_(nDimension)(LIBRARY_STATE index_t)); - throw python_error(); - } - - if (THPUtils_(checkReal)(value)) { - real v = THPUtils_(unpackReal)(value); - THTensor_(indexFill)(LIBRARY_STATE self->cdata, 0, index_t, v); - } else if (THPTensor_(Check)(value)) { - THTensor_(indexCopy)(LIBRARY_STATE self->cdata, 0, index_t, ((THPTensor*)value)->cdata); - } else { - THPUtils_setError("can't assign %s to a " THPTensorStr " using a LongTensor " - "(only " THPTensorStr " or %s are supported)", - THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str); - } - return 0; - } -#endif - - THTensorPtr tresult; - THStorage *sresult; - int64_t storage_offset; - - // Check and see if the indexing object triggers advanced indexing semantics -#ifndef TH_REAL_IS_HALF - if (THPTensor_(_checkAdvancedIndexing)(self, index)) { - tresult = THTensor_(newWithTensor)(LIBRARY_STATE self->cdata); - if (!THPTensor_(_advancedIndexSet)(index, tresult, value)) { - return -1; - } - return 0; - } - -#endif // TH_REAL_IS_HALF - if (!THPTensor_(_index)(self, index, tresult, sresult, storage_offset)) - return -1; - if (sresult) { - if (!force_tensor) { - if (!THPUtils_(checkReal)(value)) { - THPUtils_setError("can't assign a %s to a scalar value of type %s", - THPUtils_typename(value), THPUtils_typeTraits<real>::python_type_str); - return -1; - } - THStorage_(set)(LIBRARY_STATE sresult, storage_offset, THPUtils_(unpackReal)(value)); - return 0; - } else { - tresult = THTensor_(newWithStorage1d)(LIBRARY_STATE sresult, storage_offset, 1, -1); - } - } - if (tresult) { - if (THPUtils_(checkReal)(value)) { -#ifndef TH_REAL_IS_HALF - THTensor_(fill)(LIBRARY_STATE tresult.get(), THPUtils_(unpackReal)(value)); -#else - throw std::runtime_error("torch.HalfTensors don't support scalar assignments"); -#endif - } else { - // TODO: try to do this without creating a temporary object - THPTensorPtr tmp((THPTensor*)THPTensor_(New)(tresult.release())); - if (!tmp) - return -1; - if (!THPCopy(THTensor_(copy_functions), (PyObject*)tmp.get(), value, false, false)) { - return -1; - } - } - return 0; - } - THPUtils_setError("An unknown error has occurred when indexing a tensor " - "in THPTensor_(setValue). Please report this in a github issue at: " - "https://github.com/pytorch/pytorch"); - return -1; - END_HANDLE_TH_ERRORS_RET(-1) -} -#undef THIndexTensor -#undef THIndexTensor_ -#undef THPIndexTensor -#undef THPIndexTensor_Check - -Py_ssize_t THPTensor_(length)(THPTensor *self) -{ - if (self->cdata->nDimension == 0) - return 0; - return self->cdata->size[0]; -} - -#include "TensorMethods.cpp" - -static PyMappingMethods THPTensor_(mappingmethods) = { - (lenfunc)THPTensor_(length), - (binaryfunc)THPTensor_(getValue)<false>, - (objobjargproc)THPTensor_(setValue)<false> -}; - -// TODO: implement equality -PyTypeObject THPTensorType = { - PyVarObject_HEAD_INIT(NULL, 0) - "torch._C." THPTensorBaseStr, /* tp_name */ - sizeof(THPTensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)THPTensor_(dealloc), /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - &THPTensor_(mappingmethods), /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - NULL, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* will be assigned in init */ /* tp_methods */ - 0, /* will be assigned in init */ /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - THPTensor_(pynew), /* tp_new */ -}; - -static struct PyMemberDef THPTensor_(members)[] = { - {(char*)"_cdata", T_ULONGLONG, offsetof(THPTensor, cdata), READONLY, NULL}, - {NULL} -}; - -typedef struct { - PyObject_HEAD -} THPTensorStateless; - -PyTypeObject THPTensorStatelessType = { - PyVarObject_HEAD_INIT(NULL, 0) - "torch._C." THPTensorBaseStr ".stateless", /* tp_name */ - sizeof(THPTensorStateless), /* tp_basicsize */ - 0, /* tp_itemsize */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved / tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - NULL, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - THPTensor_stateless_(methods), /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - 0, /* tp_new */ - 0, /* tp_free */ - 0, /* tp_is_gc */ - 0, /* tp_bases */ - 0, /* tp_mro */ - 0, /* tp_cache */ - 0, /* tp_subclasses */ - 0, /* tp_weaklist */ -}; - -#if !defined(TH_REAL_IS_HALF) && !defined(THD_GENERIC_FILE) -#include "SparseTensor.cpp" -#endif - -#ifndef THD_GENERIC_FILE -void THPTensor_(initCopyMethods)() -{ - auto& h = THTensor_(copy_functions); - // copy from same type - THPInsertTensorCopyFunction(h, &THTensor_(copy)); - // copy from CPU types - THPInsertTensorCopyFunction(h, &THTensor_(copyByte)); - THPInsertTensorCopyFunction(h, &THTensor_(copyChar)); - THPInsertTensorCopyFunction(h, &THTensor_(copyShort)); - THPInsertTensorCopyFunction(h, &THTensor_(copyInt)); - THPInsertTensorCopyFunction(h, &THTensor_(copyLong)); - THPInsertTensorCopyFunction(h, &THTensor_(copyFloat)); - THPInsertTensorCopyFunction(h, &THTensor_(copyHalf)); - THPInsertTensorCopyFunction(h, &THTensor_(copyDouble)); -#ifdef THC_GENERIC_FILE - // copy from GPU types - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaByte)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaChar)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaShort)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaInt)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaLong)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaFloat)); - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaDouble)); -#ifdef CUDA_HALF_TENSOR - THPInsertTensorCopyFunction(h, &THTensor_(copyCudaHalf)); -#endif - THPInsertTensorCopyFunction(h, &THCTensor_(copyAsyncCPU), true); - // add CPU <- GPU copies to base type - #define THCpuTensor_(name) TH_CONCAT_4(TH, Real, Tensor_, name) - extern THPCopyList THCpuTensor_(copy_functions); - auto& b = THCpuTensor_(copy_functions); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaByte)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaChar)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaShort)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaInt)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaLong)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaFloat)); - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaDouble)); -#ifdef CUDA_HALF_TENSOR - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyCudaHalf)); -#endif - THPInsertTensorCopyFunction(b, &THCpuTensor_(copyAsyncCuda), true); - #undef THCpuTensor_ -#endif -} -#else -void THPTensor_(initCopyMethods)() -{ - // TODO: cross type copies - auto& h = THTensor_(copy_functions); - THPInsertCopyFunction(h, &THDTensor_(copy)); - - #define THCpuTensor_(name) TH_CONCAT_4(TH, Real, Tensor_, name) - #define THCpuTensor TH_CONCAT_3(TH, Real, Tensor) - #define THPCpuTensorType TH_CONCAT_3(THP, Real, TensorType) - extern THPCopyList THCpuTensor_(copy_functions); - auto& b = THCpuTensor_(copy_functions); - - THDPInsertCopyFunctionFromMaster(h, &THDTensor_(copyFromMaster), &THPCpuTensorType); - THDPInsertCopyFunctionFromWorker(b, THDTensor_(copyFromWorker)); - - #undef THCpuTensor - #undef THCpuTensor_ - #undef THPCpuTensorType -} -#endif // !defined(THD_GENERIC_FILE) - -bool THPTensor_(init)(PyObject *module) -{ -#if !defined(THC_GENERIC_FILE) && !defined(TH_REAL_IS_HALF) - THVector_(vectorDispatchInit)(); -#endif - THPTensorType.tp_methods = THPTensor_(methods); - THPTensorType.tp_members = THPTensor_(members); - if (PyType_Ready(&THPTensorType) < 0) - return false; - THPTensorStatelessType.tp_new = PyType_GenericNew; - if (PyType_Ready(&THPTensorStatelessType) < 0) - return false; - - PyModule_AddObject(module, THPTensorBaseStr, (PyObject *)&THPTensorType); - THPTensor_(initCopyMethods)(); - return true; -} - -bool THPTensor_(postInit)(PyObject *module) -{ - THPTensorClass = PyObject_GetAttrString(module,(char*)TH_CONCAT_STRING_2(Real,Tensor)); - if (!THPTensorClass) return false; - - bool is_cuda = false; -#ifdef THC_GENERIC_FILE - is_cuda = true; -#endif - const char *type_name = TH_CONCAT_STRING_2(Real,); - torch::registerPyTypeObject((PyTypeObject*)THPTensorClass, type_name, is_cuda, false); - return true; -} - -#endif diff --git a/torch/csrc/generic/Tensor.h b/torch/csrc/generic/Tensor.h deleted file mode 100644 index 01f205fb4f..0000000000 --- a/torch/csrc/generic/Tensor.h +++ /dev/null @@ -1,74 +0,0 @@ -#ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/Tensor.h" -#else - -#if defined(TH_REAL_IS_HALF) || defined(THD_GENERIC_FILE) -#define GENERATE_SPARSE 0 -#else -#define GENERATE_SPARSE 1 -#endif - -struct THPTensor { - PyObject_HEAD - // Invariant: After __new__ (not __init__), this field is always non-NULL. - THTensor *cdata; -}; - -#if GENERATE_SPARSE -struct THSPTensor { - PyObject_HEAD - // Invariant: After __new__ (not __init__), this field is always non-NULL. - THSTensor *cdata; -}; -#endif - -/** - * Creates a new Python (Sparse) Tensor object using the give THTensor. The - * returned PyObject* pointer can be safely casted to a THPTensor*. Note: This - * "steals" the THTensor* `ptr`. On error, NULL is returned and the `ptr` ref - * count is decremented. - */ -THP_API PyObject * THPTensor_(New)(THTensor *ptr); -#if GENERATE_SPARSE -THP_API PyObject * THSPTensor_(New)(THSTensor *ptr); -#endif - -/** - * Creates a new empty Python Tensor object - */ -THP_API PyObject * THPTensor_(NewEmpty)(void); -#if GENERATE_SPARSE -THP_API PyObject * THSPTensor_(NewEmpty)(void); -#endif - -THP_API PyObject *THPTensorClass; -#if GENERATE_SPARSE -THP_API PyObject *THSPTensorClass; -#endif - -#ifdef _THP_CORE -#include "torch/csrc/Types.h" - -// TODO: init stateless in THPTensor_(init) and remove this -THP_API PyTypeObject THPTensorStatelessType; -#if GENERATE_SPARSE -THP_API PyTypeObject THSPTensorStatelessType; -#endif - -bool THPTensor_(init)(PyObject *module); -bool THPTensor_(postInit)(PyObject *module); -#if GENERATE_SPARSE -bool THSPTensor_(init)(PyObject *module); -bool THSPTensor_(postInit)(PyObject *module); -#endif - -THP_API PyTypeObject THPTensorType; -template <> struct THPTypeInfo<THTensor> { - static PyTypeObject* pyType() { return &THPTensorType; } - static THTensor* cdata(PyObject* p) { return ((THPTensor*)p)->cdata; } -}; -#endif - -#undef GENERATE_SPARSE - -#endif diff --git a/torch/csrc/generic/TensorMethods.cwrap b/torch/csrc/generic/TensorMethods.cwrap deleted file mode 100644 index bcb5f141b0..0000000000 --- a/torch/csrc/generic/TensorMethods.cwrap +++ /dev/null @@ -1,173 +0,0 @@ -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || \ - defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || \ - defined(THC_REAL_IS_HALF) -#define RealStr "float" -#else -#define RealStr "int" -#endif - -#ifdef THC_REAL_IS_HALF -#define AS_REAL(x) THC_float2half(x) -#else -#define AS_REAL(x) x -#endif - -#ifdef THD_GENERIC_FILE -#define IS_DISTRIBUTED true -#else -#define IS_DISTRIBUTED false -#endif - -#ifndef THC_GENERIC_FILE -#define IS_CUDA false -#define CUDA_FLOAT false -#else - -#define IS_CUDA true - -#if defined(THC_REAL_IS_BYTE) -#define CUDA_BYTE 1 -#else -#define CUDA_BYTE 0 -#endif - -#if defined(THC_REAL_IS_CHAR) -#define CUDA_CHAR 1 -#else -#define CUDA_CHAR 0 -#endif - -#if defined(THC_REAL_IS_SHORT) -#define CUDA_SHORT 1 -#else -#define CUDA_SHORT 0 -#endif - -#if defined(THC_REAL_IS_INT) -#define CUDA_INT 1 -#else -#define CUDA_INT 0 -#endif - -#if defined(THC_REAL_IS_LONG) -#define CUDA_LONG 1 -#else -#define CUDA_LONG 0 -#endif - -#if defined(THC_REAL_IS_FLOAT) -#define CUDA_FLOAT 1 -#else -#define CUDA_FLOAT 0 -#endif - -#if defined(THC_REAL_IS_DOUBLE) -#define CUDA_DOUBLE 1 -#else -#define CUDA_DOUBLE 0 -#endif - -#if defined(THC_REAL_IS_HALF) -#define CUDA_HALF 1 -#else -#define CUDA_HALF 0 -#endif - -#endif // ifndef THC_GENERIC_FILE - -#if IS_CUDA -#define THIndexTensor THCudaLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME) -#define THPIndexTensor THCPLongTensor -#define THPIndexTensor_(NAME) TH_CONCAT_2(THCPLongTensor_,NAME) -#define THPIndexTensorClass THCPLongTensorClass -#elif IS_DISTRIBUTED -#define THIndexTensor THDLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THDLongTensor_,NAME) -#define THPIndexTensor THDPLongTensor -#define THPIndexTensor_(NAME) TH_CONCAT_2(THDPLongTensor_,NAME) -#define THPIndexTensorClass THDPLongTensorClass -#else -#define THIndexTensor THLongTensor -#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME) -#define THPIndexTensor THPLongTensor -#define THPIndexTensor_(NAME) TH_CONCAT_2(THPLongTensor_,NAME) -#define THPIndexTensorClass THPLongTensorClass -#endif - -#if IS_CUDA -#define THIntegerTensor THCudaIntTensor -#define THIntegerTensor_(NAME) TH_CONCAT_2(THCudaIntTensor_,NAME) -#define THPIntegerTensor THCPIntTensor -#define THPIntegerTensorClass THCPIntTensorClass -#elif IS_DISTRIBUTED -#define THIntegerTensor THDIntTensor -#define THIntegerTensor_(NAME) TH_CONCAT_2(THDIntTensor_,NAME) -#define THPIntegerTensor THDPIntTensor -#define THPIntegerTensorClass THDPIntTensorClass -#else -#define THIntegerTensor THIntTensor -#define THIntegerTensor_(NAME) TH_CONCAT_2(THIntTensor_,NAME) -#define THPIntegerTensor THPIntTensor -#define THPIntegerTensorClass THPIntTensorClass -#endif - -#if IS_CUDA -#define THBoolTensor THCudaByteTensor -#define THPBoolTensor THCPByteTensor -#define THPBoolTensorClass THCPByteTensorClass -#elif IS_DISTRIBUTED -#define THBoolTensor THDByteTensor -#define THPBoolTensor THDPByteTensor -#define THPBoolTensorClass THDPByteTensorClass -#else -#define THBoolTensor THByteTensor -#define THPBoolTensor THPByteTensor -#define THPBoolTensorClass THPByteTensorClass -#endif - -#if IS_CUDA -#define THPModuleStr "torch.cuda." -#elif IS_DISTRIBUTED -#define THPModuleStr "torch.distributed." -#else -#define THPModuleStr "torch." -#endif - -// The C API uses THLongStorage for size and stride, but the Python API uses -// torch.Size or tuple -typedef THLongStorage THSize; -typedef THLongStorage THStride; - -!!inc methods/Tensor.cwrap -!!inc methods/TensorApply.cwrap -!!inc methods/TensorMath.cwrap -!!inc methods/TensorCompare.cwrap -!!inc methods/TensorRandom.cwrap -!!inc methods/TensorCuda.cwrap - -#if !IS_DISTRIBUTED -!!inc methods/SparseTensor.cwrap -#endif - -// cwrap should put definitions before undefs, so let's mark this place -// PUT DEFINITIONS IN HERE PLEASE - -#undef IS_CUDA -#undef CUDA_BYTE -#undef CUDA_CHAR -#undef CUDA_SHORT -#undef CUDA_INT -#undef CUDA_LONG -#undef CUDA_FLOAT -#undef CUDA_DOUBLE -#undef CUDA_HALF -#undef THIndexTensor -#undef THIndexTensor_ -#undef THPIndexTensor -#undef THPIndexTensorClass -#undef THBoolTensor -#undef THPBoolTensor -#undef THPBoolTensorClass -#undef RealStr -#undef AS_REAL diff --git a/torch/csrc/generic/methods/SparseTensor.cwrap b/torch/csrc/generic/methods/SparseTensor.cwrap deleted file mode 100644 index 5fb5ea6aba..0000000000 --- a/torch/csrc/generic/methods/SparseTensor.cwrap +++ /dev/null @@ -1,548 +0,0 @@ -#if IS_CUDA || !defined(TH_REAL_IS_HALF) -PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS - THSTensor* tensor = ((THSPTensor*)self)->cdata; - if (PyTuple_Size(args) == 0 && (!kwargs || PyDict_Size(kwargs) == 0)) { - return THPSize_New(tensor->nDimensionI + tensor->nDimensionV, tensor->size); - } - - int tuplecount = args ? PyTuple_Size(args) : 0; - int dictcount = kwargs ? PyDict_Size(kwargs) : 0; - - PyObject* pydim = NULL; - if (tuplecount == 1 && dictcount == 0) { - pydim = PyTuple_GET_ITEM(args, 0); - } else if (dictcount == 1 && tuplecount == 0) { - pydim = PyDict_GetItemString(kwargs, "dim"); - } - - if (pydim && THPUtils_checkLong(pydim)) { - int dim = (int)THPUtils_unpackLong(pydim); - if (dim < 0) - dim += tensor->nDimensionI + tensor->nDimensionV; - return PyInt_FromLong(THSTensor_(size)(LIBRARY_STATE tensor, dim)); - } - - THPUtils_invalidArguments(args, kwargs, "size", 2, "(int dim)", "no arguments"); - return NULL; - END_HANDLE_TH_ERRORS -} -[[ - name: THSPTensor_(size) - python_name: size - method_flags: METH_KEYWORDS - only_register: True - sparse: yes -]] -#endif - -[[ - name: THSPTensor_(new) - python_name: new - method_flags: METH_KEYWORDS - backends: - - CUDA - only_register: True - sparse: yes -]] -#if IS_CUDA -static PyObject * THSPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs); -PyObject * THSPTensor_(new)(THPTensor *self, PyObject *args, PyObject *kwargs) -{ - THCPAutoGPU gpu_guard(args, (PyObject*)self); - return THSPTensor_(pynew)(Py_TYPE(self), args, kwargs); -} -#endif - -[[ - name: nDimension - sparse: yes - python_name: ndimension - return: int64_t - arguments: - - THSTensor* self -]] -[[ - name: THPTensor_(nDimension) - python_name: dim - only_register: True - method_flags: METH_KEYWORDS - sparse: yes -]] - -[[ - name: nnz - python_name: _nnz - sparse: yes - return: int64_t - arguments: - - THSTensor* self -]] - -[[ - name: nDimensionI - python_name: _dimI - sparse: yes - return: long - arguments: - - THSTensor* self -]] - -[[ - name: nDimensionV - python_name: _dimV - sparse: yes - return: long - arguments: - - THSTensor* self -]] - -[[ - name: isCoalesced - sparse: yes - python_name: is_coalesced - return: bool - arguments: - - THSTensor* self -]] - -[[ - name: indices - python_name: _indices - sparse: yes - cname: newIndices - return: THIndexTensor* - arguments: - - THSTensor* self -]] - -[[ - name: values - python_name: _values - sparse: yes - cname: newValues - return: THTensor* - arguments: - - THSTensor* self -]] - -[[ - name: coalesce - cname: newCoalesce - sparse: yes - return: THSTensor* - arguments: - - THSTensor* self -]] - -[[ - name: clone - sparse: yes - cname: newClone - return: THSTensor* - arguments: - - THSTensor* self -]] - -[[ - name: toDense - sparse: yes - python_name: to_dense - return: THTensor* - arguments: - - THSTensor* self -]] - -[[ - name: resizeAs_ - python_name: resize_as_ - sparse: yes - cname: resizeAs - return: self - arguments: - - THSTensor* self - - THSTensor* the_template -]] - -[[ - name: transpose - sparse: yes - cname: newTranspose - return: THSTensor* - arguments: - - THSTensor* self - - int64_t dim0 - - int64_t dim1 -]] - -[[ - name: transpose_ - sparse: yes - cname: transpose - return: argument 0 - arguments: - - THSTensor* self - - int64_t dim0 - - int64_t dim1 -]] - -[[ - name: t - sparse: yes - variants: - - method - - function - cname: newTranspose - return: THSTensor* - before_call: | - int64_t nDimI = ((THSPTensor*)${arg0})->cdata->nDimensionI; - int64_t nDimV = ((THSPTensor*)${arg0})->cdata->nDimensionV; - THPUtils_assert(nDimI == 2 && nDimV == 0, "t() expects a 2D sparse tensor, but self is %ldD indices and %ldD values", nDimI, nDimV); - arguments: - - THSTensor* self - - CONSTANT 0 - - CONSTANT 1 -]] - -[[ - name: t_ - sparse: yes - cname: transpose - return: self - before_call: | - int64_t nDimI = ((THSPTensor*)${arg0})->cdata->nDimensionI; - int64_t nDimV = ((THSPTensor*)${arg0})->cdata->nDimensionV; - THPUtils_assert(nDimI == 2 && nDimV == 0, "t_() expects a 2D sparse tensor, but self is %ldD indices and %ldD values", nDimI, nDimV); - arguments: - - THSTensor* self - - CONSTANT 0 - - CONSTANT 1 -]] - -[[ - name: mm - sparse: yes - variants: - - function - cname: spaddmm - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: spmm - variants: - - function - sparse: yes - cname: spaddmm - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: hspmm - variants: - - function - sparse: yes - cname: hspmm - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - CONSTANT AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: sspmm - variants: - - function - sparse: yes - cname: sspaddmm - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: sspaddmm - sparse: yes - variants: - - method - - function - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - THSTensor* self - - arg: real alpha - default: AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: spadd - sparse: yes - cname: spcadd - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* mat1 - - arg: real value - default: AS_REAL(1) - - THSTensor* mat2 -]] - -[[ - name: zero_ - sparse: yes - cname: zero - return: self - arguments: - - THSTensor* self -]] - -[[ - name: zeros - sparse: yes - variants: - - function - auto_gpu: False - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - arg: THSize* size - long_args: True -]] - -[[ - name: zeros_like - sparse: yes - cname: zerosLike - variants: - - function - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - THSTensor* input -]] - -[[ - name: add - sparse: yes - variants: - - method - - function - return: argument 0 - cname: cadd - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - -[[ - name: add_ - sparse: yes - return: argument 0 - cname: cadd - arguments: - - THSTensor* self - - THSTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - -[[ - name: sub - sparse: yes - variants: - - method - - function - return: argument 0 - cname: csub - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - -[[ - name: sub_ - sparse: yes - return: argument 0 - cname: csub - arguments: - - THSTensor* self - - THSTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - -[[ - name: mul - sparse: yes - return: argument 0 - variants: - - method - - function - options: - - cname: mul - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - real value - - cname: cmul - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - THSTensor* other -]] - -[[ - name: mul_ - sparse: yes - return: argument 0 - options: - - cname: mul - arguments: - - THSTensor* self - - THSTensor* self - - real value - - cname: cmul - arguments: - - THSTensor* self - - THSTensor* self - - THSTensor* other -]] - -[[ - name: div - sparse: yes - cname: div - variants: - - method - - function - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - real value -]] - -[[ - name: div_ - sparse: yes - cname: div - return: argument 0 - arguments: - - THSTensor* self - - THSTensor* self - - real value -]] - -[[ - name: norm - types: - - float - - double - backends: - - CPU - - CUDA - sparse: yes - cname: normall - variants: - - method - - function - return: real - arguments: - - THSTensor* self - - arg: real value - default: AS_REAL(2) -]] - -[[ - name: pow - types: - - float - - double - backends: - - CPU - - CUDA - sparse: yes - cname: pow - variants: - - method - - function - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - THSTensor* self - - real value -]] - -[[ - name: _sparse_mask - defined_if: "!IS_DISTRIBUTED" - cname: sparseMask - return: argument 0 - arguments: - - arg: THSTensor* result - output: True - - THTensor* self - - THSTensor* mask -]] - -[[ - name: getDevice - sparse: yes - python_name: get_device - backends: - - CUDA - return: int64_t - arguments: - - THSTensor* self -]] diff --git a/torch/csrc/generic/methods/Tensor.cwrap b/torch/csrc/generic/methods/Tensor.cwrap deleted file mode 100644 index edc5a6f6d5..0000000000 --- a/torch/csrc/generic/methods/Tensor.cwrap +++ /dev/null @@ -1,1152 +0,0 @@ -// TODO: check that there are no args -[[ - name: THPTensor_(elementSize) - python_name: element_size - cpu_half: True - auto_gpu: False - only_register: True -]] -static PyObject * THPTensor_(elementSize)(THPTensor *self, PyObject *args) -{ - return PyLong_FromLong(THStorage_(elementSize)(LIBRARY_STATE_NOARGS)); -} - -// TODO: check that there are no args -[[ - name: THPTensor_(storage) - python_name: storage - cpu_half: True - auto_gpu: False - only_register: True -]] -static PyObject * THPTensor_(storage)(THPTensor *self, PyObject *args) -{ - // TODO: memory leak on error - THStorage *result = THTensor_(storage)(LIBRARY_STATE self->cdata); - if (result == NULL) - Py_RETURN_NONE; - THStorage_(retain)(LIBRARY_STATE result); - THStoragePtr _tmp(result); - PyObject *ret = THPStorage_(New)(result); - _tmp.release(); - return ret; -} - -[[ - name: storageOffset - python_name: storage_offset - cpu_half: True - auto_gpu: False - return: int64_t - arguments: - - THTensor* self -]] - -[[ - name: nDimension - python_name: ndimension - cpu_half: True - auto_gpu: False - return: int64_t - arguments: - - THTensor* self -]] -[[ - name: THPTensor_(nDimension) - python_name: dim - cpu_half: True - auto_gpu: False - only_register: True - method_flags: METH_KEYWORDS -]] - -[[ - python_name: index - name: THPTensor_(getValue)<true> - only_register: True - override_method_flags: METH_O -]] - -[[ - python_name: _set_index - name: THPTensor_(setIndex) - only_register: True -]] -PyObject * THPTensor_(setIndex)(THPTensor *self, PyObject *args) -{ - THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "set_index takes exactly two " - "arguments (%d given)", (int)PyTuple_GET_SIZE(args)); - if (THPTensor_(setValue)<true>(self, PyTuple_GET_ITEM(args, 0), PyTuple_GET_ITEM(args, 1)) != 0) - return NULL; - Py_RETURN_NONE; -} - -[[ - python_name: _check_advanced_indexing - name: THPTensor_(checkAdvancedIndexing) - cpu_half: False - only_register: True - override_method_flags: METH_O -]] - -[[ - python_name: _advanced_index_add - name: THPTensor_(advancedIndexAdd) - cpu_half: False - only_register: True -]] - -[[ - python_name: _advanced_index_select - name: THPTensor_(advancedIndexSelect) - cpu_half: False - only_register: True -]] - -[[ - name: resize_ - return: self - cname: resize - cpu_half: True - before_call: - THPUtils_assert(arg_self->storage->flag & TH_STORAGE_RESIZABLE, - "calling resize_ on a tensor that has non-resizable storage. Clone it first " - "or create a new tensor instead."); - arguments: - - THTensor* self - - arg: THSize* size - long_args: True - - CONSTANT NULL -]] - -[[ - name: zeros - variants: - - function - auto_gpu: False - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THSize* size - long_args: True -]] - -[[ - name: zeros_like - cname: zerosLike - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* input -]] - -[[ - name: ones - variants: - - function - auto_gpu: False - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THSize* size - long_args: True -]] - -[[ - name: ones_like - cname: onesLike - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* input -]] - -[[ - name: numel - return: int64_t - cname: nElement - cpu_half: True - auto_gpu: False - variants: - - method - - function - arguments: - - THTensor* self -]] -[[ - name: THPTensor_(numel) - python_name: nelement - cpu_half: True - auto_gpu: False - only_register: True - method_flags: METH_KEYWORDS -]] - -[[ - name: set_ - cname: set - cpu_half: True - auto_gpu: False - return: argument 0 - options: - - cname: set - arguments: - - THTensor* self - - THTensor* source - - cname: setStorage - arguments: - - THTensor* self - - CONSTANT NULL, 0, NULL, NULL - - cname: setStorage - before_call: THLongStoragePtr __storage_size(THLongStorage_newWithSize1(THStorage_(size)(LIBRARY_STATE arg_storage))); - arguments: - - THTensor* self - - THStorage* storage - - CONSTANT 0 - - CONSTANT __storage_size.get() - - CONSTANT NULL - - cname: setStorage - arguments: - - THTensor* self - - THStorage* sourceStorage - - int64_t storage_offset - - THSize* size - - arg: THStride* stride - default: NULL -]] - -[[ - name: THPTensor_(select) - python_name: select - cpu_half: True - auto_gpu: False - only_register: True -]] -static PyObject * THPTensor_(select)(THPTensor *self, PyObject *args) -{ - HANDLE_TH_ERRORS - int64_t dim, idx; - if (!PyArg_ParseTuple(args, "LL", &dim, &idx)) - return NULL; - - int ndim = THTensor_(nDimension)(LIBRARY_STATE self->cdata); - - THPUtils_assert(dim >= -(ndim) && dim < (ndim), - "dimension out of range (expected to be in range of [%d, %d], but got %d)", - -(ndim), (ndim)-1, dim); - if (dim<0) dim += ndim; - - if(ndim > 1) { - THTensorPtr selected(THTensor_(newWithTensor)(LIBRARY_STATE self->cdata)); - THTensor_(select)(LIBRARY_STATE selected.get(), NULL, (int) dim, idx); - return THPTensor_(New)(selected.release()); - } - else { - THArgCheck(ndim == 1, 1, "empty Tensor"); - return THPUtils_(newReal)(THTensor_(get1d)(LIBRARY_STATE self->cdata, idx)); - } - END_HANDLE_TH_ERRORS -} - -PyObject * THPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS - THTensor* tensor = ((THPTensor*)self)->cdata; - if (PyTuple_Size(args) == 0 && (!kwargs || PyDict_Size(kwargs) == 0)) { - return THPSize_New(tensor->nDimension, tensor->size); - } - - int tuplecount = args ? (int) PyTuple_Size(args) : 0; - int dictcount = kwargs ? (int) PyDict_Size(kwargs) : 0; - - PyObject* pydim = NULL; - if (tuplecount == 1 && dictcount == 0) { - pydim = PyTuple_GET_ITEM(args, 0); - } else if (dictcount == 1 && tuplecount == 0) { - pydim = PyDict_GetItemString(kwargs, "dim"); - } - - if (pydim && THPUtils_checkLong(pydim)) { - int dim = (int)THPUtils_unpackLong(pydim); - if (dim < 0) - dim += tensor->nDimension; - return PyInt_FromLong(THTensor_(size)(LIBRARY_STATE tensor, dim)); - } - - THPUtils_invalidArguments(args, kwargs, "size", 2, "(int dim)", "no arguments"); - return NULL; - END_HANDLE_TH_ERRORS -} -[[ - name: THPTensor_(size) - python_name: size - cpu_half: True - auto_gpu: False - method_flags: METH_KEYWORDS - only_register: True -]] - -PyObject * THPTensor_(stride)(PyObject *self, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS - THTensor* tensor = ((THPTensor*)self)->cdata; - if (PyTuple_Size(args) == 0 && (!kwargs || PyDict_Size(kwargs) == 0)) { - PyObject* stride = PyTuple_New(tensor->nDimension); - for (int i = 0; i != tensor->nDimension; ++i) { - PyTuple_SET_ITEM(stride, i, PyLong_FromLong(tensor->stride[i])); - } - return stride; - } - - int tuplecount = args ? (int) PyTuple_Size(args) : 0; - int dictcount = kwargs ? (int) PyDict_Size(kwargs) : 0; - - PyObject* pydim = NULL; - if (tuplecount == 1 && dictcount == 0) { - pydim = PyTuple_GET_ITEM(args, 0); - } else if (dictcount == 1 && tuplecount == 0) { - pydim = PyDict_GetItemString(kwargs, "dim"); - } - - if (pydim && THPUtils_checkLong(pydim)) { - int dim = (int)THPUtils_unpackLong(pydim); - if (dim < 0) - dim += tensor->nDimension; - return PyInt_FromLong(THTensor_(stride)(LIBRARY_STATE tensor, dim)); - } - - THPUtils_invalidArguments(args, kwargs, "stride", 2, "(int dim)", "no arguments"); - return NULL; - END_HANDLE_TH_ERRORS -} -[[ - name: THPTensor_(stride) - python_name: stride - cpu_half: True - auto_gpu: False - method_flags: METH_KEYWORDS - only_register: True -]] - -[[ - name: fill_ - cname: fill - return: self - arguments: - - THTensor* self - - real value -]] - -[[ - name: isSameSizeAs - python_name: is_same_size - cpu_half: True - auto_gpu: False - return: bool - arguments: - - THTensor* self - - THTensor* other -]] - -[[ - name: isContiguous - python_name: is_contiguous - cpu_half: True - auto_gpu: False - return: bool - arguments: - - THTensor* self -]] - -[[ - name: isSetTo - python_name: is_set_to - cpu_half: True - auto_gpu: False - return: bool - arguments: - - THTensor* self - - THTensor* tensor -]] - -[[ - name: maskedFill_ - cname: maskedFill - python_name: masked_fill_ - return: self - arguments: - - arg: THTensor* self - broadcast: mask inplace fallback types:Byte - - THBoolTensor* mask - - real value -]] - -[[ - name: maskedCopy_ - cname: maskedCopy - python_name: masked_scatter_ - return: self - arguments: - - arg: THTensor* self - broadcast: mask inplace fallback types:Byte - - THBoolTensor* mask - - THTensor* source -]] - -[[ - name: maskedSelect - python_name: masked_select - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: mask fallback types:Byte - - THBoolTensor* mask -]] - -[[ - name: transpose - variants: - - method - - function - cname: newTranspose - cpu_half: True - auto_gpu: False - return: THTensor* - arguments: - - THTensor* self - - arg: int64_t dim0 - wrap_dim: self - - arg: int64_t dim1 - wrap_dim: self -]] - -[[ - name: transpose_ - cname: transpose - cpu_half: True - auto_gpu: False - return: self - arguments: - - THTensor* self - - THTensor* self - - arg: int64_t dim0 - wrap_dim: self - - arg: int64_t dim1 - wrap_dim: self -]] - -[[ - name: t - variants: - - method - - function - auto_gpu: False - cname: newTranspose - return: THTensor* - before_call: | - int64_t ndim = arg_self->nDimension; - THPUtils_assert(ndim == 2, "t() expects a 2D tensor, but self is %ldD", ndim); - arguments: - - THTensor* self - - CONSTANT 0 - - CONSTANT 1 -]] - -[[ - name: t_ - cname: transpose - auto_gpu: False - return: self - before_call: | - int64_t ndim = arg_self->nDimension; - THPUtils_assert(ndim == 2, "t_() expects a 2D tensor, but self is %ldD", ndim); - arguments: - - THTensor* self - - THTensor* self - - CONSTANT 0 - - CONSTANT 1 -]] - -[[ - name: squeeze - cpu_half: True - variants: - - method - - function - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - cname: squeeze1d - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self -]] - -[[ - name: squeeze_ - cpu_half: True - return: self - options: - - cname: squeeze - arguments: - - THTensor* self - - THTensor* self - - cname: squeeze1d - arguments: - - THTensor* self - - THTensor* self - - arg: int64_t dim - wrap_dim: self -]] - -[[ - name: unsqueeze - variants: - - method - - function - cpu_half: True - auto_gpu: False - return: argument 0 - cname: unsqueeze1d - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self+1 -]] - -[[ - name: unsqueeze_ - cpu_half: True - auto_gpu: False - return: self - cname: unsqueeze1d - arguments: - - THTensor* self - - THTensor* self - - arg: int64_t dim - wrap_dim: self+1 -]] - -[[ - name: nonzero - variants: - - method - - function - return: argument 0 - arguments: - - arg: THIndexTensor* result - output: True - - THTensor* self -]] - -[[ - name: contiguous - cname: newContiguous - return: THTensor* - arguments: - - THTensor* self -]] - -[[ - name: clone - cname: newClone - return: THTensor* - aten_sparse: True - arguments: - - THTensor* self -]] - -[[ - name: view - cname: newView - auto_gpu: False - return: THTensor* - arguments: - - THTensor* self - - arg: THSize* size - long_args: True -]] - -[[ - name: expand - cname: newExpand - return: THTensor* - arguments: - - THTensor* self - - arg: THSize* size - long_args: True -]] - -[[ - name: resizeAs_ - python_name: resize_as_ - cname: resizeAs - return: self - arguments: - - THTensor* self - - THTensor* the_template -]] - -[[ - name: indexSelect - python_name: index_select - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index -]] - -[[ - name: indexCopy_ - python_name: index_copy_ - cname: indexCopy - return: argument 0 - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - THTensor* source -]] -[[ - name: take - cname: take - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THIndexTensor* index -]] -[[ - name: put_ - cname: put - backends: - - CPU - - CUDA - return: argument 0 - arguments: - - THTensor* self - - THIndexTensor* index - - THTensor* source - - arg: bool accumulate - default: "false" -]] -[[ - name: indexAdd_ - python_name: index_add_ - cname: indexAdd - return: argument 0 - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - THTensor* source -]] - -[[ - name: indexFill_ - python_name: index_fill_ - cname: indexFill - return: argument 0 - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - real value -]] - -[[ - name: narrow - cpu_half: True - auto_gpu: False - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dimension - wrap_dim: self - - int64_t start - - int64_t length -]] - -[[ - name: unfold - cpu_half: True - auto_gpu: False - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dimension - wrap_dim: self - - int64_t size - - int64_t step -]] - -[[ - name: range - variants: - - function - backends: - - CPU - - CUDA - return: argument 0 - before_arg_assign: | - PyErr_WarnEx(PyExc_UserWarning, "torch.range is deprecated in favor of torch.arange " - "and will be removed in 0.3. Note that arange generates values in [start; end), " - "not [start; end].", 1); - arguments: - - arg: THTensor* result - output: True - - accreal start - - accreal end - - arg: accreal step - default: 1 -]] - -[[ - name: arange - variants: - - function - backends: - - CPU - - CUDA - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - accreal start - - accreal end - - accreal step - - arguments: - - arg: THTensor* result - output: True - - accreal start - - accreal end - - CONSTANT 1 - - arguments: - - arg: THTensor* result - output: True - - CONSTANT 0 - - accreal end - - CONSTANT 1 -]] - -[[ - name: scatter_ - return: argument 0 - options: - - cname: scatter - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - THTensor* src - - cname: scatterFill - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - real value -]] - -[[ - name: scatter_add_ - return: argument 0 - options: - - cname: scatterAdd - arguments: - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index - - THTensor* src -]] - -[[ - name: gather - variants: - - method - - function - return: argument 0 - before_call: | - THLongStoragePtr _size(THIndexTensor_(newSizeOf)(LIBRARY_STATE arg_index)); - THTensor_(resize)(LIBRARY_STATE arg_result, _size, NULL); - arguments: - - arg: THTensor* result - output: True - resize: index - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - THIndexTensor* index -]] - -[[ - name: THPTensor_stateless_(cat) - python_name: cat - method_flags: METH_KEYWORDS - only_register: True - variants: - - function -]] -#ifndef TH_REAL_IS_HALF -static PyObject * THPTensor_stateless_(cat)(THPTensor *_unused, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS -#if IS_CUDA - THCPAutoGPU __autogpu_guard(-1); -#endif - static char* argnames[] = { "seq", "dim", "out", NULL }; - PyObject *_seq = NULL; - int dim = 0; - PyObject *___out = NULL; - - THPObjectPtr sequence; - std::vector<THTensor *> tensors; - THPTensorPtr result; - Py_ssize_t len; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|iO", argnames, &_seq, &dim, &___out)) { - goto invalid_arguments; - } - - sequence = PySequence_Fast(_seq, "seq must be a sequence"); - if (!sequence) { - // NOTE: we use the error message from invalidArguments when _seq is not a sequence - goto invalid_arguments; - } - - len = PySequence_Fast_GET_SIZE(sequence.get()); - THPUtils_assert(len > 0, "seq can't be empty"); - - if (___out && ___out != Py_None) { - if (!THPTensor_(Check)(___out)) { - goto invalid_arguments; - } - Py_INCREF(___out); - result = (THPTensor *)___out; - } else { - result = (THPTensor *)THPTensor_(NewEmpty)(); - if (!result) return NULL; - } - - for (int i = 0; i < len; i++) { - PyObject *item = PySequence_Fast_GET_ITEM(sequence.get(), i); - if (!THPTensor_(Check)(item)) - goto invalid_arguments; - tensors.push_back(((THPTensor*)item)->cdata); - } - - for (THTensor *t : tensors) { - auto ndim = THTensor_(nDimension)(LIBRARY_STATE t); - if (ndim > 0) { - THPUtils_assert(dim > 0 ? dim < ndim : ndim + dim >= 0, - "dim out of range - got %d but the tensor is only %dD", - dim, ndim); - if (dim < 0) dim += ndim; - break; - } - } - -#if IS_CUDA - __autogpu_guard.setDevice(THTensor_(getDevice)(LIBRARY_STATE tensors[0])); -#endif - - THTensor_(catArray)(LIBRARY_STATE result->cdata, tensors.data(), (int) tensors.size(), dim); - return (PyObject*)result.release(); - -invalid_arguments: - THPUtils_invalidArguments(args, kwargs, "cat", 2, - "(sequence[" THPTensorStr "] seq)", - "(sequence[" THPTensorStr "] seq, int dim)"); - return NULL; - END_HANDLE_TH_ERRORS -} -#endif - -[[ - name: data_ptr - defined_if: "!IS_DISTRIBUTED" - with_gil: True - auto_gpu: False - return: void* - cpu_half: True - cname: data - arguments: - - THTensor* self -]] - -[[ - name: equal - variants: - - method - - function - return: bool - arguments: - - THTensor* self - - THTensor* other -]] - -[[ - python_name: copy_ - name: THPTensor_(copy_) - cpu_half: True - method_flags: METH_KEYWORDS - only_register: True -]] -PyObject * THPTensor_(copy_)(PyObject *self, PyObject *args, PyObject *kwargs) -{ - HANDLE_TH_ERRORS - return THPTensorCopyMethod(THTensor_(copy_functions), self, args, kwargs); - END_HANDLE_TH_ERRORS -} - -[[ - name: __and__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitand - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cbitand - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: __iand__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitand - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cbitand - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: __or__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitor - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cbitor - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: __ior__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitor - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cbitor - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: __xor__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitxor - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cbitxor - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: __ixor__ - variants: - - method - - function - return: argument 0 - options: - - cname: bitxor - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cbitxor - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: __lshift__ - variants: - - method - - function - return: argument 0 - options: - - cname: lshift - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: clshift - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: __ilshift__ - variants: - - method - - function - return: argument 0 - options: - - cname: lshift - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: clshift - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: __rshift__ - variants: - - method - - function - return: argument 0 - options: - - cname: rshift - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: crshift - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: __irshift__ - variants: - - method - - function - return: argument 0 - options: - - cname: rshift - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: crshift - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] diff --git a/torch/csrc/generic/methods/TensorApply.cwrap b/torch/csrc/generic/methods/TensorApply.cwrap deleted file mode 100644 index addc214939..0000000000 --- a/torch/csrc/generic/methods/TensorApply.cwrap +++ /dev/null @@ -1,158 +0,0 @@ -#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) -#define BUILD_REAL_FMT "d" -#else -#define BUILD_REAL_FMT "L" -#endif - -#if !IS_CUDA && !IS_DISTRIBUTED -[[ - name: THPTensor_(apply) - python_name: apply_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - cpu_half: True - only_register: True - override_method_flags: METH_O -]] -static PyObject * THPTensor_(apply)(THPTensor *self, PyObject *arg) -{ - HANDLE_TH_ERRORS - if (!PyCallable_Check(arg)) { - THPUtils_setError("apply requires a callable as it's first argument"); - return NULL; - } - - THTensor *tensor = self->cdata; - TH_TENSOR_APPLY(real, tensor, - PyObject *ret = - PyObject_CallFunction(arg, (char*)BUILD_REAL_FMT, *tensor_data); - if (!ret) - return NULL; - if (!THPUtils_(checkReal)(ret)) { - Py_DECREF(ret); - THError("given function should return a number"); - } - *tensor_data = THPUtils_(unpackReal)(ret); - Py_DECREF(ret); - ); - - Py_INCREF(self); - return (PyObject*)self; - END_HANDLE_TH_ERRORS -} - -[[ - name: THPTensor_(map) - python_name: map_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - cpu_half: True - only_register: True -]] -static PyObject * THPTensor_(map)(THPTensor *self, PyObject *args) -{ - HANDLE_TH_ERRORS - PyObject *fn; - THPTensor *src_object; - if (!PyArg_ParseTuple(args, "O!O&", THPTensorClass, &src_object, THPUtils_getCallable, &fn)) - return NULL; - - THTensor *tensor = self->cdata; - THTensor *src = src_object->cdata; - - THTensor *src_save = src; - THTensorPtr src_guard(THTensor_(new)(LIBRARY_STATE_NOARGS)); - - bool expand_success = false; - try { - expand_inplace1<THTensor, THTensor>(src_guard.get(), src, tensor, "src", "tensor", true); - expand_success = true; - } catch (std::exception &e) {} - if (expand_success) { - src = src_guard.get(); - } - - TH_TENSOR_APPLY2(real, tensor, real, src, - PyObject *ret = - PyObject_CallFunction(fn, (char*)(BUILD_REAL_FMT BUILD_REAL_FMT), - *tensor_data, *src_data); - if (!ret) - return NULL; - if (!THPUtils_(checkReal)(ret)) { - Py_DECREF(ret); - THError("given function should return a number"); - } - *tensor_data = THPUtils_(unpackReal)(ret); - Py_DECREF(ret); - ); - - src = src_save; - - Py_INCREF(self); - return (PyObject*)self; - END_HANDLE_TH_ERRORS -} - -[[ - name: THPTensor_(map2) - python_name: map2_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - cpu_half: True - only_register: True -]] -static PyObject * THPTensor_(map2)(THPTensor *self, PyObject *args) -{ - HANDLE_TH_ERRORS - PyObject *fn; - THPTensor *src1_object; - THPTensor *src2_object; - if (!PyArg_ParseTuple(args, "O!O!O&", THPTensorClass, &src1_object, THPTensorClass, &src2_object, THPUtils_getCallable, &fn)) - return NULL; - - THTensor *tensor = self->cdata; - THTensor *src1 = src1_object->cdata; - THTensor *src2 = src2_object->cdata; - - THTensor *src1_save = src1; - THTensorPtr src1_guard(THTensor_(new)(LIBRARY_STATE_NOARGS)); - THTensor *src2_save = src2; - THTensorPtr src2_guard(THTensor_(new)(LIBRARY_STATE_NOARGS)); - - bool expand_success = false; - try { - expand_inplace2<THTensor>(src1_guard.get(), src2_guard.get(), src1, src2, tensor, "src1", "src2", "tensor", true); - expand_success = true; - } catch (std::exception &e) {} - if (expand_success) { - src1 = src1_guard.get(); - src2 = src2_guard.get(); - } - - TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2, - PyObject *ret = - PyObject_CallFunction(fn, (char*)(BUILD_REAL_FMT BUILD_REAL_FMT BUILD_REAL_FMT), - *tensor_data, *src1_data, *src2_data); - if (!ret) - return NULL; - if (!THPUtils_(checkReal)(ret)) { - Py_DECREF(ret); - THError("given function should return a number"); - } - *tensor_data = THPUtils_(unpackReal)(ret); - Py_DECREF(ret); - ); - - src1 = src1_save; - src2 = src2_save; - - Py_INCREF(self); - return (PyObject*)self; - END_HANDLE_TH_ERRORS -} -#endif /* !IS_CUDA */ - -#undef BUILD_REAL_FMT diff --git a/torch/csrc/generic/methods/TensorCompare.cwrap b/torch/csrc/generic/methods/TensorCompare.cwrap deleted file mode 100644 index f843b242a9..0000000000 --- a/torch/csrc/generic/methods/TensorCompare.cwrap +++ /dev/null @@ -1,760 +0,0 @@ -[[ - name: lt - variants: - - method - return: argument 0 - options: - - cname: ltValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: ltTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: lt_ - return: self - options: - - cname: ltValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: ltTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - arg: THTensor* other -]] - -[[ - name: lt - variants: - - function - return: argument 0 - options: - - cname: ltValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: ltTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: ltValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: ltTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other -]] - - -[[ - name: gt - variants: - - method - return: argument 0 - options: - - cname: gtValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: gtTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: gt_ - return: self - options: - - cname: gtValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: gtTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: gt - variants: - - function - return: argument 0 - options: - - cname: gtValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: gtTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: gtValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: gtTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - -]] - - -[[ - name: le - variants: - - method - return: argument 0 - options: - - cname: leValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: leTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: le_ - return: self - options: - - cname: leValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: leTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: le - variants: - - function - return: argument 0 - options: - - cname: leValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: leTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: leValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: leTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other -]] - - -[[ - name: ge - variants: - - method - return: argument 0 - options: - - cname: geValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: geTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: ge_ - return: self - options: - - cname: geValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: geTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: ge - variants: - - function - return: argument 0 - options: - - cname: geValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: geTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: geValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: geTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other -]] - - -[[ - name: eq - variants: - - method - return: argument 0 - options: - - cname: eqValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: eqTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: eq_ - return: self - options: - - cname: eqValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: eqTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: eq - variants: - - function - return: argument 0 - options: - - cname: eqValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: eqTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: eqValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: eqTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other -]] - - -[[ - name: ne - variants: - - method - return: argument 0 - options: - - cname: neValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* self - - real value - - cname: neTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: ne_ - return: self - options: - - cname: neValueT - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: neTensorT - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: ne - variants: - - function - return: argument 0 - options: - - cname: neValue - arguments: - - arg: THBoolTensor* result - output: True - - THTensor* tensor - - real value - - cname: neTensor - arguments: - - arg: THBoolTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other - - cname: neValueT - arguments: - - arg: THTensor* result - output: True - - THTensor* tensor - - real value - - cname: neTensorT - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* tensor - broadcast: other fallback - - THTensor* other -]] - -[[ - name: min - variants: - - method - - function - options: - - cname: minall - return: real - arguments: - - THTensor* self - - cname: cmin - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other - - cname: min - return: argument 0,1 - arguments: - - arg: THTensor* min - output: True - - arg: THIndexTensor* min_indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: min - return: argument 0,1 - before_call: maybeThrowBackCompatKeepdimWarn("min"); - arguments: - - arg: THTensor* min - output: True - - arg: THIndexTensor* min_indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: max - variants: - - method - - function - options: - - cname: maxall - return: real - arguments: - - THTensor* self - - cname: cmax - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other - - cname: max - return: argument 0,1 - arguments: - - arg: THTensor* max - output: True - - arg: THIndexTensor* max_indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: max - return: argument 0,1 - before_call: maybeThrowBackCompatKeepdimWarn("max"); - arguments: - - arg: THTensor* max - output: True - - arg: THIndexTensor* max_indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: kthvalue - backends: - - CPU - variants: - - method - - function - return: argument 0,1 - options: - - before_call: int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - CONSTANT __last_dim - - bool keepdim - - before_call: | - int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - maybeThrowBackCompatKeepdimWarn("kthvalue"); - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - CONSTANT __last_dim - - CONSTANT false - - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - before_call: maybeThrowBackCompatKeepdimWarn("kthvalue"); - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: mode - variants: - - method - - function - return: argument 0,1 - options: - - before_call: int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - CONSTANT __last_dim - - bool keepdim - - before_call: | - int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - maybeThrowBackCompatKeepdimWarn("mode"); - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - CONSTANT __last_dim - - CONSTANT false - - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - before_call: maybeThrowBackCompatKeepdimWarn("mode"); - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: median - variants: - - method - - function - return: argument 0,1 - options: - - cname: medianall - return: real - arguments: - - THTensor* self - - cname: median - before_call: int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - CONSTANT __last_dim - - bool keepdim - - before_call: maybeThrowBackCompatKeepdimWarn("median"); - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false - - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim -]] - -[[ - name: sort - variants: - - method - - function - return: argument 0,1 - options: - - before_call: int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - CONSTANT __last_dim - - arg: bool descending - default: "false" - kwarg_only: True - - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - arg: bool descending - default: "false" -]] - -[[ - name: topk - variants: - - method - - function - return: argument 0,1 - options: - - before_call: int64_t __last_dim = THTensor_(nDimension)(LIBRARY_STATE ((THPTensor*)$arg2)->cdata)-1; - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - CONSTANT __last_dim - - arg: bool largest - default: "true" - kwarg_only: True - - arg: bool sorted - default: "true" - kwarg_only: True - - arguments: - - arg: THTensor* values - output: True - - arg: THIndexTensor* indices - output: True - - THTensor* self - - int64_t k - - arg: int64_t dim - wrap_dim: self - - arg: bool largest - default: "true" - - arg: bool sorted - default: "true" -]] - -[[ - name: all - types: - - Byte - backends: - - CPU - - CUDA - cname: logicalall - return: bool - arguments: - - THTensor* self -]] - -[[ - name: any - types: - - Byte - backends: - - CPU - - CUDA - cname: logicalany - return: bool - arguments: - - THTensor* self -]] diff --git a/torch/csrc/generic/methods/TensorCuda.cwrap b/torch/csrc/generic/methods/TensorCuda.cwrap deleted file mode 100644 index dd3c5a2997..0000000000 --- a/torch/csrc/generic/methods/TensorCuda.cwrap +++ /dev/null @@ -1,48 +0,0 @@ -[[ - name: getDevice - python_name: get_device - backends: - - CUDA - return: int64_t - arguments: - - THTensor* self -]] - -[[ - name: THPTensor_(new) - python_name: new - method_flags: METH_KEYWORDS - backends: - - CUDA - only_register: True -]] -#if IS_CUDA -static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs); -PyObject * THPTensor_(new)(THPTensor *self, PyObject *args, PyObject *kwargs) -{ - THCPAutoGPU gpu_guard(args, (PyObject*)self); - return THPTensor_(pynew)(Py_TYPE(self), args, kwargs); -} -#endif - -[[ - name: THPTensor_(recordStream) - python_name: record_stream - override_method_flags: METH_O - backends: - - CUDA - only_register: True -]] -#if IS_CUDA -PyObject * THPTensor_(recordStream)(THPTensor *self, PyObject *arg) -{ - HANDLE_TH_ERRORS - if (!THCPStream_Check(arg)) { - return PyErr_Format(PyExc_TypeError, "expected Stream object"); - } - void* data = THTensor_(data)(LIBRARY_STATE self->cdata); - THCCachingAllocator_recordStream(data, ((THCPStream*)arg)->cdata); - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} -#endif diff --git a/torch/csrc/generic/methods/TensorMath.cwrap b/torch/csrc/generic/methods/TensorMath.cwrap deleted file mode 100644 index 4cbd0606c6..0000000000 --- a/torch/csrc/generic/methods/TensorMath.cwrap +++ /dev/null @@ -1,2771 +0,0 @@ -[[ - name: abs - return: argument 0 - types: - - floating_point - - Long - - Int - - Short - backends: - - CPU - - CUDA - variants: - - method - - function - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: abs_ - cname: abs - return: self - types: - - floating_point - - Long - - Int - - Short - backends: - - CPU - - CUDA - arguments: - - THTensor* self - - THTensor* self -]] - - -[[ - name: sigmoid_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: sigmoid - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: sigmoid - types: - - floating_point - backends: - - CPU - - CUDA - cname: sigmoid - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: log_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: log - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: log - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: log1p_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: log1p - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: log1p - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: lgamma - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: lgamma_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: lgamma - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: digamma - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: digamma_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: digamma - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: polygamma - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - int64_t n - - THTensor* self -]] - -[[ - name: polygamma_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: polygamma - return: self - arguments: - - THTensor* self - - int64_t n - - THTensor* self -]] - -[[ - name: exp_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: exp - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: exp - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: expm1_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: expm1 - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: expm1 - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: cos_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: cos - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: cos - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: acos_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: acos - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: acos - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: cosh_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: cosh - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: cosh - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: sin_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: sin - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: sin - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: asin_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: asin - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: asin - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: sinh_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: sinh - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: sinh - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: tan_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: tan - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: tan - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: atan_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: atan - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: atan - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: tanh_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: tanh - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: tanh - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: erf_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: erf - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: erf - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: erfinv_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: erfinv - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: erfinv - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: sqrt_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: sqrt - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: sqrt - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: rsqrt_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: rsqrt - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: rsqrt - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: ceil_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: ceil - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: ceil - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: floor_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: floor - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: floor - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: round_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: round - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: round - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: trunc_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: trunc - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: trunc - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - - -[[ - name: frac_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: frac - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: frac - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: mean - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: meanall - return: accreal - arguments: - - THTensor* self - - cname: mean - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: mean - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("mean"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: var - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: varall - return: accreal - arguments: - - THTensor* self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - cname: var - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - bool keepdim - - cname: var - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("var"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - CONSTANT false -]] - -[[ - name: std - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: stdall - return: accreal - arguments: - - THTensor* self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - cname: std - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - bool keepdim - - cname: std - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("std"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - CONSTANT false -]] - -[[ - name: norm - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: normall - return: accreal - arguments: - - THTensor* self - - arg: real p - default: AS_REAL(2) - - cname: norm - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real p - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: norm - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("norm"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real p - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: renorm - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: renorm - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real p - - arg: int64_t dim - wrap_dim: self - - real maxnorm -]] - -[[ - name: renorm_ - types: - - floating_point - backends: - - CPU - - CUDA - options: - - cname: renorm - return: self - arguments: - - THTensor* self - - THTensor* self - - real p - - arg: int64_t dim - wrap_dim: self - - real maxnorm -]] - -[[ - name: dist - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: dist - return: accreal - arguments: - - arg: THTensor* self - broadcast: other fallback - - THTensor* other - - arg: real p - default: AS_REAL(2) -]] - -[[ - name: reciprocal - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: cinv - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: reciprocal_ - types: - - floating_point - backends: - - CPU - - CUDA - options: - - cname: cinv - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: neg - backends: - - CPU - - CUDA - variants: - - method - - function - options: - - cname: neg - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: neg_ - backends: - - CPU - - CUDA - options: - - cname: neg - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: atan2 - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - cname: atan2 - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: atan2_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: atan2 - return: argument 0 - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other fallback inplace - - THTensor* other -]] - - - -// These options look the same in stateful method - only the first one will -// be available. Still, they differ in torch.pow. -[[ - name: pow - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - options: - - cname: pow - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real exponent - - cname: cpow - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: exponent fallback - - THTensor* exponent - - cname: tpow - arguments: - - arg: THTensor* result - output: True - - real base - - THTensor* self -]] - -[[ - name: pow_ - types: - - floating_point - backends: - - CPU - - CUDA - return: argument 0 - cname: pow - options: - - cname: pow - arguments: - - THTensor* self - - THTensor* self - - real exponent - - cname: cpow - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: exponent inplace fallback - - THTensor* exponent -]] - -[[ - name: lerp - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - cname: lerp - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: end fallback - - THTensor* end - - real weight -]] - -[[ - name: lerp_ - types: - - floating_point - backends: - - CPU - - CUDA - return: self - cname: lerp - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: end fallback inplace - - THTensor* end - - real weight -]] - -[[ - name: linspace - types: - - Float - - Double - backends: - - CPU - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - real start - - real end - - arg: int64_t steps - default: 100 -]] - -[[ - name: logspace - types: - - Float - - Double - backends: - - CPU - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - real start - - real end - - arg: int64_t steps - default: 100 -]] - -[[ - name: histc - types: - - Float - - Double - backends: - - CPU - variants: - - method - - function - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - CONSTANT 100 - - CONSTANT 0 - - CONSTANT 0 - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - int64_t bins - - CONSTANT 0 - - CONSTANT 0 - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - int64_t bins - - real min - - CONSTANT 0 - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - int64_t bins - - real min - - real max -]] - -[[ - name: zero_ - cname: zero - return: self - arguments: - - THTensor* self -]] - -[[ - name: sum - variants: - - method - - function - options: - - cname: sumall - return: accreal - arguments: - - THTensor* self - - cname: sum - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: sum - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("sum"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: prod - variants: - - method - - function - options: - - cname: prodall - return: accreal - arguments: - - THTensor* self - - cname: prod - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - bool keepdim - - cname: prod - return: argument 0 - before_call: maybeThrowBackCompatKeepdimWarn("prod"); - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self - - CONSTANT false -]] - -[[ - name: cumsum - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self -]] - -[[ - name: cumprod - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t dim - wrap_dim: self -]] - -[[ - name: sign - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] - -[[ - name: sign_ - cname: sign - return: self - arguments: - - THTensor* self - - THTensor* self -]] - -[[ - name: trace - variants: - - method - - function - return: accreal - arguments: - - THTensor* self -]] - -[[ - name: add - variants: - - method - - function - return: argument 0 - options: - - cname: add - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cadd - aten_sparse: True - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - arg: real value - default: AS_REAL(1) - - THTensor* other - - sparse: True - cname: spcadd - aten_dense_sparse: True - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - -[[ - name: add_ - return: argument 0 - options: - - cname: add - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cadd - aten_sparse: True - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - arg: real value - default: AS_REAL(1) - - THTensor* other - - sparse: True - cname: spcadd - aten_dense_sparse: True - arguments: - - THTensor* self - - THTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* other -]] - - -[[ - name: sub - variants: - - method - - function - return: argument 0 - options: - - cname: sub - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: csub - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - arg: real value - default: AS_REAL(1) - - THTensor* other -]] - -[[ - name: sub_ - return: argument 0 - options: - - cname: sub - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: csub - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - arg: real value - default: AS_REAL(1) - - THTensor* other -]] - - -[[ - name: mul - variants: - - method - - function - return: argument 0 - options: - - cname: mul - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cmul - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - arg: THTensor* other -]] - -[[ - name: mul_ - return: argument 0 - options: - - cname: mul - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cmul - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - - -[[ - name: div - variants: - - method - - function - return: argument 0 - options: - - cname: div - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cdiv - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: div_ - return: argument 0 - options: - - cname: div - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cdiv - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - - -[[ - name: fmod - return: argument 0 - variants: - - method - - function - options: - - cname: fmod - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cfmod - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - THTensor* other -]] - -[[ - name: fmod_ - return: argument 0 - options: - - cname: fmod - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cfmod - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - - -[[ - name: remainder - return: argument 0 - variants: - - method - - function - options: - - cname: remainder - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real value - - cname: cremainder - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: other fallback - - arg: THTensor* other -]] - -[[ - name: remainder_ - return: argument 0 - options: - - cname: remainder - arguments: - - THTensor* self - - THTensor* self - - real value - - cname: cremainder - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: other inplace fallback - - THTensor* other -]] - -[[ - name: clamp - variants: - - method - - function - return: argument 0 - options: - - cname: clamp - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real min - - real max - - cname: cmaxValue - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: real min - kwarg_only: True - - cname: cminValue - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: real max - kwarg_only: True -]] - -[[ - name: clamp_ - cname: clamp - return: self - options: - - cname: clamp - arguments: - - THTensor* self - - THTensor* self - - real min - - real max - - cname: cmaxValue - arguments: - - THTensor* self - - THTensor* self - - arg: real min - kwarg_only: True - - cname: cminValue - arguments: - - THTensor* self - - THTensor* self - - arg: real max - kwarg_only: True -]] - -[[ - name: dot - backend_type_pairs: [[CUDA,floating_point], [CPU,all]] - - variants: - - method - - function - return: accreal - arguments: - - arg: THTensor* self - assert_ndim: 1 - - arg: THTensor* tensor - assert_ndim: 1 -]] - -[[ - name: tril - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t diagonal - default: 0 -]] - -[[ - name: tril_ - cname: tril - return: self - arguments: - - THTensor* self - - THTensor* self - - arg: int64_t diagonal - default: 0 -]] - -[[ - name: triu - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t diagonal - default: 0 -]] - -[[ - name: triu_ - cname: triu - return: self - arguments: - - THTensor* self - - THTensor* self - - arg: int64_t diagonal - default: 0 -]] - -[[ - name: cross - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THTensor* other - - arg: int64_t dim - default: -1 -]] - -[[ - name: eye - backends: - - CPU - - CUDA - variants: - - function - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - int64_t n - - argument 1 - - arguments: - - arg: THTensor* result - output: True - - int64_t n - - int64_t m -]] - -[[ - name: diag - variants: - - method - - function - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: int64_t diagonal - default: 0 -]] - -[[ - name: addmm - variants: - - method - - function - return: argument 0 - options: - - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - arg: THTensor* self - broadcast: mat1,mat2 dims:mat1.dim0,mat2.dim1 - - arg: real alpha - default: AS_REAL(1) - - THTensor* mat1 - - THTensor* mat2 - - cname: spaddmm - sparse: yes - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: addmm_ - return: self - options: - - cname: addmm - arguments: - - THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THTensor* mat1 - - THTensor* mat2 - - cname: spaddmm - sparse: yes - arguments: - - arg: THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THSTensor* mat1 - - THTensor* mat2 -]] - -[[ - name: addmv - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - arg: THTensor* self - broadcast: mat,vec dims:mat.dim0 - - arg: real alpha - default: AS_REAL(1) - - THTensor* mat - - THTensor* vec -]] - -[[ - name: addmv_ - cname: addmv - return: self - arguments: - - THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THTensor* mat - - THTensor* vec -]] - -[[ - name: addr - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - arg: THTensor* self - broadcast: vec1,vec2 dims:vec1.dim0,vec2.dim0 - - arg: real alpha - default: AS_REAL(1) - - THTensor* vec1 - - THTensor* vec2 -]] - -[[ - name: addr_ - cname: addr - return: self - arguments: - - THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THTensor* vec1 - - THTensor* vec2 -]] - -[[ - name: ger - cname: addr - variants: - - method - - function - return: argument 0 - before_call: | - int64_t s1 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg4)->cdata, 0); - int64_t s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 0); - THTensor_(resize2d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2); - arguments: - - arg: THTensor* result - output: True - resize: [ [self,0], [vec2,0] ] - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THTensor* self - - THTensor* vec2 -]] - -[[ - name: mv - cname: addmv - variants: - - method - - function - return: argument 0 - before_call: | - int64_t s = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg4)->cdata, 0); - THTensor_(resize1d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s); - #if !IS_CUDA - THTensor_(zero)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata); - #endif - arguments: - - arg: THTensor* result - output: True - resize: [ [self, 0] ] - cpu_zero: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THTensor* self - - THTensor* vec -]] - -[[ - name: mm - variants: - - method - - function - return: argument 0 - options: - - cname: addmm - before_call: | - int64_t s1 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg4)->cdata, 0); - int64_t s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 1); - THTensor_(resize2d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2); - #if !IS_CUDA - THTensor_(zero)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata); - #endif - arguments: - - arg: THTensor* result - output: True - resize: [ [self, 0], [mat2,1] ] - cpu_zero: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THTensor* self - - THTensor* mat2 - - cname: spaddmm - sparse: True - arguments: - - arg: THTensor* result - output: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THSTensor* self - - THTensor* mat2 -]] - -[[ - name: bmm - cname: baddbmm - variants: - - method - - function - return: argument 0 - before_call: | - int64_t s1 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg4)->cdata, 0); - int64_t s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg4)->cdata, 1); - int64_t s3 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 2); - THTensor_(resize3d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2, s3); - #if !IS_CUDA - THTensor_(zero)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata); - #endif - arguments: - - arg: THTensor* result - output: True - resize: [ [self,0], [self,1], [mat2,2] ] - cpu_zero: True - - CONSTANT AS_REAL(0) - - argument 0 - - CONSTANT AS_REAL(1) - - THTensor* self - - THTensor* mat2 -]] - -[[ - name: addbmm - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - arg: THTensor* self - broadcast: batch1,batch2 dims:batch1.dim1,batch2.dim2 - - arg: real alpha - default: AS_REAL(1) - - THTensor* batch1 - - THTensor* batch2 -]] - -[[ - name: addbmm_ - cname: addbmm - return: self - arguments: - - THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THTensor* batch1 - - THTensor* batch2 -]] - -[[ - name: baddbmm - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: real beta - default: AS_REAL(1) - - arg: THTensor* self - broadcast: batch1,batch2 dims:batch1.dim0,batch1.dim1,batch2.dim2 - - arg: real alpha - default: AS_REAL(1) - - THTensor* batch1 - - THTensor* batch2 -]] - -[[ - name: baddbmm_ - cname: baddbmm - return: argument 0 - arguments: - - THTensor* self - - arg: real beta - default: AS_REAL(1) - - THTensor* self - - arg: real alpha - default: AS_REAL(1) - - THTensor* batch1 - - THTensor* batch2 -]] - -[[ - name: addcmul - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: tensor1,tensor2 fallback - - arg: real value - default: AS_REAL(1) - - THTensor* tensor1 - - THTensor* tensor2 -]] - -[[ - name: addcmul_ - options: - - cname: addcmul - return: argument 0 - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: tensor1,tensor2 inplace fallback - - arg: real value - default: AS_REAL(1) - - THTensor* tensor1 - - THTensor* tensor2 - - cname: spaddcmul - defined_if: "!IS_DISTRIBUTED" - return: argument 0 - arguments: - - THTensor* self - - THTensor* self - - arg: real value - default: AS_REAL(1) - - THSTensor* tensor1 - - THSTensor* tensor2 -]] - -[[ - name: addcdiv - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: tensor1,tensor2 fallback - - arg: real value - default: AS_REAL(1) - - THTensor* tensor1 - - THTensor* tensor2 -]] - -[[ - name: addcdiv_ - cname: addcdiv - return: argument 0 - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: tensor1,tensor2 inplace fallback - - arg: real value - default: AS_REAL(1) - - THTensor* tensor1 - - THTensor* tensor2 -]] - -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || CUDA_FLOAT || CUDA_DOUBLE -// We need to pass pointers to chars to tensor lapack functions... -static const char __U = 'U'; -static const char __L = 'L'; -static const char __N = 'N'; -static const char __V = 'V'; -static const char __A = 'A'; -static const char __S = 'S'; -#if !IS_CUDA -static const char __T = 'T'; -static const char __R = 'R'; -#endif -static const char *U = &__U; -static const char *L = &__L; -static const char *N = &__N; -static const char *V = &__V; -static const char *A = &__A; -static const char *S = &__S; -#if !IS_CUDA -static const char *T = &__T; -static const char *R = &__R; -#endif -#endif - -[[ - name: gesv - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* solution - output: True - - arg: THTensor* lu - output: True - - THTensor* self - - THTensor* A -]] - -[[ - name: gels - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self - - THTensor* A -]] - -[[ - name: trtrs - types: - - Float - - Double - backends: - - CPU - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self - - THTensor* A - - arg: bool upper - if_true: U - if_false: L - default: U - - arg: bool transpose - if_true: T - if_false: N - default: N - - arg: bool unitriangular - if_true: U - if_false: N - default: N -]] - -[[ - name: symeig - cname: syev - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self - - arg: bool eigenvectors - if_true: V - if_false: N - default: N - - arg: bool upper - if_true: U - if_false: L - default: U -]] - -[[ - name: eig - cname: geev - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self - - arg: bool eigenvectors - if_true: V - if_false: N - default: N -]] - -[[ - name: svd - cname: gesvd - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1,2 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - arg: THTensor* res3 - output: True - - THTensor* self - - arg: bool some - if_true: S - if_false: A - default: S -]] - -[[ - name: inverse - cname: getri - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* output - output: True - - THTensor* self -]] - -[[ - name: potrf - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* output - output: True - - THTensor* self - - arg: bool upper - if_true: U - if_false: L - default: U -]] - -[[ - name: potrs - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THTensor* input2 - - arg: bool upper - if_true: U - if_false: L - default: U -]] - -[[ - name: potri - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* output - output: True - - THTensor* self - - arg: bool upper - if_true: U - if_false: L - default: U -]] - -[[ - name: pstrf - types: - - Float - - Double - backends: - - CPU - variants: - - method - - function - return: argument 0,1 - after_call: - THIntTensor_sub(((THPIntTensor*)$arg1)->cdata, ((THPIntTensor*)$arg1)->cdata, 1); - arguments: - - arg: THTensor* res1 - output: True - - arg: THIntegerTensor* res2 - output: True - - THTensor* self - - arg: bool upper - if_true: U - if_false: L - default: U - - arg: real tol - default: -1 -]] - -[[ - name: qr - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self -]] - -[[ - name: geqrf - types: - - Float - - Double - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* res1 - output: True - - arg: THTensor* res2 - output: True - - THTensor* self -]] - -[[ - name: orgqr - types: - - Float - - Double - backends: - - CPU - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THTensor* input2 -]] - -[[ - name: ormqr - types: - - Float - - Double - backends: - - CPU - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THTensor* input2 - - THTensor* input3 - - arg: bool left - if_true: L - if_false: R - default: L - - arg: bool transpose - if_true: T - if_false: N - default: N -]] - -[[ - name: btrifact - cname: btrifact - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1 - arguments: - - arg: THTensor* result - output: True - - arg: THIntegerTensor* pivots - output: True - - arg: THIntegerTensor* info - kwarg_only: True - default: NULL - - arg: bool pivot - kwarg_only: True - default: "true" - - THTensor* self -]] - -[[ - name: btrifact_with_info - cname: btrifact - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0,1,2 - arguments: - - arg: THTensor* result - output: True - - arg: THIntegerTensor* pivots - output: True - - arg: THIntegerTensor* info - output: True - - arg: bool pivot - kwarg_only: True - default: "true" - - THTensor* self -]] - -[[ - name: btrisolve - cname: btrisolve - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - THTensor* LU_data - - THIntegerTensor* LU_pivots -]] diff --git a/torch/csrc/generic/methods/TensorRandom.cwrap b/torch/csrc/generic/methods/TensorRandom.cwrap deleted file mode 100644 index a0ee66ec03..0000000000 --- a/torch/csrc/generic/methods/TensorRandom.cwrap +++ /dev/null @@ -1,374 +0,0 @@ -[[ - name: randperm - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - int64_t n -]] - -[[ - name: random_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - - CUDA - return: self - options: - - cname: random - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - cname: cappedRandom - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - int64_t to - - cname: clampedRandom - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - int64_t from - - int64_t to -]] - -[[ - name: multinomial - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - method - - function - return: argument 0 - arguments: - - arg: THIndexTensor* result - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - THTensor* self - - int num_samples - - arg: bool replacement - default: "false" -]] - -[[ - name: uniform_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: uniform - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double from - default: 0 - - arg: double to - default: 1 -]] - -[[ - name: normal - types: - - floating_point - backends: - - CPU - - CUDA - return: argument 0 - variants: - - function - options: - - cname: normal_means - arguments: - - arg: THTensor* output - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - THTensor* means - - arg: double std - default: 1 - - cname: normal_stddevs - arguments: - - arg: THTensor* output - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double mean - default: 0 - - THTensor* std - - cname: normal_means_stddevs - arguments: - - arg: THTensor* output - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - THTensor* means - - THTensor* std -]] - -[[ - name: normal_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: normal - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double mean - default: 0 - - arg: double std - default: 1 -]] - -[[ - name: cauchy_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: cauchy - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double median - default: 0 - - arg: double sigma - default: 1 -]] - -[[ - name: logNormal_ - cname: logNormal - python_name: log_normal_ - types: - - floating_point - backends: - - CPU - - CUDA - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double mean - default: 1 - - arg: double std - default: 2 -]] - -[[ - name: exponential_ - types: - - floating_point - backends: - - CPU - - CUDA - cname: exponential - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double lambd - default: 1 -]] - -[[ - name: _standard_gamma - types: - - floating_point - backends: - - CPU - return: argument 0 - variants: - - function - options: - - cname: standard_gamma - arguments: - - arg: THTensor* output - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - THTensor* alpha -]] - -[[ - name: _dirichlet_grad - types: - - floating_point - backends: - - CPU - return: argument 0 - variants: - - function - options: - - cname: dirichlet_grad - arguments: - - arg: THTensor* output - output: True - - THTensor* x - - THTensor* alpha - - THTensor* total -]] - -[[ - name: rand - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: THSize* size - long_args: True -]] - -[[ - name: randn - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: THSize* size - long_args: True -]] - -[[ - name: geometric_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - - CUDA - cname: geometric - return: self - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - double p -]] - -#define THDoubleTensor_BERNOULLI_TENSOR THDoubleTensor_bernoulli_DoubleTensor -#define THFloatTensor_BERNOULLI_TENSOR THFloatTensor_bernoulli_FloatTensor -#define THCudaDoubleTensor_BERNOULLI_TENSOR THCudaDoubleTensor_bernoulli_DoubleTensor -#define THCudaTensor_BERNOULLI_TENSOR THCudaTensor_bernoulli_FloatTensor - -[[ - name: bernoulli - types: - - Float - - Double - backends: - - CPU - - CUDA - return: argument 0 - variants: - - method - - function - before_call: - CPU: THTensor_(resizeAs)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, ((THPTensor*)$arg2)->cdata); - CUDA: THTensor_(resizeAs)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, ((THPTensor*)$arg1)->cdata); - cname: BERNOULLI_TENSOR - arguments: - - arg: THTensor* output - output: True - resize: self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - THTensor* self -]] - -#undef THDoubleTensor_BERNOULLI_TENSOR -#undef THFloatTensor_BERNOULLI_TENSOR -#undef THCudaDoubleTensor_BERNOULLI_TENSOR -#undef THCudaTensor_BERNOULLI_TENSOR - -[[ - name: bernoulli_ - defined_if: "!IS_DISTRIBUTED" - backends: - - CPU - - CUDA - return: self - options: - - cname: bernoulli - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - arg: double p - default: 0.5 - - cname: bernoulli_FloatTensor - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - BackendFloatTensor* float_p - - cname: bernoulli_DoubleTensor - arguments: - - THTensor* self - - arg: THGenerator* generator - default: THPGenerator_TH_CData(THPDefaultGenerator) - kwarg_only: True - - BackendDoubleTensor* float_p -]] diff --git a/torch/csrc/generic/utils.cpp b/torch/csrc/generic/utils.cpp index 881eab4970..33c9a8366e 100644 --- a/torch/csrc/generic/utils.cpp +++ b/torch/csrc/generic/utils.cpp @@ -26,34 +26,20 @@ void THPPointer<THPStorage>::free() { Py_DECREF(ptr); } -template<> -void THPPointer<THPTensor>::free() { - if (ptr) - Py_DECREF(ptr); -} - #if GENERATE_SPARSE template<> void THPPointer<THSTensor>::free() { if (ptr) THSTensor_(free)(LIBRARY_STATE ptr); } - -template<> -void THPPointer<THSPTensor>::free() { - if (ptr) - Py_DECREF(ptr); -} #endif template class THPPointer<THStorage>; template class THPPointer<THTensor>; template class THPPointer<THPStorage>; -template class THPPointer<THPTensor>; #if GENERATE_SPARSE template class THPPointer<THSTensor>; -template class THPPointer<THSPTensor>; #endif #undef GENERATE_SPARSE diff --git a/torch/csrc/generic/utils.h b/torch/csrc/generic/utils.h index 83d508eecc..a9191a7c6f 100644 --- a/torch/csrc/generic/utils.h +++ b/torch/csrc/generic/utils.h @@ -9,17 +9,14 @@ #endif struct THPStorage; -struct THPTensor; struct THSPTensor; typedef class THPPointer<THStorage> THStoragePtr; typedef class THPPointer<THTensor> THTensorPtr; typedef class THPPointer<THPStorage> THPStoragePtr; -typedef class THPPointer<THPTensor> THPTensorPtr; #if GENERATE_SPARSE typedef class THPPointer<THSTensor> THSTensorPtr; -typedef class THPPointer<THSPTensor> THSPTensorPtr; #endif #if (!defined(THC_GENERIC_FILE) || defined(THC_REAL_IS_HALF)) && \ diff --git a/torch/csrc/jit/pybind.h b/torch/csrc/jit/pybind.h index 7a7f224ab3..2aa239aa38 100644 --- a/torch/csrc/jit/pybind.h +++ b/torch/csrc/jit/pybind.h @@ -23,9 +23,6 @@ public: if (THPVariable_Check(source)) { value = torch::jit::tracer::TraceInput(((THPVariable*)source)->cdata); return true; - } else if (THPModule_isTensor(source)) { - value = torch::jit::tracer::TraceInput(torch::createTensor(source)); - return true; } else { return false; } diff --git a/torch/csrc/nn/type_checks.h b/torch/csrc/nn/type_checks.h index db51d0ba35..b0463c309b 100644 --- a/torch/csrc/nn/type_checks.h +++ b/torch/csrc/nn/type_checks.h @@ -5,96 +5,101 @@ #include <ATen/ATen.h> -#include "THP_API.h" #include "torch/csrc/autograd/python_variable.h" namespace torch { namespace nn { -inline bool check_type(PyObject* obj, PyObject* cls, at::TypeID typeID) { - if ((PyObject*)Py_TYPE(obj) == cls) { - return true; - } +inline bool check_type(PyObject* obj, at::TypeID typeID) { if (THPVariable_Check(obj)) { return ((THPVariable*)obj)->cdata.data().type().ID() == typeID; } return false; } -template<typename TP, typename T> -inline T* unpack(PyObject* obj, PyObject* cls) { - if ((PyObject*)Py_TYPE(obj) == cls) { - return ((TP*)obj)->cdata; - } +template<typename T> +inline T* unpack(PyObject* obj) { return (T*) ((THPVariable*)obj)->cdata.data().unsafeGetTH(false); } }} // namespace torch::nn +static inline int get_device(PyObject* args) { + for (int i = 0, n = PyTuple_GET_SIZE(args); i != n; i++) { + PyObject* arg = PyTuple_GET_ITEM(args, i); + if (THPVariable_Check(arg)) { + auto& tensor = THPVariable_UnpackData(arg); + if (tensor.type().is_cuda()) { + return tensor.get_device(); + } + } + } + return -1; +} static inline bool THNN_FloatTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THPFloatTensorClass, at::TypeID::CPUFloat); + return torch::nn::check_type(obj, at::TypeID::CPUFloat); } static inline bool THNN_DoubleTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THPDoubleTensorClass, at::TypeID::CPUDouble); + return torch::nn::check_type(obj, at::TypeID::CPUDouble); } static inline bool THNN_LongTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THPLongTensorClass, at::TypeID::CPULong); + return torch::nn::check_type(obj, at::TypeID::CPULong); } static inline bool THNN_IntTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THPIntTensorClass, at::TypeID::CPUInt); + return torch::nn::check_type(obj, at::TypeID::CPUInt); } static inline THFloatTensor* THNN_FloatTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THPFloatTensor, THFloatTensor>(obj, THPFloatTensorClass); + return torch::nn::unpack<THFloatTensor>(obj); } static inline THDoubleTensor* THNN_DoubleTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THPDoubleTensor, THDoubleTensor>(obj, THPDoubleTensorClass); + return torch::nn::unpack<THDoubleTensor>(obj); } static inline THLongTensor* THNN_LongTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THPLongTensor, THLongTensor>(obj, THPLongTensorClass); + return torch::nn::unpack<THLongTensor>(obj); } static inline THIntTensor* THNN_IntTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THPIntTensor, THIntTensor>(obj, THPIntTensorClass); + return torch::nn::unpack<THIntTensor>(obj); } #ifdef WITH_CUDA static inline bool THNN_CudaHalfTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THCPHalfTensorClass, at::TypeID::CUDAHalf); + return torch::nn::check_type(obj, at::TypeID::CUDAHalf); } static inline bool THNN_CudaFloatTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THCPFloatTensorClass, at::TypeID::CUDAFloat); + return torch::nn::check_type(obj, at::TypeID::CUDAFloat); } static inline bool THNN_CudaDoubleTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THCPDoubleTensorClass, at::TypeID::CUDADouble); + return torch::nn::check_type(obj, at::TypeID::CUDADouble); } static inline bool THNN_CudaLongTensor_Check(PyObject* obj) { - return torch::nn::check_type(obj, THCPLongTensorClass, at::TypeID::CUDALong); + return torch::nn::check_type(obj, at::TypeID::CUDALong); } static inline THCudaHalfTensor* THNN_CudaHalfTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THCPHalfTensor, THCudaHalfTensor>(obj, THCPHalfTensorClass); + return torch::nn::unpack<THCudaHalfTensor>(obj); } static inline THCudaTensor* THNN_CudaFloatTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THCPFloatTensor, THCudaTensor>(obj, THCPFloatTensorClass); + return torch::nn::unpack<THCudaTensor>(obj); } static inline THCudaDoubleTensor* THNN_CudaDoubleTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THCPDoubleTensor, THCudaDoubleTensor>(obj, THCPDoubleTensorClass); + return torch::nn::unpack<THCudaDoubleTensor>(obj); } static inline THCudaLongTensor* THNN_CudaLongTensor_Unpack(PyObject* obj) { - return torch::nn::unpack<THCPLongTensor, THCudaLongTensor>(obj, THCPLongTensorClass); + return torch::nn::unpack<THCudaLongTensor>(obj); } #endif // WITH_CUDA diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index f63309eb4b..b2e1374f9c 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -194,7 +194,7 @@ static void initialize_aten_types(std::vector<PyTensorType>& tensor_types) { default_tensor_type->is_base_type = true; } -void initialize_python_bindings(PyObject* module) { +void initialize_python_bindings() { // Initialize the at::Type* pointers, name, and properties of the PyTensorType // vector. After this call, the vector must not be resized. initialize_aten_types(tensor_types); diff --git a/torch/csrc/tensor/python_tensor.h b/torch/csrc/tensor/python_tensor.h index f6da26fd47..5a27ccca53 100644 --- a/torch/csrc/tensor/python_tensor.h +++ b/torch/csrc/tensor/python_tensor.h @@ -7,7 +7,7 @@ namespace torch { namespace tensor { // Initializes the Python tensor type objects: torch.Tensor, torch.FloatTensor, // etc. and binds them in their containing modules. -void initialize_python_bindings(PyObject* module); +void initialize_python_bindings(); // Sets the concrete type constructed by calls to torch.Tensor() and most // factory methods on the torch module. diff --git a/torch/csrc/utils/tuple_parser.cpp b/torch/csrc/utils/tuple_parser.cpp index 727c4b8efa..86d19794f1 100644 --- a/torch/csrc/utils/tuple_parser.cpp +++ b/torch/csrc/utils/tuple_parser.cpp @@ -45,11 +45,6 @@ auto TupleParser::parse(double& x, const std::string& param_name) -> void { x = THPUtils_unpackDouble(obj); } -auto TupleParser::parse(at::Tensor& x, const std::string& param_name) -> void { - PyObject* obj = next_arg(); - x = torch::createTensor(obj); -} - auto TupleParser::parse(std::vector<int>& x, const std::string& param_name) -> void { PyObject* obj = next_arg(); if (!PyTuple_Check(obj)) { diff --git a/torch/csrc/utils/tuple_parser.h b/torch/csrc/utils/tuple_parser.h index f487c4c711..648f3ba975 100644 --- a/torch/csrc/utils/tuple_parser.h +++ b/torch/csrc/utils/tuple_parser.h @@ -13,7 +13,6 @@ struct TupleParser { void parse(bool& x, const std::string& param_name); void parse(int& x, const std::string& param_name); void parse(double& x, const std::string& param_name); - void parse(at::Tensor& x, const std::string& param_name); void parse(std::vector<int>& x, const std::string& param_name); void parse(std::string& x, const std::string& param_name); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 7268c266c3..67ef77f62a 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -159,7 +159,6 @@ def _lazy_init(): "Cannot re-initialize CUDA in forked subprocess. " + msg) _check_driver() torch._C._cuda_init() - torch._C._cuda_sparse_init() _cudart = _load_cudart() _cudart.cudaGetErrorName.restype = ctypes.c_char_p _cudart.cudaGetErrorString.restype = ctypes.c_char_p @@ -466,7 +465,6 @@ from .random import * ################################################################################ -from ..tensor import _TensorBase from ..storage import _StorageBase @@ -541,87 +539,6 @@ class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase): pass -class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return DoubleStorage - - -class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return FloatStorage - - -class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return LongStorage - - -class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return IntStorage - - -class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return ShortStorage - - -class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase): - - def is_signed(self): - # TODO - return False - - @classmethod - def storage_type(cls): - return CharStorage - - -class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase): - - def is_signed(self): - return False - - @classmethod - def storage_type(cls): - return ByteStorage - - -class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase): - - def is_signed(self): - return True - - @classmethod - def storage_type(): - return HalfStorage - - torch._storage_classes.add(DoubleStorage) torch._storage_classes.add(FloatStorage) torch._storage_classes.add(LongStorage) @@ -631,12 +548,6 @@ torch._storage_classes.add(CharStorage) torch._storage_classes.add(ByteStorage) torch._storage_classes.add(HalfStorage) -torch._integer_tensor_classes.add(LongTensor) -torch._integer_tensor_classes.add(IntTensor) -torch._integer_tensor_classes.add(ShortTensor) -torch._integer_tensor_classes.add(CharTensor) -torch._integer_tensor_classes.add(ByteTensor) - from . import sparse from . import profiler from . import nvtx diff --git a/torch/cuda/sparse.py b/torch/cuda/sparse.py index b0dc613f48..f37a34118d 100644 --- a/torch/cuda/sparse.py +++ b/torch/cuda/sparse.py @@ -1,93 +1 @@ -import torch -from torch import _C -from ..tensor import _TensorBase -from torch.sparse import _SparseBase, _sparse_tensor_classes -from . import _lazy_init, device, _dummy_type - - -if not hasattr(torch._C, 'CudaSparseDoubleTensorBase'): - # Define dummy base classes - for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half']: - tensor_name = 'CudaSparse{0}TensorBase'.format(t) - - torch._C.__dict__[tensor_name] = _dummy_type(tensor_name) - - -class _CudaSparseBase(object): - is_cuda = True - is_sparse = True - - def type(self, *args, **kwargs): - with device(self.get_device()): - return super(_CudaSparseBase, self).type(*args, **kwargs) - - def __new__(cls, *args, **kwargs): - _lazy_init() - # We need this method only for lazy init, so we can remove it - del _CudaSparseBase.__new__ - return super(_CudaSparseBase, cls).__new__(cls, *args, **kwargs) - - -class DoubleTensor(_CudaSparseBase, torch._C.CudaSparseDoubleTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -class FloatTensor(_CudaSparseBase, torch._C.CudaSparseFloatTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -class LongTensor(_CudaSparseBase, torch._C.CudaSparseLongTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -class IntTensor(_CudaSparseBase, torch._C.CudaSparseIntTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -class ShortTensor(_CudaSparseBase, torch._C.CudaSparseShortTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -class CharTensor(_CudaSparseBase, torch._C.CudaSparseCharTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - # TODO - return False - - -class ByteTensor(_CudaSparseBase, torch._C.CudaSparseByteTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return False - - -class HalfTensor(_CudaSparseBase, torch._C.CudaSparseHalfTensorBase, _SparseBase, _TensorBase): - - def is_signed(self): - return True - - -_sparse_tensor_classes.add(DoubleTensor) -_sparse_tensor_classes.add(FloatTensor) -_sparse_tensor_classes.add(LongTensor) -_sparse_tensor_classes.add(IntTensor) -_sparse_tensor_classes.add(ShortTensor) -_sparse_tensor_classes.add(CharTensor) -_sparse_tensor_classes.add(ByteTensor) -_sparse_tensor_classes.add(HalfTensor) - -torch._integer_tensor_classes.add(LongTensor) -torch._integer_tensor_classes.add(IntTensor) -torch._integer_tensor_classes.add(ShortTensor) -torch._integer_tensor_classes.add(CharTensor) -torch._integer_tensor_classes.add(ByteTensor) +# The Tensor classes are added to this module by python_tensor.cpp diff --git a/torch/distributed/remote_types.py b/torch/distributed/remote_types.py index de85040ffb..a8d10cd93b 100644 --- a/torch/distributed/remote_types.py +++ b/torch/distributed/remote_types.py @@ -1,6 +1,5 @@ import torch -from ..tensor import _TensorBase from ..storage import _StorageBase @@ -41,87 +40,6 @@ class HalfStorage(_DistributedBase, torch._C.DistributedHalfStorageBase, _Storag pass -class DoubleTensor(_DistributedBase, torch._C.DistributedDoubleTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return DoubleStorage - - -class HalfTensor(_DistributedBase, torch._C.DistributedHalfTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return HalfStorage - - -class FloatTensor(_DistributedBase, torch._C.DistributedFloatTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return FloatStorage - - -class LongTensor(_DistributedBase, torch._C.DistributedLongTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return LongStorage - - -class IntTensor(_DistributedBase, torch._C.DistributedIntTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return IntStorage - - -class ShortTensor(_DistributedBase, torch._C.DistributedShortTensorBase, _TensorBase): - def is_signed(self): - return True - - @classmethod - def storage_type(cls): - return ShortStorage - - -class CharTensor(_DistributedBase, torch._C.DistributedCharTensorBase, _TensorBase): - def is_signed(self): - # TODO - return False - - @classmethod - def storage_type(cls): - return CharStorage - - -class ByteTensor(_DistributedBase, torch._C.DistributedByteTensorBase, _TensorBase): - def is_signed(self): - return False - - @classmethod - def storage_type(cls): - return ByteStorage - - -# class HalfTensor(_DistributedBase, torch._C.DistributedHalfTensorBase, _TensorBase): - # def is_signed(self): - # return True - # @classmethod - # def storage_type(): - # return HalfStorage - - torch._storage_classes.add(DoubleStorage) torch._storage_classes.add(FloatStorage) torch._storage_classes.add(HalfStorage) @@ -131,20 +49,6 @@ torch._storage_classes.add(ShortStorage) torch._storage_classes.add(CharStorage) torch._storage_classes.add(ByteStorage) -torch._tensor_classes.add(DoubleTensor) -torch._tensor_classes.add(FloatTensor) -torch._tensor_classes.add(HalfTensor) -torch._tensor_classes.add(LongTensor) -torch._tensor_classes.add(IntTensor) -torch._tensor_classes.add(ShortTensor) -torch._tensor_classes.add(CharTensor) -torch._tensor_classes.add(ByteTensor) - -torch._integer_tensor_classes.add(LongTensor) -torch._integer_tensor_classes.add(IntTensor) -torch._integer_tensor_classes.add(ShortTensor) -torch._integer_tensor_classes.add(CharTensor) -torch._integer_tensor_classes.add(ByteTensor) _type_names = ['Double', 'Float', 'Half', 'Long', 'Int', 'Short', 'Char', 'Byte'] _locals = locals() diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index 87b7bb9204..9624446f99 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -9,7 +9,7 @@ from torch.distributions.utils import _finfo, broadcast_all def _dirichlet_sample_nograd(concentration): - probs = torch._C._standard_gamma(concentration) + probs = torch._standard_gamma(concentration) probs /= probs.sum(-1, True) eps = _finfo(probs).eps return probs.clamp_(min=eps, max=1 - eps) @@ -18,7 +18,7 @@ def _dirichlet_sample_nograd(concentration): # This helper is exposed for testing. def _Dirichlet_backward(x, concentration, grad_output): total = concentration.sum(-1, True).expand_as(concentration) - grad = torch._C._dirichlet_grad(x, concentration, total) + grad = torch._dirichlet_grad(x, concentration, total) return grad * (grad_output - (x * grad_output).sum(-1, True)) diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 60ec109ef1..69500070ab 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -9,8 +9,6 @@ from torch.distributions.utils import _finfo, broadcast_all, lazy_property def _standard_gamma(concentration): - if not isinstance(concentration, Variable): - return torch._C._standard_gamma(concentration) return concentration._standard_gamma() diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 53fc7306eb..f37a34118d 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -1,186 +1 @@ -import torch -from torch import _C -from ..tensor import _TensorBase - -_sparse_tensor_classes = set() - - -class _SparseBase(object): - is_cuda = False - is_sparse = True - - def cpu(self): - return self.type(getattr(torch.sparse, self.__class__.__name__)) - - def is_pinned(self): - raise NotImplementedError - - def pin_memory(self): - raise NotImplementedError - - def share_memory_(self): - raise NotImplementedError - - def is_shared(self): - raise NotImplementedError - - def __deepcopy__(self, _memo): - memo = _memo.setdefault('torch', {}) - if self._cdata in memo: - return memo[self._cdata] - new_tensor = self.clone() - memo[self._cdata] = new_tensor - return new_tensor - - def __reduce__(self): - raise NotImplementedError - - def __getstate__(self): - raise NotImplementedError - - def __setstate__(self, state): - raise NotImplementedError - - def __bool__(self): - # TODO (easy) implement numel and remove this override - raise NotImplementedError - - def __iter__(self): - raise NotImplementedError - - def split(self, split_size, dim=0): - raise NotImplementedError - - def chunk(self, n_chunks, dim=0): - raise NotImplementedError - - def tolist(self): - raise NotImplementedError - - def view_as(self, tensor): - raise NotImplementedError - - def permute(self, *dims): - raise NotImplementedError - - def expand(self, *sizes): - raise NotImplementedError - - def expand_as(self, tensor): - raise NotImplementedError - - def repeat(self, *sizes): - raise NotImplementedError - - def __rsub__(self, other): - raise NotImplementedError - - def __matmul__(self, other): - raise NotImplementedError - - def __rdiv__(self, other): - raise NotImplementedError - - def __idiv__(self, other): - raise NotImplementedError - - def __mod__(self, other): - raise NotImplementedError - - def __neg__(self): - raise NotImplementedError - - def __eq__(self, other): - raise NotImplementedError - - def __ne__(self, other): - raise NotImplementedError - - def __lt__(self, other): - raise NotImplementedError - - def __le__(self, other): - raise NotImplementedError - - def __gt__(self, other): - raise NotImplementedError - - def __ge__(self, other): - raise NotImplementedError - - def __and__(self, other): - raise NotImplementedError - - def __or__(self, other): - raise NotImplementedError - - def __xor__(self, other): - raise NotImplementedError - - def __iand__(self, other): - raise NotImplementedError - - def __ior__(self, other): - raise NotImplementedError - - def __ixor__(self, other): - raise NotImplementedError - - def __str__(self): - # NB: modest duplication with _tensor_str - size_str = 'x'.join(str(size) for size in self.size()) - return '{} of size {} with indices:\n{}and values:\n{}'.format( - self.__class__.__name__, size_str, self._indices(), self._values()) - - -class DoubleTensor(_SparseBase, _C.SparseDoubleTensorBase, _TensorBase): - def is_signed(self): - return True - - -class FloatTensor(_SparseBase, _C.SparseFloatTensorBase, _TensorBase): - def is_signed(self): - return True - - -class LongTensor(_SparseBase, _C.SparseLongTensorBase, _TensorBase): - def is_signed(self): - return True - - -class IntTensor(_SparseBase, _C.SparseIntTensorBase, _TensorBase): - def is_signed(self): - return True - - -class ShortTensor(_SparseBase, _C.SparseShortTensorBase, _TensorBase): - def is_signed(self): - return True - - -class CharTensor(_SparseBase, _C.SparseCharTensorBase, _TensorBase): - def is_signed(self): - # TODO - return False - - -class ByteTensor(_SparseBase, _C.SparseByteTensorBase, _TensorBase): - def is_signed(self): - return False - - -_sparse_tensor_classes.add(DoubleTensor) -_sparse_tensor_classes.add(FloatTensor) -_sparse_tensor_classes.add(LongTensor) -_sparse_tensor_classes.add(IntTensor) -_sparse_tensor_classes.add(ShortTensor) -_sparse_tensor_classes.add(CharTensor) -_sparse_tensor_classes.add(ByteTensor) - -torch._integer_tensor_classes.add(LongTensor) -torch._integer_tensor_classes.add(IntTensor) -torch._integer_tensor_classes.add(ShortTensor) -torch._integer_tensor_classes.add(CharTensor) -torch._integer_tensor_classes.add(ByteTensor) - -_C._sparse_init() +# The Tensor classes are added to this module by python_tensor.cpp diff --git a/torch/tensor.py b/torch/tensor.py deleted file mode 100644 index 5fb6d37f7d..0000000000 --- a/torch/tensor.py +++ /dev/null @@ -1,391 +0,0 @@ -import torch -import warnings -from . import _tensor_str -from . import _utils -import sys - - -class _TensorBase(object): - #: bool: True if this is a CUDA tensor - is_cuda = False - is_sparse = False - - # NB: This implementation is CPU only; see THPTensor_(new) for the - # CUDA case, which handles constructing the tensor on the same GPU - # as this tensor. - def new(self, *args, **kwargs): - r"""Constructs a new tensor of the same data type as :attr:`self` tensor. - - Any valid argument combination to the tensor constructor is accepted by - this method, including sizes, :class:`torch.Storage`, NumPy ndarray, - Python Sequence, etc. See :ref:`torch.Tensor <tensor-doc>` for more - details. - - .. note:: For CUDA tensors, this method will create new tensor on the - same device as this tensor. - """ - return self.__class__(*args, **kwargs) - - type = _utils._type - cuda = _utils._cuda - - def type_as(self, tensor): - r"""Returns this :attr:`self` tensor cast to the type of the given - tensor. - - This is a no-op if the :attr:`self` tensor is already of the correct - type. This is equivalent to:: - - self.type(tensor.type()) - - Params: - tensor (Tensor): the tensor with the desired type - """ - return self.type(tensor.type()) - - def cpu(self): - r"""Returns a CPU copy of this tensor if it's not already on the CPU""" - return self.type(getattr(torch, self.__class__.__name__)) - - def double(self): - r"""Casts this tensor to double type""" - return self.type(type(self).__module__ + '.DoubleTensor') - - def float(self): - r"""Casts this tensor to float type""" - return self.type(type(self).__module__ + '.FloatTensor') - - def half(self): - r"""Casts this tensor to half-precision float type""" - return self.type(type(self).__module__ + '.HalfTensor') - - def long(self): - r"""Casts this tensor to long type""" - return self.type(type(self).__module__ + '.LongTensor') - - def int(self): - r"""Casts this tensor to int type""" - return self.type(type(self).__module__ + '.IntTensor') - - def short(self): - r"""Casts this tensor to short type""" - return self.type(type(self).__module__ + '.ShortTensor') - - def char(self): - r"""Casts this tensor to char type""" - return self.type(type(self).__module__ + '.CharTensor') - - def byte(self): - r"""Casts this tensor to byte type""" - return self.type(type(self).__module__ + '.ByteTensor') - - def is_pinned(self): - r"""Returns true if this tensor resides in pinned memory""" - storage = self.storage() - return storage.is_pinned() if storage else False - - def pin_memory(self): - r"""Copies the tensor to pinned memory, if it's not already pinned.""" - if self.is_cuda: - raise TypeError("cannot pin '{0}' only CPU memory can be pinned" - .format(self.type())) - storage = self.contiguous().storage() - if storage is None: - storage = (self.storage_type())() - return type(self)().set_(storage.pin_memory()).view_as(self) - - def share_memory_(self): - r"""Moves the underlying storage to shared memory. - - This is a no-op if the underlying storage is already in shared memory - and for CUDA tensors. Tensors in shared memory cannot be resized. - """ - self.storage().share_memory_() - return self - - def is_shared(self): - r"""Checks if tensor is in shared memory. - - This is always ``True`` for CUDA tensors. - """ - return self.storage().is_shared() - - @property - def shape(self): - r"""Alias for .size() - - Returns a torch.Size object, containing the dimensions of the - :attr:`self` Tensor. - """ - return self.size() - - def __deepcopy__(self, _memo): - memo = _memo.setdefault('torch', {}) - if self._cdata in memo: - return memo[self._cdata] - new_storage = self.storage().__deepcopy__(_memo) - new_tensor = self.new() - new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride()) - memo[self._cdata] = new_tensor - return new_tensor - - def __reduce__(self): - # NOTE: _rebuild_tensor does not call __setstate__ - args = self.__getstate__() - return (_utils._rebuild_tensor, args) - - def __getstate__(self): - return (self.storage(), - self.storage_offset(), - tuple(self.size()), - self.stride()) - - def __setstate__(self, state): - self.set_(*state) - - def __repr__(self): - return str(self) - - def __str__(self): - # All strings are unicode in Python 3, while we have to encode unicode - # strings in Python2. If we can't, let python decide the best - # characters to replace unicode characters with. - if sys.version_info > (3,): - return _tensor_str._str(self) - else: - if hasattr(sys.stdout, 'encoding'): - return _tensor_str._str(self).encode( - sys.stdout.encoding or 'UTF-8', 'replace') - else: - return _tensor_str._str(self).encode('UTF-8', 'replace') - - def __bool__(self): - if self.numel() == 0: - return False - elif self.numel() == 1: - return torch.squeeze(self)[0] != 0 - raise RuntimeError("bool value of " + torch.typename(self) + - " containing more than one value is ambiguous") - - __nonzero__ = __bool__ - - def __iter__(self): - if self.nelement() > 0: - return iter(map(lambda i: self.select(0, i), range(self.size(0)))) - else: - return iter([]) - - def split(self, split_size, dim=0): - r"""Splits this tensor into tensor chunks of :attr:`split_size` size. - - See :func:`torch.split`. - """ - return torch.split(self, split_size, dim) - - def chunk(self, n_chunks, dim=0): - r"""Splits this tensor into a certain number of tensor chunks. - - See :func:`torch.chunk`. - """ - return torch.chunk(self, n_chunks, dim) - - def matmul(self, other): - r"""Matrix product of two tensors. - - See :func:`torch.matmul`.""" - return torch.matmul(self, other) - - def tolist(self): - r"""Returns a nested list represenation of this tensor.""" - return torch.autograd.Variable(self).tolist() - - def view_as(self, tensor): - r"""Returns this tensor viewed as the size as the specified tensor. - - This is equivalent to:: - - self.view(tensor.size()) - """ - return self.view(tensor.size()) - - def permute(self, *dims): - r"""Permute the dimensions of this tensor. - - Args: - *dims (int...): The desired ordering of dimensions - - Example: - >>> x = torch.randn(2, 3, 5) - >>> x.size() - torch.Size([2, 3, 5]) - >>> x.permute(2, 0, 1).size() - torch.Size([5, 2, 3]) - """ - perm = list(dims) - tensor = self - n_dims = tensor.dim() - assert len(perm) == n_dims, 'Invalid permutation' - for i, p in enumerate(perm): - if p != i and p != -1: - j = i - while True: - assert 0 <= perm[j] and perm[j] < n_dims, 'Invalid permutation' - tensor = tensor.transpose(j, perm[j]) - perm[j], j = -1, perm[j] - if perm[j] == i: - break - perm[j] = -1 - return tensor - - def expand_as(self, tensor): - r"""Expands this tensor to the size of the specified tensor. - - This is equivalent to:: - - self.expand(tensor.size()) - """ - return self.expand(tensor.size()) - - repeat = _utils._repeat - - def masked_copy_(self, *args, **kwargs): - warnings.warn("masked_copy_ is deprecated and renamed to masked_scatter_, and will be removed in v0.3") - return self.masked_scatter_(*args, **kwargs) - - # TODO: add tests for operators - def __add__(self, other): - return self.add(other) - __radd__ = __add__ - - def __iadd__(self, other): - return self.add_(other) - - def __sub__(self, other): - return self.sub(other) - - def __rsub__(self, other): - return self.new().resize_as_(self).fill_(other).add_(-1, self) - - def __isub__(self, other): - return self.sub_(other) - - def __mul__(self, other): - return self.mul(other) - __rmul__ = __mul__ - - def __imul__(self, other): - return self.mul_(other) - - def __matmul__(self, other): - if not torch.is_tensor(other): - return NotImplemented - return self.matmul(other) - - def __pow__(self, other): - return self.pow(other) - - def __rpow__(self, other): - return torch.pow(other, self) - - def __ipow__(self, other): - return self.pow_(other) - - def __div__(self, other): - return self.div(other) - __truediv__ = __div__ - - def __rdiv__(self, other): - return self.new().resize_as_(self).fill_(other).div_(self) - __rtruediv__ = __rdiv__ - - def __idiv__(self, other): - return self.div_(other) - __itruediv__ = __idiv__ - - def __mod__(self, other): - return self.remainder(other) - - def __neg__(self): - return self.neg() - - def __eq__(self, other): - return self.eq(other) - - def __ne__(self, other): - return self.ne(other) - - def __lt__(self, other): - return self.lt(other) - - def __le__(self, other): - return self.le(other) - - def __gt__(self, other): - return self.gt(other) - - def __ge__(self, other): - return self.ge(other) - - # TODO: add native add or and xor in the libs - def __invert__(self): - if type(self).__name__ != 'ByteTensor': - raise RuntimeError('logical operations are supported on ByteTensors only') - return (1 - self) - - def __hash__(self): - return id(self) - - def __int__(self): - if self.numel() == 1: - return int(self[(0,) * self.ndimension()]) - raise TypeError("only 1-element tensors can be converted " - "to Python scalars") - - def __long__(self): - if self.numel() == 1: - return long(self[(0,) * self.ndimension()]) - raise TypeError("only 1-element tensors can be converted " - "to Python scalars") - - def __float__(self): - if self.numel() == 1: - return float(self[(0,) * self.ndimension()]) - raise TypeError("only 1-element tensors can be converted " - "to Python scalars") - - # provide user guidance when they inavertently call autograd properties on a Tensor - @property - def data(self): - raise RuntimeError('cannot call .data on a torch.Tensor: did you intend to use autograd.Variable?') - - def numpy(self): - return torch.autograd.Variable(self).numpy() - - # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` - def __array__(self, dtype=None): - if dtype is None: - return self.cpu().numpy() - else: - return self.cpu().numpy().astype(dtype, copy=False) - - # Wrap Numpy array again in a suitable tensor when done, to support e.g. - # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` - def __array_wrap__(self, array): - if array.ndim == 0: - # TODO: remove this when 0-dimensional tensors are supported - if array.dtype.kind == 'b': - return bool(array) - elif array.dtype.kind in ('i', 'u'): - return int(array) - elif array.dtype.kind == 'f': - return float(array) - elif array.dtype.kind == 'c': - return complex(array) - else: - raise RuntimeError('bad scalar {!r}'.format(array)) - else: - if array.dtype == bool: - # Workaround, torch has no built-in bool tensor - array = array.astype('uint8') - - return torch.from_numpy(array) |