author     SsnL <tongzhou.wang.1994@gmail.com>   2019-01-29 12:23:06 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>   2019-01-29 13:06:29 -0800
commit     ded6fb0293f45b4f6ab70a831e66516266fdc96f (patch)
tree       cddad340513004e290a989acc9d8c412f085c2b6
parent     d79e45bbba250b6e9e1e7854cb2f2ec45f9d9800 (diff)
Add stack & cat support for CPU Half (#16389)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/6968
Needed for #14705
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16389
Differential Revision: D13861446
Pulled By: gchanan
fbshipit-source-id: 7b8700b95aaf252d9669693dbddccb2302e58409
-rw-r--r--  aten/src/ATen/Declarations.cwrap             1
-rw-r--r--  aten/src/TH/generic/THTensor.cpp           125
-rw-r--r--  aten/src/TH/generic/THTensor.h               4
-rw-r--r--  aten/src/TH/generic/THTensorMath.h           2
-rw-r--r--  aten/src/TH/generic/THTensorMoreMath.cpp   123
-rw-r--r--  test/common_utils.py                        36
-rw-r--r--  test/test_torch.py                          97
7 files changed, 201 insertions, 187 deletions
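
For context, the user-visible behavior this commit enables, as a minimal sketch (assuming a PyTorch build that includes this change): torch.cat and torch.stack now accept CPU half (float16) tensors instead of raising a not-implemented error.

import torch

a = torch.randn(2, 3).half()   # CPU float16 tensor
b = torch.randn(2, 3).half()

catted = torch.cat((a, b), dim=0)     # shape (4, 3), dtype torch.float16
stacked = torch.stack((a, b), dim=0)  # shape (2, 2, 3), dtype torch.float16
assert catted.dtype == torch.float16
assert stacked.size() == (2, 2, 3)
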
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index ae6d4bd4f8..903fbd51af 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2872,6 +2872,7 @@
name: _th_cat
cname: catArray
variants: [function]
+ cpu_half: True
return: self
arguments:
- arg: THTensor* self
diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp
index e3dfc26ea6..193a7ced1e 100644
--- a/aten/src/TH/generic/THTensor.cpp
+++ b/aten/src/TH/generic/THTensor.cpp
@@ -668,6 +668,131 @@ scalar_t THTensor_(get4d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_
return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3));
}
+
+/* Shape manipulation methods */
+void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
+{
+ THTensor* inputs[2];
+ inputs[0] = ta;
+ inputs[1] = tb;
+ THTensor_(catArray)(r_, inputs, 2, dimension);
+}
+
+void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
+inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
+{
+ int first_dims = first->dim();
+ int second_dims = second->dim();
+ THArgCheck(first_dims == second_dims, 0,
+ "Tensors must have same number of dimensions: got %d and %d",
+ first_dims, second_dims);
+ for (int dim = 0; dim < first_dims; dim++) {
+ if (dim == dimension) {
+ continue;
+ }
+ int64_t first_dim_size = first->size(dim);
+ int64_t second_dim_size = second->size(dim);
+ THArgCheck(first_dim_size == second_dim_size, 0,
+ "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
+ dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
+ }
+}
+
+void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
+{
+ // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
+ // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
+ // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
+ // size (i.e. other empty sizes are not skipped).
+ // FIXME: warn if this is the case
+ bool allSkipped= true;
+ int64_t nDims = 0;
+ THTensor *notSkippedTensor; // non-owning reference
+ auto should_skip = [](THTensor *t) { return t->is_empty() && t->dim() == 1; };
+ for (int i = 0; i < numInputs; i++) {
+ if (should_skip(inputs[i])) {
+ continue;
+ }
+ // We've found a non-empty tensor
+ allSkipped = false;
+ notSkippedTensor = inputs[i];
+ nDims = notSkippedTensor->dim();
+ break;
+ }
+ if (allSkipped) {
+ return;
+ }
+
+ // Compute cat_dimension based on the non-empty tensor
+ THArgCheck(dimension < nDims, 4, "invalid dimension %d", dimension);
+ THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
+
+ // Compute size of the result in the cat dimension
+ int64_t cat_dim_size = 0;
+ for (int i = 0; i < numInputs; i++) {
+ THTensor *tensor = inputs[i];
+ if (should_skip(tensor)) {
+ continue;
+ }
+ THTensor_(check_shape_except_dim)(notSkippedTensor, tensor, dimension);
+ cat_dim_size += tensor->size(dimension);
+ }
+
+ // Compute the size of the result
+ std::vector<int64_t> size(nDims);
+ for (int dim = 0; dim < nDims; dim++) {
+ int64_t result_dim_size = notSkippedTensor->size(dim);
+ if (dim == dimension) {
+ result_dim_size = cat_dim_size;
+ }
+ size[dim] = result_dim_size;
+ }
+ THTensor_(resize)(result, size, {});
+
+ // Check contiguity of all inputs and result
+ bool allContiguous = true;
+ for (int i = 0; i < numInputs; i++) {
+ if(!should_skip(inputs[i])) {
+ allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
+ }
+ }
+ allContiguous = allContiguous && THTensor_(isContiguous)(result);
+
+ // First path is for contiguous inputs along dim 0
+ // Second path for non-contiguous
+ int64_t offset;
+ if (dimension == 0 && allContiguous) {
+ scalar_t* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset();
+ offset = 0;
+ for (int j = 0; j < numInputs; j++) {
+ if (!should_skip(inputs[j])) {
+ THTensor* input0 = inputs[j];
+ scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
+ int64_t input0_size = THTensor_(nElement)(input0);
+ // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this.
+ if (input0_size != 0) {
+ memcpy(result_data + offset, input0_data, input0_size*sizeof(scalar_t));
+ }
+ offset += input0_size;
+ }
+ }
+ } else {
+ offset = 0;
+ for (int j = 0; j < numInputs; j++) {
+ if (!should_skip(inputs[j])) {
+ int64_t dimSize = inputs[j]->size(dimension);
+ THTensor *nt = THTensor_(newWithTensor)(result);
+ THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
+ at::Tensor nt__wrap = THTensor_wrap(nt);
+ at::Tensor inputs_wrap = THTensor_wrap(inputs[j]);
+ at::_copy_same_type_(nt__wrap, inputs_wrap);
+ c10::raw::intrusive_ptr::decref(nt);
+ offset += dimSize;
+ }
+ }
+ }
+}
+
THDescBuff THTensor_(desc)(const THTensor *tensor) {
const int L = TH_DESC_BUFF_LEN;
THDescBuff buf;
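
To summarize the C++ added above: THTensor_(catArray) skips legacy size-[0] 1-D empties, validates that the remaining shapes match outside the cat dimension, sizes the result, and then copies either via one flat memcpy per input (dim 0, all contiguous) or via narrowed views. A hedged NumPy sketch of that control flow follows; cat_array is an illustrative name, not a PyTorch API, and the real code operates on TH storages rather than arrays.

import numpy as np

def cat_array(inputs, dim):
    # Illustrative mirror of THTensor_(catArray); not the real implementation.
    should_skip = lambda t: t.ndim == 1 and t.size == 0  # legacy size-[0] rule
    kept = [t for t in inputs if not should_skip(t)]
    if not kept:                          # the allSkipped early return
        return None
    ref = kept[0]                         # notSkippedTensor
    assert dim < ref.ndim, "invalid dimension"
    for t in kept[1:]:                    # check_shape_except_dim
        assert t.ndim == ref.ndim, "tensors must have same number of dimensions"
        for d in range(ref.ndim):
            assert d == dim or t.shape[d] == ref.shape[d], \
                "sizes must match except in the cat dimension"
    out_shape = list(ref.shape)
    out_shape[dim] = sum(t.shape[dim] for t in kept)  # cat_dim_size
    result = np.empty(out_shape, dtype=ref.dtype)     # freshly allocated, contiguous
    if dim == 0 and all(t.flags["C_CONTIGUOUS"] for t in kept):
        flat, offset = result.reshape(-1), 0          # the memcpy fast path
        for t in kept:
            flat[offset:offset + t.size] = t.reshape(-1)
            offset += t.size
    else:
        offset = 0                                    # the narrow-and-copy path
        for t in kept:
            sl = [slice(None)] * result.ndim
            sl[dim] = slice(offset, offset + t.shape[dim])
            result[tuple(sl)] = t
            offset += t.shape[dim]
    return result

For example, cat_array([np.ones((2, 3)), np.zeros((1, 3))], 0) yields a (3, 3) array via the fast path.
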
diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h
index 6edd40042d..8d0c3184f4 100644
--- a/aten/src/TH/generic/THTensor.h
+++ b/aten/src/TH/generic/THTensor.h
@@ -125,6 +125,10 @@ TH_API scalar_t THTensor_(get2d)(const THTensor *tensor, int64_t x0, int64_t x1)
TH_API scalar_t THTensor_(get3d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2);
TH_API scalar_t THTensor_(get4d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3);
+/* Shape manipulation methods */
+TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
+TH_API void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension);
+
/* Debug methods */
TH_API THDescBuff THTensor_(desc)(const THTensor *tensor);
TH_API THDescBuff THTensor_(sizeDesc)(const THTensor *tensor);
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 244b60c3e7..cea6de0366 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -103,8 +103,6 @@ TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n
TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted);
TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k);
-TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
-TH_API void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension);
TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 27f67985ee..46769b7438 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -1238,129 +1238,6 @@ void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k)
}
}
-void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
-{
- THTensor* inputs[2];
- inputs[0] = ta;
- inputs[1] = tb;
- THTensor_(catArray)(r_, inputs, 2, dimension);
-}
-
-void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
-inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
-{
- int first_dims = first->dim();
- int second_dims = second->dim();
- THArgCheck(first_dims == second_dims, 0,
- "Tensors must have same number of dimensions: got %d and %d",
- first_dims, second_dims);
- for (int dim = 0; dim < first_dims; dim++) {
- if (dim == dimension) {
- continue;
- }
- int64_t first_dim_size = first->size(dim);
- int64_t second_dim_size = second->size(dim);
- THArgCheck(first_dim_size == second_dim_size, 0,
- "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
- dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
- }
-}
-
-void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
-{
- // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
- // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
- // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
- // size (i.e. other empty sizes are not skipped).
- // FIXME: warn if this is the case
- bool allSkipped= true;
- int64_t nDims = 0;
- THTensor *notSkippedTensor; // non-owning reference
- auto should_skip = [](THTensor *t) { return t->is_empty() && t->dim() == 1; };
- for (int i = 0; i < numInputs; i++) {
- if (should_skip(inputs[i])) {
- continue;
- }
- // We've found a non-empty tensor
- allSkipped = false;
- notSkippedTensor = inputs[i];
- nDims = notSkippedTensor->dim();
- break;
- }
- if (allSkipped) {
- return;
- }
-
- // Compute cat_dimension based on the non-empty tensor
- THArgCheck(dimension < nDims, 4, "invalid dimension %d", dimension);
- THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
-
- // Compute size of the result in the cat dimension
- int64_t cat_dim_size = 0;
- for (int i = 0; i < numInputs; i++) {
- THTensor *tensor = inputs[i];
- if (should_skip(tensor)) {
- continue;
- }
- THTensor_(check_shape_except_dim)(notSkippedTensor, tensor, dimension);
- cat_dim_size += tensor->size(dimension);
- }
-
- // Compute the size of the result
- std::vector<int64_t> size(nDims);
- for (int dim = 0; dim < nDims; dim++) {
- int64_t result_dim_size = notSkippedTensor->size(dim);
- if (dim == dimension) {
- result_dim_size = cat_dim_size;
- }
- size[dim] = result_dim_size;
- }
- THTensor_(resize)(result, size, {});
-
- // Check contiguity of all inputs and result
- bool allContiguous = true;
- for (int i = 0; i < numInputs; i++) {
- if(!should_skip(inputs[i])) {
- allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
- }
- }
- allContiguous = allContiguous && THTensor_(isContiguous)(result);
-
- // First path is for contiguous inputs along dim 0
- // Second path for non-contiguous
- int64_t offset;
- if (dimension == 0 && allContiguous) {
- scalar_t* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset();
- offset = 0;
- for (int j = 0; j < numInputs; j++) {
- if (!should_skip(inputs[j])) {
- THTensor* input0 = inputs[j];
- scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
- int64_t input0_size = THTensor_(nElement)(input0);
- // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this.
- if (input0_size != 0) {
- memcpy(result_data + offset, input0_data, input0_size*sizeof(scalar_t));
- }
- offset += input0_size;
- }
- }
- } else {
- offset = 0;
- for (int j = 0; j < numInputs; j++) {
- if (!should_skip(inputs[j])) {
- int64_t dimSize = inputs[j]->size(dimension);
- THTensor *nt = THTensor_(newWithTensor)(result);
- THTensor_(narrow)(nt, NULL, dimension, offset, dimSize);
- at::Tensor nt__wrap = THTensor_wrap(nt);
- at::Tensor inputs_wrap = THTensor_wrap(inputs[j]);
- at::_copy_same_type_(nt__wrap, inputs_wrap);
- c10::raw::intrusive_ptr::decref(nt);
- offset += dimSize;
- }
- }
- }
-}
-
int THTensor_(equal)(THTensor *ta, THTensor* tb)
{
int equal = 1;
diff --git a/test/common_utils.py b/test/common_utils.py
index ae258ec3ef..5dce3d80db 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -392,22 +392,28 @@ class TestCase(expecttest.TestCase):
def assertTensorsEqual(a, b):
super(TestCase, self).assertEqual(a.size(), b.size(), message)
if a.numel() > 0:
- b = b.type_as(a)
- b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
- # check that NaNs are in the same locations
- nan_mask = a != a
- self.assertTrue(torch.equal(nan_mask, b != b), message)
+ if a.device.type == 'cpu' and a.dtype == torch.float16:
+ # CPU half tensors don't have the methods we need below
+ a = a.to(torch.float32)
+ if TEST_WITH_ROCM:
+ # Workaround for bug https://github.com/pytorch/pytorch/issues/16448
+ # TODO: remove after the bug is resolved.
+ b = b.to(a.dtype).to(a.device)
+ else:
+ b = b.to(a)
diff = a - b
- diff[nan_mask] = 0
- # inf check if allow_inf=True
- if allow_inf:
- inf_mask = (a == float("inf")) | (a == float("-inf"))
- self.assertTrue(torch.equal(inf_mask,
- (b == float("inf")) | (b == float("-inf"))),
- message)
- diff[inf_mask] = 0
- # TODO: implement abs on CharTensor
- if diff.is_signed() and 'CharTensor' not in diff.type():
+ if a.is_floating_point():
+ # check that NaNs are in the same locations
+ nan_mask = torch.isnan(a)
+ self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
+ diff[nan_mask] = 0
+ # inf check if allow_inf=True
+ if allow_inf:
+ inf_mask = torch.isinf(a)
+ self.assertTrue(torch.equal(inf_mask, torch.isinf(b)), message)
+ diff[inf_mask] = 0
+ # TODO: implement abs on CharTensor (int8)
+ if diff.is_signed() and diff.dtype != torch.int8:
diff = diff.abs()
max_err = diff.max()
self.assertLessEqual(max_err, prec, message)
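
Standalone, the comparison scheme the rewritten assertTensorsEqual applies looks roughly like the sketch below (assert_tensors_close is a hypothetical helper, not the test-suite API): upcast CPU half to float32, align b with a, require NaNs (and, when allowed, infs) in identical positions, zero them out, then bound the max absolute difference.

import torch

def assert_tensors_close(a, b, prec=1e-5, allow_inf=False):
    # CPU half lacks some of the ops used below, so compare in float32.
    if a.device.type == 'cpu' and a.dtype == torch.float16:
        a = a.to(torch.float32)
    b = b.to(a)                      # match dtype and device of a
    diff = a - b
    if a.is_floating_point():
        nan_mask = torch.isnan(a)    # NaNs must appear in the same places
        assert torch.equal(nan_mask, torch.isnan(b))
        diff[nan_mask] = 0
        if allow_inf:
            inf_mask = torch.isinf(a)
            assert torch.equal(inf_mask, torch.isinf(b))
            diff[inf_mask] = 0
    if diff.is_signed() and diff.dtype != torch.int8:
        diff = diff.abs()
    assert diff.max() <= prec

assert_tensors_close(torch.tensor([1.0, float('nan')]),
                     torch.tensor([1.0, float('nan')]))
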
diff --git a/test/test_torch.py b/test/test_torch.py
index 5bf6a4a93d..e5304ef020 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4100,27 +4100,28 @@ class _TestTorchMixin(object):
def test_cat(self):
SIZE = 10
- for dim in range(-3, 3):
- pos_dim = dim if dim >= 0 else 3 + dim
- x = torch.rand(13, SIZE, SIZE).transpose(0, pos_dim)
- y = torch.rand(17, SIZE, SIZE).transpose(0, pos_dim)
- z = torch.rand(19, SIZE, SIZE).transpose(0, pos_dim)
+ for dtype in (torch.half, torch.double, torch.int):
+ for dim in range(-3, 3):
+ pos_dim = dim if dim >= 0 else 3 + dim
+ x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
+ y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
+ z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
- res1 = torch.cat((x, y, z), dim)
- self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
- self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
- self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
+ res1 = torch.cat((x, y, z), dim)
+ self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
+ self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
+ self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
- x = torch.randn(20, SIZE, SIZE)
- self.assertEqual(torch.cat(torch.split(x, 7)), x)
- self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
+ x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype)
+ self.assertEqual(torch.cat(torch.split(x, 7)), x)
+ self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
- y = torch.randn(1, SIZE, SIZE)
- z = torch.cat([x, y])
- self.assertEqual(z.size(), (21, SIZE, SIZE))
+ y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype)
+ z = torch.cat([x, y])
+ self.assertEqual(z.size(), (21, SIZE, SIZE))
- self.assertRaises(RuntimeError, lambda: torch.cat([]))
- self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None]))
+ self.assertRaises(RuntimeError, lambda: torch.cat([]))
+ self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None]))
def test_cat_bad_input_sizes(self):
x = torch.randn(2, 1)
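
The shape-mismatch error path (the test body is truncated above) is the one guarded by check_shape_except_dim: sizes may differ only along the cat dimension. A small sketch of the failure mode:

import torch

x = torch.randn(2, 1)
y = torch.randn(2, 2)
try:
    torch.cat([x, y], 0)   # dim 1 differs (1 vs 2), and 1 is not the cat dim
except RuntimeError as e:
    print(e)  # sizes of tensors must match except in the cat dimension
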
@@ -4227,38 +4228,40 @@ class _TestTorchMixin(object):
self.assertEqual(sz, y.size())
def test_stack(self):
- x = torch.rand(2, 3, 4)
- y = torch.rand(2, 3, 4)
- z = torch.rand(2, 3, 4)
- for dim in range(4):
- res = torch.stack((x, y, z), dim)
- res_neg = torch.stack((x, y, z), dim - 4)
- expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
- self.assertEqual(res, res_neg)
- self.assertEqual(res.size(), expected_size)
- self.assertEqual(res.select(dim, 0), x, 0)
- self.assertEqual(res.select(dim, 1), y, 0)
- self.assertEqual(res.select(dim, 2), z, 0)
+ for dtype in (torch.half, torch.double, torch.int):
+ x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ for dim in range(4):
+ res = torch.stack((x, y, z), dim)
+ res_neg = torch.stack((x, y, z), dim - 4)
+ expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
+ self.assertEqual(res, res_neg)
+ self.assertEqual(res.size(), expected_size)
+ self.assertEqual(res.select(dim, 0), x, 0)
+ self.assertEqual(res.select(dim, 1), y, 0)
+ self.assertEqual(res.select(dim, 2), z, 0)
def test_stack_out(self):
- x = torch.rand(2, 3, 4)
- y = torch.rand(2, 3, 4)
- z = torch.rand(2, 3, 4)
- for dim in range(4):
- expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
- res_out = x.new(expected_size)
- res_neg_out = x.new(expected_size)
- res_out_dp = res_out.data_ptr()
- res_out_neg_dp = res_neg_out.data_ptr()
- torch.stack((x, y, z), dim, out=res_out)
- torch.stack((x, y, z), dim - 4, out=res_neg_out)
- self.assertEqual(res_out, res_neg_out)
- self.assertEqual(res_out.size(), expected_size)
- self.assertEqual(res_out_dp, res_out.data_ptr())
- self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr())
- self.assertEqual(res_out.select(dim, 0), x, 0)
- self.assertEqual(res_out.select(dim, 1), y, 0)
- self.assertEqual(res_out.select(dim, 2), z, 0)
+ for dtype in (torch.half, torch.double, torch.int):
+ x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
+ for dim in range(4):
+ expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
+ res_out = x.new(expected_size)
+ res_neg_out = x.new(expected_size)
+ res_out_dp = res_out.data_ptr()
+ res_out_neg_dp = res_neg_out.data_ptr()
+ torch.stack((x, y, z), dim, out=res_out)
+ torch.stack((x, y, z), dim - 4, out=res_neg_out)
+ self.assertEqual(res_out, res_neg_out)
+ self.assertEqual(res_out.size(), expected_size)
+ self.assertEqual(res_out_dp, res_out.data_ptr())
+ self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr())
+ self.assertEqual(res_out.select(dim, 0), x, 0)
+ self.assertEqual(res_out.select(dim, 1), y, 0)
+ self.assertEqual(res_out.select(dim, 2), z, 0)
def test_unbind(self):
x = torch.rand(2, 3, 4, 5)
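
Finally, a hedged standalone rendering of what test_stack_out checks, runnable on CPU half once this commit is in: stacking into a correctly sized out tensor must not reallocate its storage, and dim - 4 must behave as the equivalent negative index.

import torch

x = torch.randn(2, 3, 4).half()
y = torch.randn(2, 3, 4).half()
z = torch.randn(2, 3, 4).half()
for dim in range(4):
    expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
    out = torch.empty(expected_size, dtype=x.dtype)
    ptr = out.data_ptr()                    # out is already correctly sized,
    torch.stack((x, y, z), dim, out=out)    # so stacking must not reallocate
    assert out.data_ptr() == ptr
    assert torch.equal(out, torch.stack((x, y, z), dim - 4))
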