diff options
Diffstat (limited to 'caffe2')
44 files changed, 2684 insertions, 2 deletions
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b634670d27..4a7664588f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -96,6 +96,17 @@ set(Caffe2_MAIN_LIBS) # Compile exposed libraries. add_library(caffe2 ${Caffe2_CPU_SRCS} $<TARGET_OBJECTS:Caffe_PROTO> $<TARGET_OBJECTS:Caffe2_PROTO>) +if(USE_ACL) + if(NOT USE_ARM64) + target_compile_options(caffe2 PRIVATE "-mfpu=neon-fp16") + endif() + + include(CheckCCompilerFlag) + CHECK_C_COMPILER_FLAG(-mfp16-format=ieee CAFFE2_COMPILER_SUPPORTS_FP16_FORMAT) + if(CAFFE2_COMPILER_SUPPORTS_FP16_FORMAT) + target_compile_options(caffe2 PRIVATE "-mfp16-format=ieee") + endif() +endif() target_link_libraries(caffe2 PRIVATE ${Caffe2_DEPENDENCY_LIBS}) target_include_directories(caffe2 INTERFACE $<INSTALL_INTERFACE:include>) target_compile_options(caffe2 INTERFACE "-std=c++11") diff --git a/caffe2/mobile/CMakeLists.txt b/caffe2/mobile/CMakeLists.txt index 08f11f1ef9..de20a27804 100644 --- a/caffe2/mobile/CMakeLists.txt +++ b/caffe2/mobile/CMakeLists.txt @@ -8,4 +8,4 @@ set(Caffe2_CPU_BINARY_SRCS ${Caffe2_CPU_BINARY_SRCS} PARENT_SCOPE) # GPU source, test sources, binary sources set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE) set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE) -set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE) +set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE)
\ No newline at end of file diff --git a/caffe2/mobile/contrib/CMakeLists.txt b/caffe2/mobile/contrib/CMakeLists.txt index 4722166241..29a35812bc 100644 --- a/caffe2/mobile/contrib/CMakeLists.txt +++ b/caffe2/mobile/contrib/CMakeLists.txt @@ -1,5 +1,8 @@ add_subdirectory(ios) add_subdirectory(opengl) +if (USE_ACL) + add_subdirectory(arm-compute) +endif() # Finally pass the src lists back to the parent if (USE_NNAPI) @@ -14,4 +17,4 @@ set(Caffe2_CPU_BINARY_SRCS ${Caffe2_CPU_BINARY_SRCS} PARENT_SCOPE) # GPU source, test sources, binary sources set(Caffe2_GPU_SRCS ${Caffe2_GPU_SRCS} PARENT_SCOPE) set(Caffe2_GPU_TEST_SRCS ${Caffe2_GPU_TEST_SRCS} PARENT_SCOPE) -set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE) +set(Caffe2_GPU_BINARY_SRCS ${Caffe2_GPU_BINARY_SRCS} PARENT_SCOPE)
\ No newline at end of file diff --git a/caffe2/mobile/contrib/arm-compute/CMakeLists.txt b/caffe2/mobile/contrib/arm-compute/CMakeLists.txt new file mode 100644 index 0000000000..f064601978 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(core) +add_subdirectory(operators) +add_subdirectory(test) + +set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
\ No newline at end of file diff --git a/caffe2/mobile/contrib/arm-compute/README.md b/caffe2/mobile/contrib/arm-compute/README.md new file mode 100644 index 0000000000..f128bc7efc --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/README.md @@ -0,0 +1,62 @@ +# Caffe2 - ARM Compute Backend + +## Build + +To build, clone and install scons + +``` +brew install scons +``` + +set ANDROID_NDK to /opt/android_ndk/xxx + +setup toolchain: +arm +``` +rm -rf PATH_TO_TOOLCHAIN +$ANDROID_NDK/build/tools/make_standalone_toolchain.py --arch arm64 --api 21 --install-dir PATH_TO_TOOLCHAIN +``` + +arm64 +``` +rm -rf PATH_TO_TOOLCHAIN +$ANDROID_NDK/build/tools/make_standalone_toolchain.py --arch arm64 --api 21 --install-dir PATH_TO_TOOLCHAIN +``` + +add the toolchain path to .bashrc/.zshrc etc. +e.g. +``` +export PATH=$PATH:PATH_TO_TOOLCHAIN +``` + +use the build\_android.sh: + +for 32bit +``` +./scripts/build_android.sh -DUSE_ACL=ON -DBUILD_TEST=ON +``` + +for 64bit +``` +./scripts/build_android.sh -DUSE_ACL=ON -DBUILD_TEST=ON -DUSE_NNPACK=OFF -DUSE_ARM64=ON +``` + +Before switch between 32 bit and 64 bit, please make sure to delete build\_android folder: +``` +rm -rf build_android +``` +## Test +Plug in an android device, and run a test + +``` +cd build_android +adb push bin/gl_conv_op_test /data/local/tmp && adb shell '/data/local/tmp/gl_conv_op_test' +``` +or use a script to run them all + +In caffe2 top level directory +``` +./caffe2/mobile/contrib/arm-compute/run_tests.sh build_android +``` + +Note that some tests(fully_connected and alignment) have been disabled until the next release of ACL.
\ No newline at end of file diff --git a/caffe2/mobile/contrib/arm-compute/core/CMakeLists.txt b/caffe2/mobile/contrib/arm-compute/core/CMakeLists.txt new file mode 100644 index 0000000000..dbc170e14e --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE tmp *.cc) +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp} PARENT_SCOPE) diff --git a/caffe2/mobile/contrib/arm-compute/core/context.cc b/caffe2/mobile/contrib/arm-compute/core/context.cc new file mode 100644 index 0000000000..72d0a6df2f --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/context.cc @@ -0,0 +1,38 @@ +#include "context.h" + +#include "caffe2/core/allocator.h" +#include "caffe2/core/context.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/tensor.h" + +namespace caffe2 { + +CAFFE_KNOWN_TYPE(GLTensor<GLfloat>); +CAFFE_KNOWN_TYPE(GLTensor<GLhalf>); +CAFFE_KNOWN_TYPE(GLTensor<half>); + +bool GLContext::initialized = false; + +GLContext::GLContext() { + CAFFE_ENFORCE(arm_compute::opengles31_is_available()); + if(!initialized) { + arm_compute::GCScheduler::get().default_init(); + initialized = true; + } +} + +void EventCreateOPENGL(const DeviceOption & /* unused */, + Event * /* unused */) {} +void EventRecordOPENGL(Event * /* unused */, const void * /* unused */, + const char * /* unused */) {} +void EventWaitOPENGLOPENGL(const Event * /* unused */, void * /* unused */) {} +void EventFinishOPENGL(const Event * /* unused */) {} +void EventResetOPENGL(Event * /* unused */) {} + +REGISTER_EVENT_CREATE_FUNCTION(OPENGL, EventCreateOPENGL); +REGISTER_EVENT_RECORD_FUNCTION(OPENGL, EventRecordOPENGL); +REGISTER_EVENT_WAIT_FUNCTION(OPENGL, OPENGL, EventWaitOPENGLOPENGL); +REGISTER_EVENT_FINISH_FUNCTION(OPENGL, EventFinishOPENGL); +REGISTER_EVENT_RESET_FUNCTION(OPENGL, EventResetOPENGL); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/core/context.h b/caffe2/mobile/contrib/arm-compute/core/context.h new file mode 100644 index 0000000000..7e3c936ba8 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/context.h @@ -0,0 +1,330 @@ +#ifndef CAFFE2_OPENGL_CONTEXT_H_ +#define CAFFE2_OPENGL_CONTEXT_H_ + +#ifdef CAFFE2_OPENGL_BACKEND +#error Can only build one OpenGL backend at a time. +#else +#define CAFFE2_OPENGL_BACKEND +#endif + +#include "caffe2/core/allocator.h" +#include "caffe2/core/blob.h" +#include "caffe2/core/context.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/tensor.h" + +#include "arm_compute/core/GLES_COMPUTE/OpenGLES.h" +#include "arm_compute/runtime/GLES_COMPUTE/GCFunctions.h" +#include "arm_compute/runtime/GLES_COMPUTE/GCScheduler.h" +#include "arm_compute/runtime/GLES_COMPUTE/GCTensor.h" + +#include "arm_compute/core/Types.h" +#include "arm_compute/runtime/Allocator.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/MemoryManagerOnDemand.h" +#include "arm_compute/runtime/PoolManager.h" +#include "utils/Utils.h" +#include "include/half/half.hpp" + +namespace caffe2 { + +typedef half_float::half half; +typedef half DataType; + +template <typename T> class GLTensor; + +class GLContext final { +public: + static bool initialized; + explicit GLContext(); + explicit GLContext(const DeviceOption &option) { + DCHECK_EQ(option.device_type(), OPENGL); + GLContext(); + } + ~GLContext() {} + + static void sync() { arm_compute::GCScheduler::get().memory_barrier(); } + + template <typename T> + using deleted_unique_ptr = std::unique_ptr<T, std::function<void(T *)>>; + + template <typename T> + static deleted_unique_ptr<const GLTensor<T>> getGLTensor(const Blob *b) { + if (b->IsType<TensorCPU>()) { + auto &Xcpu = b->Get<TensorCPU>(); + GLTensor<T> *X_raw_ptr; + X_raw_ptr = new GLTensor<T>(); + X_raw_ptr->ResizeLike(Xcpu); + deleted_unique_ptr<const GLTensor<T>> X_unique_ptr( + X_raw_ptr, [](const GLTensor<T> *X) { delete X; }); + return X_unique_ptr; + } + const GLTensor<T> *X_raw_ptr; + X_raw_ptr = &b->Get<GLTensor<T>>(); + deleted_unique_ptr<const GLTensor<T>> X_unique_ptr( + X_raw_ptr, [](const GLTensor<T> *X) { return; }); + return X_unique_ptr; + } + + /* + * Everything below is basically boiler plate for Context classes + */ + static std::pair<void *, MemoryDeleter> New(size_t nbytes) { + return std::pair<void *, MemoryDeleter>(malloc(nbytes), GLContext::Delete); + } + + static void Delete(void *data) { + if (data != nullptr) { + free(data); + } + } + + template <class SrcContext, class DstContext> + inline void CopyBytes(size_t nbytes, const void *src, void *dst) {} + + template <typename T, class SrcContext, class DstContext> + inline void Copy(int n, const T *src, T *dst) { + CopyBytes<SrcContext, DstContext>(n * sizeof(T), + static_cast<const void *>(src), + static_cast<void *>(dst)); + } + + template <class SrcContext, class DstContext> + inline void CopyItems(const TypeMeta &meta, size_t n, const void *src, + void *dst) { + CAFFE_ENFORCE(!meta.copy(), "GLContext requires fundamental types."); + CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst); + } + + void SwitchToDevice(int a, ...) { /* TODO */ + } + void SwitchToDevice() { SwitchToDevice(0); } + + inline void WaitEvent(const Event &ev) { /* TODO */ + } + void FinishDeviceComputation() { /* TODO */ + } + inline void Record(Event *ev, const char *&) const { /* TODO */ + } + static bool IsStreamFree(const DeviceOption& /* unused */, int /* unused */) { + return true; + } + bool HasAsyncPartDefault() const { return false; } + bool SupportsAsyncScheduling() const { return false; } + +}; + +template <typename T> class GLTensor { +private: + bool allocated_ = false; +public: + GLTensor() { tensor_ = make_unique<arm_compute::GCTensor>(); } + ~GLTensor() { tensor_->allocator()->free(); } + + template <typename TensorType> void ResizeLike(TensorType &X) { + tensor_->allocator()->free(); + SetDims(X.dims()); + shape_ = arm_compute::TensorShape(); + for (int i = 0; i < dims_.size(); i++) { + shape_.set(dims_.size() - i - 1, dims_[i]); + } + + tensor_->allocator()->init( + arm_compute::TensorInfo(shape_, 1, arm_compute::DataType::F16)); + } + + template <typename... Ts> void Resize(Ts... dim_source) { + bool size_changed = SetDims(dim_source...); + if (size_changed) { + // TODO: Make it type generic + int64_t new_size = size_ * sizeof(T); + tensor_->allocator()->free(); + for (int i = 0; i < dims_.size(); i++) { + shape_.set(dims_.size() - i - 1, dims_[i]); + } + tensor_->allocator()->init( + arm_compute::TensorInfo(shape_, 1, arm_compute::DataType::F16)); + } + } + + // Allocates and copies data if needed + void lazy_allocate(const Blob *b, bool allocate_tensor, bool try_to_copy_from_cpu) const { + if (try_to_copy_from_cpu) { + // we skip GLTensors, nothing to copy + if (!b->IsType<GLTensor>()) { + // typically only called on the second run + if (allocate_tensor) { + allocate(); + } + fillGLTensor(b); + } + } + } + + void allocate() const { + tensor_->allocator()->allocate(); + } + + void fillGLTensor(const Blob *b) const { + if (b->IsType<TensorCPU>()) { + auto &Xcpu = b->Get<TensorCPU>(); + + T *buffer = map(); + char *byte_buffer = (char *)buffer; + auto info = tensor_->info(); + if (Xcpu.ndim() == 4) { + auto M = Xcpu.dim32(0); + auto C = Xcpu.dim32(1); + auto H = Xcpu.dim32(2); + auto W = Xcpu.dim32(3); + for (auto m = 0; m < M; ++m) { + for (auto c = 0; c < C; ++c) { + for (auto h = 0; h < H; ++h) { + for (auto w = 0; w < W; ++w) { + T *b = (T *)(&byte_buffer[info->offset_element_in_bytes( + arm_compute::Coordinates(w, h, c, m))]); + // require cpu input blob to be float + *b = T(Xcpu.data<float>()[((m * C + c) * H + h) * W + w]); + } + } + } + } + } else if (Xcpu.ndim() == 3) { + auto C = Xcpu.dim32(0); + auto H = Xcpu.dim32(1); + auto W = Xcpu.dim32(2); + for (auto c = 0; c < C; ++c) { + for (auto h = 0; h < H; ++h) { + for (auto w = 0; w < W; ++w) { + T *b = (T *)(&byte_buffer[info->offset_element_in_bytes( + arm_compute::Coordinates(w, h, c))]); + // require cpu input blob to be float + *b = T(Xcpu.data<float>()[(c * H + h) * W + w]); + } + } + } + } else if (Xcpu.ndim() == 2) { + auto H = Xcpu.dim32(0); + auto W = Xcpu.dim32(1); + for (auto h = 0; h < H; ++h) { + for (auto w = 0; w < W; ++w) { + T *b = (T *)(&byte_buffer[info->offset_element_in_bytes( + arm_compute::Coordinates(w, h))]); + // require cpu input blob to be float + *b = T(Xcpu.data<float>()[h * W + w]); + } + } + } else { + auto size = Xcpu.dim32(0); + for (auto i = 0; i < size; ++i) { + T *b = (T *)(&byte_buffer[info->offset_element_in_bytes(arm_compute::Coordinates(i))]); + // require cpu input blob to be float + *b = T(Xcpu.data<float>()[i]); + } + } + unmap(); + } + } + + + const int32_t ndim() const { return dims_.size(); } + + vector<TIndex> dims() const { return dims_; } + + const int32_t dim32(const int index) const { return dims_.at(index); } + + const int32_t size() const { + int32_t s = 1; + for (int i = 0; i < dims_.size(); i++) { + s *= dims_[i]; + } + return s; + } + + arm_compute::GCTensor *get_underlying() const { return tensor_.get(); } + + T *map() const { + GLContext::sync(); + tensor_->map(true); + return reinterpret_cast<T *>(tensor_->buffer()); + } + + void unmap() const { return tensor_->unmap(); } + + void sync() const { + GLContext::sync(); + tensor_->map(); + tensor_->unmap(); + } + +private: + template <typename TI, typename = typename std::enable_if< + std::is_integral<TI>::value>::type> + bool SetDims(const vector<TI> &src) { + auto old_size = size_; + dims_.resize(src.size()); + TIndex new_size = 1; + for (unsigned int i = 0; i < src.size(); ++i) { + new_size *= src[i]; + dims_[i] = src[i]; + } + size_ = new_size; + return size_ != old_size; + } + + bool SetDims() { + auto old_size = size_; + dims_.resize(0); + size_ = 1; + return size_ != old_size; + } + + bool SetDims(const TIndex d0) { + auto old_size = size_; + dims_.resize(1); + dims_[0] = d0; + size_ = d0; + return size_ != old_size; + } + + bool SetDims(const TIndex d0, const TIndex d1) { + auto old_size = size_; + dims_.resize(2); + dims_[0] = d0; + dims_[1] = d1; + size_ = d0 * d1; + return size_ != old_size; + } + + bool SetDims(const TIndex d0, const TIndex d1, const TIndex d2) { + auto old_size = size_; + dims_.resize(3); + dims_[0] = d0; + dims_[1] = d1; + dims_[2] = d2; + size_ = d0 * d1 * d2; + return size_ != old_size; + } + + bool SetDims(const TIndex d0, const TIndex d1, const TIndex d2, + const TIndex d3) { + auto old_size = size_; + dims_.resize(4); + dims_[0] = d0; + dims_[1] = d1; + dims_[2] = d2; + dims_[3] = d3; + size_ = d0 * d1 * d2 * d3; + return size_ != old_size; + } + + vector<TIndex> dims_; + TIndex size_ = -1; + arm_compute::TensorShape shape_; + unique_ptr<arm_compute::GCTensor> tensor_; +}; + + +} // namespace caffe2 + +#endif /* CAFFE2_OPENGL_CONTEXT_H_ */ diff --git a/caffe2/mobile/contrib/arm-compute/core/net_gl.cc b/caffe2/mobile/contrib/arm-compute/core/net_gl.cc new file mode 100644 index 0000000000..43207a0ef0 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/net_gl.cc @@ -0,0 +1,219 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "caffe2/mobile/contrib/arm-compute/core/net_gl.h" +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/core/net.h" + +#include <set> +#include <unordered_map> +#include <unordered_set> + +#include "caffe2/core/operator.h" +#include "caffe2/core/static_tracepoint.h" +#include "caffe2/core/timer.h" +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/utils/proto_utils.h" + +namespace caffe2 { + +GLNet::GLNet( + const std::shared_ptr<const NetDef>& net_def, + Workspace* ws) + : NetBase(net_def, ws) { + ws_ = ws; + VLOG(1) << "Constructing GLNet " << net_def->name(); + const bool net_def_has_device_option = net_def->has_device_option(); + // Initialize the operators + for (int idx = 0; idx < net_def->op_size(); ++idx) { + const auto& operator_def = net_def->op(idx); + VLOG(1) << "Creating operator " << operator_def.name() << ": " + << operator_def.type(); + output_blobs_.push_back(operator_def.output(0)); + if (operator_def.has_device_option() && operator_def.device_option().device_type() == OPENGL) { + opengl_device_.push_back(true); + } else { + opengl_device_.push_back(false); + } + + std::unique_ptr<OperatorBase> op{nullptr}; + if (!operator_def.has_device_option() && net_def_has_device_option) { + // In the case that the operator def does not specify a device option but + // the net def has a default option, we copy the device option over to the + // operator def. + OperatorDef temp_def(operator_def); + temp_def.mutable_device_option()->CopyFrom(net_def->device_option()); + op = CreateOperator(temp_def, ws, idx); + } else { + op = CreateOperator(operator_def, ws, idx); + op->set_debug_def( + std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))}); + } + operators_.emplace_back(std::move(op)); + } +} + +bool GLNet::Run() { + StartAllObservers(); + if (first_run_) { + first_run_ = false; + for (auto& op: operators_) { + op->Run(); + } + } + VLOG(1) << "Running net " << name_; + for (auto& op : operators_) { + bool res = op->Run(); + if (!res) { + LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); + return false; + } + } + StopAllObservers(); + return true; +} + +bool GLNet::RunAsync() { + return Run(); +} + +namespace { +template <typename A, typename B> +bool PairLargerThan(const std::pair<A, B>& x, const std::pair<A, B>& y) { + return x.second > y.second; +} +} + +vector<float> GLNet::TEST_Benchmark( + const int warmup_runs, + const int main_runs, + const bool run_individual) { + LOG(INFO) << "Starting benchmark."; + LOG(INFO) << "Running warmup runs."; + CAFFE_ENFORCE( + warmup_runs >= 0, + "Number of warm up runs should be non negative, provided ", + warmup_runs, + "."); + for (int i = 0; i < warmup_runs; ++i) { + CAFFE_ENFORCE(Run(), "Warmup run ", i, " has failed."); + } + + auto last_blob = output_blobs_[output_blobs_.size() - 1]; + Blob *gpu_out_blob = ws_->GetBlob(last_blob); + auto &g_ = gpu_out_blob->Get<GLTensor<half>>(); + // Enforce gpu execution + g_.sync(); + + LOG(INFO) << "Main runs."; + CAFFE_ENFORCE( + main_runs >= 0, + "Number of main runs should be non negative, provided ", + main_runs, + "."); + Timer timer; + for (int i = 0; i < main_runs; ++i) { + CAFFE_ENFORCE(Run(), "Main run ", i, " has failed."); + } + g_.sync(); + + auto millis = timer.MilliSeconds(); + LOG(INFO) << "[C2DEBUG] Main run finished. Milliseconds per iter: " + << millis / main_runs + << ". Iters per second: " << 1000.0 * main_runs / millis; + + vector<float> time_per_op(operators_.size(), 0); + vector<uint64_t> flops_per_op(operators_.size(), 0); + CaffeMap<string, float> time_per_op_type; + if (run_individual) { + for (int i = 0; i < main_runs; ++i) { + for (auto& op : operators_) { + op->ResetEvent(); + } + int idx = 0; + for (auto& op : operators_) { + const string& op_type = op->debug_def().type(); + if (i == 0) { // Gather flops on the first run. + auto* schema = OpSchemaRegistry::Schema(op_type); + if (schema && schema->HasCostInferenceFunction()) { + vector<TensorShape> shapes = op->InputTensorShapes(); + flops_per_op[idx] = + schema->InferCost(op->debug_def(), shapes).flops; + } + } + timer.Start(); + CAFFE_ENFORCE( + op->Run(), + "operator ", + op->debug_def().name(), + "(", + op_type, + ") has failed."); + if (opengl_device_[idx]) { + Blob *gpu_out_blob = ws_->GetBlob(output_blobs_[idx]); + auto &g_ = gpu_out_blob->Get<GLTensor<half>>(); + g_.sync(); + } + float spent = timer.MilliSeconds(); + time_per_op[idx] += spent; + time_per_op_type[op_type] += spent; + ++idx; + } + } + + int idx = 0; + for (auto& op : operators_) { + const string& op_type = op->debug_def().type(); + const string& print_name = + (op->debug_def().name().size() + ? op->debug_def().name() + : (op->debug_def().output_size() ? op->debug_def().output(0) + : "NO_OUTPUT")); + std::stringstream flops_str; + if (flops_per_op[idx]) { + flops_str << " (" + << to_string(1.0e-6 * flops_per_op[idx] / time_per_op[idx]) + << " GFLOPS)"; + } + LOG(INFO) << "[C2DEBUG] Operator #" << idx << " (" << print_name << ", " << op_type + << ") " << time_per_op[idx] / main_runs << " ms/iter" + << flops_str.str(); + ++idx; + } + LOG(INFO) << "[C2DEBUG] Time per operator type:"; + // sort by decreasing time spending. + std::vector<std::pair<string, float>> time_per_op_type_vec( + time_per_op_type.begin(), time_per_op_type.end()); + std::sort( + time_per_op_type_vec.begin(), + time_per_op_type_vec.end(), + PairLargerThan<string, float>); + for (const auto& item : time_per_op_type_vec) { + LOG(INFO) << "[C2DEBUG] " << std::setw(15) << std::setfill(' ') << item.second / main_runs + << " " << item.first; + } + } + // We will reuse time_per_op to return the result of BenchmarkNet. + for (int i = 0; i < time_per_op.size(); ++i) { + time_per_op[i] /= main_runs; + } + time_per_op.insert(time_per_op.begin(), millis / main_runs); + return time_per_op; +} + +REGISTER_NET(opengl, GLNet); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/core/net_gl.h b/caffe2/mobile/contrib/arm-compute/core/net_gl.h new file mode 100644 index 0000000000..27654ba875 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/net_gl.h @@ -0,0 +1,81 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CAFFE2_CORE_NET_GL_H_ +#define CAFFE2_CORE_NET_GL_H_ + +#include <vector> + +#include "caffe2/core/common.h" +#include "caffe2/core/logging.h" +#include "caffe2/core/net.h" +#include "caffe2/core/registry.h" +#include "caffe2/core/tensor.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +// This is the very basic structure you need to run a network with +// ARM's compute library +class GLNet : public NetBase { + private: + bool first_run_ = true; + Workspace* ws_; + // record output blob for sync step in operator level benchmarking + std::vector<string> output_blobs_; + // record operator type and only sync after gpu op + std::vector<bool> opengl_device_; + public: + GLNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws); + bool SupportsAsync() override { + return false; + } + + vector<float> TEST_Benchmark( + const int warmup_runs, + const int main_runs, + const bool run_individual) override; + + /* + * This returns a list of pointers to objects stored in unique_ptrs. + * Used by Observers. + * + * Think carefully before using. + */ + vector<OperatorBase*> GetOperators() const override { + vector<OperatorBase*> op_list; + for (auto& op : operators_) { + op_list.push_back(op.get()); + } + return op_list; + } + + protected: + bool Run(); + bool RunAsync(); + bool DoRunAsync() override { + return Run(); + } + + vector<unique_ptr<OperatorBase>> operators_; + + DISABLE_COPY_AND_ASSIGN(GLNet); +}; + +} // namespace caffe2 + +#endif // CAFFE2_CORE_NET_SIMPLE_H_ diff --git a/caffe2/mobile/contrib/arm-compute/core/operator.cc b/caffe2/mobile/contrib/arm-compute/core/operator.cc new file mode 100644 index 0000000000..bd4337aa85 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/operator.cc @@ -0,0 +1,9 @@ +#include "operator.h" + +namespace caffe2 { + +CAFFE_DEFINE_REGISTRY(GLOperatorRegistry, OperatorBase, const OperatorDef &, + Workspace *); +CAFFE_REGISTER_DEVICE_TYPE(DeviceType::OPENGL, GLOperatorRegistry); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/core/operator.h b/caffe2/mobile/contrib/arm-compute/core/operator.h new file mode 100644 index 0000000000..037173054f --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/core/operator.h @@ -0,0 +1,27 @@ +#ifndef CAFFE2_OPENGL_OPERATOR_H_ +#define CAFFE2_OPENGL_OPERATOR_H_ + +#include "caffe2/core/operator.h" +#include "caffe2/core/registry.h" + +namespace caffe2 { + +CAFFE_DECLARE_REGISTRY(GLOperatorRegistry, OperatorBase, const OperatorDef &, + Workspace *); +#define REGISTER_GL_OPERATOR_CREATOR(key, ...) \ + CAFFE_REGISTER_CREATOR(GLOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_GL_OPERATOR(name, ...) \ + extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ + static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_GL##name() { \ + CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ + } \ + CAFFE_REGISTER_CLASS(GLOperatorRegistry, name, __VA_ARGS__) +#define REGISTER_GL_OPERATOR_STR(str_name, ...) \ + CAFFE_REGISTER_TYPED_CLASS(GLOperatorRegistry, str_name, __VA_ARGS__) + +#define REGISTER_GL_OPERATOR_WITH_ENGINE(name, engine, ...) \ + CAFFE_REGISTER_CLASS(GLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) + +} // namespace caffe2 + +#endif // CAFFE2_OPENGL_OPERATOR_H_ diff --git a/caffe2/mobile/contrib/arm-compute/models/squeezenet_init.pb b/caffe2/mobile/contrib/arm-compute/models/squeezenet_init.pb Binary files differnew file mode 100644 index 0000000000..3d3df32128 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/models/squeezenet_init.pb diff --git a/caffe2/mobile/contrib/arm-compute/models/squeezenet_predict.pb b/caffe2/mobile/contrib/arm-compute/models/squeezenet_predict.pb Binary files differnew file mode 100644 index 0000000000..188c347788 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/models/squeezenet_predict.pb diff --git a/caffe2/mobile/contrib/arm-compute/operators/CMakeLists.txt b/caffe2/mobile/contrib/arm-compute/operators/CMakeLists.txt new file mode 100644 index 0000000000..dbc170e14e --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE tmp *.cc) +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp} PARENT_SCOPE) diff --git a/caffe2/mobile/contrib/arm-compute/operators/activation_ops.cc b/caffe2/mobile/contrib/arm-compute/operators/activation_ops.cc new file mode 100644 index 0000000000..43e0e8fd95 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/activation_ops.cc @@ -0,0 +1,89 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +#include "caffe2/mobile/contrib/arm-compute/operators/activation_ops.h" +#include "caffe2/operators/relu_op.h" + +namespace caffe2 { + +template <typename T> +bool GLReluOp<T>::RunOnDevice() { + + auto *Xblob = OperatorBase::Inputs()[0]; + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + } + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + + if (first_run_) { + first_run_ = false; + if (Y->get_underlying() != X_->get_underlying()) + { + Y->ResizeLike(*X_); + } + relu_layer_.configure( + X_->get_underlying(), Y->get_underlying(), + arm_compute::ActivationLayerInfo( + arm_compute::ActivationLayerInfo::ActivationFunction::RELU)); + + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + // in place activation, do not need to allocate new memory + if (Y->get_underlying() != X_->get_underlying()) + { + Y->allocate(); + } + } + relu_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Relu, GLReluOp<half>); + +template <typename T> +bool GLSigmoidOp<T>::RunOnDevice() { + + auto *Xblob = OperatorBase::Inputs()[0]; + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + } + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + + if (Y->get_underlying() != X_->get_underlying()) + { + Y->ResizeLike(*X_); + } + + sigmoid_layer_.configure( + X_->get_underlying(), Y->get_underlying(), + arm_compute::ActivationLayerInfo( + arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC)); + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + // in place activation, do not need to allocate new memory + if (Y->get_underlying() != X_->get_underlying()) + { + Y->allocate(); + } + } + sigmoid_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Sigmoid, GLSigmoidOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/activation_ops.h b/caffe2/mobile/contrib/arm-compute/operators/activation_ops.h new file mode 100644 index 0000000000..4de6a07cf6 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/activation_ops.h @@ -0,0 +1,38 @@ +#ifndef CAFFE2_OPENGL_OPERATORS_ACTIVATION_OPS_H_ +#define CAFFE2_OPENGL_OPERATORS_ACTIVATION_OPS_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template <typename T> +class GLSigmoidOp final : public Operator<GLContext> { +public: + GLSigmoidOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCActivationLayer sigmoid_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; +}; + +template <typename T> class GLReluOp final : public Operator<GLContext> { +public: + GLReluOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLReluOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCActivationLayer relu_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; + +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPENGL_OPERATORS_ACTIVATION_OPS_H_ diff --git a/caffe2/mobile/contrib/arm-compute/operators/concat_op.cc b/caffe2/mobile/contrib/arm-compute/operators/concat_op.cc new file mode 100644 index 0000000000..b72a3f3186 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/concat_op.cc @@ -0,0 +1,88 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" +#include "caffe2/operators/concat_split_op.h" + +namespace caffe2 { + +template <typename T> class GLConcatOp final : public Operator<GLContext> { +public: + GLConcatOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLConcatOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCDepthConcatenateLayer concat_layer_; + bool first_run_ = true, second_run_ = true; + std::vector<GLContext::deleted_unique_ptr<const GLTensor<T>>> inputs_; + int channelCount_ = 0; +}; + + +template <typename T> +bool GLConcatOp<T>::RunOnDevice() { + + CAFFE_ENFORCE(InputSize() <= 4 && InputSize() >= 2, "Number \ + of input must be between 2 and 4."); + + auto *X0blob = OperatorBase::Inputs()[0]; + auto X0 = GLContext::getGLTensor<T>(X0blob); + if (first_run_) { + inputs_.push_back(std::move(X0)); + } + + int N = inputs_[0]->dim32(0); + int channels = inputs_[0]->dim32(1); + int height = inputs_[0]->dim32(2); + int width = inputs_[0]->dim32(3); + std::vector<const Blob*> inputsBlob; + inputsBlob.push_back(X0blob); + + if (first_run_) { + channelCount_ = channels; + for (int i = 1; i < Inputs().size(); ++i) { + auto *Xblob = OperatorBase::Inputs()[i]; + auto X = GLContext::getGLTensor<T>(Xblob); + CAFFE_ENFORCE_EQ(N, X->dim32(0), X->dim32(0)); + CAFFE_ENFORCE_EQ(height, X->dim32(2), X->dim32(2)); + CAFFE_ENFORCE_EQ(width, X->dim32(3), X->dim32(3)); + channelCount_ += X->dim32(1); + inputs_.push_back(std::move(X)); + } + } + + for (int i = 1; i < Inputs().size(); ++i) { + auto *Xblob = OperatorBase::Inputs()[i]; + inputsBlob.push_back(Xblob); + } + std::vector<int> output_dims = {N, channelCount_, height, width}; + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->Resize(output_dims); + + std::vector<arm_compute::IGCTensor*> inputsGC; + for (int i = 0; i < inputs_.size(); ++i) { + inputsGC.push_back(inputs_[i]->get_underlying()); + } + concat_layer_.configure(inputsGC, Y->get_underlying()); + } else { + for (int i = 0; i < inputs_.size(); ++i) { + auto* X = inputs_[i].get(); + auto* Xblob = inputsBlob[i]; + X->lazy_allocate(Xblob, second_run_, true); + } + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + concat_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Concat, GLConcatOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/conv_op.cc b/caffe2/mobile/contrib/arm-compute/operators/conv_op.cc new file mode 100644 index 0000000000..b1f0953b16 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/conv_op.cc @@ -0,0 +1,105 @@ +#include "arm_compute/graph/Graph.h" +#include "arm_compute/graph/Nodes.h" +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +#include "caffe2/operators/conv_op.h" + +namespace caffe2 { + +template <typename T> +class GLConvOp final : public ConvPoolOpBase<GLContext> { + public: + USE_CONV_POOL_BASE_FUNCTIONS(GLContext); + GLConvOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<GLContext>(operator_def, ws) { + // Since this is the default convolution implementation, we will + // use CAFFE_ENFORCE instead of OPERATOR_NEEDS_FEATURE. + CAFFE_ENFORCE( + group_ == 1 || order_ == StorageOrder::NCHW, + "Group convolution only supports NCHW order right now."); + } + ~GLConvOp() {} + + bool RunOnDevice() override; +private: + arm_compute::GCDirectConvolutionLayer conv_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_, filter_, bias_; +}; + +template <typename T> +bool GLConvOp<T>::RunOnDevice() { + auto *Xblob = OperatorBase::Inputs()[0]; + auto *filterblob = OperatorBase::Inputs()[1]; + auto *biasblob = OperatorBase::Inputs()[2]; + + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + filter_ = GLContext::getGLTensor<T>(filterblob); + bias_ = GLContext::getGLTensor<T>(biasblob); + } + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + + const int N = X_->dim32(0), H = X_->dim32(2), W = X_->dim32(3), C = X_->dim32(1); + + CAFFE_ENFORCE_EQ(kernel_.size(), 2, + "Only 2d convolution is supported with ARM compute backend"); + + CAFFE_ENFORCE(X_->ndim(), filter_->ndim()); + const int M = filter_->dim32(0); + CAFFE_ENFORCE(filter_->dim32(2) == kernel_h()); + CAFFE_ENFORCE(filter_->dim32(3) == kernel_w()); + CAFFE_ENFORCE(filter_->dim32(1) == C); + + if (first_run_) { + first_run_ = false; + + // resize output accordingly + TensorCPU fakeX; + fakeX.Resize(X_->dims()); + TensorCPU fakeY; + ConvPoolOpBase<GLContext>::SetOutputSize(fakeX, &fakeY, filter_->dim32(0)); + Y->ResizeLike(fakeY); + LOG(INFO) << "[C2DEBUG] dims of X " << X_->dims(); + LOG(INFO) << "[C2DEBUG] dims of X(gctensor) " + << X_->get_underlying()->info()->dimension(3) << " " + << X_->get_underlying()->info()->dimension(2) << " " + << X_->get_underlying()->info()->dimension(1) << " " + << X_->get_underlying()->info()->dimension(0) << " " + ; + LOG(INFO) << "[C2DEBUG] dims of Y " << Y->dims(); + LOG(INFO) << "[C2DEBUG] dims of Y(gctensor) " + << Y->get_underlying()->info()->dimension(3) << " " + << Y->get_underlying()->info()->dimension(2) << " " + << Y->get_underlying()->info()->dimension(1) << " " + << Y->get_underlying()->info()->dimension(0) << " " + ; + + conv_.configure( + X_->get_underlying(), filter_->get_underlying(), bias_->get_underlying(), + Y->get_underlying(), + arm_compute::PadStrideInfo(stride_[0], stride_[1], pads_[0], pads_[1])); + + } else { + // Always attempt to copy the CPU to GPU on input + X_->lazy_allocate(Xblob, second_run_, true); + filter_->lazy_allocate(filterblob, second_run_, second_run_); + bias_->lazy_allocate(biasblob, second_run_, second_run_); + if (second_run_) { + second_run_ = false; + if (Y->get_underlying() != X_->get_underlying()) { + Y->allocate(); + } + } + conv_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Conv, GLConvOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/elementwise_sum_op.cc b/caffe2/mobile/contrib/arm-compute/operators/elementwise_sum_op.cc new file mode 100644 index 0000000000..68638026e1 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/elementwise_sum_op.cc @@ -0,0 +1,54 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" +#include "caffe2/operators/utility_ops.h" + +namespace caffe2 { + +template <typename T> class GLSumOp final : public Operator<GLContext> { +public: + GLSumOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLSumOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCArithmeticAddition add_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> A_, B_; +}; + + +template <typename T> +bool GLSumOp<T>::RunOnDevice() { + + auto *Ablob = OperatorBase::Inputs()[0]; + auto *Bblob = OperatorBase::Inputs()[1]; + + if (first_run_) { + A_ = GLContext::getGLTensor<T>(Ablob); + B_ = GLContext::getGLTensor<T>(Bblob); + } + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->ResizeLike(*A_); + add_layer_.configure(A_->get_underlying(), B_->get_underlying(), Y->get_underlying(), arm_compute::ConvertPolicy::SATURATE); + } else { + A_->lazy_allocate(Ablob, second_run_, true); + B_->lazy_allocate(Bblob, second_run_, true); + if (second_run_) { + Y->allocate(); + second_run_ = false; + } + add_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Sum, GLSumOp<DataType>); +REGISTER_GL_OPERATOR(Add, GLSumOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/fully_connected_op.cc b/caffe2/mobile/contrib/arm-compute/operators/fully_connected_op.cc new file mode 100644 index 0000000000..da4d5cc9a8 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/fully_connected_op.cc @@ -0,0 +1,68 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +#include "caffe2/operators/fully_connected_op.h" + +namespace caffe2 { + +template <typename T> class GLFullyConnectedOp final : public Operator<GLContext> { +public: + GLFullyConnectedOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLFullyConnectedOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCFullyConnectedLayer fc_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_, W_, B_; +}; + +template <typename T> +bool GLFullyConnectedOp<T>::RunOnDevice() { + + auto Xblob = OperatorBase::Inputs()[0]; + auto *Wblob = OperatorBase::Inputs()[1]; + auto *Bblob = OperatorBase::Inputs()[2]; + + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + W_ = GLContext::getGLTensor<T>(Wblob); + B_ = GLContext::getGLTensor<T>(Bblob); + } + + auto M = X_->dim32(0); + auto CIn = X_->dim32(1); + auto Height = X_->dim32(2); + auto Width = X_->dim32(3); + auto N = W_->dim32(0); + + CAFFE_ENFORCE_EQ(1, B_->ndim()); + CAFFE_ENFORCE_EQ(N, B_->dim32(0)); + + vector<TIndex> output_dims = {M, N}; + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->Resize(output_dims); + + fc_layer_.configure(X_->get_underlying(), W_->get_underlying(), + B_->get_underlying(), Y->get_underlying(), true, false); + } else { + X_->lazy_allocate(Xblob, second_run_, true); + W_->lazy_allocate(Wblob, second_run_, second_run_); + B_->lazy_allocate(Bblob, second_run_, second_run_); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + fc_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(FC, GLFullyConnectedOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/norm_planar_yuv_op.cc b/caffe2/mobile/contrib/arm-compute/operators/norm_planar_yuv_op.cc new file mode 100644 index 0000000000..f0eee4a259 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/norm_planar_yuv_op.cc @@ -0,0 +1,63 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +namespace caffe2 { + +template <typename T> +class GLNormalizePlanarYUVOp final : public Operator<GLContext> { +public: + GLNormalizePlanarYUVOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLNormalizePlanarYUVOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCNormalizePlanarYUVLayer norm_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_, mean_, sd_; +}; + +template <typename T> bool GLNormalizePlanarYUVOp<T>::RunOnDevice() { + + auto Xblob = OperatorBase::Inputs()[0]; + auto *meanblob = OperatorBase::Inputs()[1]; + auto *sdblob = OperatorBase::Inputs()[2]; + + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + mean_ = GLContext::getGLTensor<T>(meanblob); + sd_ = GLContext::getGLTensor<T>(sdblob); + } + + CAFFE_ENFORCE_EQ(X_->ndim(), 4); + auto N = X_->dim32(0); + auto C = X_->dim32(1); + auto H = X_->dim32(2); + auto W = X_->dim32(3); + + CAFFE_ENFORCE_EQ(C, mean_->dim32(1)); + CAFFE_ENFORCE_EQ(C, sd_->dim32(1)); + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->ResizeLike(*X_); + norm_layer_.configure(X_->get_underlying(), Y->get_underlying(), mean_->get_underlying(), sd_->get_underlying()); + } else { + X_->lazy_allocate(Xblob, second_run_, true); + mean_->lazy_allocate(meanblob, second_run_, second_run_); + sd_->lazy_allocate(sdblob, second_run_, second_run_); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + norm_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(NormalizePlanarYUV, GLNormalizePlanarYUVOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/pool_op.cc b/caffe2/mobile/contrib/arm-compute/operators/pool_op.cc new file mode 100644 index 0000000000..972c857374 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/pool_op.cc @@ -0,0 +1,159 @@ +#include "caffe2/operators/pool_op.h" +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +namespace caffe2 { + +template <typename T> +class GLAveragePoolOp final : public ConvPoolOpBase<GLContext> { + public: + USE_CONV_POOL_BASE_FUNCTIONS(GLContext); + GLAveragePoolOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<GLContext>(operator_def, ws) { + } + ~GLAveragePoolOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; +private: + arm_compute::GCPoolingLayer pooling_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; +}; + +template<typename T> +class GLMaxPoolOp final : public ConvPoolOpBase<GLContext> { + public: + USE_CONV_POOL_BASE_FUNCTIONS(GLContext); + GLMaxPoolOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase<GLContext>(operator_def, ws) { + } + ~GLMaxPoolOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; +private: + arm_compute::GCPoolingLayer pooling_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; +}; + +template <> +bool GLAveragePoolOp<half>::RunOnDeviceWithOrderNCHW() { + + auto *Xblob = OperatorBase::Inputs()[0]; + if (first_run_) { + X_ = GLContext::getGLTensor<half>(Xblob); + } + + int N = X_->dim32(0); + int channels = X_->dim32(1); + int height = X_->dim32(2); + int width = X_->dim32(3); + + GLTensor<half> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<half>>(); + if (first_run_) { + first_run_ = false; + CAFFE_ENFORCE_EQ(kernel_.size(), 2, "ARM OpenGL only supports 2D pooling"); + CAFFE_ENFORCE_EQ(kernel_h(), kernel_w(), + "ARM OpenGL only supports equal kernel size"); + if (global_pooling_) { + vector<TIndex> output_dims = {N, channels, 1, 1}; + Y->Resize(output_dims); + } else { + vector<TIndex> output_dims = {N, channels, 0, 0}; + output_dims[2] = (height + pad_t() + pad_b() - kernel_h()) / stride_h() + 1; + output_dims[3] = (width + pad_l() + pad_r() - kernel_w()) / stride_w() + 1; + Y->Resize(output_dims); + } + if (global_pooling_) { + arm_compute::PoolingLayerInfo info(arm_compute::PoolingType::AVG); + pooling_layer_.configure(X_->get_underlying(), Y->get_underlying(), info); + } else { + arm_compute::PadStrideInfo ps_info(stride_w(), stride_h(), pad_l(), pad_r(), + pad_t(), pad_b(), + arm_compute::DimensionRoundingType::FLOOR); + arm_compute::PoolingLayerInfo info(arm_compute::PoolingType::AVG, kernel_h(), + ps_info); + pooling_layer_.configure(X_->get_underlying(), Y->get_underlying(), info); + } + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + pooling_layer_.run(); + } + + return true; +} + +template <> bool GLMaxPoolOp<half>::RunOnDeviceWithOrderNCHW() { + + auto *Xblob = OperatorBase::Inputs()[0]; + if (first_run_) { + X_ = GLContext::getGLTensor<half>(Xblob); + } + + int N = X_->dim32(0); + int channels = X_->dim32(1); + int height = X_->dim32(2); + int width = X_->dim32(3); + + GLTensor<half> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<half>>(); + + if (first_run_) { + first_run_ = false; + CAFFE_ENFORCE_EQ(kernel_.size(), 2, "ARM OpenGL only supports 2D pooling"); + CAFFE_ENFORCE_EQ(kernel_h(), kernel_w(), + "ARM OpenGL only supports equal kernel size"); + if (global_pooling_) { + vector<TIndex> output_dims = {N, channels, 1, 1}; + Y->Resize(output_dims); + } else { + vector<int> output_dims = {1, 0, 0, 0}; + output_dims[1] = channels; + output_dims[2] = (height + pad_t() + pad_b() - kernel_h()) / stride_h() + 1; + output_dims[3] = (width + pad_l() + pad_r() - kernel_w()) / stride_w() + 1; + Y->Resize(output_dims); + } + if (global_pooling_) { + arm_compute::PoolingLayerInfo info(arm_compute::PoolingType::MAX); + pooling_layer_.configure(X_->get_underlying(), Y->get_underlying(), info); + } else { + arm_compute::PadStrideInfo ps_info(stride_w(), stride_h(), pad_l(), pad_r(), + pad_t(), pad_b(), + arm_compute::DimensionRoundingType::FLOOR); + arm_compute::PoolingLayerInfo info(arm_compute::PoolingType::MAX, kernel_h(), + ps_info); + pooling_layer_.configure(X_->get_underlying(), Y->get_underlying(), info); + } + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + pooling_layer_.run(); + } + + return true; +} + +template <> +bool GLAveragePoolOp<half>::RunOnDeviceWithOrderNHWC() { + return false; +} + +template <> +bool GLMaxPoolOp<half>::RunOnDeviceWithOrderNHWC() { + return false; +} + +REGISTER_GL_OPERATOR(AveragePool, GLAveragePoolOp<DataType>); +REGISTER_GL_OPERATOR(MaxPool, GLMaxPoolOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/reshape_op.cc b/caffe2/mobile/contrib/arm-compute/operators/reshape_op.cc new file mode 100644 index 0000000000..7eb860f7d4 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/reshape_op.cc @@ -0,0 +1,30 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" +#include "caffe2/operators/reshape_op.h" + +namespace caffe2 { + +template <typename T> class GLReshapeOp final : public Operator<GLContext> { +public: + GLReshapeOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLReshapeOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +}; + +template <typename T> +bool GLReshapeOp<T>::RunOnDevice() { + auto *Xblob = OperatorBase::Inputs()[0]; + auto X = GLContext::getGLTensor<T>(Xblob); + LOG(INFO) << "[C2DEBUG] X: " << X->dim32(0) << " " << X->dim32(1) << " " << X->dim32(2) << " " << X->dim32(3); + auto arg = OperatorBase::GetRepeatedArgument<int>("shape"); + for (int i = 0; i < arg.size(); ++i) { + LOG(INFO) << "[C2DEBUG] shape: " << arg[i]; + } + return true; +} + +REGISTER_GL_OPERATOR(Reshape, GLReshapeOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/resize_op.cc b/caffe2/mobile/contrib/arm-compute/operators/resize_op.cc new file mode 100644 index 0000000000..ef0c6dd733 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/resize_op.cc @@ -0,0 +1,69 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" +#include "caffe2/operators/resize_op.h" + +namespace caffe2 { + +template<typename T> +class GLResizeNearestOp final : public Operator<GLContext> { +public: + GLResizeNearestOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws), width_scale_(1), height_scale_(1) { + if (HasArgument("width_scale")) { + width_scale_ = static_cast<float>( + OperatorBase::GetSingleArgument<float>("width_scale", 1)); + } + if (HasArgument("height_scale")) { + height_scale_ = static_cast<float>( + OperatorBase::GetSingleArgument<float>("height_scale", 1)); + } + CAFFE_ENFORCE_GT(width_scale_, 0); + CAFFE_ENFORCE_GT(height_scale_, 0); + } + virtual ~GLResizeNearestOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + float width_scale_; + float height_scale_; + arm_compute::GCScale resize_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; +}; + +template <typename T> +bool GLResizeNearestOp<T>::RunOnDevice() { + + auto Xblob = OperatorBase::Inputs()[0]; + + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + } + + auto N = X_->dim32(0); + auto C = X_->dim32(1); + auto H = X_->dim32(2); + auto W = X_->dim32(3); + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + vector<TIndex> output_dims = {N, C, H * height_scale_, W * width_scale_}; + Y->Resize(output_dims); + first_run_ = false; + resize_layer_.configure(X_->get_underlying(), Y->get_underlying(), arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR, arm_compute::BorderMode::UNDEFINED); + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + resize_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(ResizeNearest, GLResizeNearestOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/softmax_op.cc b/caffe2/mobile/contrib/arm-compute/operators/softmax_op.cc new file mode 100644 index 0000000000..cb7891f24c --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/softmax_op.cc @@ -0,0 +1,49 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +#include "caffe2/operators/softmax_op.h" + +namespace caffe2 { + +template <typename T> class GLSoftmaxOp final : public Operator<GLContext> { +public: + GLSoftmaxOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws) {} + virtual ~GLSoftmaxOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; +private: + arm_compute::GCSoftmaxLayer softmax_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_; +}; + +template <typename T> +bool GLSoftmaxOp<T>::RunOnDevice() { + + auto *Xblob = OperatorBase::Inputs()[0]; + if (first_run_) { + X_ = GLContext::getGLTensor<T>(Xblob); + } + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->ResizeLike(*X_); + softmax_layer_.configure(X_->get_underlying(), Y->get_underlying()); + } else { + X_->lazy_allocate(Xblob, second_run_, true); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + softmax_layer_.run(); + } + + return true; +} + +REGISTER_GL_OPERATOR(Softmax, GLSoftmaxOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/operators/spatial_batch_norm_op.cc b/caffe2/mobile/contrib/arm-compute/operators/spatial_batch_norm_op.cc new file mode 100644 index 0000000000..495f2a2eb6 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/operators/spatial_batch_norm_op.cc @@ -0,0 +1,85 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/core/operator.h" + +#include "caffe2/operators/spatial_batch_norm_op.h" + +namespace caffe2 { + +template <typename T> class GLSpatialBNOp final : public Operator<GLContext> { +public: + GLSpatialBNOp(const OperatorDef &operator_def, Workspace *ws) + : Operator<GLContext>(operator_def, ws), + is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)), + epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)), + momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9)), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument<string>("order", "NCHW"))) { } + virtual ~GLSpatialBNOp() noexcept {} + USE_OPERATOR_FUNCTIONS(GLContext); + bool RunOnDevice() override; + protected: + bool is_test_; + double epsilon_; + double momentum_; + StorageOrder order_; + INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR); + OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR); +private: + arm_compute::GCBatchNormalizationLayer bn_layer_; + bool first_run_ = true, second_run_ = true; + GLContext::deleted_unique_ptr<const GLTensor<T>> X_, mean_, var_, bias_, scale_; +}; + +template <typename T> +bool GLSpatialBNOp<T>::RunOnDevice() { + auto *XBlob = OperatorBase::Inputs()[0]; + auto *scaleBlob = OperatorBase::Inputs()[SCALE]; + auto *biasBlob = OperatorBase::Inputs()[BIAS]; + auto *meanBlob = OperatorBase::Inputs()[EST_MEAN]; + auto *varBlob = OperatorBase::Inputs()[EST_VAR]; + + if (first_run_) { + X_ = GLContext::getGLTensor<T>(XBlob); + scale_ = GLContext::getGLTensor<T>(scaleBlob); + bias_ = GLContext::getGLTensor<T>(biasBlob); + mean_ = GLContext::getGLTensor<T>(meanBlob); + var_ = GLContext::getGLTensor<T>(varBlob); + } + + auto C = X_->dim32(1); + CAFFE_ENFORCE_EQ(scale_->ndim(), 1); + CAFFE_ENFORCE_EQ(bias_->ndim(), 1); + CAFFE_ENFORCE_EQ(mean_->ndim(), 1); + CAFFE_ENFORCE_EQ(var_->ndim(), 1); + + CAFFE_ENFORCE_EQ(scale_->dim32(0), C); + CAFFE_ENFORCE_EQ(bias_->dim32(0), C); + CAFFE_ENFORCE_EQ(mean_->dim32(0), C); + CAFFE_ENFORCE_EQ(var_->dim32(0), C); + + GLTensor<T> *Y = + OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>(); + if (first_run_) { + first_run_ = false; + Y->ResizeLike(*X_); + bn_layer_.configure(X_->get_underlying(), Y->get_underlying(), + mean_->get_underlying(), var_->get_underlying(), + bias_->get_underlying(), scale_->get_underlying(), epsilon_); + } else { + X_->lazy_allocate(XBlob, second_run_, true); + scale_->lazy_allocate(scaleBlob, second_run_, second_run_); + bias_->lazy_allocate(biasBlob, second_run_, second_run_); + mean_->lazy_allocate(meanBlob, second_run_, second_run_); + var_->lazy_allocate(varBlob, second_run_, second_run_); + if (second_run_) { + second_run_ = false; + Y->allocate(); + } + bn_layer_.run(); + } + return true; +} + +REGISTER_GL_OPERATOR(SpatialBN, GLSpatialBNOp<DataType>); + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/run_tests.sh b/caffe2/mobile/contrib/arm-compute/run_tests.sh new file mode 100755 index 0000000000..c08eece2f4 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/run_tests.sh @@ -0,0 +1,22 @@ +set -vex + +if [ -z "$CAFFE2_BINARY_DIR" ] ; then + if [ -z "$1" ] ; then + CAFFE2_BINARY_DIR=. + else + CAFFE2_BINARY_DIR=$1 + fi +fi + +files=($(find "$CAFFE2_BINARY_DIR" -type f -name "*_test")) +for test_binary in "${files[@]}"; +do + test_binary_base=$(basename $test_binary) + if [[ $test_binary_base == gl* ]];then + echo Running $test_binary_base + adb push $test_binary "/data/local/tmp/$test_binary_base" + adb shell "GLOG_logtostderr=1 /data/local/tmp/$test_binary_base" + fi +done + +echo All tests passed. diff --git a/caffe2/mobile/contrib/arm-compute/test/CMakeLists.txt b/caffe2/mobile/contrib/arm-compute/test/CMakeLists.txt new file mode 100644 index 0000000000..480846c965 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB tmp *_test.cc) +set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp} PARENT_SCOPE) diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_activation_ops_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_activation_ops_test.cc new file mode 100644 index 0000000000..7b1d261a75 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_activation_ops_test.cc @@ -0,0 +1,70 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, Sigmoid) { + Workspace ws; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, 4, 4, 4}); + + NetDef cpu_net; + { + AddOp(&cpu_net, "Sigmoid", {"cpu_X"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Sigmoid", {"cpu_X"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + compareNetResult(ws, cpu_net, gpu_net); + +} + +TEST(OPENGLOperatorTest, ReLU) { + Workspace ws; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, 4, 4, 4}); + + NetDef cpu_net; + { + AddOp(&cpu_net, "Relu", {"cpu_X"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Relu", {"cpu_X"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); +} + +TEST(OPENGLOperatorTest, SigmoidTwice) { + Workspace ws; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, 4, 4, 4}); + + NetDef cpu_net; + { + AddOp(&cpu_net, "Sigmoid", {"cpu_X"}, {"ref_Y1"}); + AddOp(&cpu_net, "Sigmoid", {"ref_Y1"}, {"ref_Y2"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Sigmoid", {"cpu_X"}, {"gpu_Y1"}); + MAKE_OPENGL_OPERATOR(def); + } + { + OperatorDef* def = AddOp(&gpu_net, "Sigmoid", {"gpu_Y1"}, {"gpu_Y2"}); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net, "ref_Y2", "gpu_Y2"); +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_alignment_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_alignment_test.cc new file mode 100644 index 0000000000..9b74fdc871 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_alignment_test.cc @@ -0,0 +1,197 @@ +#include "gl_operator_test.h" +#include "caffe2/core/timer.h" + +namespace caffe2 { + +constexpr float tol = 5.0e-2; + +// {MaxPool, Relu, Add} followed by pad 1 conv +TEST(OPENGLOperatorTest, ConvMaxPoolConv) { + + Workspace ws; + auto channel_in = 16; + auto channel_out = 16; + auto spatial = 32; + auto kern = 3; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, channel_in, spatial, spatial}, 1337); + PopulateCPUBlob(&ws, true, "W", {channel_out, channel_in, kern, kern}, 1337); + PopulateCPUBlob(&ws, false, "b", {channel_out}, 0); + PopulateCPUBlob(&ws, true, "W2", {channel_out, channel_in, kern, kern}); + PopulateCPUBlob(&ws, true, "b2", {channel_out}); + +#define ADD_CONV_ARGS \ + { \ + ADD_ARG((*def), "kernel", i, kern); \ + ADD_ARG((*def), "stride", i, 1); \ + ADD_ARG((*def), "pad", i, 1); \ + ADD_ARG((*def), "order", s, "NCHW"); \ + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"cpu_X", "W", "b"}, {"ref_Y"}); + def->set_name("cpu_conv"); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&cpu_net, "MaxPool", {"ref_Y"}, {"ref_maxpool"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride_w", i, 2); + ADD_ARG((*def), "stride_h", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + } + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"ref_maxpool", "W2", "b2"}, {"ref_Y2"}); + ADD_CONV_ARGS; + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"cpu_X", "W", "b"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&gpu_net, "MaxPool", {"gpu_Y"}, {"gpu_maxpool"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride_w", i, 2); + ADD_ARG((*def), "stride_h", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + MAKE_OPENGL_OPERATOR(def); + } + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"gpu_maxpool", "W2", "b2"}, {"gpu_Y2"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + +#undef ADD_CONV_ARGS + + // will work after next release of ACL + // compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y2", "gpu_Y2", tol); +} + +TEST(OPENGLOperatorTest, ConvReluConv) { + + Workspace ws; + auto channel_in = 16; + auto channel_out = 16; + auto spatial = 32; + auto kern = 3; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, channel_in, spatial, spatial}, 1337); + PopulateCPUBlob(&ws, true, "W", {channel_out, channel_in, kern, kern}, 1337); + PopulateCPUBlob(&ws, false, "b", {channel_out}, 0); + PopulateCPUBlob(&ws, true, "W2", {channel_out, channel_in, kern, kern}); + PopulateCPUBlob(&ws, true, "b2", {channel_out}); + +#define ADD_CONV_ARGS \ + { \ + ADD_ARG((*def), "kernel", i, kern); \ + ADD_ARG((*def), "stride", i, 1); \ + ADD_ARG((*def), "pad", i, 1); \ + ADD_ARG((*def), "order", s, "NCHW"); \ + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"cpu_X", "W", "b"}, {"ref_Y"}); + def->set_name("cpu_conv"); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&cpu_net, "Relu", {"ref_Y"}, {"ref_relu"}); + } + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"ref_relu", "W2", "b2"}, {"ref_Y2"}); + ADD_CONV_ARGS; + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"cpu_X", "W", "b"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&gpu_net, "Relu", {"gpu_Y"}, {"gpu_relu"}); + MAKE_OPENGL_OPERATOR(def); + } + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"gpu_relu", "W2", "b2"}, {"gpu_Y2"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + +#undef ADD_CONV_ARGS + + // will work after next release of ACL + // compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y2", "gpu_Y2", tol); + +} + +TEST(OPENGLOperatorTest, ConvAddConv) { + + Workspace ws; + auto channel_in = 16; + auto channel_out = 16; + auto spatial = 32; // --> 2x2 w no padding, all values 9 + auto kern = 3; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, channel_in, spatial, spatial}, 1337); + PopulateCPUBlob(&ws, true, "W", {channel_out, channel_in, kern, kern}, 1337); + PopulateCPUBlob(&ws, false, "b", {channel_out}, 0); + PopulateCPUBlob(&ws, true, "W2", {channel_out, channel_in, kern, kern}); + PopulateCPUBlob(&ws, true, "b2", {channel_out}); + PopulateCPUBlob(&ws, true, "cpu_Y", {1, channel_in, spatial, spatial}, 1337); + +#define ADD_CONV_ARGS \ + { \ + ADD_ARG((*def), "kernel", i, kern); \ + ADD_ARG((*def), "stride", i, 1); \ + ADD_ARG((*def), "pad", i, 1); \ + ADD_ARG((*def), "order", s, "NCHW"); \ + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"cpu_X", "W", "b"}, {"ref_Y"}); + def->set_name("cpu_conv"); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&cpu_net, "Add", {"ref_Y", "cpu_Y"}, {"ref_add"}); + } + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"ref_add", "W2", "b2"}, {"ref_Y2"}); + ADD_CONV_ARGS; + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"cpu_X", "W", "b"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&gpu_net, "Add", {"gpu_Y", "cpu_Y"}, {"gpu_add"}); + MAKE_OPENGL_OPERATOR(def); + } + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"gpu_add", "W2", "b2"}, {"gpu_Y2"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } +#undef ADD_CONV_ARGS + + // will work after next release of ACL + // compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y2", "gpu_Y2", tol); + +} +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_concat_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_concat_op_test.cc new file mode 100644 index 0000000000..5676521ab6 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_concat_op_test.cc @@ -0,0 +1,45 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, Concat) { + + for (auto Cs: std::vector<std::vector<int>>{ + {4, 4}, + {4, 4, 4}, + {6, 6, 6}, + {16, 8, 4}, + {32, 8, 16, 4}, + }) { + Workspace ws; + int batchSize = 1; + int H = 8; + int W = 8; + for (int i = 0; i < Cs.size(); ++i) { + PopulateCPUBlob(&ws, true, std::string("cpu_X") + caffe2::to_string(i), {batchSize, Cs[i], H, W}); + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Concat", {}, {"ref_Y", "cpu_dummy"}); + for (int i = 0; i < Cs.size(); ++i ) { + def->add_input(std::string("cpu_X") + caffe2::to_string(i)); + } + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Concat", {}, {"gpu_Y", "gpu_dummy"}); + MAKE_OPENGL_OPERATOR(def); + for (int i = 0; i < Cs.size(); ++i ) { + def->add_input(std::string("cpu_X") + caffe2::to_string(i)); + } + } + + compareNetResult(ws, cpu_net, gpu_net); + + } +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_context_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_context_test.cc new file mode 100644 index 0000000000..71d1136821 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_context_test.cc @@ -0,0 +1,11 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include <gtest/gtest.h> + +namespace caffe2 { + +TEST(OPENGLContextTest, Initialization) { + auto gc = new GLContext(); + delete gc; +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_conv_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_conv_op_test.cc new file mode 100644 index 0000000000..bb8f397873 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_conv_op_test.cc @@ -0,0 +1,162 @@ +#include "gl_operator_test.h" +#include "caffe2/core/timer.h" + +namespace caffe2 { + +constexpr float tol = 3.0e-2; + +TEST(OPENGLOperatorTest, Conv) { + + Workspace ws; + auto channel_in = 16; + auto channel_out = 16; + auto spatial = 16; // --> 2x2 w no padding, all values 9 + auto kern = 3; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, channel_in, spatial, spatial}, 1337); + PopulateCPUBlob(&ws, true, "W", {channel_out, channel_in, kern, kern}, 1337); + PopulateCPUBlob(&ws, false, "b", {channel_out}, 0); + +#define ADD_CONV_ARGS \ + { \ + ADD_ARG((*def), "kernel", i, kern); \ + ADD_ARG((*def), "stride", i, 1); \ + ADD_ARG((*def), "pad", i, 0); \ + ADD_ARG((*def), "order", s, "NCHW"); \ + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"cpu_X", "W", "b"}, {"ref_Y"}); + def->set_name("cpu_conv"); + ADD_CONV_ARGS; + } + ws.RunNetOnce(cpu_net); + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + + OperatorDef* def = AddOp(&gpu_net, "Conv", {"cpu_X", "W", "b"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + +#undef ADD_CONV_ARGS + + compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y", "gpu_Y", tol); + +} + +TEST(OPENGLOperatorTest, ConvReluConv) { + + Workspace ws; + auto channel_in = 16; + auto channel_out = 16; + auto spatial = 32; // --> 2x2 w no padding, all values 9 + auto kern = 3; + + PopulateCPUBlob(&ws, true, "cpu_X", {1, channel_in, spatial, spatial}, 1337); + PopulateCPUBlob(&ws, true, "W", {channel_out, channel_in, kern, kern}, 1337); + PopulateCPUBlob(&ws, false, "b", {channel_out}, 0); + PopulateCPUBlob(&ws, true, "W2", {channel_out, channel_in, kern, kern}); + PopulateCPUBlob(&ws, true, "b2", {channel_out}); + +#define ADD_CONV_ARGS \ + { \ + ADD_ARG((*def), "kernel", i, kern); \ + ADD_ARG((*def), "stride", i, 1); \ + ADD_ARG((*def), "pad", i, 0); \ + ADD_ARG((*def), "order", s, "NCHW"); \ + } + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"cpu_X", "W", "b"}, {"ref_Y"}); + def->set_name("cpu_conv"); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&cpu_net, "Relu", {"ref_Y"}, {"ref_relu"}); + } + { + OperatorDef* def = AddOp(&cpu_net, "Conv", {"ref_relu", "W2", "b2"}, {"ref_Y2"}); + ADD_CONV_ARGS; + } + + ws.RunNetOnce(cpu_net); + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"cpu_X", "W", "b"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + { + OperatorDef* def = AddOp(&gpu_net, "Relu", {"gpu_Y"}, {"gpu_relu"}); + MAKE_OPENGL_OPERATOR(def); + } + { + OperatorDef* def = AddOp(&gpu_net, "Conv", {"gpu_relu", "W2", "b2"}, {"gpu_Y2"}); + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS; + } + +#undef ADD_CONV_ARGS + + compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y2", "gpu_Y2", tol); + +} + +TEST(OPENGLOperatorTest, ConvBenchmark) { + + Workspace ws; + auto channel_in = 4; + auto channel_out = 4; + auto spatial = 10; + auto kern = 3; + long long iters = 2; + + PopulateCPUBlob(&ws, false, "cpu_X", {1, channel_in, spatial, spatial}, 1, 0, 0.1); + +#define ADD_CONV_ARGS(_def) \ + { \ + ADD_ARG((*_def), "kernel", i, kern); \ + ADD_ARG((*_def), "stride", i, 1); \ + ADD_ARG((*_def), "pad", i, 0); \ + ADD_ARG((*_def), "order", s, "NCHW"); \ + } + + NetDef gpu_net; + NetDef cpu_net; + gpu_net.set_type("opengl"); + + std::string prev_out = "cpu_X"; + for (auto i = 0; i < iters; ++i) { + std::string weightName = "W" + to_string(i); + std::string biasName = "b" + to_string(i); + std::string output = "conv" + to_string(i); + PopulateCPUBlob(&ws, false, weightName, {channel_out, channel_in, kern, kern}, 1); + PopulateCPUBlob(&ws, false, biasName, {channel_out}, 0); + OperatorDef* def = AddOp(&gpu_net, "Conv", {prev_out, weightName, biasName}, {output}); + if (i == 0) { + OperatorDef* def2 = AddOp(&cpu_net, "Conv", {prev_out, weightName, biasName}, {"cpu" + output}); + ADD_CONV_ARGS(def2); + } else { + OperatorDef* def2 = AddOp(&cpu_net, "Conv", {"cpu" + prev_out, weightName, biasName}, {"cpu" + output}); + ADD_CONV_ARGS(def2); + } + prev_out = output; + MAKE_OPENGL_OPERATOR(def); + ADD_CONV_ARGS(def); + } + +#undef ADD_CONV_ARGS + + compareNetResult4D(ws, cpu_net, gpu_net, "cpu" + prev_out, prev_out, tol); + +} + +} // namespace caffe2 + diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_elementwise_sum_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_elementwise_sum_op_test.cc new file mode 100644 index 0000000000..a732237538 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_elementwise_sum_op_test.cc @@ -0,0 +1,27 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, Sum) { + Workspace ws; + int N = 28; + int D = 128; + PopulateCPUBlob(&ws, true, "cpu_X", {N, D}, 1); + PopulateCPUBlob(&ws, true, "cpu_Y", {N, D}, 1); + + NetDef cpu_net; + { + AddOp(&cpu_net, "Sum", {"cpu_X", "cpu_Y"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Sum", {"cpu_X", "cpu_Y"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_fully_connected_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_fully_connected_op_test.cc new file mode 100644 index 0000000000..0880b2a574 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_fully_connected_op_test.cc @@ -0,0 +1,36 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, FC) { + + Workspace ws; + int batchSize = 1; + int CIn = 4; + int H = 8; + int W = 8; + int COut = 16; + + PopulateCPUBlob(&ws, true, "cpu_X", {batchSize, CIn, H, W}); + PopulateCPUBlob(&ws, true, "cpu_W", {COut, CIn * H * W}); + PopulateCPUBlob(&ws, true, "cpu_B", {COut}); + + constexpr float tol = 0.2; + + NetDef cpu_net; + { + AddOp(&cpu_net, "FC", {"cpu_X", "cpu_W", "cpu_B"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "FC", {"cpu_X", "cpu_W", "cpu_B"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + + // will work after the next release of ACL + // compareNetResult(ws, cpu_net, gpu_net, "ref_Y", "gpu_Y", tol, true); +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_model_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_model_test.cc new file mode 100644 index 0000000000..0bb976ad65 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_model_test.cc @@ -0,0 +1,11 @@ +#include "caffe2/mobile/contrib/arm-compute/test/gl_model_test.h" + +namespace caffe2 { + +// The last softmax op didn't pass because of the dimension mismatch, and we are not likely to hit it in other models, but the implementation should be correct +// TEST(OPENGLModelTest, SqueezenetV11) { +// std::string parent_path = "/data/local/tmp/"; +// benchmarkModel(parent_path + "squeezenet_init.pb", parent_path + "squeezenet_predict.pb", "data", {1, 3, 224, 224}, "squeezenet_v11"); +// } + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_model_test.h b/caffe2/mobile/contrib/arm-compute/test/gl_model_test.h new file mode 100644 index 0000000000..55baf24328 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_model_test.h @@ -0,0 +1,63 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include "caffe2/mobile/contrib/arm-compute/test/gl_operator_test.h" +#include <gtest/gtest.h> + +#include "caffe2/core/operator.h" +#include "caffe2/core/workspace.h" + +CAFFE2_DEFINE_int(warmup, 10, "The number of iterations to warm up."); +CAFFE2_DEFINE_int(iter, 100, "The number of iterations to run."); +CAFFE2_DEFINE_bool( + run_individual, + false, + "Whether to benchmark individual operators."); + + +constexpr float tol = 0.03; +namespace caffe2 { + void benchmarkModel(std::string init_net_pb, std::string predict_net_pb, std::string input_name, std::vector<int> input_dims, std::string net_name="benchmark_net", std::vector<std::string> cpu_ops = {}) { + unique_ptr<caffe2::Workspace> ws(new caffe2::Workspace()); + NetDef init_net_def; + CAFFE_ENFORCE(ReadProtoFromFile(init_net_pb, &init_net_def)); + CAFFE_ENFORCE(ws->RunNetOnce(init_net_def)); + NetDef predict_net_def, predict_net_def_gpu; + CAFFE_ENFORCE(ReadProtoFromFile(predict_net_pb, &predict_net_def)); + PopulateCPUBlob(ws.get(), true, input_name, input_dims); + predict_net_def.clear_external_output(); + + predict_net_def_gpu.CopyFrom(predict_net_def); + predict_net_def_gpu.set_type("opengl"); + + for (auto i = 0; i < predict_net_def_gpu.op().size(); ++i) { + auto op = predict_net_def_gpu.mutable_op(i); + if (std::find(cpu_ops.begin(), cpu_ops.end(), op->type()) == cpu_ops.end()) { + op->mutable_device_option()->set_device_type(OPENGL); + } + } + // change the name of last op + auto index = predict_net_def_gpu.op().size() - 1; + auto last_blob = predict_net_def_gpu.op()[index].output()[0]; + auto op = predict_net_def_gpu.mutable_op(index); + auto output = op->mutable_output(0); + *output = last_blob + "_gpu"; + + compareNetResult4D(*ws, predict_net_def, predict_net_def_gpu, last_blob, last_blob + "_gpu"); + + NetBase* net = ws->CreateNet(predict_net_def); + LOG(INFO) << "[C2DEBUG] Benchmarking OpenGL Net"; + net->TEST_Benchmark(caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual); + // Test CPU + for (auto i = 0; i < predict_net_def.op().size(); ++i) { + auto op = predict_net_def.mutable_op(i); + if (std::find(cpu_ops.begin(), cpu_ops.end(), op->type()) == cpu_ops.end()) { + op->mutable_device_option()->set_device_type(CPU); + } + } + predict_net_def.set_type("simple"); + predict_net_def.set_name("cpu_net"); + net = ws->CreateNet(predict_net_def); + LOG(INFO) << "[C2DEBUG] Benchmarking CPU Net"; + net->TEST_Benchmark(caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual); + + } +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_norm_planar_yuv_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_norm_planar_yuv_op_test.cc new file mode 100644 index 0000000000..9720dfd06c --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_norm_planar_yuv_op_test.cc @@ -0,0 +1,33 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +constexpr float tol = 5.0e-2; + +TEST(OPENGLOperatorTest, NormPlanarYUV) { + + Workspace ws; + int batchSize = 1; + int channels = 8; + + PopulateCPUBlob(&ws, true, "cpu_X", {batchSize, channels, 8, 13}); + + PopulateCPUBlob(&ws, true, "cpu_mean", {1, channels}); + PopulateCPUBlob(&ws, true, "cpu_stddev", {1, channels}, 1, 0.5); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "NormalizePlanarYUV", {"cpu_X", "cpu_mean", "cpu_stddev"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "NormalizePlanarYUV", {"cpu_X", "cpu_mean", "cpu_stddev"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult4D(ws, cpu_net, gpu_net); +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_operator_test.h b/caffe2/mobile/contrib/arm-compute/test/gl_operator_test.h new file mode 100644 index 0000000000..d47f713160 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_operator_test.h @@ -0,0 +1,129 @@ +#include "caffe2/mobile/contrib/arm-compute/core/context.h" +#include <gtest/gtest.h> + +#include "caffe2/core/graph.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/workspace.h" + +namespace caffe2 { + +#define DECLARE_OPENGL_OPERATOR(_name) \ + OperatorDef _name; \ + _name.mutable_device_option()->set_device_type(OPENGL); + +#define MAKE_OPENGL_OPERATOR(_op) \ + _op->mutable_device_option()->set_device_type(OPENGL); + +#define ADD_ARG(_op, _name, _type, _val) \ + { \ + Argument *arg = _op.add_arg(); \ + arg->set_name(_name); \ + arg->set_##_type(_val); \ + } + +// Use value 1337 to generate a blob that is deterministic +// and unique at each value (for debugging purposes) +template<typename T = float> +void PopulateCPUBlob(Workspace *ws, bool random, std::string name, + std::vector<int> dims, int val = 1, int dist_shift = 0, float variance = 1) { + Blob *blob = ws->CreateBlob(name); + auto *tensor = blob->GetMutable<TensorCPU>(); + tensor->Resize(dims); + T *t_data = tensor->mutable_data<T>(); + std::random_device rd; + std::mt19937 e2(rd()); + std::normal_distribution<> dist(0 + dist_shift, variance + dist_shift); + for (int i = 0; i < tensor->size(); ++i) { + t_data[i] = T(random ? dist(e2) : (val == 1337 ? i : val)); + } +} + +template<typename T = half> +void compareNetResult(Workspace& ws, + NetDef& cpu_net, NetDef& gpu_net, + string cpu_blob="ref_Y", + string gpu_blob="gpu_Y", + double tol=0.01, + bool relative=false) { + ws.RunNetOnce(cpu_net); + ws.RunNetOnce(gpu_net); + + Blob *cpu_out = ws.GetBlob(cpu_blob); + Blob *gpu_out = ws.GetBlob(gpu_blob); + EXPECT_NE(nullptr, cpu_out); + EXPECT_NE(nullptr, gpu_out); + + auto &g_ = gpu_out->Get<GLTensor<T>>(); + TensorCPU g; + g.Resize(g_.dims()); + T *buffer = g_.map(); + + for (auto i = 0; i < g.size(); ++i) { + auto tmp = buffer[i]; + g.mutable_data<float>()[i] = tmp; + } + g_.unmap(); + + auto &t = cpu_out->Get<TensorCPU>(); + EXPECT_EQ(g.size(), t.size()); + + for (auto i = 0; i < g.size(); ++i) { + if (relative) { + EXPECT_NEAR(g.data<float>()[i], t.data<float>()[i], tol + tol * std::abs(t.data<float>()[i])) << "at index " << i; + } else{ + EXPECT_NEAR(g.data<float>()[i], t.data<float>()[i], tol) + << "at index " << i; + } + } +} + +template<typename T = half> +void compareNetResult4D(Workspace& ws, + NetDef& cpu_net, NetDef& gpu_net, + string cpu_blob="ref_Y", + string gpu_blob="gpu_Y", + double tol=0.05) { + ws.RunNetOnce(cpu_net); + ws.RunNetOnce(gpu_net); + + Blob *cpu_out = ws.GetBlob(cpu_blob); + Blob *gpu_out = ws.GetBlob(gpu_blob); + auto &g_ = gpu_out->Get<GLTensor<T>>(); + + EXPECT_NE(nullptr, cpu_out); + EXPECT_NE(nullptr, gpu_out); + + TensorCPU g; + auto &t = cpu_out->Get<TensorCPU>(); + g.Resize(g_.dims()); + T *buffer = g_.map(); + char *byte_buffer = (char *)buffer; + auto info = g_.get_underlying()->info(); + + CAFFE_ENFORCE(byte_buffer != NULL); + auto C = t.dim32(1); + auto H = t.dim32(2); + auto W = t.dim32(3); + int diff_num = 0; +#define get_elem(_a, _b, _c) \ + (half *)&byte_buffer[info->offset_element_in_bytes( \ + arm_compute::Coordinates(_a, _b, _c))] + for (auto c = 0; c < C; ++c) { + for (auto h = 0; h < H; ++h) { + for (auto w = 0; w < W; ++w) { + auto t_elem = t.data<float>()[(c * H + h) * W + w]; + auto g_elem = get_elem(w, h, c); + + if (!isnan(t_elem) && (std::abs(t_elem - float(*g_elem)) > tol + tol * std::abs(t_elem))) { + diff_num++; + } + CHECK(diff_num <= 0.03 * C*H*W); + } + } + } +#undef get_elem + g_.unmap(); +} + + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_pool_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_pool_op_test.cc new file mode 100644 index 0000000000..26ca2ba524 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_pool_op_test.cc @@ -0,0 +1,89 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, AveragePool) { + Workspace ws; + PopulateCPUBlob(&ws, true, "cpu_X", {1, 1, 8, 8}); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "AveragePool", {"cpu_X"}, {"ref_Y"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "AveragePool", {"cpu_X"}, {"gpu_Y"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); + +} + +TEST(OPENGLOperatorTest, MaxPool) { + Workspace ws; + PopulateCPUBlob(&ws, true, "cpu_X", {1, 1, 8, 8}); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "MaxPool", {"cpu_X"}, {"ref_Y"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "MaxPool", {"cpu_X"}, {"gpu_Y"}); + ADD_ARG((*def), "kernel", i, 2); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 2); + ADD_ARG((*def), "order", s, "NCHW"); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); + +} + +TEST(OPENGLOperatorTest, AverageGlobalPool) { + Workspace ws; + PopulateCPUBlob(&ws, true, "cpu_X", {1, 1, 8, 8}); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "AveragePool", {"cpu_X"}, {"ref_Y"}); + ADD_ARG((*def), "global_pooling", i, 1); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 1); + ADD_ARG((*def), "order", s, "NCHW"); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "AveragePool", {"cpu_X"}, {"gpu_Y"}); + ADD_ARG((*def), "global_pooling", i, 1); + ADD_ARG((*def), "pad", i, 0); + ADD_ARG((*def), "stride", i, 1); + ADD_ARG((*def), "order", s, "NCHW"); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); + +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_resize_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_resize_op_test.cc new file mode 100644 index 0000000000..7b3c7a9154 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_resize_op_test.cc @@ -0,0 +1,35 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, ResizeNearest) { + + Workspace ws; + float height_scale = 2; + float width_scale = 2; + int N = 1; + int CIn = 7; + + PopulateCPUBlob(&ws, true, "cpu_X", {N, CIn, 37, 89}); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "ResizeNearest", {"cpu_X"}, {"ref_Y"}); + ADD_ARG((*def), "height_scale", f, height_scale); + ADD_ARG((*def), "width_scale", f, width_scale); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "ResizeNearest", {"cpu_X"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_ARG((*def), "height_scale", f, height_scale); + ADD_ARG((*def), "width_scale", f, width_scale); + } + + compareNetResult4D(ws, cpu_net, gpu_net); + +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_softmax_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_softmax_op_test.cc new file mode 100644 index 0000000000..28b834eed1 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_softmax_op_test.cc @@ -0,0 +1,28 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, Softmax) { + + Workspace ws; + int N = 1; + int D = 128; + PopulateCPUBlob(&ws, true, "cpu_X", {N, D}, 1); + + NetDef cpu_net; + { + AddOp(&cpu_net, "Softmax", {"cpu_X"}, {"ref_Y"}); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "Softmax", {"cpu_X"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + } + + compareNetResult(ws, cpu_net, gpu_net); + +} + +} // namespace caffe2 diff --git a/caffe2/mobile/contrib/arm-compute/test/gl_spatial_batch_norm_op_test.cc b/caffe2/mobile/contrib/arm-compute/test/gl_spatial_batch_norm_op_test.cc new file mode 100644 index 0000000000..38fa2b85a5 --- /dev/null +++ b/caffe2/mobile/contrib/arm-compute/test/gl_spatial_batch_norm_op_test.cc @@ -0,0 +1,35 @@ +#include "gl_operator_test.h" + +namespace caffe2 { + +TEST(OPENGLOperatorTest, SpatialBN) { + + Workspace ws; + int batchSize = 1; + int channels = 8; + + PopulateCPUBlob(&ws, true, "cpu_X", {3, channels, 8, 13}); + PopulateCPUBlob(&ws, true, "cpu_scale", {channels}); + PopulateCPUBlob(&ws, true, "cpu_bias", {channels}); + PopulateCPUBlob(&ws, true, "cpu_mean", {channels}); + PopulateCPUBlob(&ws, true, "cpu_var", {channels}, 1, 0.5); + + NetDef cpu_net; + { + OperatorDef* def = AddOp(&cpu_net, "SpatialBN", {"cpu_X", "cpu_scale", "cpu_bias", "cpu_mean", "cpu_var"}, {"ref_Y"}); + ADD_ARG((*def), OpSchema::Arg_IsTest, i, 1); + } + + NetDef gpu_net; + gpu_net.set_type("opengl"); + { + OperatorDef* def = AddOp(&gpu_net, "SpatialBN", {"cpu_X", "cpu_scale", "cpu_bias", "cpu_mean", "cpu_var"}, {"gpu_Y"}); + MAKE_OPENGL_OPERATOR(def); + ADD_ARG((*def), OpSchema::Arg_IsTest, i, 1); + } + + compareNetResult4D(ws, cpu_net, gpu_net, "ref_Y", "gpu_Y", 0.01); + +} + +} // namespace caffe2 |