summaryrefslogtreecommitdiff
path: root/runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc
blob: 089c783c1ac9f69c79be14d28fc9f3a16ac01390 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "internal/layers/SimpleEmbeddingLookup.h"

#include <arm_compute/runtime/CL/CLScheduler.h>

void SimpleEmbeddingLookup::configure(::arm_compute::ITensor *lookups,
                                      ::arm_compute::ITensor *values,
                                      ::arm_compute::ITensor *output)
{
  // Assume that verification of operands are already done at Planner::visit()
  _lookups = lookups;
  _values = values;
  _output = output;
}

void SimpleEmbeddingLookup::run()
{
  if (::internal::arm_compute::isGpuMode())
  {
    auto &q = ::arm_compute::CLScheduler::get().queue();

    CAST_CL(_lookups)->map(q);
    CAST_CL(_values)->map(q);
    CAST_CL(_output)->map(q);
  }

  // type of elements of lookups is always integer
  const int32_t *lookups_buf = reinterpret_cast<int32_t *>(_lookups->buffer());
  const auto values_buf = _values->buffer();
  auto output_buf = _output->buffer();

  const auto lookups_info = _lookups->info();
  const auto values_info = _values->info();
  const auto output_info = _output->info();

  // TODO Refactor below duplicated code!
  const auto values_rank = values_info->num_dimensions();
  switch (values_rank)
  {
    case 2:
      // (H,W) in nnapi -> (W,H) in acl
      {
        const size_t row_size = values_info->dimension(1);
        const size_t row_bytes = values_info->total_size() / row_size;
        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
        {
          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
            throw std::runtime_error("Embedding Lookup: index out of bounds.");

          size_t idx = lookups_buf[i];
          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, idx});
          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, i});

          unsigned char *sink_addr = output_buf + row_offset_by_i;
          unsigned char *source_addr = values_buf + row_offset_by_idx;
          memcpy(sink_addr, source_addr, row_bytes);
        }
      }
      break;
    case 3:
      // (B,H,W) in nnapi -> (W,H,B) in acl
      {
        const size_t row_size = values_info->dimension(2);
        const size_t row_bytes = values_info->total_size() / row_size;
        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
        {
          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
            throw std::runtime_error("Embedding Lookup: index out of bounds.");

          size_t idx = lookups_buf[i];
          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, idx});
          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, i});

          unsigned char *sink_addr = output_buf + row_offset_by_i;
          unsigned char *source_addr = values_buf + row_offset_by_idx;
          memcpy(sink_addr, source_addr, row_bytes);
        }
      }
      break;
    case 4:
      // (N,H,W,C) in nnapi -> (N,C,H,W) in acl
      {
        const size_t row_size = values_info->dimension(3);
        const size_t row_bytes = values_info->total_size() / row_size;
        for (size_t i = 0; i < lookups_info->dimension(0); ++i)
        {
          if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
            throw std::runtime_error("Embedding Lookup: index out of bounds.");

          size_t idx = lookups_buf[i];
          size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, 0, idx});
          size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, 0, i});

          unsigned char *sink_addr = output_buf + row_offset_by_i;
          unsigned char *source_addr = values_buf + row_offset_by_idx;
          memcpy(sink_addr, source_addr, row_bytes);
        }
      }
      break;
    case 1:
      // In this case, shape of values actually is matrix but the height(row size) is 1 in acl. If
      // row size is 1, this op is not needed and it means this situtation could be wrong.
      throw std::runtime_error("Wrong usage of EmbeddingLookup op!");
    default:
      throw std::runtime_error("Not supported rank!");
  }

  if (::internal::arm_compute::isGpuMode())
  {
    auto &q = ::arm_compute::CLScheduler::get().queue();

    CAST_CL(_lookups)->unmap(q);
    CAST_CL(_values)->unmap(q);
    CAST_CL(_output)->unmap(q);
  }
}