author | Gregory Chanan <gchanan@fb.com> | 2017-09-20 11:19:23 -0700
committer | Edward Z. Yang <ezyang@mit.edu> | 2017-11-02 19:53:36 -0400
commit | a10030eec79fa72abea103f9fbf4a0d37105569c (patch)
tree | f261def95018ef2d2f5a9e9e345109b5a733cc9f /aten
parent | c369d4da851e72f8c06d688a984ed54d9a098824 (diff)
download | pytorch-a10030eec79fa72abea103f9fbf4a0d37105569c.tar.gz pytorch-a10030eec79fa72abea103f9fbf4a0d37105569c.tar.bz2 pytorch-a10030eec79fa72abea103f9fbf4a0d37105569c.zip
Represent empty tensors as size {0} tensors and fix scalar checks.
This gets rid of kUndefinedDimensions and gives empty tensors nice properties (see the sketch below), such as:
- the dimensionality always matches the length of the sizes and strides;
- the number of elements is always the product of the sizes (the empty product being 1, so a 0-dim tensor holds one element);
- the shape you pass to factory functions (e.g. randn) matches the shape that is returned;
and so on.
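
For concreteness, here is a minimal sketch of those invariants, written against the ATen API of this commit (Type, CPU(kFloat), and the factory methods used by the new scalar_tensor_test.cpp); the assertions mirror the ones the test adds:

```cpp
#include "ATen/ATen.h"
#include <cassert>

using namespace at;

int main() {
  Type &T = CPU(kFloat);

  // An empty tensor is now size {0}, not "undefined dimensions".
  auto empty = T.randn({0});
  assert(empty.dim() == 1);             // dim matches sizes().size()
  assert(empty.sizes()[0] == 0);        // the shape passed to randn round-trips
  assert(empty.strides().size() == 1);  // strides have the same length as sizes
  assert(empty.numel() == 0);           // numel is the product of the sizes

  // A scalar (0-dim tensor) has an empty size list and one element:
  // the empty product is 1.
  auto scalar = T.ones({});
  assert(scalar.dim() == 0);
  assert(scalar.numel() == 1);
  return 0;
}
```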
In addition to the empty tensor change, this makes some related changes:
1) Makes expand a native function, because it needs to operate on the ATen view of the sizes/strides (usage sketched after this list).
2) Adds tests for a number of functions operating on empty, scalar, and non-scalar tensors. This uncovered a number of scalar_check bugs; some are fixed in the generated code, while others must be specified manually via a 'scalar_check' argument in the cwrap.
3) Fixes the formatting of empty tensors.
4) Simplifies the THLongStorageView API, which was getting overly complicated: you now call 'makeFromSize', 'makeFromStride', or 'makeFromLength', and it handles the correct mapping for that type (also sketched below).
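
As a usage sketch of change (1), hedged to this era's ATen API: expand returns a view whose broadcast dimensions get stride 0, which is exactly the geometry that inferExpandGeometry computes before the native function dispatches to as_strided:

```cpp
#include "ATen/ATen.h"
#include <cassert>

using namespace at;

int main() {
  Type &T = CPU(kFloat);

  auto t = T.ones({1, 3});
  // expand now dispatches to at::native::expand (see NativeFunctions.h in the
  // diff below), which computes the new sizes/strides with inferExpandGeometry
  // and calls as_strided; no data is copied.
  auto e = t.expand({4, 3});
  assert(e.sizes()[0] == 4 && e.sizes()[1] == 3);
  assert(e.strides()[0] == 0);  // the broadcast dimension has stride 0
  return 0;
}
```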
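
And a sketch of change (4), assuming an in-tree build where ATen/THLongStorageView.h is includable; the comments restate the mapping rules documented in the diff, and the variable names are illustrative:

```cpp
#include "ATen/ATen.h"
#include "ATen/THLongStorageView.h"
#include <vector>

using namespace at;

int main() {
  std::vector<int64_t> sizes = {0};    // an empty tensor's size in ATen...
  std::vector<int64_t> strides = {1};  // ...and its stride

  // Sizes: an empty ArrayRef (a scalar) maps to [1], since TH has no 0-dim tensors.
  auto size_ = THLongStorageView::makeFromSize(sizes);
  // Strides: [1] maps to an empty THLongStorage when the sizes say "no elements",
  // which is how TH distinguishes an empty tensor (stride []) from a scalar ([1]).
  auto stride_ = THLongStorageView::makeFromStride(strides, is_noelem_tensor_size(sizes));
  // Plain lengths (e.g. a storage's size) map through unchanged.
  int64_t n = 5;
  auto length_ = THLongStorageView::makeFromLength(n);
  (void)size_; (void)stride_; (void)length_;
  return 0;
}
```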
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/Declarations.cwrap | 13
-rw-r--r-- | aten/src/ATen/ExpandUtils.cpp | 76
-rw-r--r-- | aten/src/ATen/ExpandUtils.h | 40
-rw-r--r-- | aten/src/ATen/Formatting.cpp | 25
-rw-r--r-- | aten/src/ATen/Local.cwrap | 3
-rw-r--r-- | aten/src/ATen/NativeFunctions.h | 27
-rw-r--r-- | aten/src/ATen/THLongStorageView.h | 36
-rw-r--r-- | aten/src/ATen/function_wrapper.py | 28
-rw-r--r-- | aten/src/ATen/templates/TensorDense.cpp | 7
-rw-r--r-- | aten/src/ATen/templates/TensorDerived.cpp | 9
-rw-r--r-- | aten/src/ATen/templates/Type.h | 8
-rw-r--r-- | aten/src/ATen/test/CMakeLists.txt | 3
-rw-r--r-- | aten/src/ATen/test/scalar_tensor_test.cpp | 226
-rwxr-xr-x | aten/tools/run_tests.sh | 1
14 files changed, 427 insertions, 75 deletions
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index e9b001c477..e9c10ca271 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -157,15 +157,18 @@ return: argument 0 options: - cname: set + scalar_check: source arguments: - THTensor* self - THTensor* source - cname: setStorage + scalar_check: False arguments: - THTensor* self - CONSTANT NULL, 0, NULL, NULL - cname: setStorage before_call: THLongStoragePtr __storage_size(THLongStorage_newWithSize1(THStorage_(size)(LIBRARY_STATE arg_storage))); + scalar_check: False arguments: - THTensor* self - THStorage* storage @@ -428,19 +431,11 @@ long_args: True ]] [[ - name: expand - cname: newExpand - return: THTensor* - arguments: - - THTensor* self - - arg: THSize* size - long_args: True -]] -[[ name: resizeAs_ python_name: resize_as_ cname: resizeAs return: self + scalar_check: the_template arguments: - THTensor* self - THTensor* the_template diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp new file mode 100644 index 0000000000..9b6d228976 --- /dev/null +++ b/aten/src/ATen/ExpandUtils.cpp @@ -0,0 +1,76 @@ +#include "ATen/ExpandUtils.h" + +namespace at { + +std::vector<int64_t> infer_size(IntList a, IntList b) { + auto dimsA = a.size(); + auto dimsB = b.size(); + ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB; + std::vector<int64_t> expandedSizes(ndim); + + for (long i = ndim - 1; i >= 0; --i) { + long offset = ndim - 1 - i; + long dimA = dimsA - 1 - offset; + long dimB = dimsB - 1 - offset; + long sizeA = (dimA >= 0) ? a[dimA] : 1; + long sizeB = (dimB >= 0) ? b[dimB] : 1; + if (sizeA == sizeB || sizeA == 1 || sizeB == 1) { + expandedSizes[i] = std::max(sizeA, sizeB); + } else { + std::ostringstream oss; + oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b (" + << sizeB << ") at non-singleton dimension " << i; + throw std::runtime_error(oss.str()); + } + } + + return expandedSizes; +} + +std::tuple<std::vector<int64_t>, std::vector<int64_t> > +inferExpandGeometry(const Tensor &tensor, IntList sizes) { + int64_t ndim = sizes.size(); + + if (tensor.dim() == 0) { + std::vector<int64_t> expandedStrides(ndim, 0); + return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(sizes.vec(), expandedStrides); + } + std::vector<int64_t> expandedSizes(ndim); + std::vector<int64_t> expandedStrides(ndim); + + // create a new geometry for the tensors + for (int64_t i = ndim - 1; i >= 0; --i) { + int64_t offset = ndim - 1 - i; + int64_t dim = tensor.dim() - 1 - offset; + int64_t size = (dim >= 0) ? tensor.sizes()[dim] : 1; + int64_t stride = (dim >= 0) ? 
+ tensor.strides()[dim] : expandedSizes[i + 1] * expandedStrides[i + 1]; + int64_t targetSize = sizes[i]; + if (targetSize == -1) { + if (dim < 0) { + std::ostringstream oss; + oss << "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, " + << "non-existing dimension " << i; + throw std::runtime_error(oss.str()); + } else { + targetSize = size; + } + } + if (size != targetSize) { + if (size == 1) { + size = targetSize; + stride = 0; + } else { + std::ostringstream oss; + oss << "The expanded size of the tensor (" << targetSize << ") must match the existing size (" << size + << ") at non-singleton dimension " << i; + throw std::runtime_error(oss.str()); + } + } + expandedSizes[i] = size; + expandedStrides[i] = stride; + } + return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(expandedSizes, expandedStrides); +} + +} diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index b3b1c3832c..a5c6970456 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -5,6 +5,9 @@ namespace at { +std::vector<int64_t> infer_size(IntList a, IntList b); +std::tuple<std::vector<int64_t>, std::vector<int64_t> > inferExpandGeometry(const Tensor &tensor, IntList sizes); + inline std::tuple<Tensor> expand_inplace(const Tensor &tensor, const Tensor &to_expand) { if (tensor.sizes().equals(to_expand.sizes())) { return std::make_tuple(to_expand); @@ -21,49 +24,24 @@ inline std::tuple<Tensor, Tensor> expand_inplace(const Tensor &tensor, const Ten return std::make_tuple(to_expand1.expand(tensor.sizes()), to_expand2.expand(tensor.sizes())); } -inline std::vector<int64_t> infer_size2(IntList a, IntList b) { - auto dimsA = a.size(); - auto dimsB = b.size(); - ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB; - std::vector<int64_t> expandedSizes(ndim); - - for (long i = ndim - 1; i >= 0; --i) { - long offset = ndim - 1 - i; - long dimA = dimsA - 1 - offset; - long dimB = dimsB - 1 - offset; - long sizeA = (dimA >= 0) ? a[dimA] : 1; - long sizeB = (dimB >= 0) ? 
b[dimB] : 1; - if (sizeA == sizeB || sizeA == 1 || sizeB == 1) { - expandedSizes[i] = std::max(sizeA, sizeB); - } else { - std::ostringstream oss; - oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b (" - << sizeB << ") at non-singleton dimension " << i; - throw std::runtime_error(oss.str()); - } - } - - return expandedSizes; -} - inline std::tuple<Tensor, Tensor> expand_outplace(const Tensor &to_expand1, const Tensor &to_expand2) { if (to_expand1.sizes().equals(to_expand2.sizes())) { return std::make_tuple(to_expand1, to_expand2); } - auto expanded_size = infer_size2(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = infer_size(to_expand1.sizes(), to_expand2.sizes()); return std::make_tuple(to_expand1.expand(expanded_size), to_expand2.expand(expanded_size)); } -std::tuple<Tensor, Tensor, Tensor> expand_outplace(const Tensor &to_expand1, - const Tensor &to_expand2, - const Tensor &to_expand3) { +inline std::tuple<Tensor, Tensor, Tensor> expand_outplace(const Tensor &to_expand1, + const Tensor &to_expand2, + const Tensor &to_expand3) { if (to_expand1.sizes().equals(to_expand2.sizes()) && to_expand1.sizes().equals(to_expand3.sizes())) { return std::make_tuple(to_expand1, to_expand2, to_expand3); } - auto expanded_size12 = infer_size2(to_expand1.sizes(), to_expand2.sizes()); - auto expanded_size = infer_size2(expanded_size12, to_expand3.sizes()); + auto expanded_size12 = infer_size(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = infer_size(expanded_size12, to_expand3.sizes()); return std::make_tuple(to_expand1.expand(expanded_size), to_expand2.expand(expanded_size), to_expand3.expand(expanded_size)); } diff --git a/aten/src/ATen/Formatting.cpp b/aten/src/ATen/Formatting.cpp index 799878016a..353158a066 100644 --- a/aten/src/ATen/Formatting.cpp +++ b/aten/src/ATen/Formatting.cpp @@ -242,7 +242,7 @@ void __printTensor(std::ostream& stream, Tensor& self, int64_t linesize) std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesize) { FormatGuard guard(stream); if(!tensor_.defined()) { - stream << "[ Tensor (empty) ]"; + stream << "[ Tensor (undefined) ]"; } else { Type& cpudouble = tensor_.type().toBackend(kCPU).toScalarType(kDouble); Tensor tensor = tensor_.toType(cpudouble).contiguous(); @@ -250,17 +250,22 @@ std::ostream& print(std::ostream& stream, const Tensor & tensor_, int64_t linesi stream << defaultfloat << tensor.data<double>()[0] << std::endl; stream << "[ " << tensor_.pImpl->toString() << "{} ]"; } else if(tensor.ndimension() == 1) { - double scale; - int64_t sz; - std::tie(scale, sz) = __printFormat(stream, tensor); - if(scale != 1) { - printScale(stream, scale); + if (tensor.numel() == 0) { + stream << "[ Tensor (empty) ]"; } - double* tensor_p = tensor.data<double>(); - for(int64_t i = 0; i < tensor.size(0); i++) { - stream << std::setw(sz) << tensor_p[i]/scale << std::endl; + else { + double scale; + int64_t sz; + std::tie(scale, sz) = __printFormat(stream, tensor); + if(scale != 1) { + printScale(stream, scale); + } + double* tensor_p = tensor.data<double>(); + for(int64_t i = 0; i < tensor.size(0); i++) { + stream << std::setw(sz) << tensor_p[i]/scale << std::endl; + } + stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "} ]"; } - stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "} ]"; } else if(tensor.ndimension() == 2) { __printMatrix(stream, tensor, linesize, 0); stream << "[ " << tensor_.pImpl->toString() << "{" << tensor.size(0) << "," << 
tensor.size(1) << "} ]"; diff --git a/aten/src/ATen/Local.cwrap b/aten/src/ATen/Local.cwrap index 580a2fa1d4..f9973cebaa 100644 --- a/aten/src/ATen/Local.cwrap +++ b/aten/src/ATen/Local.cwrap @@ -98,7 +98,7 @@ AT_ASSERT(ndim > 0, "unnarrow() cannot be applied to a 0-dim tensor."); std::vector<int64_t> self_sizes = self.sizes(); self_sizes[dimension] = dimSize; - auto self_sizes_ = THLongStorageView::make(self_sizes, true); + auto self_sizes_ = THLongStorageView::makeFromSize(self_sizes); ${THTensor}_zeros(${state,}result_->tensor, self_sizes_); auto narrowed_result = result.narrow(dimension, offset, self.size(dimension)); narrowed_result.copy_(self); @@ -126,6 +126,7 @@ - THStride* stride aten_custom_call: | ${THTensor}_setStorage(${state,}result_->tensor, self_->tensor->storage, self_->tensor->storageOffset, size_, stride_); + result_->maybeScalar(size.size() == 0); ]] [[ diff --git a/aten/src/ATen/NativeFunctions.h b/aten/src/ATen/NativeFunctions.h index 820aca11a8..1fc784e212 100644 --- a/aten/src/ATen/NativeFunctions.h +++ b/aten/src/ATen/NativeFunctions.h @@ -2,6 +2,7 @@ #include "ATen/ATen.h" #include "ATen/WrapDimUtils.h" +#include "ATen/ExpandUtils.h" #include <vector> namespace at { @@ -62,7 +63,7 @@ type_method_definition_dispatch: at::native::is_same_size [/NativeFunction] */ static inline bool is_same_size(const Tensor &self, const Tensor &other) { - return self.dim() == other.dim() && self.sizes().equals(other.sizes()); + return self.sizes().equals(other.sizes()); } /* @@ -98,5 +99,29 @@ static inline Tensor permute(const Tensor & self, IntList dims) { return self.as_strided(newSizes, newStrides); } +/* +[NativeFunction] +name: expand +arg: Tensor self +arg: IntList sizes +return: Tensor +variants: method, function +type_method_definition_level: base +type_method_definition_dispatch: at::native::expand +[/NativeFunction] +*/ +static inline Tensor expand(const Tensor &self, IntList sizes) { + if (sizes.size() < (size_t)self.dim()) { + throw std::runtime_error("the number of sizes provided must be greater or equal to the " + "number of dimensions in the tensor"); + } + + std::vector<int64_t> expandedSizes; + std::vector<int64_t> expandedStrides; + std::tie(expandedSizes, expandedStrides) = inferExpandGeometry(self, sizes); + + return self.as_strided(expandedSizes, expandedStrides); +} + } } diff --git a/aten/src/ATen/THLongStorageView.h b/aten/src/ATen/THLongStorageView.h index b7dc6752c4..4ac5d62060 100644 --- a/aten/src/ATen/THLongStorageView.h +++ b/aten/src/ATen/THLongStorageView.h @@ -4,25 +4,37 @@ namespace at { +static inline bool is_noelem_tensor_size(ArrayRef<int64_t> size) { + return size.size() == 1 && size[0] == 0; +} + // make a fake storage out of a size, pointer pair... // used as an argument where THSize and THStride are passed into TH class THLongStorageView { public: - // zero_dim_to_one converts an empty ArrayRef into [1] - // empty_to_null converts an empty ArrayRef into a null THLongStorage - static THLongStorageView make(ArrayRef<int64_t> ref, bool zero_dim_to_one = false, bool empty_to_null = false) { - assert(!(zero_dim_to_one && empty_to_null)); - return THLongStorageView(ref, zero_dim_to_one, empty_to_null); + static THLongStorageView makeFromSize(ArrayRef<int64_t> ref) { + return THLongStorageView(ref, true, false, false); + } + // noelem_to_empty is to differentiate strides of empty tensors vs scalars. In ATen, both may have strides [1], + // but in TH an empty tensor should have stride [], while a scalar should have stride [1]. 
+ static THLongStorageView makeFromStride(ArrayRef<int64_t> ref, bool noelem_to_empty) { + return THLongStorageView(ref, false, true, noelem_to_empty); + } + static THLongStorageView makeFromLength(ArrayRef<int64_t> ref) { + return THLongStorageView(ref, false, false, false); } operator THLongStorage*() { - if (storage.size == 0 && empty_to_null) { + if (storage.size == 0 && zero_dim_to_null) { return nullptr; } return &storage; } private: - THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool empty_to_null) - : empty_to_null(empty_to_null) + // zero_dim_to_one converts an empty ArrayRef into [1] + // zero_dim_to_null converts an empty ArrayRef into a null THLongStorage + // noelem_to_empty makes an ArrayRef of [0] into an empty THLongStorage + THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool zero_dim_to_null, bool noelem_to_empty) + : zero_dim_to_null(zero_dim_to_null) { if(zero_dim_to_one && ref.size() == 0) { // make storage of size 0 actually a 1-length storage with 1 element @@ -30,7 +42,11 @@ private: one = 1; storage.data = &one; storage.size = 1; - } else { + } else if (noelem_to_empty && is_noelem_tensor_size(ref)) { + storage.data = (int64_t*)(ref.data()); + storage.size = 0; + } + else { storage.data = (int64_t*)(ref.data()); storage.size = ref.size(); } @@ -41,7 +57,7 @@ private: } int64_t one; THLongStorage storage; - bool empty_to_null; + bool zero_dim_to_null; }; } diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 24a64970f2..c2f2fec6c1 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -159,8 +159,8 @@ CHECKED_CAST = { 'THGenerator*': CodeTemplate( 'check_generator<${Backend}Generator>(${arg_name}, &context->defaultGenerator(backend()))'), - 'THSize*': CodeTemplate('THLongStorageView::make(${arg_name}, true)'), - 'THStride*': CodeTemplate('THLongStorageView::make(${arg_name}, false, true)'), + 'THSize*': CodeTemplate('THLongStorageView::makeFromSize(${arg_name})'), + 'THStride*': CodeTemplate('THLongStorageView::makeFromStride(${arg_name}, ${noelem_to_empty})'), 'real': CodeTemplate('${arg_name}.to${ScalarName}()'), 'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'), 'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, ' @@ -194,7 +194,7 @@ CONSTANT_REPLACEMENTS = [ ('THPDefaultGenerator->cdata', 'dynamic_cast<${Generator}&>().generator'), ('__storage_size.get\\(\\)', - 'THLongStorageView::make(static_cast<int64_t>(storage.size()))'), + 'THLongStorageView::makeFromLength(static_cast<int64_t>(storage.size()))'), ('__last_dim', 'self.ndimension()-1'), ] @@ -773,11 +773,20 @@ def create_derived(backend_type_env, declarations): # if there is a THSize* argument, then its dimensions are used to determine scalar. 
# otherwise, it is true if all the input tensors are scalars, scalar_check_is_from_size = False + scalar_check_is_from_option = False scalar_check = None + scalar_check_opt = option.get('scalar_check') + if scalar_check_opt is not None: + if scalar_check_opt is not False: + scalar_check = '{}->isScalar()'.format(scalar_check_opt + '_') + else: + scalar_check = 'false' + scalar_check_is_from_option = True + for arg in option['arguments']: if is_real_argument_to_wrapper(arg): count += 1 - if arg['type'] == 'THSize*': + if arg['type'] == 'THSize*' and not scalar_check_is_from_option: scalar_check_is_from_size = True scalar_check = '{}.size() == 0'.format(arg['name']) if arg['type'] == 'TensorList': @@ -816,10 +825,12 @@ def create_derived(backend_type_env, declarations): if 'default_init' in arg: default_init.append(arg['default_init']) + noelem_to_empty = 'is_noelem_tensor_size(size)' if 'size' in seen_names else 'false' check_cast = CHECKED_CAST[arg['type']].substitute( env, arg_name=arg['name'], arg_pos=count, null_okay=null_okay, default_init=default_init, - size=arg.get('size')) + size=arg.get('size'), + noelem_to_empty=noelem_to_empty) body.append("auto {}_ = {};".format( arg['name'], check_cast)) if drop_argument(arg, option) or replace_with_null(arg): @@ -845,12 +856,15 @@ def create_derived(backend_type_env, declarations): else: body += initializers - # isScalar() for all input tensors is and'd to form + # for out-of-place: isScalar() for all input tensors is and'd to form # the test for whether the output is also a scalar + # for in-place: isScalar() shouldn't change as a result of the operation if (not arg.get('output') and 'Tensor' in arg['type'] and 'TensorList' not in arg['type'] and 'THS' not in arg['type'] and - not scalar_check_is_from_size): + not scalar_check_is_from_size and + not scalar_check_is_from_option and + not option['inplace']): check = '{}->isScalar()'.format(arg['name'] + '_') if nullable_argument(arg): check = '(!{} || {})'.format(arg['name'] + '_', check) diff --git a/aten/src/ATen/templates/TensorDense.cpp b/aten/src/ATen/templates/TensorDense.cpp index 2f49c27750..878b222632 100644 --- a/aten/src/ATen/templates/TensorDense.cpp +++ b/aten/src/ATen/templates/TensorDense.cpp @@ -1,7 +1,12 @@ // included as 'TensorDenseOrSparse' in TensorDerived.cpp IntList ${Tensor}::strides() const { - return IntList(reinterpret_cast<int64_t*>(tensor->stride),dim()); + int64_t d = tensor->nDimension; + if (d != 0) { + return IntList(reinterpret_cast<int64_t*>(tensor->stride),dim()); + } else { + return IntList(kEmptyStrides); + } } Scalar ${Tensor}::localScalar() { AT_ASSERT(isScalar(),"localScalar() called on Tensor with %d dims",sizes().size()); diff --git a/aten/src/ATen/templates/TensorDerived.cpp b/aten/src/ATen/templates/TensorDerived.cpp index e39507823a..de4ce41ac8 100644 --- a/aten/src/ATen/templates/TensorDerived.cpp +++ b/aten/src/ATen/templates/TensorDerived.cpp @@ -20,7 +20,12 @@ const char * ${Tensor}::toString() const { } IntList ${Tensor}::sizes() const { - return IntList(reinterpret_cast<int64_t*>(tensor->size),dim()); + int64_t d = ${THTensor_nDimension}; + if (d != 0) { + return IntList(reinterpret_cast<int64_t*>(tensor->size),dim()); + } else { + return IntList(kEmptySizes); + } } int64_t ${Tensor}::dim() const { @@ -30,7 +35,7 @@ int64_t ${Tensor}::dim() const { // See Note [Undefined-dim versus 0-dim] if (d != 0) return d; - return kUndefinedDimensions; + return kEmptySizes.size(); } const char * ${Tensor}::typeString() { diff --git 
a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h index ce3613f002..5dd0a9c2df 100644 --- a/aten/src/ATen/templates/Type.h +++ b/aten/src/ATen/templates/Type.h @@ -34,9 +34,11 @@ struct Generator; // zero elements. // // Because we are backed by Torch tensors, we need to be able to -// represent this state (of numel==0). kUndefinedDimensions represents this -// situation. -constexpr int64_t kUndefinedDimensions = std::numeric_limits<int64_t>::min(); +// represent this state (of numel==0). These tensors are represented +// by one-dimensional tensors with size[0] == 0 and stride[0] == 1 +// (the stride is arbitrary but matches the NumPy equivalent). +constexpr std::array<int64_t, 1> kEmptySizes { {0} }; +constexpr std::array<int64_t, 1> kEmptyStrides { {1} }; static inline void noop_deleter(void*) {} diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index f2425afc99..6f0e22ba3c 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -18,3 +18,6 @@ target_link_libraries(dlconvertor_test ATen) add_executable(native_test native_test.cpp) target_link_libraries(native_test ATen) + +add_executable(scalar_tensor_test scalar_tensor_test.cpp) +target_link_libraries(scalar_tensor_test ATen) diff --git a/aten/src/ATen/test/scalar_tensor_test.cpp b/aten/src/ATen/test/scalar_tensor_test.cpp new file mode 100644 index 0000000000..60b65938bb --- /dev/null +++ b/aten/src/ATen/test/scalar_tensor_test.cpp @@ -0,0 +1,226 @@ +#include "ATen/ATen.h" +#include <iostream> +#include <numeric> + +using namespace at; + +void assert_equal_size_dim(const Tensor &lhs, const Tensor &rhs) { + assert(lhs.dim() == rhs.dim()); + assert(lhs.sizes().equals(rhs.sizes())); +} + +bool should_expand(const IntList &from_size, const IntList &to_size) { + if(from_size.size() > to_size.size()) { + return false; + } + for (auto from_dim_it = from_size.rbegin(); from_dim_it != from_size.rend(); ++from_dim_it) { + for (auto to_dim_it = to_size.rbegin(); to_dim_it != to_size.rend(); ++to_dim_it) { + if (*from_dim_it != 1 && *from_dim_it != *to_dim_it) { + return false; + } + } + } + return true; +} + +int main() { + Type & T = CPU(kFloat); + + std::vector<std::vector<int64_t> > sizes = { {}, {0}, {1}, {1, 1}, {2}}; + + // single-tensor/size tests + for (auto s = sizes.begin(); s != sizes.end(); ++s) { + // verify that the dim, sizes, strides, etc match what was requested. 
+ auto t = T.ones(*s); + assert(t.dim() == s->size()); + assert(t.ndimension() == s->size()); + assert(t.sizes().equals(*s)); + assert(t.strides().size() == s->size()); + auto numel = std::accumulate(s->begin(), s->end(), 1, std::multiplies<int64_t>()); + assert(t.numel() == numel); + // verify we can output + std::cout << t << std::endl; + + // set_ + auto t2 = T.ones(*s); + t2.set_(); + assert_equal_size_dim(t2, T.ones({0})); + + // unsqueeze + if (t.numel() != 0) { + if (t.dim() > 0) { + assert(t.unsqueeze(0).dim() == t.dim() + 1); + } else { + // FIXME: should be able to remove this if/else, unsqueezing a scalar should give 1-dimension + assert(t.unsqueeze(0).dim() == t.dim() + 2); + } + } else { + try { + // can't unsqueeze empty tensor + t.unsqueeze(0); + assert (false); + } catch (std::runtime_error &e) {} + } + + // squeeze + if (t.dim() > 0 && t.sizes()[0] == 1) { + // FIXME: the max should be 0, but we don't reduce down to scalars properly yet + assert(t.squeeze(0).dim() == std::max<int64_t>(t.dim() - 1, 1)); + } else if (t.dim() == 0 || t.numel() == 0) { + try { + t.squeeze(0); + assert(false); + } catch (std::runtime_error &e) {} + } else { + // In PyTorch, it is a no-op to try to squeeze a dimension that has size != 1; + // in NumPy this is an error. + assert(t.squeeze(0).dim() == t.dim()); + } + + // reduce + if (t.dim() > 0 && t.numel() != 0) { + // FIXME: the max should be 0, but we don't reduce down to scalars properly yet + assert(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 1)); + } else if (t.dim() == 0) { + try { + t.sum(0); + assert(false); + } catch (std::runtime_error &e) {} + } else { + // FIXME: you should be able to reduce over size {0} + try { + t.sum(0); + assert(false); + } catch (std::runtime_error &e) {} + } + + // simple indexing + if (t.dim() > 0 && t.numel() != 0) { + assert(t[0].dim() == std::max<int64_t>(t.dim() - 1, 0)); + } else if (t.dim() == 0) { + try { + t[0]; + assert(false); + } catch (std::runtime_error &e) {} + } + } + + for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) { + for (auto rhs_it = sizes.begin(); rhs_it != sizes.end(); ++rhs_it) { + // is_same_size should only match if they are the same shape + { + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + if(*lhs_it != *rhs_it) { + assert(!lhs.is_same_size(rhs)); + assert(!rhs.is_same_size(lhs)); + } + } + // forced size functions (resize_, resize_as, set_) + { + // resize_ + { + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + lhs.resize_(*rhs_it); + assert_equal_size_dim(lhs, rhs); + } + // resize_as_ + { + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + lhs.resize_as_(rhs); + assert_equal_size_dim(lhs, rhs); + } + // set_ + { + { + // with tensor + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + lhs.set_(rhs); + assert_equal_size_dim(lhs, rhs); + } + { + // with storage + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + auto storage = T.storage(rhs.numel()); + lhs.set_(*storage); + // should not be dim 0 because an empty storage is dim 1; all other storages aren't scalars + assert(lhs.dim() != 0); + } + { + // with storage, offset, sizes, strides + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + auto storage = T.storage(rhs.numel()); + lhs.set_(*storage, rhs.storage_offset(), rhs.sizes(), rhs.strides()); + assert_equal_size_dim(lhs, rhs); + } + } + + // assign_ + { + auto lhs = T.ones(*lhs_it); + auto lhs_save = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + try { + lhs.assign_(rhs); + 
assert(lhs_save.numel() == rhs.numel()); + // ensure didn't change shape + assert_equal_size_dim(lhs, lhs_save); + } catch (std::runtime_error &e) { + assert(lhs_save.numel() != rhs.numel()); + } + } + } + + // view + { + auto lhs = T.ones(*lhs_it); + auto rhs = T.ones(*rhs_it); + auto rhs_size = *rhs_it; + try { + auto result = lhs.view(rhs_size); + assert(lhs.numel() == rhs.numel()); + assert_equal_size_dim(result, rhs); + } catch (std::runtime_error &e) { + assert(lhs.numel() != rhs.numel()); + } + } + + // expand + { + auto lhs = T.ones(*lhs_it); + auto lhs_size = *lhs_it; + auto rhs = T.ones(*rhs_it); + auto rhs_size = *rhs_it; + bool should_pass = should_expand(lhs_size, rhs_size); + try { + auto result = lhs.expand(rhs_size); + assert(should_pass); + assert_equal_size_dim(result, rhs); + } catch (std::runtime_error &e) { + assert(!should_pass); + } + + // in-place functions (would be good if we can also do a non-broadcasting one, b/c + // broadcasting functions will always end up operating on tensors of same size; + // is there an example of this outside of assign_ ?) + { + bool should_pass_inplace = should_expand(rhs_size, lhs_size); + try { + lhs.add_(rhs); + assert(should_pass_inplace); + assert_equal_size_dim(lhs, T.ones(*lhs_it)); + } catch (std::runtime_error &e) { + assert(!should_pass_inplace); + } + } + } + } + } + + return 0; +} diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh index 462dde53a0..aca2a93d7c 100755 --- a/aten/tools/run_tests.sh +++ b/aten/tools/run_tests.sh @@ -9,4 +9,5 @@ $BUILD_ROOT/src/ATen/test/broadcast_test $BUILD_ROOT/src/ATen/test/wrapdim_test $BUILD_ROOT/src/ATen/test/dlconvertor_test $BUILD_ROOT/src/ATen/test/native_test +$BUILD_ROOT/src/ATen/test/scalar_tensor_test valgrind --suppressions=`dirname $0`/valgrind.sup --error-exitcode=1 $BUILD_ROOT/src/ATen/test/basic -n |