summaryrefslogtreecommitdiff
path: root/runtimes/libs/misc/include/misc/tensor/IndexEnumerator.h
blob: 6ce3add7743b9075867d68bc6fe897a76da2eb1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * @file IndexEnumerator.h
 * @ingroup COM_AI_RUNTIME
 * @brief This file contains nnfw::misc::tensor::IndexEnumerator class
 */

#ifndef __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__
#define __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__

#include "misc/tensor/Shape.h"
#include "misc/tensor/Index.h"

namespace nnfw
{
namespace misc
{
namespace tensor
{
/**
 * @brief Class to enumerate every index of a tensor
 *
 * Indices are visited with axis 0 varying fastest. For example, a shape [2, 3] yields
 * {0,0}, {1,0}, {0,1}, {1,1}, {0,2}, {1,2}.
 *
 * Typical use:
 * @code
 *   for (IndexEnumerator e{shape}; e.valid(); e.advance())
 *   {
 *     const Index &index = e.curr();
 *     ...
 *   }
 * @endcode
 *
 * A shape with any zero-sized axis has no elements, so nothing is enumerated for it.
 * A rank-0 shape is also treated as having nothing to enumerate (valid() is false
 * right after construction).
 */
class IndexEnumerator
{
public:
  /**
   * @brief Construct a new @c IndexEnumerator object positioned at the all-zero index
   * @param[in] shape   Shape of the tensor whose indices will be enumerated
   */
  explicit IndexEnumerator(const Shape &shape) : _shape(shape), _cursor(0), _index(shape.rank())
  {
    const uint32_t rank = _shape.rank();

    // Start from the all-zero index
    for (uint32_t axis = 0; axis < rank; ++axis)
    {
      _index.at(axis) = 0;
    }

    // If index 0 is already out of range on ANY axis (i.e. the axis is empty), the tensor has
    // no elements at all. Start out exhausted instead of exposing an out-of-range index;
    // checking only the leading axes is not enough (e.g. shape [2, 0]).
    for (uint32_t axis = 0; axis < rank; ++axis)
    {
      if (!(_index.at(axis) < _shape.dim(axis)))
      {
        _cursor = rank;
        break;
      }
    }
  }

public:
  /**
   * @brief Prevent constructing @c IndexEnumerator object by using R-value reference
   */
  IndexEnumerator(IndexEnumerator &&) = delete;
  /**
   * @brief Prevent copy constructor
   */
  IndexEnumerator(const IndexEnumerator &) = delete;
  /**
   * @brief Prevent copy assignment (already impossible because of the const @c _shape member)
   */
  IndexEnumerator &operator=(const IndexEnumerator &) = delete;
  /**
   * @brief Prevent move assignment
   */
  IndexEnumerator &operator=(IndexEnumerator &&) = delete;

public:
  /**
   * @brief Check whether @c curr() refers to an index inside the shape
   * @return @c true while enumeration is in progress, @c false once every index has been
   *         visited (or immediately, if there is nothing to enumerate)
   */
  bool valid(void) const { return _cursor < _shape.rank(); }

public:
  /**
   * @brief Get the current index of the enumeration
   * @return Current index (only meaningful while @c valid() returns @c true)
   */
  const Index &curr(void) const { return _index; }

public:
  /**
   * @brief Advance to the next index, with axis 0 varying fastest
   *
   * Once the last index has been passed, @c valid() becomes @c false and further calls
   * are no-ops.
   */
  void advance(void)
  {
    const uint32_t rank = _shape.rank();

    // Find the lowest axis that can still be incremented without leaving the shape
    while ((_cursor < rank) && !(_index.at(_cursor) + 1 < _shape.dim(_cursor)))
    {
      ++_cursor;
    }

    // Every axis is at its last value: enumeration is exhausted
    if (_cursor == rank)
    {
      return;
    }

    // Increment that axis and reset all faster-varying (lower) axes to 0
    _index.at(_cursor) += 1;

    for (uint32_t axis = 0; axis < _cursor; ++axis)
    {
      _index.at(axis) = 0;
    }

    // Restart the search from axis 0 on the next advance()
    _cursor = 0;
  }

public:
  // NOTE: kept public (rather than private) so existing callers that read it keep compiling
  const Shape _shape; //!< Shape to enumerate

private:
  uint32_t _cursor; //!< Axis the next advance() starts searching from; == rank when exhausted
  Index _index;     //!< Current index
};

} // namespace tensor
} // namespace misc
} // namespace nnfw

#endif // __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__