diff options
-rw-r--r-- | compiler/loco/src/Service/GraphTestcase.h | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 1f42effad..0c8fd45be 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -32,6 +32,7 @@ enum class GraphCode FeatureCodec, AvgPool2D, DepthwiseConv2D, + TransposedConv2D, MaxPool2D, TensorBroadcast, TensorConcat, @@ -299,6 +300,45 @@ private: std::unique_ptr<loco::Graph> _graph; }; +template <> class GraphTestcase<GraphCode::TransposedConv2D> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Prepare permutations + Permutation<Domain::Feature> feature_perm = make_NHWC_perm<Domain::Feature>(); + Permutation<Domain::Filter> filter_perm = make_HWCN_perm<Domain::Filter>(); + + // Build graph + _graph = make_graph(); + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(feature_perm)->node(); + const_node = graph_builder->push<ConstGenLayer>()->node(); + filter_encode_node = graph_builder->push<FilterEncodeLayer>()->perm(filter_perm)->node(); + tr_conv2d_node = graph_builder->push<TransposedConv2DLayer>()->node(); + decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(feature_perm)->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::ConstGen *const_node = nullptr; + loco::FilterEncode *filter_encode_node = nullptr; + loco::TransposedConv2D *tr_conv2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + template <> class GraphTestcase<GraphCode::MaxPool2D> final { public: |