summaryrefslogtreecommitdiff
path: root/runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc')
-rw-r--r--runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc115
1 files changed, 115 insertions, 0 deletions
diff --git a/runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc b/runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc
new file mode 100644
index 000000000..089c783c1
--- /dev/null
+++ b/runtimes/pure_arm_compute/src/internal/layers/SimpleEmbeddingLookup.cc
@@ -0,0 +1,115 @@
+#include "internal/layers/SimpleEmbeddingLookup.h"
+
+#include <arm_compute/runtime/CL/CLScheduler.h>
+
+void SimpleEmbeddingLookup::configure(::arm_compute::ITensor *lookups,
+ ::arm_compute::ITensor *values,
+ ::arm_compute::ITensor *output)
+{
+ // Assume that verification of operands are already done at Planner::visit()
+ _lookups = lookups;
+ _values = values;
+ _output = output;
+}
+
+void SimpleEmbeddingLookup::run()
+{
+ if (::internal::arm_compute::isGpuMode())
+ {
+ auto &q = ::arm_compute::CLScheduler::get().queue();
+
+ CAST_CL(_lookups)->map(q);
+ CAST_CL(_values)->map(q);
+ CAST_CL(_output)->map(q);
+ }
+
+ // type of elements of lookups is always integer
+ const int32_t *lookups_buf = reinterpret_cast<int32_t *>(_lookups->buffer());
+ const auto values_buf = _values->buffer();
+ auto output_buf = _output->buffer();
+
+ const auto lookups_info = _lookups->info();
+ const auto values_info = _values->info();
+ const auto output_info = _output->info();
+
+ // TODO Refactor below duplicated code!
+ const auto values_rank = values_info->num_dimensions();
+ switch (values_rank)
+ {
+ case 2:
+ // (H,W) in nnapi -> (W,H) in acl
+ {
+ const size_t row_size = values_info->dimension(1);
+ const size_t row_bytes = values_info->total_size() / row_size;
+ for (size_t i = 0; i < lookups_info->dimension(0); ++i)
+ {
+ if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
+ throw std::runtime_error("Embedding Lookup: index out of bounds.");
+
+ size_t idx = lookups_buf[i];
+ size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, idx});
+ size_t row_offset_by_i = output_info->offset_element_in_bytes({0, i});
+
+ unsigned char *sink_addr = output_buf + row_offset_by_i;
+ unsigned char *source_addr = values_buf + row_offset_by_idx;
+ memcpy(sink_addr, source_addr, row_bytes);
+ }
+ }
+ break;
+ case 3:
+ // (B,H,W) in nnapi -> (W,H,B) in acl
+ {
+ const size_t row_size = values_info->dimension(2);
+ const size_t row_bytes = values_info->total_size() / row_size;
+ for (size_t i = 0; i < lookups_info->dimension(0); ++i)
+ {
+ if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
+ throw std::runtime_error("Embedding Lookup: index out of bounds.");
+
+ size_t idx = lookups_buf[i];
+ size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, idx});
+ size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, i});
+
+ unsigned char *sink_addr = output_buf + row_offset_by_i;
+ unsigned char *source_addr = values_buf + row_offset_by_idx;
+ memcpy(sink_addr, source_addr, row_bytes);
+ }
+ }
+ break;
+ case 4:
+ // (N,H,W,C) in nnapi -> (N,C,H,W) in acl
+ {
+ const size_t row_size = values_info->dimension(3);
+ const size_t row_bytes = values_info->total_size() / row_size;
+ for (size_t i = 0; i < lookups_info->dimension(0); ++i)
+ {
+ if (lookups_buf[i] < 0 || lookups_buf[i] >= row_size)
+ throw std::runtime_error("Embedding Lookup: index out of bounds.");
+
+ size_t idx = lookups_buf[i];
+ size_t row_offset_by_idx = values_info->offset_element_in_bytes({0, 0, 0, idx});
+ size_t row_offset_by_i = output_info->offset_element_in_bytes({0, 0, 0, i});
+
+ unsigned char *sink_addr = output_buf + row_offset_by_i;
+ unsigned char *source_addr = values_buf + row_offset_by_idx;
+ memcpy(sink_addr, source_addr, row_bytes);
+ }
+ }
+ break;
+ case 1:
+ // In this case, shape of values actually is matrix but the height(row size) is 1 in acl. If
+ // row size is 1, this op is not needed and it means this situtation could be wrong.
+ throw std::runtime_error("Wrong usage of EmbeddingLookup op!");
+ default:
+ throw std::runtime_error("Not supported rank!");
+ }
+
+ if (::internal::arm_compute::isGpuMode())
+ {
+ auto &q = ::arm_compute::CLScheduler::get().queue();
+
+ CAST_CL(_lookups)->unmap(q);
+ CAST_CL(_values)->unmap(q);
+ CAST_CL(_output)->unmap(q);
+ }
+}