summaryrefslogtreecommitdiff
path: root/compiler/enco/frontend/caffe/src/Layer
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/enco/frontend/caffe/src/Layer')
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/BatchNorm.cpp254
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/BatchNorm.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Concatenation.cpp138
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Concatenation.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Convolution.cpp197
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Convolution.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Eltwise.cpp134
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Eltwise.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Input.cpp60
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Input.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Pooling.cpp138
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Pooling.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/ReLU.cpp83
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/ReLU.h35
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Scale.cpp160
-rw-r--r--compiler/enco/frontend/caffe/src/Layer/Scale.h35
16 files changed, 1444 insertions, 0 deletions
diff --git a/compiler/enco/frontend/caffe/src/Layer/BatchNorm.cpp b/compiler/enco/frontend/caffe/src/Layer/BatchNorm.cpp
new file mode 100644
index 000000000..ff1e86570
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/BatchNorm.cpp
@@ -0,0 +1,254 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BatchNorm.h"
+#include "IRBuilder.h"
+
+#include <morph/caffe.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+using tensor::num_elements;
+
+namespace caffeimport
+{
+
+void BatchNormBuilder::build(const ::caffe::LayerParameter &layer,
+ GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Data *data = context->data();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+ WeightContext &weight_ctx = context->weight_ctx();
+
+ assert(layer.bottom().size() == 1);
+ assert(layer.top().size() == 1);
+
+ assert(layer.has_batch_norm_param());
+ const auto &param = layer.batch_norm_param();
+
+ // TODO Support training case
+ assert(param.use_global_stats() == true);
+
+ // Create an object for an input feature map
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ auto ifm_bag = bag_ctx.at(ifm_name);
+ auto ifm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ifm_obj->bag(ifm_bag);
+ ifm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ifm_shape)));
+
+ // Create an object for an output feature map
+ const auto ofm_name = layer.top(0);
+ const auto ofm_shape = ifm_shape;
+ auto ofm_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto ofm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ofm_obj->bag(ofm_bag);
+ ofm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Create an object for the scaled mean estimates data
+ auto mean_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto mean_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ mean_obj->bag(mean_bag);
+ mean_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Create an object for the scaled variance estimates data
+ auto variance_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto variance_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ variance_obj->bag(variance_bag);
+ variance_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ if (param.use_global_stats())
+ {
+ // Use the stored mean/variance estimates.
+ assert(weight_ctx.blob_count(layer.name()) == 3);
+
+ // Create an object for scale factor data
+ auto factor_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto factor_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ factor_obj->bag(factor_bag);
+ factor_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Fill "scale factor" data
+ {
+ data->f32()->allocate(factor_bag);
+
+ auto dst = data->f32()->weight(factor_bag);
+ // Calculate scale factor
+ auto blob = weight_ctx.blob_get(layer.name(), 2);
+ const auto scale_factor = blob->data(0) == 0 ? 0.f : 1 / blob->data(0);
+
+ for (uint32_t ch = 0; ch < factor_obj->shape().depth(); ++ch)
+ {
+ dst[ch] = scale_factor;
+ }
+ }
+
+ // Create an object for saved mean data
+ auto saved_mean_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto saved_mean_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ saved_mean_obj->bag(saved_mean_bag);
+ saved_mean_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Fill "saved mean estimates" data
+ {
+ data->f32()->allocate(saved_mean_bag);
+
+ auto dst = data->f32()->weight(saved_mean_bag);
+ auto blob = weight_ctx.blob_get(layer.name(), 0);
+
+ for (uint32_t ch = 0; ch < saved_mean_obj->shape().depth(); ++ch)
+ {
+ dst[ch] = blob->data(ch);
+ }
+ }
+
+ // Multiply scale factor to mean data
+ {
+ auto mul_op = op_builder(module).load(factor_obj).load(saved_mean_obj).mul().pop();
+ auto mul_ins = instr_builder(module).eval(mean_obj, mul_op);
+
+ blk->instr()->append(mul_ins);
+ }
+
+ // Create an object for saved variance data
+ auto saved_variance_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto saved_variance_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ saved_variance_obj->bag(saved_variance_bag);
+ saved_variance_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Fill "saved variance estimates" data
+ {
+ data->f32()->allocate(saved_variance_bag);
+
+ auto dst = data->f32()->weight(saved_variance_bag);
+ auto blob = weight_ctx.blob_get(layer.name(), 1);
+
+ for (uint32_t ch = 0; ch < saved_variance_obj->shape().depth(); ++ch)
+ {
+ dst[ch] = blob->data(ch);
+ }
+ }
+
+ // Multiply scale factor to variance data
+ {
+ auto mul_op = op_builder(module).load(factor_obj).load(saved_variance_obj).mul().pop();
+ auto mul_ins = instr_builder(module).eval(variance_obj, mul_op);
+
+ blk->instr()->append(mul_ins);
+ }
+ }
+ else
+ {
+ // TODO use_global_stats() == false case
+ }
+
+ // Create an object for subtraction
+ auto sub_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto sub_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ sub_obj->bag(sub_bag);
+ sub_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Subtract mean
+ {
+ auto sub_op = op_builder(module).load(mean_obj).load(ifm_obj).sub().pop();
+ auto sub_ins = instr_builder(module).eval(sub_obj, sub_op);
+
+ blk->instr()->append(sub_ins);
+ }
+
+ // Create an object for normalize variance data
+ auto norm_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto norm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ norm_obj->bag(norm_bag);
+ norm_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Normalize variance
+ {
+ // Create an object for epsilon data
+ auto eps_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto eps_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ eps_obj->bag(eps_bag);
+ eps_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Fill "epsilon" data
+ {
+ data->f32()->allocate(eps_bag);
+
+ auto dst = data->f32()->weight(eps_bag);
+ auto eps = param.eps();
+
+ for (uint32_t ch = 0; ch < eps_obj->shape().depth(); ++ch)
+ {
+ dst[ch] = eps;
+ }
+ }
+
+ // Create a temp object
+ auto temp_bag = module->entity()->bag()->create(ofm_shape.dim(1));
+ auto temp_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ temp_obj->bag(temp_bag);
+ temp_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ // Add epsilon to variance
+ {
+ auto add_op = op_builder(module).load(variance_obj).load(eps_obj).add().pop();
+ auto add_ins = instr_builder(module).eval(temp_obj, add_op);
+
+ blk->instr()->append(add_ins);
+ }
+
+ // Sqrt variance
+ {
+ auto load = op_builder(module).load(temp_obj).pop();
+ auto sqrt_op = module->entity()->op()->create<coco::Sqrt>();
+ sqrt_op->arg(load);
+ auto sqrt_ins = instr_builder(module).eval(norm_obj, sqrt_op);
+
+ blk->instr()->append(sqrt_ins);
+ }
+ }
+
+ // Replicate variance to input size
+ {
+ auto div_op = op_builder(module).load(norm_obj).load(sub_obj).div().pop();
+ auto div_ins = instr_builder(module).eval(ofm_obj, div_op);
+
+ blk->instr()->append(div_ins);
+ }
+
+ // Update bag and shape context
+ bag_ctx[ofm_name] = ofm_bag;
+ shape_ctx[ofm_name] = ofm_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/BatchNorm.h b/compiler/enco/frontend/caffe/src/Layer/BatchNorm.h
new file mode 100644
index 000000000..613b6687e
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/BatchNorm.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __BATCHNORM_BUILDER_H__
+#define __BATCHNORM_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class BatchNormBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __BATCHNORM_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Concatenation.cpp b/compiler/enco/frontend/caffe/src/Layer/Concatenation.cpp
new file mode 100644
index 000000000..f05f5908a
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Concatenation.cpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Concatenation.h"
+#include "IRBuilder.h"
+
+#include <coco/IR/FeatureLayouts.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+namespace caffeimport
+{
+
+void ConcatBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+
+ assert(layer.bottom().size() > 0);
+ assert(layer.top().size() == 1);
+
+ // Assume default concat axis
+ // - Please refer to http://caffe.berkeleyvision.org/tutorial/layers/concat.html for details
+ // TODO Get concat axis from concat param
+ assert(!layer.has_concat_param());
+ const uint32_t concat_axis = 1;
+
+ // Construct a vector of input objects
+ std::vector<coco::FeatureObject *> input_objects;
+
+ for (const auto &input_name : layer.bottom())
+ {
+ const auto input_shape = as_feature_shape(shape_ctx.at(input_name));
+
+ auto input_bag = bag_ctx.at(input_name);
+ auto input_feature = module->entity()->object()->create<coco::FeatureObject>();
+
+ input_feature->bag(input_bag);
+ input_feature->layout(coco::FeatureLayouts::BCHW::create(input_shape));
+
+ input_objects.emplace_back(input_feature);
+ }
+
+ coco::FeatureObject *last_feature = input_objects.at(0);
+
+ assert(last_feature != nullptr);
+ assert(last_feature->bag() != nullptr);
+
+ // Update coco IR
+ //
+ // Given a sequence of input features %in[0] / %in[1] / ... / %in[N]
+ // the below code constructs a sequence of eval instructions
+ // - Load is omitted for simplicity
+ //
+ // %out[0] = eval(ConcatF(%in[0], %in[1]))
+ // %out[1] = eval(ConcatF(%out[0], %in[2]))
+ // ...
+ // %out[N - 1] = eval(ConcatF(%out[N - 2], %in[N]))
+ //
+ for (uint32_t n = 1; n < input_objects.size(); ++n)
+ {
+ auto const left_feature = last_feature;
+ auto const left_shape = left_feature->layout()->shape();
+
+ auto right_feature = input_objects.at(n);
+ auto right_shape = right_feature->layout()->shape();
+
+ // Batch is not supported, yet
+ assert(left_feature->layout()->batch() == 1);
+ assert(right_feature->layout()->batch() == 1);
+
+ // Height and Width SHOULD BE IDENTICAL for depth concat
+ assert(left_shape.height() == right_shape.height());
+ assert(left_shape.width() == right_shape.width());
+
+ const uint32_t C = left_shape.depth() + right_shape.depth();
+ const uint32_t H = left_shape.height();
+ const uint32_t W = left_shape.width();
+
+ const nncc::core::ADT::feature::Shape out_shape{C, H, W};
+
+ auto out_bag = module->entity()->bag()->create(num_elements(out_shape));
+ auto out_feature = module->entity()->object()->create<coco::FeatureObject>();
+
+ out_feature->bag(out_bag);
+ out_feature->layout(coco::FeatureLayouts::BCHW::create(out_shape));
+
+ auto left_load = op_builder(module).load(left_feature).pop();
+ auto right_load = op_builder(module).load(right_feature).pop();
+
+ auto concat_f = module->entity()->op()->create<coco::ConcatF>();
+
+ concat_f->axis(coco::ConcatF::Axis::Depth);
+ concat_f->left(left_load);
+ concat_f->right(right_load);
+
+ auto eval = instr_builder(module).eval(out_feature, concat_f);
+
+ // Append the constructed Shuffle instruction
+ blk->instr()->append(eval);
+
+ // Update 'last_feature'
+ last_feature = out_feature;
+ }
+
+ assert(last_feature != nullptr);
+ assert(last_feature->bag() != nullptr);
+
+ // Update bag and shape context
+ auto const out_name = layer.top(0);
+ auto const out_shape = as_tensor_shape(last_feature->layout()->shape());
+ auto const out_bag = last_feature->bag();
+
+ bag_ctx[out_name] = out_bag;
+ shape_ctx[out_name] = out_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Concatenation.h b/compiler/enco/frontend/caffe/src/Layer/Concatenation.h
new file mode 100644
index 000000000..85e04000d
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Concatenation.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONCAT_BUILDER_H__
+#define __CONCAT_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class ConcatBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __CONCAT_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Convolution.cpp b/compiler/enco/frontend/caffe/src/Layer/Convolution.cpp
new file mode 100644
index 000000000..9fb096d49
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Convolution.cpp
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Convolution.h"
+#include "ConvolutionSpec.h"
+#include "Convert.h"
+#include "IRBuilder.h"
+
+#include <nncc/core/ADT/kernel/Overlay.h>
+#include <nncc/core/ADT/kernel/NCHWLayout.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+using tensor::num_elements;
+
+namespace caffeimport
+{
+
+void ConvolutionBuilder::build(const ::caffe::LayerParameter &layer,
+ GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Data *data = context->data();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+ WeightContext &weight_ctx = context->weight_ctx();
+
+ assert(layer.bottom().size() == 1);
+ assert(layer.top().size() == 1);
+
+ assert(layer.has_convolution_param());
+ const auto &param = layer.convolution_param();
+
+ ConvolutionSpec spec{param};
+ {
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ spec.ifm_shape(ifm_shape);
+ }
+
+ // NOTE The current implementation focuses on 2D convolution
+ // TODO Support general ND convolution
+ assert(spec.num_batch_axes() == 1);
+ assert(spec.num_spatial_axes() == 2);
+
+ // Create an object for an input feature map
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ auto ifm_bag = bag_ctx.at(ifm_name);
+ auto ifm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ifm_obj->bag(ifm_bag);
+ ifm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ifm_shape)));
+
+ // Create an object for an output feature map
+ const auto ofm_name = layer.top(0);
+ const auto ofm_shape = spec.ofm_shape();
+ auto ofm_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto ofm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ofm_obj->bag(ofm_bag);
+ ofm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Create an object for kernel
+ using namespace coco::KernelLayouts;
+
+ const auto ker_shape = spec.ker_shape();
+ auto ker_bag = module->entity()->bag()->create(num_elements(ker_shape));
+ auto ker_obj = module->entity()->object()->create<coco::KernelObject>();
+
+ ker_obj->bag(ker_bag);
+ ker_obj->layout(NCHW::create(as_kernel_shape(ker_shape)));
+
+ // Create a kernel overlay for the kernel object
+ data->f32()->allocate(ker_bag);
+
+ // Initialize the kernel overlay
+ assert(weight_ctx.blob_count(layer.name()) >= 1);
+ auto ker_blob = weight_ctx.blob_get(layer.name(), 0);
+
+ assert(ker_shape == caffeimport::as_tensor_shape(ker_blob));
+
+ auto ker_dst = data->f32()->access(ker_obj);
+ auto ker_src = kernel::OverlayFactory<float, kernel::NCHWLayout>::make(
+ ker_obj->shape(), ker_blob->mutable_data()->begin());
+
+ for (uint32_t n = 0; n < ker_obj->shape().count(); ++n)
+ {
+ for (uint32_t ch = 0; ch < ker_obj->shape().depth(); ++ch)
+ {
+ for (uint32_t row = 0; row < ker_obj->shape().height(); ++row)
+ {
+ for (uint32_t col = 0; col < ker_obj->shape().width(); ++col)
+ {
+ ker_dst->at(n, ch, row, col) = ker_src.at(n, ch, row, col);
+ }
+ }
+ }
+ }
+
+ // Create a Load op
+ auto load = op_builder(module).load(ifm_obj).pop();
+
+ // Create a Conv2D op
+ auto op = module->entity()->op()->create<coco::Conv2D>();
+
+ op->group(spec.group());
+
+ op->ker(ker_obj);
+ op->stride()->vertical(spec.stride(0));
+ op->stride()->horizontal(spec.stride(1));
+
+ op->pad()->top(spec.pad(0));
+ op->pad()->bottom(spec.pad(0));
+ op->pad()->left(spec.pad(1));
+ op->pad()->right(spec.pad(1));
+
+ op->arg(load);
+
+ // Create an Eval instruction
+ auto ins = instr_builder(module).eval(ofm_obj, op);
+
+ // Append the instruction to the block
+ blk->instr()->append(ins);
+
+ //
+ // coco IR allows Conv2D fused with Add, but the current implementation of enco backend
+ // is unable to process such a tree.
+ //
+ // As a workaround, caffe frontend constructs a instruction for Conv2D and Add.
+ //
+ if (param.bias_term())
+ {
+ assert(weight_ctx.blob_count(layer.name()) >= 2);
+
+ // Create Bag & Object
+ auto bias_bag = module->entity()->bag()->create(ker_shape.dim(0));
+ auto bias_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ bias_obj->bag(bias_bag);
+ bias_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(ofm_shape)));
+
+ auto added_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto added_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ added_obj->bag(added_bag);
+ added_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Create Op
+ auto bias_add = op_builder(module).load(bias_obj).load(ofm_obj).add().pop();
+
+ // Create Instr
+ auto bias_add_ins = instr_builder(module).eval(added_obj, bias_add);
+
+ // Append the instruction
+ blk->instr()->append(bias_add_ins);
+
+ // Fill bias data
+ data->f32()->allocate(bias_bag);
+
+ auto bias_span = data->f32()->weight(bias_bag);
+ auto bias_blob = weight_ctx.blob_get(layer.name(), 1);
+
+ for (uint32_t ch = 0; ch < ker_obj->shape().count(); ++ch)
+ {
+ bias_span[ch] = bias_blob->data(ch);
+ }
+
+ // Update output
+ ofm_bag = added_bag;
+ }
+
+ // Update bag and shape context
+ bag_ctx[ofm_name] = ofm_bag;
+ shape_ctx[ofm_name] = ofm_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Convolution.h b/compiler/enco/frontend/caffe/src/Layer/Convolution.h
new file mode 100644
index 000000000..a944f12a3
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Convolution.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONVOLUTION_BUILDER_H__
+#define __CONVOLUTION_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class ConvolutionBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __CONVOLUTION_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Eltwise.cpp b/compiler/enco/frontend/caffe/src/Layer/Eltwise.cpp
new file mode 100644
index 000000000..6a5d4f196
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Eltwise.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Eltwise.h"
+#include "IRBuilder.h"
+
+#include <coco/IR/FeatureLayouts.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+#include <functional>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+namespace caffeimport
+{
+
+void EltwiseBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+
+ using coco::FeatureLayouts::BCHW;
+
+ assert(layer.bottom().size() > 1);
+ assert(layer.top().size() == 1);
+
+ assert(layer.has_eltwise_param());
+ const auto &param = layer.eltwise_param();
+
+ using ::caffe::EltwiseParameter_EltwiseOp;
+ using ::caffe::EltwiseParameter_EltwiseOp_SUM;
+ using ::caffe::EltwiseParameter_EltwiseOp_PROD;
+
+ using Reducer = std::function<coco::Op *(coco::Op * lhs, coco::Op * rhs)>;
+ using ReducerRegistry = std::map<EltwiseParameter_EltwiseOp, Reducer>;
+
+ ReducerRegistry registry;
+
+ // MAX are not supported, yet
+ registry[EltwiseParameter_EltwiseOp_SUM] = [](coco::Op *lhs, coco::Op *rhs) -> coco::Op * {
+ if (lhs == nullptr)
+ {
+ assert(rhs != nullptr);
+ return rhs;
+ }
+
+ assert(lhs != nullptr && rhs != nullptr);
+ assert(lhs->module() == rhs->module());
+ assert(lhs->module() != nullptr);
+
+ auto m = lhs->module();
+ return op_builder(m).push(rhs).push(lhs).add().pop();
+ };
+
+ registry[EltwiseParameter_EltwiseOp_PROD] = [](coco::Op *lhs, coco::Op *rhs) -> coco::Op * {
+ if (lhs == nullptr)
+ {
+ assert(rhs != nullptr);
+ return rhs;
+ }
+
+ assert(lhs != nullptr && rhs != nullptr);
+ assert(lhs->module() == rhs->module());
+ assert(lhs->module() != nullptr);
+
+ auto m = lhs->module();
+ return op_builder(m).push(rhs).push(lhs).mul().pop();
+ };
+
+ // coeff is not supported, yet
+ assert(!param.coeff().size());
+
+ // Decide appropriate reduce function
+ auto reduce = registry.at(param.operation());
+
+ coco::Op *op = nullptr;
+
+ for (const auto &ifm_name : layer.bottom())
+ {
+ auto ifm_shape = shape_ctx.at(ifm_name);
+
+ // NOTE The current implementation does not work in general
+ auto ifm_bag = bag_ctx.at(ifm_name);
+ auto ifm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ifm_obj->bag(ifm_bag);
+ ifm_obj->layout(BCHW::create(as_feature_shape(ifm_shape)));
+
+ auto load = op_builder(module).load(ifm_obj).pop();
+
+ op = reduce(op, load);
+ }
+
+ assert(op != nullptr);
+
+ const auto ofm_name = layer.top(0);
+ const auto ofm_shape = shape_ctx.at(layer.bottom(0));
+
+ auto ofm_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto ofm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ofm_obj->bag(ofm_bag);
+ ofm_obj->layout(BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Create "Eval" instruction
+ auto eval = instr_builder(module).eval(ofm_obj, op);
+
+ // Append the instruction to the block
+ blk->instr()->append(eval);
+
+ // Update bag and shape context
+ bag_ctx[ofm_name] = ofm_bag;
+ shape_ctx[ofm_name] = ofm_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Eltwise.h b/compiler/enco/frontend/caffe/src/Layer/Eltwise.h
new file mode 100644
index 000000000..e717077ec
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Eltwise.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ELTWISE_BUILDER_H__
+#define __ELTWISE_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class EltwiseBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __ELTWISE_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Input.cpp b/compiler/enco/frontend/caffe/src/Layer/Input.cpp
new file mode 100644
index 000000000..39e44fa31
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Input.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Input.h"
+#include "Convert.h"
+
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+
+using tensor::num_elements;
+using tensor::LexicalLayout;
+
+namespace caffeimport
+{
+
+void InputBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+
+ assert(layer.has_input_param());
+ const auto &param = layer.input_param();
+
+ for (uint32_t n = 0; n < layer.top_size(); ++n)
+ {
+ const auto &name = layer.top(n);
+ const auto shape = as_tensor_shape(param.shape(n));
+
+ auto bag = module->entity()->bag()->create(num_elements(shape));
+ auto input = module->entity()->input()->create(shape);
+
+ input->bag(bag);
+ input->name(name);
+ input->reorder<LexicalLayout>();
+
+ module->input()->insert(input);
+
+ bag_ctx[name] = bag;
+ shape_ctx[name] = shape;
+ }
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Input.h b/compiler/enco/frontend/caffe/src/Layer/Input.h
new file mode 100644
index 000000000..2f464748d
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Input.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __INPUT_BUILDER_H__
+#define __INPUT_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class InputBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __INPUT_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Pooling.cpp b/compiler/enco/frontend/caffe/src/Layer/Pooling.cpp
new file mode 100644
index 000000000..36220d841
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Pooling.cpp
@@ -0,0 +1,138 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Pooling.h"
+#include "PoolingSpec.h"
+#include "IRBuilder.h"
+
+#include <coco/IR/FeatureLayouts.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+#include <functional>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+namespace caffeimport
+{
+
+void PoolingBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+
+ assert(layer.bottom().size() == 1);
+ assert(layer.top().size() == 1);
+
+ assert(layer.has_pooling_param());
+ const auto &param = layer.pooling_param();
+
+ PoolingSpec spec{param};
+ {
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ spec.ifm_shape(ifm_shape);
+ }
+
+ // Create an object for an input feature map
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ auto ifm_bag = bag_ctx.at(ifm_name);
+ auto ifm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ifm_obj->bag(ifm_bag);
+ ifm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ifm_shape)));
+
+ // Create an object for an output feature map
+ const auto ofm_name = layer.top(0);
+ const auto ofm_shape = spec.ofm_shape();
+ auto ofm_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto ofm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ofm_obj->bag(ofm_bag);
+ ofm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ using PoolingOpBuilder = std::function<coco::Op *(coco::Module * m, const PoolingSpec &spec)>;
+
+ std::map<PoolingMethod, PoolingOpBuilder> builders;
+
+ // MaxPool2D op builder
+ builders[PoolingMethod::Max] = [ifm_obj](coco::Module *module, const PoolingSpec &spec) {
+ auto load = op_builder(module).load(ifm_obj).pop();
+
+ auto op = module->entity()->op()->create<coco::MaxPool2D>();
+
+ op->arg(load);
+
+ op->window()->height(spec.window_height());
+ op->window()->width(spec.window_width());
+
+ op->stride()->vertical(spec.vertical_stride());
+ op->stride()->horizontal(spec.horizontal_stride());
+
+ op->pad()->top(spec.vertical_pad());
+ op->pad()->bottom(spec.vertical_pad());
+ op->pad()->left(spec.horizontal_pad());
+ op->pad()->right(spec.horizontal_pad());
+
+ return op;
+ };
+
+ // AvgPool2D op builder
+ builders[PoolingMethod::Avg] = [ifm_obj](coco::Module *module, const PoolingSpec &spec) {
+ auto load = op_builder(module).load(ifm_obj).pop();
+
+ auto op = module->entity()->op()->create<coco::AvgPool2D>();
+
+ op->arg(load);
+
+ // NOTE Caffe use static divisor on average pooling
+ op->divisor(coco::AvgPool2D::Divisor::Static);
+
+ op->window()->height(spec.window_height());
+ op->window()->width(spec.window_width());
+
+ op->stride()->vertical(spec.vertical_stride());
+ op->stride()->horizontal(spec.horizontal_stride());
+
+ op->pad()->top(spec.vertical_pad());
+ op->pad()->bottom(spec.vertical_pad());
+ op->pad()->left(spec.horizontal_pad());
+ op->pad()->right(spec.horizontal_pad());
+
+ return op;
+ };
+
+ // Create a pooling op
+ auto builder = builders.at(spec.method());
+ auto op = builder(module, spec);
+
+ // Create a UnitF instruction
+ auto ins = instr_builder(module).eval(ofm_obj, op);
+
+ // Append the instruction to the block
+ blk->instr()->append(ins);
+
+ // Update bag and shape context
+ bag_ctx[ofm_name] = ofm_bag;
+ shape_ctx[ofm_name] = ofm_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Pooling.h b/compiler/enco/frontend/caffe/src/Layer/Pooling.h
new file mode 100644
index 000000000..e72fd7aef
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Pooling.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __POOLING_BUILDER_H__
+#define __POOLING_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class PoolingBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __POOLING_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/ReLU.cpp b/compiler/enco/frontend/caffe/src/Layer/ReLU.cpp
new file mode 100644
index 000000000..61e206dc2
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/ReLU.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ReLU.h"
+#include "IRBuilder.h"
+
+#include <coco/IR/FeatureLayouts.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+namespace caffeimport
+{
+
+void ReLUBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+
+ assert(layer.bottom().size() == 1);
+ assert(layer.top().size() == 1);
+
+ // PReLU is not supported, yet
+ // TODO Support PReLU
+ assert(!layer.has_relu_param());
+
+ // NOTE The current implementation treats ReLU as Feature op
+ // TODO Support ReLU over general tensor
+ const auto ifm_name = layer.bottom(0);
+ const auto ifm_shape = shape_ctx.at(ifm_name);
+ auto ifm_bag = bag_ctx.at(ifm_name);
+ auto ifm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ifm_obj->bag(ifm_bag);
+ ifm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ifm_shape)));
+
+ const auto ofm_name = layer.top(0);
+ const auto ofm_shape = ifm_shape;
+ auto ofm_bag = module->entity()->bag()->create(num_elements(ofm_shape));
+ auto ofm_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ ofm_obj->bag(ofm_bag);
+ ofm_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(ofm_shape)));
+
+ // Create a Load Op
+ auto load = op_builder(module).load(ifm_obj).pop();
+
+ // Create a ReLU op
+ auto op = module->entity()->op()->create<coco::ReLU>();
+
+ op->arg(load);
+
+ // Create a Eval instruction
+ auto ins = instr_builder(module).eval(ofm_obj, op);
+
+ // Append the instruction to the block
+ blk->instr()->append(ins);
+
+ // Update bag and shape context
+ bag_ctx[ofm_name] = ofm_bag;
+ shape_ctx[ofm_name] = ofm_shape;
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/ReLU.h b/compiler/enco/frontend/caffe/src/Layer/ReLU.h
new file mode 100644
index 000000000..94836fd8e
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/ReLU.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __RELU_BUILDER_H__
+#define __RELU_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class ReLUBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __RELU_BUILDER_H__
diff --git a/compiler/enco/frontend/caffe/src/Layer/Scale.cpp b/compiler/enco/frontend/caffe/src/Layer/Scale.cpp
new file mode 100644
index 000000000..b9925978c
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Scale.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Scale.h"
+#include "IRBuilder.h"
+
+#include <coco/IR/FeatureLayouts.h>
+
+#include <morph/caffe.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+using namespace morph::caffe;
+
+namespace caffeimport
+{
+
+void ScaleBuilder::build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const
+{
+ coco::Module *module = context->module();
+ coco::Data *data = context->data();
+ coco::Block *blk = context->block();
+ std::map<std::string, tensor::Shape> &shape_ctx = context->shape_ctx();
+ std::map<std::string, coco::Bag *> &bag_ctx = context->bag_ctx();
+ WeightContext &weight_ctx = context->weight_ctx();
+
+ // TODO Support Scale layer with 2 bottoms
+ assert(layer.bottom().size() == 1);
+ assert(layer.top().size() == 1);
+
+ assert(layer.has_scale_param());
+ const auto &param = layer.scale_param();
+
+ assert(param.axis() == 1);
+ assert(!param.has_num_axes());
+
+ assert(weight_ctx.blob_count(layer.name()) >= 1);
+
+ // NOTE The shape of "Scale" output is same as that of its input
+ // NOTE The current implementation assumes that input/output is of feature type
+ // TODO Support generic tensor arguments
+ auto shape = shape_ctx.at(layer.bottom(0));
+
+ coco::Bag *last_bag = bag_ctx.at(layer.bottom(0));
+
+ // Create channel-wise multiplication
+ {
+ auto in_bag = last_bag;
+ auto in_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ in_obj->bag(in_bag);
+ in_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
+
+ auto factor_bag = module->entity()->bag()->create(num_elements(shape));
+ auto factor_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ factor_obj->bag(factor_bag);
+ factor_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(shape)));
+
+ auto out_bag = module->entity()->bag()->create(num_elements(shape));
+ auto out_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ out_obj->bag(out_bag);
+ out_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
+
+ auto mul_op = op_builder(module).load(factor_obj).load(in_obj).mul().pop();
+ auto mul_ins = instr_builder(module).eval(out_obj, mul_op);
+
+ blk->instr()->append(mul_ins);
+
+ // Fill "factor" data
+ {
+ data->f32()->allocate(factor_bag);
+
+ auto span = data->f32()->weight(factor_bag);
+ auto blob = weight_ctx.blob_get(layer.name(), 0);
+
+ for (uint32_t ch = 0; ch < factor_obj->shape().depth(); ++ch)
+ {
+ span[ch] = blob->data(ch);
+ }
+ }
+
+ // Update "last_bag"
+ last_bag = out_bag;
+ }
+
+ assert(last_bag != nullptr);
+
+ // Create bias addition (as channel-wise addition)
+ if (param.bias_term())
+ {
+ assert(weight_ctx.blob_count(layer.name()) >= 2);
+
+ auto in_bag = last_bag; /* Use the output of the last computation as an input */
+ auto in_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ in_obj->bag(in_bag);
+ in_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
+
+ auto bias_bag = module->entity()->bag()->create(num_elements(shape));
+ auto bias_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ bias_obj->bag(bias_bag);
+ bias_obj->layout(coco::FeatureLayouts::BC::create(as_feature_shape(shape)));
+
+ auto out_bag = module->entity()->bag()->create(num_elements(shape));
+ auto out_obj = module->entity()->object()->create<coco::FeatureObject>();
+
+ out_obj->bag(out_bag);
+ out_obj->layout(coco::FeatureLayouts::BCHW::create(as_feature_shape(shape)));
+
+ auto add_op = op_builder(module).load(bias_obj).load(in_obj).add().pop();
+ auto add_ins = instr_builder(module).eval(out_obj, add_op);
+
+ blk->instr()->append(add_ins);
+
+ // Fill bias data
+ {
+ data->f32()->allocate(bias_bag);
+
+ auto bias_span = data->f32()->weight(bias_bag);
+ auto bias_blob = weight_ctx.blob_get(layer.name(), 1);
+
+ for (uint32_t ch = 0; ch < bias_obj->shape().depth(); ++ch)
+ {
+ bias_span[ch] = bias_blob->data(ch);
+ }
+ }
+
+ // Update "last_bag"
+ last_bag = out_bag;
+ }
+
+ // Update bag and shape context
+ {
+ const auto &out_name = layer.top(0);
+ const auto &out_bag = last_bag;
+ const auto &out_shape = shape;
+
+ bag_ctx[out_name] = out_bag;
+ shape_ctx[out_name] = out_shape;
+ }
+}
+
+} // namespace caffeimport
diff --git a/compiler/enco/frontend/caffe/src/Layer/Scale.h b/compiler/enco/frontend/caffe/src/Layer/Scale.h
new file mode 100644
index 000000000..491cc31cf
--- /dev/null
+++ b/compiler/enco/frontend/caffe/src/Layer/Scale.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __SCALE_BUILDER_H__
+#define __SCALE_BUILDER_H__
+
+#include "GraphBuilder.h"
+
+#include "Context.h"
+
+namespace caffeimport
+{
+
+class ScaleBuilder final : public GraphBuilder
+{
+public:
+ void build(const ::caffe::LayerParameter &layer, GraphBuilderContext *context) const override;
+};
+
+} // namespace caffeimport
+
+#endif // __SCALE_BUILDER_H__