summaryrefslogtreecommitdiff
path: root/compiler/luci-interpreter/src/kernels/Transpose.test.cpp
blob: 87e6e2a00a16184ad7109773ecd2b3ea42303ae2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Transpose.h"
#include "kernels/TestUtils.h"

namespace luci_interpreter
{
namespace kernels
{
namespace
{

using namespace testing;

template <typename T>
void Check(std::initializer_list<int32_t> input_shape, std::initializer_list<int32_t> perm_shape,
           std::initializer_list<int32_t> output_shape, std::initializer_list<T> input_data,
           std::initializer_list<int32_t> perm_data, std::initializer_list<T> output_data,
           DataType element_type)
{
  Tensor input_tensor{element_type, input_shape, {}, ""};
  input_tensor.writeData(input_data.begin(), input_data.size() * sizeof(T));

  Tensor perm_tensor{DataType::S32, perm_shape, {}, ""};
  perm_tensor.writeData(perm_data.begin(), perm_data.size() * sizeof(int32_t));
  Tensor output_tensor = makeOutputTensor(element_type);

  Transpose kernel(&input_tensor, &perm_tensor, &output_tensor);
  kernel.configure();
  kernel.execute();

  EXPECT_THAT(extractTensorData<T>(output_tensor), ::testing::ElementsAreArray(output_data));
}

template <typename T> class TransposeTest : public ::testing::Test
{
};

using DataTypes = ::testing::Types<float, uint8_t>;
TYPED_TEST_CASE(TransposeTest, DataTypes);

TYPED_TEST(TransposeTest, Small3D)
{
  Check<TypeParam>(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3}, /*output_shape=*/{4, 2, 3},
                   /*input_data=*/{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                                   12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                   /*perm_data=*/{2, 0, 1},
                   /*output_data=*/{0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                    2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23},
                   getElementType<TypeParam>());
}

TYPED_TEST(TransposeTest, Large4D)
{
  Check<TypeParam>(
      /*input_shape=*/{2, 3, 4, 5}, /*perm_shape=*/{4}, /*output_shape=*/{4, 2, 3, 5},
      /*input_data=*/{0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
                      15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
                      30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
                      45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                      60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
                      75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
                      90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
                      105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
      /*perm_data=*/{2, 0, 1, 3},
      /*output_data=*/{0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
                       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
                       5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
                       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
                       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
                       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
                       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
                       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119},
      getElementType<TypeParam>());
}

TYPED_TEST(TransposeTest, Large2D)
{
  Check<TypeParam>(
      /*input_shape=*/{10, 12}, /*perm_shape=*/{2}, /*output_shape=*/{12, 10},
      /*input_data=*/{0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,
                      15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
                      30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
                      45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                      60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
                      75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
                      90,  91,  92,  93,  94,  95,  96,  97,  98,  99,  100, 101, 102, 103, 104,
                      105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119},
      /*perm_data=*/{1, 0},
      /*output_data=*/{0,  12, 24, 36,  48,  60, 72, 84, 96,  108, 1,  13, 25, 37,  49,
                       61, 73, 85, 97,  109, 2,  14, 26, 38,  50,  62, 74, 86, 98,  110,
                       3,  15, 27, 39,  51,  63, 75, 87, 99,  111, 4,  16, 28, 40,  52,
                       64, 76, 88, 100, 112, 5,  17, 29, 41,  53,  65, 77, 89, 101, 113,
                       6,  18, 30, 42,  54,  66, 78, 90, 102, 114, 7,  19, 31, 43,  55,
                       67, 79, 91, 103, 115, 8,  20, 32, 44,  56,  68, 80, 92, 104, 116,
                       9,  21, 33, 45,  57,  69, 81, 93, 105, 117, 10, 22, 34, 46,  58,
                       70, 82, 94, 106, 118, 11, 23, 35, 47,  59,  71, 83, 95, 107, 119},
      getElementType<TypeParam>());
}

} // namespace
} // namespace kernels
} // namespace luci_interpreter