summaryrefslogtreecommitdiff
path: root/compiler/moco-tf
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco-tf')
-rw-r--r--compiler/moco-tf/CMakeLists.txt26
-rw-r--r--compiler/moco-tf/include/moco/tf/Frontend.h66
-rw-r--r--compiler/moco-tf/include/moco/tf/Names.h93
-rw-r--r--compiler/moco-tf/proto/CMakeLists.txt22
-rw-r--r--compiler/moco-tf/requires.cmake5
-rw-r--r--compiler/moco-tf/src/Annotations/ConcatData.h44
-rw-r--r--compiler/moco-tf/src/Annotations/PadData.h51
-rw-r--r--compiler/moco-tf/src/Annotations/PaddingData.h49
-rw-r--r--compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp264
-rw-r--r--compiler/moco-tf/src/Annotations/ShapeInferenceData.h75
-rw-r--r--compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp174
-rw-r--r--compiler/moco-tf/src/Annotations/StrideData.h48
-rw-r--r--compiler/moco-tf/src/Annotations/WindowData.h46
-rw-r--r--compiler/moco-tf/src/CanonicalEltwiseInputConnector.cpp1
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp24
-rw-r--r--compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp116
-rw-r--r--compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp86
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp59
-rw-r--r--compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp371
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp110
-rw-r--r--compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp105
-rw-r--r--compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp29
-rw-r--r--compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp114
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h47
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp100
-rw-r--r--compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp102
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp47
-rw-r--r--compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp62
-rw-r--r--compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp37
-rw-r--r--compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp115
-rw-r--r--compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h42
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp36
-rw-r--r--compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp74
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp27
-rw-r--r--compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h7
-rw-r--r--compiler/moco-tf/src/Canonicalizer.cpp23
-rw-r--r--compiler/moco-tf/src/CodecHelper.h74
-rw-r--r--compiler/moco-tf/src/Dialect/TFDialect.cpp (renamed from compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp)8
-rw-r--r--compiler/moco-tf/src/Dialect/TFDialect.h46
-rw-r--r--compiler/moco-tf/src/Dialect/TFDialect.test.cpp29
-rw-r--r--compiler/moco-tf/src/Dialect/TFNode.cpp (renamed from compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp)14
-rw-r--r--compiler/moco-tf/src/Dialect/TFNode.h23
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodeDecl.h96
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodeImpl.h67
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodeVisitor.forward.h33
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodeVisitor.h82
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodes.h46
-rw-r--r--compiler/moco-tf/src/Dialect/TFNodes.lst34
-rw-r--r--compiler/moco-tf/src/Dialect/TFOpcode.h38
-rw-r--r--compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp58
-rw-r--r--compiler/moco-tf/src/Dialect/TFShapeInferenceRule.h39
-rw-r--r--compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp101
-rw-r--r--compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h40
-rw-r--r--compiler/moco-tf/src/Dialect/VariadicArityNode.h80
-rw-r--r--compiler/moco-tf/src/Dialect/VariadicArityNode.test.cpp55
-rw-r--r--compiler/moco-tf/src/Frontend.cpp184
-rw-r--r--compiler/moco-tf/src/Frontend.test.cpp7
-rw-r--r--compiler/moco-tf/src/GraphBuilder.h43
-rw-r--r--compiler/moco-tf/src/GraphBuilderContext.cpp82
-rw-r--r--compiler/moco-tf/src/GraphBuilderContext.h147
-rw-r--r--compiler/moco-tf/src/GraphBuilderContext.test.cpp75
-rw-r--r--compiler/moco-tf/src/GraphBuilderRegistry.h102
-rw-r--r--compiler/moco-tf/src/IR/TFAdd.h59
-rw-r--r--compiler/moco-tf/src/IR/TFAdd.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFAvgPool.h104
-rw-r--r--compiler/moco-tf/src/IR/TFAvgPool.test.cpp35
-rw-r--r--compiler/moco-tf/src/IR/TFBiasAdd.h71
-rw-r--r--compiler/moco-tf/src/IR/TFBiasAdd.test.cpp33
-rw-r--r--compiler/moco-tf/src/IR/TFConcatV2.h94
-rw-r--r--compiler/moco-tf/src/IR/TFConcatV2.test.cpp35
-rw-r--r--compiler/moco-tf/src/IR/TFConst.cpp66
-rw-r--r--compiler/moco-tf/src/IR/TFConst.h86
-rw-r--r--compiler/moco-tf/src/IR/TFConst.test.cpp65
-rw-r--r--compiler/moco-tf/src/IR/TFConv2D.h58
-rw-r--r--compiler/moco-tf/src/IR/TFConv2D.test.cpp35
-rw-r--r--compiler/moco-tf/src/IR/TFConv2DBackpropInput.h105
-rw-r--r--compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.h65
-rw-r--r--compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.test.cpp35
-rw-r--r--compiler/moco-tf/src/IR/TFFusedBatchNorm.h58
-rw-r--r--compiler/moco-tf/src/IR/TFFusedBatchNorm.test.cpp36
-rw-r--r--compiler/moco-tf/src/IR/TFIdentity.h55
-rw-r--r--compiler/moco-tf/src/IR/TFIdentity.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFMaxPool.h104
-rw-r--r--compiler/moco-tf/src/IR/TFMaxPool.test.cpp35
-rw-r--r--compiler/moco-tf/src/IR/TFMean.h52
-rw-r--r--compiler/moco-tf/src/IR/TFMean.test.cpp33
-rw-r--r--compiler/moco-tf/src/IR/TFMul.h59
-rw-r--r--compiler/moco-tf/src/IR/TFMul.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFRealDiv.h59
-rw-r--r--compiler/moco-tf/src/IR/TFRealDiv.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFRelu.h40
-rw-r--r--compiler/moco-tf/src/IR/TFRelu.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFRelu6.h40
-rw-r--r--compiler/moco-tf/src/IR/TFRelu6.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFReshape.h57
-rw-r--r--compiler/moco-tf/src/IR/TFReshape.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFRsqrt.h55
-rw-r--r--compiler/moco-tf/src/IR/TFRsqrt.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFShape.h63
-rw-r--r--compiler/moco-tf/src/IR/TFShape.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFSoftmax.h40
-rw-r--r--compiler/moco-tf/src/IR/TFSoftmax.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFSqrt.h55
-rw-r--r--compiler/moco-tf/src/IR/TFSqrt.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFSquaredDifference.h59
-rw-r--r--compiler/moco-tf/src/IR/TFSquaredDifference.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFSqueeze.h74
-rw-r--r--compiler/moco-tf/src/IR/TFSqueeze.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFStopGradient.h55
-rw-r--r--compiler/moco-tf/src/IR/TFStopGradient.test.cpp31
-rw-r--r--compiler/moco-tf/src/IR/TFSub.h59
-rw-r--r--compiler/moco-tf/src/IR/TFSub.test.cpp32
-rw-r--r--compiler/moco-tf/src/IR/TFTanh.h40
-rw-r--r--compiler/moco-tf/src/IR/TFTanh.test.cpp31
-rw-r--r--compiler/moco-tf/src/ImportTarget.h26
-rw-r--r--compiler/moco-tf/src/Importer.cpp290
-rw-r--r--compiler/moco-tf/src/Importer.h (renamed from compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h)32
-rw-r--r--compiler/moco-tf/src/Importer.test.cpp148
-rw-r--r--compiler/moco-tf/src/Knob.lst23
-rw-r--r--compiler/moco-tf/src/LogHelper.cpp16
-rw-r--r--compiler/moco-tf/src/LogHelper.h15
-rw-r--r--compiler/moco-tf/src/Op/Add.cpp107
-rw-r--r--compiler/moco-tf/src/Op/Add.test.cpp136
-rw-r--r--compiler/moco-tf/src/Op/AvgPool.cpp325
-rw-r--r--compiler/moco-tf/src/Op/AvgPool.h52
-rw-r--r--compiler/moco-tf/src/Op/AvgPool.test.cpp211
-rw-r--r--compiler/moco-tf/src/Op/BiasAdd.cpp240
-rw-r--r--compiler/moco-tf/src/Op/BiasAdd.h52
-rw-r--r--compiler/moco-tf/src/Op/BiasAdd.test.cpp301
-rw-r--r--compiler/moco-tf/src/Op/COpCall.cpp19
-rw-r--r--compiler/moco-tf/src/Op/COpCall.h5
-rw-r--r--compiler/moco-tf/src/Op/COpCall.test.cpp14
-rw-r--r--compiler/moco-tf/src/Op/Concat.cpp276
-rw-r--r--compiler/moco-tf/src/Op/Concat.h52
-rw-r--r--compiler/moco-tf/src/Op/Concat.test.cpp449
-rw-r--r--compiler/moco-tf/src/Op/Const.cpp359
-rw-r--r--compiler/moco-tf/src/Op/Const.h52
-rw-r--r--compiler/moco-tf/src/Op/Const.test.cpp464
-rw-r--r--compiler/moco-tf/src/Op/Conv2D.cpp322
-rw-r--r--compiler/moco-tf/src/Op/Conv2D.h52
-rw-r--r--compiler/moco-tf/src/Op/Conv2D.test.cpp513
-rw-r--r--compiler/moco-tf/src/Op/DepthwiseConv2dNative.cpp155
-rw-r--r--compiler/moco-tf/src/Op/DepthwiseConv2dNative.test.cpp219
-rw-r--r--compiler/moco-tf/src/Op/FusedBatchNorm.cpp121
-rw-r--r--compiler/moco-tf/src/Op/FusedBatchNorm.test.cpp223
-rw-r--r--compiler/moco-tf/src/Op/Identity.cpp185
-rw-r--r--compiler/moco-tf/src/Op/Identity.h52
-rw-r--r--compiler/moco-tf/src/Op/MaxPool.cpp297
-rw-r--r--compiler/moco-tf/src/Op/MaxPool.h52
-rw-r--r--compiler/moco-tf/src/Op/MaxPool.test.cpp299
-rw-r--r--compiler/moco-tf/src/Op/Mul.cpp107
-rw-r--r--compiler/moco-tf/src/Op/Mul.test.cpp136
-rw-r--r--compiler/moco-tf/src/Op/Placeholder.cpp100
-rw-r--r--compiler/moco-tf/src/Op/Placeholder.test.cpp88
-rw-r--r--compiler/moco-tf/src/Op/RealDiv.cpp109
-rw-r--r--compiler/moco-tf/src/Op/RealDiv.test.cpp136
-rw-r--r--compiler/moco-tf/src/Op/Relu.cpp159
-rw-r--r--compiler/moco-tf/src/Op/Relu.h51
-rw-r--r--compiler/moco-tf/src/Op/Relu.test.cpp133
-rw-r--r--compiler/moco-tf/src/Op/Relu6.cpp149
-rw-r--r--compiler/moco-tf/src/Op/Relu6.h53
-rw-r--r--compiler/moco-tf/src/Op/Relu6.test.cpp133
-rw-r--r--compiler/moco-tf/src/Op/Reshape.cpp119
-rw-r--r--compiler/moco-tf/src/Op/Reshape.test.cpp108
-rw-r--r--compiler/moco-tf/src/Op/Rsqrt.cpp103
-rw-r--r--compiler/moco-tf/src/Op/Rsqrt.test.cpp103
-rw-r--r--compiler/moco-tf/src/Op/Shape.cpp118
-rw-r--r--compiler/moco-tf/src/Op/Shape.test.cpp94
-rw-r--r--compiler/moco-tf/src/Op/Softmax.cpp104
-rw-r--r--compiler/moco-tf/src/Op/Softmax.test.cpp94
-rw-r--r--compiler/moco-tf/src/Op/Sqrt.cpp102
-rw-r--r--compiler/moco-tf/src/Op/Sqrt.test.cpp103
-rw-r--r--compiler/moco-tf/src/Op/SquaredDifference.cpp114
-rw-r--r--compiler/moco-tf/src/Op/SquaredDifference.test.cpp136
-rw-r--r--compiler/moco-tf/src/Op/Squeeze.cpp121
-rw-r--r--compiler/moco-tf/src/Op/Squeeze.test.cpp162
-rw-r--r--compiler/moco-tf/src/Op/StopGradient.cpp105
-rw-r--r--compiler/moco-tf/src/Op/StopGradient.test.cpp100
-rw-r--r--compiler/moco-tf/src/Op/Sub.cpp107
-rw-r--r--compiler/moco-tf/src/Op/Sub.test.cpp136
-rw-r--r--compiler/moco-tf/src/Op/Tanh.cpp102
-rw-r--r--compiler/moco-tf/src/Op/Tanh.test.cpp103
-rw-r--r--compiler/moco-tf/src/Phase.cpp107
-rw-r--r--compiler/moco-tf/src/Phase.h78
-rw-r--r--compiler/moco-tf/src/SimpleNodeTransform.h64
-rw-r--r--compiler/moco-tf/src/SimpleNodeTransform.test.cpp56
-rw-r--r--compiler/moco-tf/src/TFEltwiseBinaryCanonicalzeHelper.h60
-rw-r--r--compiler/moco-tf/src/TFFormattedGraph.cpp93
-rw-r--r--compiler/moco-tf/src/TFOptimizer.cpp22
-rw-r--r--compiler/moco-tf/src/TFReduceCanonicalzeHelper.h118
-rw-r--r--compiler/moco-tf/src/TestHelper.cpp47
-rw-r--r--compiler/moco-tf/src/TestHelper.h39
-rw-r--r--compiler/moco-tf/src/TestHelper.test.cpp121
-rw-r--r--compiler/moco-tf/src/Transforms.h8
-rw-r--r--compiler/moco-tf/src/Transforms/ClearAnnotTransform.cpp63
-rw-r--r--compiler/moco-tf/src/Transforms/ClearAnnotTransform.h (renamed from compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h)17
-rw-r--r--compiler/moco-tf/src/Transforms/ClearAnnotTransform.test.cpp29
-rw-r--r--compiler/moco-tf/src/Transforms/FixShapeTransform.cpp1539
-rw-r--r--compiler/moco-tf/src/Transforms/FixShapeTransform.h (renamed from compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h)19
-rw-r--r--compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp227
-rw-r--r--compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp547
-rw-r--r--compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.h44
-rw-r--r--compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.cpp67
-rw-r--r--compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.h50
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp126
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveConstantShape.h (renamed from compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h)17
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp259
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.h44
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp232
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.cpp157
-rw-r--r--compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.h (renamed from compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h)20
-rw-r--r--compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp7
-rw-r--r--compiler/moco-tf/src/Transforms/TypeInferencePass.cpp7
241 files changed, 19355 insertions, 1895 deletions
diff --git a/compiler/moco-tf/CMakeLists.txt b/compiler/moco-tf/CMakeLists.txt
index 5516388a4..9ca5777ad 100644
--- a/compiler/moco-tf/CMakeLists.txt
+++ b/compiler/moco-tf/CMakeLists.txt
@@ -1,6 +1,16 @@
-if(NOT TARGET mio_tf)
+nncc_find_package(Protobuf QUIET)
+# TensorFlowSource package is used to use ~.proto files
+nncc_find_package(TensorFlowSource EXACT 1.12 QUIET)
+
+if(NOT Protobuf_FOUND)
+ return()
+endif(NOT Protobuf_FOUND)
+
+if(NOT TensorFlowSource_FOUND)
return()
-endif(NOT TARGET mio_tf)
+endif(NOT TensorFlowSource_FOUND)
+
+add_subdirectory(proto)
file(GLOB_RECURSE SOURCES "src/*.cpp")
file(GLOB_RECURSE TESTS "src/*.test.cpp")
@@ -9,17 +19,13 @@ list(REMOVE_ITEM SOURCES ${TESTS})
add_library(moco_tf_frontend SHARED ${SOURCES})
target_include_directories(moco_tf_frontend PRIVATE src)
target_include_directories(moco_tf_frontend PUBLIC include)
+target_link_libraries(moco_tf_frontend PUBLIC moco_tf_proto)
target_link_libraries(moco_tf_frontend PUBLIC loco)
-target_link_libraries(moco_tf_frontend PUBLIC moco_lang)
-target_link_libraries(moco_tf_frontend PUBLIC moco_import)
-target_link_libraries(moco_tf_frontend PUBLIC moco_pass)
-target_link_libraries(moco_tf_frontend PUBLIC mio_tf)
-target_link_libraries(moco_tf_frontend PRIVATE moco_service)
-target_link_libraries(moco_tf_frontend PRIVATE moco_support)
target_link_libraries(moco_tf_frontend PRIVATE bino)
target_link_libraries(moco_tf_frontend PRIVATE fipe)
target_link_libraries(moco_tf_frontend PRIVATE locop)
target_link_libraries(moco_tf_frontend PRIVATE stdex)
+target_link_libraries(moco_tf_frontend PRIVATE cwrap)
target_link_libraries(moco_tf_frontend PRIVATE moco_log)
target_link_libraries(moco_tf_frontend PRIVATE pepper_str)
target_link_libraries(moco_tf_frontend PRIVATE pepper_strcast)
@@ -27,14 +33,12 @@ target_link_libraries(moco_tf_frontend PRIVATE locomotiv)
target_link_libraries(moco_tf_frontend PRIVATE plier_tf)
target_link_libraries(moco_tf_frontend PRIVATE locoex_customop)
target_link_libraries(moco_tf_frontend PRIVATE logo)
-target_link_libraries(moco_tf_frontend PRIVATE oops)
-install(TARGETS moco_tf_frontend DESTINATION lib)
if(NOT ENABLE_TEST)
return()
endif(NOT ENABLE_TEST)
-nnas_find_package(GTest REQUIRED)
+nncc_find_package(GTest REQUIRED)
add_executable(moco_tf_frontend_test ${TESTS})
target_include_directories(moco_tf_frontend_test PRIVATE src)
diff --git a/compiler/moco-tf/include/moco/tf/Frontend.h b/compiler/moco-tf/include/moco/tf/Frontend.h
index 6914fdd38..4507a76d4 100644
--- a/compiler/moco-tf/include/moco/tf/Frontend.h
+++ b/compiler/moco-tf/include/moco/tf/Frontend.h
@@ -17,17 +17,72 @@
#ifndef __MOCO_TENSORFLOW_FRONTEND_H__
#define __MOCO_TENSORFLOW_FRONTEND_H__
-#include <moco/Import/ModelSignature.h>
+#include <moco/tf/Names.h>
#include <loco.h>
+#include <angkor/TensorShape.h>
#include <tensorflow/core/framework/graph.pb.h>
+#include <istream>
+#include <memory>
+#include <string>
+#include <vector>
+
namespace moco
{
namespace tf
{
+using TensorShape = angkor::TensorShape;
+
+/**
+ * @brief Class to store information to run a model. Normally this info comes from users
+ * via CLI params or configuration file.
+ */
+struct ModelSignature
+{
+public:
+ void add_input(const TensorName &input) { _inputs.push_back(input); }
+ void add_input(const TensorName &&input) { _inputs.push_back(input); }
+ void add_output(const TensorName &output) { _outputs.push_back(output); }
+ void add_output(const TensorName &&output) { _outputs.push_back(output); }
+
+ const std::vector<TensorName> &inputs() const { return _inputs; }
+ const std::vector<TensorName> &outputs() const { return _outputs; }
+
+ /**
+ * @brief Adds customop op type (not name of node) provided from user
+ */
+ void add_customop(const std::string &op);
+ const std::vector<std::string> &customops() const { return _customops; }
+
+ /**
+ * @brief Adds node name and its shape provided from user
+ */
+ void shape(const std::string &node_name, const TensorShape &shape);
+ const TensorShape *shape(const std::string &node_name) const;
+
+ /**
+ * @brief Adds node name and its dtype provided from user
+ */
+ void dtype(const std::string &node_name, loco::DataType dtype);
+ loco::DataType dtype(const std::string &node_name) const;
+
+private:
+ std::vector<TensorName> _inputs; // graph inputs
+ std::vector<TensorName> _outputs; // graph outputs
+
+ // For custom op types passed from user (e.g., via CLI)
+ std::vector<std::string> _customops;
+
+  // For node names and shapes passed from user (e.g., via CLI)
+ std::map<std::string, TensorShape> _shapes;
+
+  // For node names and dtypes passed from user (e.g., via CLI)
+ std::map<std::string, loco::DataType> _dtypes;
+};
+
class Frontend
{
public:
@@ -48,6 +103,15 @@ private:
std::unique_ptr<loco::Graph> import(const ModelSignature &, tensorflow::GraphDef &) const;
};
+/**
+ * @brief This will do internal memory cleanup of the graph returned from
+ * Frontend::load() method.
+ *
+ * @note Calling this can be omitted as all allocations will be freed at
+ * termination but memory usage can be unnecessary higher.
+ */
+void cleanup(loco::Graph *graph);
+
} // namespace tf
} // namespace moco
diff --git a/compiler/moco-tf/include/moco/tf/Names.h b/compiler/moco-tf/include/moco/tf/Names.h
new file mode 100644
index 000000000..fc612d27d
--- /dev/null
+++ b/compiler/moco-tf/include/moco/tf/Names.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_NAMES_H__
+#define __MOCO_TF_NAMES_H__
+
+#include <string>
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+struct TensorName final
+{
+public:
+ /**
+ * @brief Constructor
+ *
+ * @note If tensor_name does not have ":index", this constructor adds ":0" by default
+ */
+ explicit TensorName(const std::string &tensor_name)
+ {
+ if (tensor_name.find(":") != std::string::npos) // tensor_name is a form of letter:0
+ {
+ _name.assign(tensor_name);
+ }
+ else
+ {
+ _name.assign(tensor_name + ":0"); // if it does not have ":index", adds ":0" by default
+ }
+ }
+
+ explicit TensorName(const std::string &node_name, const int tensor_index)
+ {
+    if (node_name.find(":") != std::string::npos) // node_name is already in the form of name:0
+ {
+ throw std::runtime_error("Node name has already tensor index:" + node_name);
+ }
+ else
+ {
+ _name.assign(node_name + ":" + std::to_string(tensor_index));
+ }
+ }
+
+ const std::string &name() const { return _name; }
+
+ /**
+ * @brief Returns node name from tensor name by removing, e.g., ":0"
+ */
+ const std::string nodeName() const
+ {
+ auto index = _name.find(":");
+
+ if (index != std::string::npos)
+ return _name.substr(0, index);
+ else
+ throw std::runtime_error{"Tensor name should be a name:number format: " + _name};
+ };
+
+private:
+ std::string _name;
+};
+
+/**
+ * @brief To use TensorName as a key in std::map, this struct defines how to compare two TensorNames
+ */
+struct TensorNameCompare
+{
+ bool operator()(const TensorName &lhs, const TensorName &rhs) const
+ {
+ return lhs.name() < rhs.name();
+ }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_NAMES_H__
diff --git a/compiler/moco-tf/proto/CMakeLists.txt b/compiler/moco-tf/proto/CMakeLists.txt
new file mode 100644
index 000000000..2053a6893
--- /dev/null
+++ b/compiler/moco-tf/proto/CMakeLists.txt
@@ -0,0 +1,22 @@
+# Minimal Protocol Buffer specification for GraphDef file (.pb) encoding/decoding
+unset(PROTO_FILES)
+list(APPEND PROTO_FILES tensorflow/core/framework/versions.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/resource_handle.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/types.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/tensor.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/tensor_shape.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/attr_value.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/op_def.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/node_def.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/function.proto)
+list(APPEND PROTO_FILES tensorflow/core/framework/graph.proto)
+
+Protobuf_Generate(GRAPHDEF_PROTO
+ "${CMAKE_CURRENT_BINARY_DIR}/generated"
+ "${TensorFlowSource_DIR}"
+ ${PROTO_FILES})
+
+add_library(moco_tf_proto STATIC ${GRAPHDEF_PROTO_SOURCES})
+set_target_properties(moco_tf_proto PROPERTIES POSITION_INDEPENDENT_CODE ON)
+target_include_directories(moco_tf_proto PUBLIC ${GRAPHDEF_PROTO_INCLUDE_DIRS})
+target_link_libraries(moco_tf_proto PUBLIC libprotobuf)
diff --git a/compiler/moco-tf/requires.cmake b/compiler/moco-tf/requires.cmake
index 3e0fabee9..10e4774e7 100644
--- a/compiler/moco-tf/requires.cmake
+++ b/compiler/moco-tf/requires.cmake
@@ -1,14 +1,11 @@
require("fipe")
require("loco")
-require("moco")
require("locop")
+require("cwrap")
require("stdex")
require("moco-log")
require("pepper-strcast")
require("locomotiv")
-require("mio-tf")
require("plier-tf")
require("locoex-customop")
require("logo")
-require("oops")
-require("bino")
diff --git a/compiler/moco-tf/src/Annotations/ConcatData.h b/compiler/moco-tf/src/Annotations/ConcatData.h
new file mode 100644
index 000000000..4c8e5fa5e
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/ConcatData.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_CONCAT_DATA_H__
+#define __MOCO_TF_CONCAT_DATA_H__
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief ConcatData holds temporary axis attribute of Concat while building the graph
+*/
+class ConcatData : public loco::NodeAnnotation
+{
+public:
+ ConcatData(int32_t axis) : _axis(axis) {}
+
+ int32_t axis(void) const { return _axis; }
+
+private:
+ int32_t _axis;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_CONCAT_DATA_H__
diff --git a/compiler/moco-tf/src/Annotations/PadData.h b/compiler/moco-tf/src/Annotations/PadData.h
new file mode 100644
index 000000000..887a1c503
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/PadData.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PAD_DATA_H__
+#define __MOCO_TF_PAD_DATA_H__
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief PadData holds temporary 'pad' attribute of TFConv2D
+ *
+ * @note This holds the same pad attribute that exist in Canonical Conv2D
+ * to simplify Canonicalizing step of TFConv2D to Conv2D conversion.
+ * Values of 'pad' will be calculated in FixPaddingTransformation.
+ * PadData holds Padding2D where PaddingData holds 'padding' as a string.
+ */
+class PadData : public loco::NodeAnnotation
+{
+public:
+ PadData() = default;
+
+public:
+ const loco::Padding2D *pad(void) const { return &_pad; }
+ loco::Padding2D *pad(void) { return &_pad; }
+
+private:
+ loco::Padding2D _pad;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PAD_DATA_H__
diff --git a/compiler/moco-tf/src/Annotations/PaddingData.h b/compiler/moco-tf/src/Annotations/PaddingData.h
new file mode 100644
index 000000000..e875cca7d
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/PaddingData.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PADDING_DATA_H__
+#define __MOCO_TF_PADDING_DATA_H__
+
+#include <loco.h>
+
+#include <string>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief PaddingData holds temporary padding attribute
+ *
+ * @note Related nodes are AvgPool2D, MaxPool2D, Conv2D and maybe others
+ * PaddingData holds 'padding' as a string where PadData holds Padding2D
+ */
+class PaddingData : public loco::NodeAnnotation
+{
+public:
+ PaddingData(const std::string &padding) : _padding(padding) {}
+
+ const std::string &padding(void) const { return _padding; }
+
+private:
+ std::string _padding;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PADDING_DATA_H__
diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp b/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp
new file mode 100644
index 000000000..e6ffa98ae
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.cpp
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferenceData.h"
+
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+loco::TensorShape ShapeInferenceData::tensor_shape(void) const
+{
+ assert(_domain == loco::Domain::Tensor);
+
+ loco::TensorShape shape;
+
+ shape.rank(rank());
+ for (uint32_t r = 0; r < rank(); ++r)
+ {
+ if (dim(r).known())
+ shape.dim(r) = dim(r).value();
+ else
+ shape.dim(r).unset();
+ }
+
+ return shape;
+}
+
+loco::FeatureShape ShapeInferenceData::feature_shape(void) const
+{
+ assert(_domain == loco::Domain::Feature);
+
+ loco::FeatureShape shape;
+
+ if (rank() != 4)
+ throw std::runtime_error("Feature should be rank 4");
+
+ shape.count() = dim(0);
+ shape.height() = dim(1);
+ shape.width() = dim(2);
+ shape.depth() = dim(3);
+
+ return shape;
+}
+
+loco::FilterShape ShapeInferenceData::filter_shape(void) const
+{
+ assert(_domain == loco::Domain::Filter);
+
+ loco::FilterShape shape;
+
+ if (rank() != 4)
+ throw std::runtime_error("Filter should be rank 4");
+
+ shape.count() = dim(0);
+ shape.height() = dim(1);
+ shape.width() = dim(2);
+ shape.depth() = dim(3);
+
+ return shape;
+}
+
+loco::DepthwiseFilterShape ShapeInferenceData::depthwisefilter_shape(void) const
+{
+ assert(_domain == loco::Domain::DepthwiseFilter);
+
+ loco::DepthwiseFilterShape shape;
+
+ if (rank() != 4)
+ throw std::runtime_error("DepthwiseFilter should be rank 4");
+
+ shape.height() = dim(0);
+ shape.width() = dim(1);
+ shape.depth() = dim(2);
+ shape.multiplier() = dim(3);
+
+ return shape;
+}
+
+loco::BiasShape ShapeInferenceData::bias_shape(void) const
+{
+ assert(_domain == loco::Domain::Bias);
+
+ loco::BiasShape shape;
+
+ // Note: this may change when loco::BiasShape becomes available
+ shape.length() = dim(0).value();
+
+ return shape;
+}
+
+void ShapeInferenceData::tensor_shape(const loco::TensorShape &shape)
+{
+ _domain = loco::Domain::Tensor;
+
+ rank(shape.rank());
+ for (uint32_t r = 0; r < shape.rank(); ++r)
+ {
+ if (shape.dim(r).known())
+ dim(r) = shape.dim(r).value();
+ else
+ dim(r).unset();
+ }
+}
+
+void ShapeInferenceData::feature_shape(const loco::FeatureShape &shape)
+{
+ _domain = loco::Domain::Feature;
+
+ rank(4);
+ dim(0) = shape.count();
+ dim(1) = shape.height();
+ dim(2) = shape.width();
+ dim(3) = shape.depth();
+}
+
+void ShapeInferenceData::filter_shape(const loco::FilterShape &shape)
+{
+ _domain = loco::Domain::Filter;
+
+ rank(4);
+ dim(0) = shape.count();
+ dim(1) = shape.height();
+ dim(2) = shape.width();
+ dim(3) = shape.depth();
+}
+
+void ShapeInferenceData::depthwisefilter_shape(const loco::DepthwiseFilterShape &shape)
+{
+ _domain = loco::Domain::DepthwiseFilter;
+
+ rank(4);
+ dim(0) = shape.height();
+ dim(1) = shape.width();
+ dim(2) = shape.depth();
+ dim(3) = shape.multiplier();
+}
+
+void ShapeInferenceData::bias_shape(const loco::BiasShape &shape)
+{
+ _domain = loco::Domain::Bias;
+
+ // Note: this may change when loco::BiasShape becomes available
+ rank(1);
+ dim(0) = shape.length();
+}
+
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &feature_shape,
+ const TFDataLayout &data_layout)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ if (data_layout == "NHWC")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.height();
+ tensor_shape.dim(2) = feature_shape.width();
+ tensor_shape.dim(3) = feature_shape.depth();
+ }
+ else if (data_layout == "NCHW")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.depth();
+ tensor_shape.dim(2) = feature_shape.height();
+ tensor_shape.dim(3) = feature_shape.width();
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ throw std::runtime_error("as_tensor_shape: only supports NHWC or NCHW");
+ }
+
+ shapedata.tensor_shape(tensor_shape);
+}
+
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+ const TFDataLayout &data_layout)
+{
+ if (shapedata.domain() == loco::Domain::Feature)
+ return shapedata.feature_shape();
+
+ loco::FeatureShape feature_shape;
+
+ // only convert from tensor to feature
+ if (shapedata.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("as_feature_shape: domain is not tensor");
+ }
+ if (shapedata.rank() != 4)
+ {
+ throw std::runtime_error("as_feature_shape: rank is not 4");
+ }
+
+ // TODO support for other data_layout if needed
+ if (data_layout != "NHWC" && data_layout != "NCHW")
+ {
+ throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+ }
+
+ if (data_layout == "NHWC")
+ {
+ feature_shape.count() = shapedata.dim(0);
+ feature_shape.height() = shapedata.dim(1);
+ feature_shape.width() = shapedata.dim(2);
+ feature_shape.depth() = shapedata.dim(3);
+ }
+ else
+ {
+ feature_shape.count() = shapedata.dim(0);
+ feature_shape.depth() = shapedata.dim(1);
+ feature_shape.height() = shapedata.dim(2);
+ feature_shape.width() = shapedata.dim(3);
+ }
+
+ return feature_shape;
+}
+
+bool operator==(const ShapeInferenceData &lhs, const ShapeInferenceData &rhs)
+{
+ if (lhs.domain() != rhs.domain())
+ return false;
+
+ switch (lhs.domain())
+ {
+ case loco::Domain::Tensor:
+ {
+ auto lhs_t = lhs.tensor_shape();
+ auto rhs_t = rhs.tensor_shape();
+    if (lhs_t.rank() != rhs_t.rank())
+ return false;
+ for (uint32_t axis = 0; axis < lhs_t.rank(); ++axis)
+ {
+ if (!(lhs_t.dim(axis) == rhs_t.dim(axis)))
+ return false;
+ }
+ return true;
+ }
+ // TODO Support other domains
+ // case loco::Domain::Feature:
+ // case loco::Domain::Filter:
+ // case loco::Domain::Bias:
+ default:
+ throw std::runtime_error("Not supported domain for ShapeInferenceData equality");
+ }
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.h b/compiler/moco-tf/src/Annotations/ShapeInferenceData.h
new file mode 100644
index 000000000..d48699356
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_SHAPEINFERENCE_DATA_H__
+#define __MOCO_TF_SHAPEINFERENCE_DATA_H__
+
+#include <loco.h>
+#include "loco/IR/BiasShape.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note Below alias may be introduced as separate class
+using TFDataLayout = std::string;
+
+/**
+ * @brief ShapeInferenceData provides shape inference data tracking from the start (input)
+ *
+ * @note For Feature and Filter, NHWC is used for shape layout
+ */
+class ShapeInferenceData : public loco::NodeAnnotation,
+ public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+public:
+ ~ShapeInferenceData(){};
+
+public:
+ loco::Domain domain(void) const { return _domain; }
+
+ loco::TensorShape tensor_shape(void) const;
+ loco::FeatureShape feature_shape(void) const;
+ loco::FilterShape filter_shape(void) const;
+ loco::DepthwiseFilterShape depthwisefilter_shape(void) const;
+ loco::BiasShape bias_shape(void) const;
+
+ void tensor_shape(const loco::TensorShape &shape);
+ void feature_shape(const loco::FeatureShape &shape);
+ void filter_shape(const loco::FilterShape &shape);
+ void depthwisefilter_shape(const loco::DepthwiseFilterShape &shape);
+ void bias_shape(const loco::BiasShape &shape);
+
+private:
+ // TODO set default as Unknown, setting Tensor is to minimize change
+ loco::Domain _domain{loco::Domain::Tensor};
+};
+
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &shape,
+ const TFDataLayout &data_layout);
+
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+ const TFDataLayout &data_layout);
+
+bool operator==(const ShapeInferenceData &lhs, const ShapeInferenceData &rhs);
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_SHAPEINFERENCE_DATA_H__
diff --git a/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp b/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp
new file mode 100644
index 000000000..8b8de535e
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/ShapeInferenceData.test.cpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInferenceData.h"
+
+#include <gtest/gtest.h>
+
+TEST(TensorFlowImport, shapeinferencedata_tensor_get)
+{
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.rank(4);
+ shapedata.dim(0) = 1;
+ shapedata.dim(1) = 2;
+ shapedata.dim(2) = 3;
+ shapedata.dim(3) = 4;
+
+ loco::TensorShape tensor = shapedata.tensor_shape();
+
+ ASSERT_EQ(tensor.rank(), 4);
+ ASSERT_EQ(tensor.dim(0), 1);
+ ASSERT_EQ(tensor.dim(1), 2);
+ ASSERT_EQ(tensor.dim(2), 3);
+ ASSERT_EQ(tensor.dim(3), 4);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_feature)
+{
+ loco::FeatureShape feature_s;
+
+ feature_s.count() = 1;
+ feature_s.height() = 2;
+ feature_s.width() = 3;
+ feature_s.depth() = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.feature_shape(feature_s);
+
+ loco::FeatureShape feature_g = shapedata.feature_shape();
+
+ ASSERT_EQ(feature_g.count(), 1);
+ ASSERT_EQ(feature_g.height(), 2);
+ ASSERT_EQ(feature_g.width(), 3);
+ ASSERT_EQ(feature_g.depth(), 4);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_filter)
+{
+ loco::FilterShape filter_s;
+
+ filter_s.count() = 1;
+ filter_s.height() = 2;
+ filter_s.width() = 3;
+ filter_s.depth() = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.filter_shape(filter_s);
+
+ ASSERT_EQ(shapedata.domain(), loco::Domain::Filter);
+
+ loco::FilterShape filter_g = shapedata.filter_shape();
+
+ ASSERT_EQ(filter_g.count(), 1);
+ ASSERT_EQ(filter_g.height(), 2);
+ ASSERT_EQ(filter_g.width(), 3);
+ ASSERT_EQ(filter_g.depth(), 4);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_bias)
+{
+ // Note: this may change when loco::BiasShape becomes available
+
+ loco::BiasShape bias_s;
+
+ bias_s.length() = 3;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.bias_shape(bias_s);
+
+ loco::BiasShape bias_g = shapedata.bias_shape();
+
+ ASSERT_EQ(bias_g.length(), 3);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_as_tensor_set)
+{
+ loco::FeatureShape feature_s;
+
+ feature_s.count() = 1;
+ feature_s.height() = 2;
+ feature_s.width() = 3;
+ feature_s.depth() = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ as_tensor_shape(shapedata, feature_s, "NHWC");
+
+ loco::TensorShape tensor_g;
+
+ tensor_g = shapedata.tensor_shape();
+
+ ASSERT_EQ(tensor_g.rank(), 4);
+ ASSERT_EQ(tensor_g.dim(0), 1);
+ ASSERT_EQ(tensor_g.dim(1), 2);
+ ASSERT_EQ(tensor_g.dim(2), 3);
+ ASSERT_EQ(tensor_g.dim(3), 4);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_as_feature)
+{
+ loco::TensorShape tensor_s;
+
+ tensor_s.rank(4);
+ tensor_s.dim(0) = 1;
+ tensor_s.dim(1) = 2;
+ tensor_s.dim(2) = 3;
+ tensor_s.dim(3) = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.tensor_shape(tensor_s);
+
+ loco::FeatureShape feature_g = as_feature_shape(shapedata, "NHWC");
+
+ ASSERT_EQ(feature_g.count(), 1);
+ ASSERT_EQ(feature_g.height(), 2);
+ ASSERT_EQ(feature_g.width(), 3);
+ ASSERT_EQ(feature_g.depth(), 4);
+}
+
+TEST(TensorFlowImport, shapeinferencedata_equality_tensor)
+{
+ moco::tf::ShapeInferenceData left;
+ moco::tf::ShapeInferenceData right, wrong1, wrong2;
+
+ left.rank(2);
+ left.dim(0) = 1;
+ left.dim(1) = 2;
+ ASSERT_EQ(left.domain(), loco::Domain::Tensor);
+
+ right.rank(2);
+ right.dim(0) = 1;
+ right.dim(1) = 2;
+ ASSERT_TRUE(left == right);
+
+ wrong1.rank(1);
+ wrong1.dim(0) = 1;
+ ASSERT_FALSE(left == wrong1);
+
+ loco::FeatureShape wrong2_f;
+ wrong2_f.count() = 1;
+ wrong2_f.depth() = 1;
+ wrong2_f.height() = 1;
+ wrong2_f.width() = 1;
+
+ wrong2.feature_shape(wrong2_f);
+ ASSERT_FALSE(left == wrong2);
+}
diff --git a/compiler/moco-tf/src/Annotations/StrideData.h b/compiler/moco-tf/src/Annotations/StrideData.h
new file mode 100644
index 000000000..fb9a4b304
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/StrideData.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_STRIDE_DATA_H__
+#define __MOCO_TF_STRIDE_DATA_H__
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief StrideData holds temporary 'stride' attribute of TFConv2D, the same
+ *        as the one in Conv2D, used for canonicalizing TFConv2D to Conv2D.
+ *        'stride' will be calculated in FixShapeTransformation for now.
+ */
+class StrideData : public loco::NodeAnnotation
+{
+public:
+ StrideData() = default;
+
+public:
+ const loco::Stride<2> *stride(void) const { return &_stride; }
+ loco::Stride<2> *stride(void) { return &_stride; }
+
+private:
+ loco::Stride<2> _stride;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_STRIDE_DATA_H__
diff --git a/compiler/moco-tf/src/Annotations/WindowData.h b/compiler/moco-tf/src/Annotations/WindowData.h
new file mode 100644
index 000000000..8bd962578
--- /dev/null
+++ b/compiler/moco-tf/src/Annotations/WindowData.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_WINDOW_DATA_H__
+#define __MOCO_TF_WINDOW_DATA_H__
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief WindowData holds temporary 'window' attribute of AvgPool2D, MaxPool2D
+ */
+class WindowData : public loco::NodeAnnotation
+{
+public:
+ WindowData() = default;
+
+public:
+ const loco::Window<2> *window(void) const { return &_window; }
+ loco::Window<2> *window(void) { return &_window; }
+
+private:
+ loco::Window<2> _window;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_WINDOW_DATA_H__
diff --git a/compiler/moco-tf/src/CanonicalEltwiseInputConnector.cpp b/compiler/moco-tf/src/CanonicalEltwiseInputConnector.cpp
index adeae39de..7142336bf 100644
--- a/compiler/moco-tf/src/CanonicalEltwiseInputConnector.cpp
+++ b/compiler/moco-tf/src/CanonicalEltwiseInputConnector.cpp
@@ -37,7 +37,6 @@ template <typename NodeTy> void InputConnector<NodeTy>::operator()(const NodePai
INSTANTIATE(EltwiseAdd);
INSTANTIATE(EltwiseSub);
-INSTANTIATE(EltwiseMax);
INSTANTIATE(EltwiseMul);
INSTANTIATE(EltwiseDiv);
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
index 8028a870c..ef82f3dab 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.cpp
@@ -16,8 +16,8 @@
#include "AddCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -26,9 +26,25 @@ namespace moco
namespace tf
{
-bool AddCanonicalizer::transform(TFAdd *node) const
+bool AddCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFAdd *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
index 53ba9ed58..07b8a72de 100644
--- a/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AddCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_ADD_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFAdd to Canonical EltwiseAdd
*/
-class AddCanonicalizer : public SimpleNodeTransform<TFAdd>
+class AddCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "AddCanonicalizer"; }
public:
- bool transform(TFAdd *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
index e07a4f64f..66a71089e 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp
@@ -16,19 +16,71 @@
#include "AvgPoolCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/WindowData.h"
-#include "CodecHelper.h"
-
-#include <loco/IR/NodeShape.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
-bool canonicalize_avgpool2d(loco::Graph *graph, moco::TFAvgPool *node)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node)
{
LOGGER(l);
@@ -61,24 +113,30 @@ bool canonicalize_avgpool2d(loco::Graph *graph, moco::TFAvgPool *node)
avgPool2d_node->convention(loco::AvgPool2D::Convention::Valid);
- auto value_shape = moco::node_shape(node->value());
- assert(value_shape.domain() != loco::Domain::Unknown);
+ // paddata to pad
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(node->ksize(), node->data_layout());
+ avgPool2d_node->pad()->top(pad_data->pad()->top());
+ avgPool2d_node->pad()->bottom(pad_data->pad()->bottom());
+ avgPool2d_node->pad()->left(pad_data->pad()->left());
+ avgPool2d_node->pad()->right(pad_data->pad()->right());
- moco::Padding2DInference infer_padding2d;
+ // windowdata to window (ksize to window)
+ auto window_data = node->annot<moco::tf::WindowData>();
+ assert(window_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ auto window = avgPool2d_node->window();
+ window->vertical(window_data->window()->vertical());
+ window->horizontal(window_data->window()->horizontal());
- auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // stridedata to stride (strides to stride)
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *avgPool2d_node->pad() = infer_padding2d(input_plane_shape);
- *avgPool2d_node->stride() = node_stride;
- *avgPool2d_node->window() = node_window;
+ auto stride = avgPool2d_node->stride();
+ stride->vertical(stride_data->stride()->vertical());
+ stride->horizontal(stride_data->stride()->horizontal());
INFO(l) << "Canonicalize TFAvgPool pad = T " << avgPool2d_node->pad()->top() << ", L "
<< avgPool2d_node->pad()->left() << ", B " << avgPool2d_node->pad()->bottom() << ", R "
@@ -105,9 +163,25 @@ namespace moco
namespace tf
{
-bool AvgPoolCanonicalizer::transform(TFAvgPool *node) const
+bool AvgPoolCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_avgpool2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFAvgPool *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_avgpool2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
index e9c56c868..7d7e6a80b 100644
--- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_AVGPOOL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFAvgPool to Canonical AvgPool2D
*/
-class AvgPoolCanonicalizer : public SimpleNodeTransform<moco::TFAvgPool>
+class AvgPoolCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "AvgPoolCanonicalizer"; }
public:
- bool transform(TFAvgPool *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
index a5568ce1a..37b660e4a 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.cpp
@@ -16,9 +16,12 @@
#include "BiasAddCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
#include <plier/tf/Convert.h>
@@ -26,7 +29,7 @@ namespace
{
using plier::tf::DataLayout;
-bool canonicalize_biasadd(loco::Graph *graph, moco::TFBiasAdd *node)
+bool canonicalize_biasadd(loco::Graph *graph, moco::tf::TFBiasAdd *node)
{
LOGGER(l);
@@ -100,9 +103,25 @@ namespace moco
namespace tf
{
-bool BiasAddCanonicalizer::transform(TFBiasAdd *node) const
+bool BiasAddCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_biasadd(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_biasadd = dynamic_cast<moco::tf::TFBiasAdd *>(node);
+ if (tf_biasadd != nullptr)
+ {
+ if (canonicalize_biasadd(graph, tf_biasadd))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
index ff4032ca9..a30894708 100644
--- a/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/BiasAddCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_BIASADD_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFBiasAdd to Canonical BiasAdd
*/
-class BiasAddCanonicalizer final : public SimpleNodeTransform<moco::TFBiasAdd>
+class BiasAddCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "BiasAddCanonicalizer"; }
public:
- bool transform(TFBiasAdd *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
index b59a3f3d7..e3939adb9 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.cpp
@@ -15,39 +15,27 @@
*/
#include "ConcatV2Canonicalizer.h"
+
#include "LogHelper.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ConcatData.h"
+#include "Annotations/ShapeInferenceData.h"
-#include <moco/Log.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <loco/Service/ShapeInference.h>
+#include <moco/Log.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
namespace
{
using namespace moco::tf;
-bool scalar_value(moco::TFConst *node, int32_t &ret)
-{
- auto nodeshape = node_shape(node);
- if (!(node->dtype() == loco::DataType::S32))
- return false;
-
- auto tensor_shape = nodeshape.as<loco::TensorShape>();
- if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
- return false;
-
- ret = node->at<loco::DataType::S32>(0);
-
- return true;
-}
-
-bool canonicalize_concat(loco::Graph *graph, moco::TFConcatV2 *node)
+bool canonicalize_concat(loco::Graph *graph, moco::tf::TFConcatV2 *node)
{
LOGGER(l);
@@ -83,43 +71,19 @@ bool canonicalize_concat(loco::Graph *graph, moco::TFConcatV2 *node)
const int num_values = node->num_values();
assert(num_values >= 2);
- // get axis absolute value
- auto value_a = node->values(0);
- if (!loco::shape_known(value_a))
- return false;
+ // get axis value
+ auto concat_data = node->annot<ConcatData>();
+ assert(concat_data != nullptr);
+ auto axis_value = concat_data->axis();
- uint32_t node_rank = 0;
- {
- auto value_a_shape = moco::node_shape(value_a);
- assert(value_a_shape.domain() == loco::Domain::Tensor);
-
- auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
- node_rank = value_a_tensor_shape.rank();
- }
+ auto shapedata = node->annot<ShapeInferenceData>();
+ auto node_rank = shapedata->rank();
- int32_t axis_value = 0;
- {
- // axis should be TFConst
- auto axis_node = node->axis();
- auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
- if (tfconst == nullptr)
- {
- // TODO Check this: this error can be from TFOptimizatier.
- throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
- }
- auto result = scalar_value(tfconst, axis_value);
- if (!result)
- {
- // TODO Check this: this error can be from TFOptimizatier.
- throw oops::UserExn("ConcatV2 node has invalid input for axis", node->name());
- }
- }
uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)node_rank + axis_value;
INFO(l) << "canonicalize_concat axis(" << axis_absolute << "), value(" << axis_value << "), rank("
<< node_rank << ")";
- // Convert series of TensorConcat if num_values > 2
auto concat_node = graph->nodes()->create<loco::TensorConcat>();
concat_node->lhs(node->values(0));
concat_node->rhs(node->values(1));
@@ -151,9 +115,25 @@ namespace moco
namespace tf
{
-bool ConcatV2Canonicalizer::transform(TFConcatV2 *node) const
+bool ConcatV2Canonicalizer::run(loco::Graph *graph)
{
- return canonicalize_concat(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFConcatV2 *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_concat(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
index e6b471b89..4448ddb16 100644
--- a/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConcatV2Canonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONCATV2_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConcatV2 to Canonical TensorConcat
*/
-class ConcatV2Canonicalizer : public SimpleNodeTransform<moco::TFConcatV2>
+class ConcatV2Canonicalizer : public Transform
{
public:
const char *name(void) const final { return "ConcatV2Canonicalizer"; }
public:
- bool transform(moco::TFConcatV2 *node) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
index 60629cd5a..dea97f94a 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.cpp
@@ -16,17 +16,18 @@
#include "ConstCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
-#include <oops/UserExn.h>
-
namespace
{
-bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
+bool canonicalize_const(loco::Graph *graph, moco::tf::TFConst *node)
{
LOGGER(l);
@@ -54,27 +55,13 @@ bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
const_node->dtype(dtype);
auto rank = node->rank();
-
- if (rank == 0)
- {
- // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
- // into a rank-1 tensor of shape [1].
- //
- // TODO Revise this implementation later
- const_node->rank(1);
- const_node->dim(0) = 1;
- }
- else
+ const_node->rank(rank);
+ for (uint32_t r = 0; r < rank; ++r)
{
- const_node->rank(rank);
-
- for (uint32_t r = 0; r < rank; ++r)
- {
- if (node->dim(r).known())
- const_node->dim(r) = node->dim(r);
- else
- const_node->dim(r).unset();
- }
+ if (node->dim(r).known())
+ const_node->dim(r) = node->dim(r);
+ else
+ const_node->dim(r).unset();
}
switch (dtype)
@@ -100,7 +87,7 @@ bool canonicalize_const(loco::Graph *graph, moco::TFConst *node)
break;
}
default:
- throw oops::UserExn("Const has unsupported data type", node->name());
+ throw std::runtime_error("NYI for this DataType");
}
// update graph
@@ -118,9 +105,25 @@ namespace moco
namespace tf
{
-bool ConstCanonicalizer::transform(TFConst *node) const
+bool ConstCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_const(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_const = dynamic_cast<moco::tf::TFConst *>(node);
+ if (tf_const != nullptr)
+ {
+ if (canonicalize_const(graph, tf_const))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
index 1b0b2b867..53f3ca8e3 100644
--- a/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ConstCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONST_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConst to Canonical ConstGen
*/
-class ConstCanonicalizer : public SimpleNodeTransform<moco::TFConst>
+class ConstCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ConstCanonicalizer"; }
public:
- bool transform(moco::TFConst *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
deleted file mode 100644
index d3cbd4ab3..000000000
--- a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.cpp
+++ /dev/null
@@ -1,371 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "Conv2DBackpropInputCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "CodecHelper.h"
-
-#include <loco/IR/Stride.h>
-#include <loco/IR/Padding2D.h>
-#include <loco/Service/ShapeInference.h>
-
-#include <oops/UserExn.h>
-
-namespace
-{
-using plier::tf::DataLayout;
-
-void set_filter_enc(loco::FilterEncode *filter_enc)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
-
- // In TensorFlow, Conv2dBackpropInput's filter is a 4-D tensor of following shape:
- // [filter_height, filter_width, out_channels, in_channels] or HWOI or HWNC (in/out in loco sense)
- enc->perm()->axis(loco::FilterAxis::Height) = 0;
- enc->perm()->axis(loco::FilterAxis::Width) = 1;
- enc->perm()->axis(loco::FilterAxis::Count) = 2;
- enc->perm()->axis(loco::FilterAxis::Depth) = 3;
-
- filter_enc->encoder(std::move(enc));
-}
-
-} // namespace
-
-namespace
-{
-
-bool stride_2d_from_4d(loco::Stride<2> &ret, const std::vector<int64_t> &strides_4d,
- const DataLayout data_layout)
-{
- if (!(strides_4d.size() == 4))
- return false;
-
- switch (data_layout)
- {
- case DataLayout::NHWC:
- ret.vertical(strides_4d.at(1));
- ret.horizontal(strides_4d.at(2));
- break;
- case DataLayout::NCHW:
- ret.vertical(strides_4d.at(2));
- ret.horizontal(strides_4d.at(3));
- break;
- default:
- return false;
- }
- return true;
-}
-
-struct PlaneShape
-{
- loco::Dimension vertical;
- loco::Dimension horizontal;
-};
-
-class Padding2DInference final
-{
-public:
- Padding2DInference(const moco::TFNode *node) { _node = node; }
-
-public:
- loco::Padding2D operator()(void);
-
-public:
- PlaneShape &input() { return _input; }
- PlaneShape &output() { return _output; }
- loco::Stride<2> &stride() { return _stride; }
- loco::Window<2> &window() { return _window; }
- moco::TFPadding &padding() { return _padding; }
-
-private:
- /// @brief Check whether ingredients set by non-default values
- bool ready()
- {
- if (not input().vertical.known())
- return false;
- if (not input().horizontal.known())
- return false;
- if (not output().vertical.known())
- return false;
- if (not output().horizontal.known())
- return false;
- if (stride().vertical() == 0)
- return false;
- if (stride().horizontal() == 0)
- return false;
- if (window().vertical() == 0)
- return false;
- if (window().horizontal() == 0)
- return false;
- if (padding().empty())
- return false;
-
- return true;
- }
-
- inline uint32_t tight_output_for_valid_padding(uint32_t input, uint32_t stride, uint32_t filter)
- {
- return stride * (input - 1) + filter;
- }
-
- /**
- * @note For Conv2DBackpropInput SAME padding, TensorFlow requires this condition to hold
- *
- * Reference: `::tensorflow::GetWindowedOutputSizeVerboseV2()` from TensorFlow project
- */
- inline bool same_padding_applicable(uint32_t input, uint32_t output, uint32_t stride)
- {
- // Here 'input' and 'output' means Conv2DBackpropInput's actual node input and output.
- // Then these three conditions are equivalent:
- //
- // input == floor((output + stride - 1) / stride)
- // input == ceil(output / stride)
- // (stride * (input - 1) < output) and (output <= stride * input)
- return (stride * (input - 1) < output) and (output <= stride * input);
- }
-
- inline uint32_t padding_needed(uint32_t input, uint32_t output, uint32_t stride, uint32_t filter)
- {
- return stride * (input - 1) + filter - output;
- }
-
-private:
- const moco::TFNode *_node;
- PlaneShape _input;
- PlaneShape _output;
- loco::Stride<2> _stride;
- loco::Window<2> _window;
- moco::TFPadding _padding;
-};
-
-loco::Padding2D Padding2DInference::operator()(void)
-{
- assert(ready());
-
- if (padding() == "VALID")
- {
- // In case of VALID padding, TensorFlow accepts any size same or larger than
- // 'tight fit' output. When output size (set by 'input sizes' node input) is
- // larger than tight fit, extra spaces filled with zero.
- auto tight_output_vertical = tight_output_for_valid_padding(
- input().vertical.value(), stride().vertical(), window().vertical());
- auto tight_output_horizontal = tight_output_for_valid_padding(
- input().horizontal.value(), stride().horizontal(), window().horizontal());
-
- if (output().vertical.value() < tight_output_vertical or
- output().horizontal.value() < tight_output_horizontal)
- throw oops::UserExn("input_sizes is too small", _node->name());
-
- // Currently, only accept tight fit.
- // TODO Support non-tight case by adding zero padding operation
- assert(output().vertical.value() == tight_output_vertical);
- assert(output().horizontal.value() == tight_output_horizontal);
-
- return loco::Padding2D(0, 0, 0, 0);
- }
-
- if (padding() == "SAME")
- {
- // This condition is required by TensorFlow
- if (not same_padding_applicable(input().vertical.value(), output().vertical.value(),
- stride().vertical()) or
- not same_padding_applicable(input().horizontal.value(), output().horizontal.value(),
- stride().horizontal()))
- throw oops::UserExn("Size mismatch for SAME padding", _node->name());
-
- auto whole_pad_vertical = padding_needed(input().vertical.value(), output().vertical.value(),
- stride().vertical(), window().vertical());
- auto whole_pad_horizontal =
- padding_needed(input().horizontal.value(), output().horizontal.value(),
- stride().horizontal(), window().horizontal());
-
- loco::Padding2D res;
-
- res.top(whole_pad_vertical / 2);
- res.bottom(whole_pad_vertical - res.top());
- res.left(whole_pad_horizontal / 2);
- res.right(whole_pad_horizontal - res.left());
-
- return res;
- }
-
- throw oops::UserExn("Usupported padding " + padding(), _node->name());
-}
-
-/**
- * @param[out] ret PlaneShape extracted from 'node' with given 'data_layout'
- * @param[in] node
- * @param[in] data_layout
- *
- * @return true on success
- */
-bool set_plane_shape(PlaneShape &ret, const loco::Node *node, const DataLayout data_layout)
-{
- auto tensor_shape = loco::shape_get(node).as<loco::TensorShape>();
- if (!(tensor_shape.rank() == 4))
- return false;
-
- switch (data_layout)
- {
- case DataLayout::NHWC:
- ret.vertical = tensor_shape.dim(1).value();
- ret.horizontal = tensor_shape.dim(2).value();
- break;
- case DataLayout::NCHW:
- ret.vertical = tensor_shape.dim(2).value();
- ret.horizontal = tensor_shape.dim(3).value();
- break;
- default:
- return false;
- }
-
- return true;
-}
-
-/**
- * @param[out] ret 2D Window extracted from HW** filter node
- * @param[in] filter_node
- *
- * @return true on success
- */
-bool set_window(loco::Window<2> &ret, const loco::Node *filter_node)
-{
- auto tensor_shape = loco::shape_get(filter_node).as<loco::TensorShape>();
- assert(tensor_shape.rank() == 4);
-
- ret.vertical(tensor_shape.dim(0).value());
- ret.horizontal(tensor_shape.dim(1).value());
-
- return true;
-}
-
-} // namespace
-
-namespace
-{
-
-bool canonicalize_conv2d_backprop_input(loco::Graph *graph,
- moco::TFConv2DBackpropInput *conv2d_backprop)
-{
- /**
- * @note This will replace TFConv2DBackpropInput node with canonical
- * FeatureEncode + FilterEncode + TransposedConv2D + FeatureDecode
- *
- * Before
- * input_sizes ----
- * \
- * filter -------- TFConv2DBackpropInput --- output(s)
- * /
- * out_backprop ---
- *
- * After
- * input_sizes ----
- * \
- * filter -------- TFConv2DBackpropInput ---
- * /
- * out_backprop ---
- *
- * filter ------ FilterEncode ------ TransposedConv2D --- FeatureDecode --- output(s)
- * (as ker) /
- * out_backprop --- FeatureEncode ---
- * (as ifm)
- */
-
- if (!loco::shape_known(conv2d_backprop->out_backprop()))
- return false;
- if (!loco::shape_known(conv2d_backprop))
- return false;
- if (!loco::shape_known(conv2d_backprop->filter()))
- return false;
-
- auto data_layout = plier::tf::as_data_layout(conv2d_backprop->data_layout());
-
- // Nodes to replace
- auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
- auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
- auto tr_conv2d = graph->nodes()->create<loco::TransposedConv2D>();
- auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
-
- set_feature_enc(feature_enc, data_layout);
- set_filter_enc(filter_enc);
- set_feature_dec(feature_dec, data_layout);
-
- // Attributes for new TransposedConv2D
- loco::Stride<2> stride;
- loco::Padding2D pad;
-
- // Get attributes
- {
- if (!stride_2d_from_4d(stride, conv2d_backprop->strides(), data_layout))
- throw oops::UserExn("Unsupported strides", conv2d_backprop->name());
-
- Padding2DInference infer_pad(conv2d_backprop);
-
- if (!set_plane_shape(infer_pad.input(), conv2d_backprop->out_backprop(), data_layout))
- throw oops::UserExn("Unsupported out_backprop data_format", conv2d_backprop->name());
- if (!set_plane_shape(infer_pad.output(), conv2d_backprop, data_layout))
- throw oops::UserExn("Unsupported data_format", conv2d_backprop->name());
- if (!set_window(infer_pad.window(), conv2d_backprop->filter()))
- throw oops::UserExn("Unsupported filter shape", conv2d_backprop->name());
- infer_pad.stride() = stride;
- infer_pad.padding() = conv2d_backprop->padding();
-
- // Run padding infer_pad
- pad = infer_pad();
- }
-
- // Set attributes
- tr_conv2d->pad()->top(pad.top());
- tr_conv2d->pad()->bottom(pad.bottom());
- tr_conv2d->pad()->left(pad.left());
- tr_conv2d->pad()->right(pad.right());
-
- tr_conv2d->stride()->vertical(stride.vertical());
- tr_conv2d->stride()->horizontal(stride.horizontal());
-
- // Update graph
- auto input_node = conv2d_backprop->out_backprop();
- auto filter_node = conv2d_backprop->filter();
-
- // Update connections
- feature_enc->input(input_node);
- filter_enc->input(filter_node);
- tr_conv2d->ifm(feature_enc);
- tr_conv2d->ker(filter_enc);
- feature_dec->input(tr_conv2d);
-
- // Replace old conv2d_backprop
- replace(conv2d_backprop).with(feature_dec);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool Conv2DBackpropInputCanonicalizer::transform(TFConv2DBackpropInput *node) const
-{
- return canonicalize_conv2d_backprop_input(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
index a955793a8..f34339d0f 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp
@@ -16,18 +16,46 @@
#include "Conv2DCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
-void set_filter_enc(loco::FilterEncode *filter_enc)
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_filter_enc(loco::FilterEncode *filter_enc, DataLayout data_layout)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
@@ -41,7 +69,29 @@ void set_filter_enc(loco::FilterEncode *filter_enc)
filter_enc->encoder(std::move(enc));
}
-bool canonicalize_conv2d(loco::Graph *graph, moco::TFConv2D *node)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node)
{
LOGGER(l);
@@ -75,29 +125,23 @@ bool canonicalize_conv2d(loco::Graph *graph, moco::TFConv2D *node)
auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
set_feature_enc(feature_enc, data_layout);
- set_filter_enc(filter_enc);
+ set_filter_enc(filter_enc, data_layout);
set_feature_dec(feature_dec, data_layout);
- auto input_shape = moco::node_shape(node->input());
- assert(input_shape.domain() != loco::Domain::Unknown);
-
- auto ker_shape = moco::node_shape(node->filter());
- auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
-
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
-
- moco::Padding2DInference infer_padding2d;
+ // Set Conv2D attributes from TFConv2D
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ conv2d->pad()->top(pad_data->pad()->top());
+ conv2d->pad()->bottom(pad_data->pad()->bottom());
+ conv2d->pad()->left(pad_data->pad()->left());
+ conv2d->pad()->right(pad_data->pad()->right());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *conv2d->pad() = infer_padding2d(input_plane_shape);
- *conv2d->stride() = node_stride;
+ conv2d->stride()->vertical(stride_data->stride()->vertical());
+ conv2d->stride()->horizontal(stride_data->stride()->horizontal());
// update graph
auto node_A = node->input();
@@ -123,9 +167,25 @@ namespace moco
namespace tf
{
-bool Conv2DCanonicalizer::transform(TFConv2D *node) const
+bool Conv2DCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_conv2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFConv2D *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_conv2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
index ea39667f3..6be264f90 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Conv2DCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_CONV2D_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFConv2D to Canonical Conv2D
*/
-class Conv2DCanonicalizer : public SimpleNodeTransform<TFConv2D>
+class Conv2DCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "Conv2DCanonicalizer"; }
public:
- bool transform(TFConv2D *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
index 50dddf637..ee63efa2f 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.cpp
@@ -16,18 +16,47 @@
#include "DepthwiseConv2dNativeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/StrideData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
using plier::tf::DataLayout;
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::DepthwiseFilter>>();
@@ -42,7 +71,29 @@ void set_filter_enc(loco::DepthwiseFilterEncode *filter_enc)
filter_enc->encoder(std::move(enc));
}
-bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::TFDepthwiseConv2dNative *node)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::tf::TFDepthwiseConv2dNative *node)
{
LOGGER(l);
@@ -83,24 +134,20 @@ bool canonicalize_depthwiseconv2dnative(loco::Graph *graph, moco::TFDepthwiseCon
set_filter_enc(filter_enc);
set_feature_dec(feature_dec, data_layout);
- // Calculate Pad and Stride from inference
- auto input_shape = moco::node_shape(node->input());
- auto ker_shape = moco::node_shape(node->filter());
- auto ker_tensor_shape = ker_shape.as<loco::TensorShape>();
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
-
- moco::Padding2DInference infer_padding2d;
+ // Set DetphwiseConv2D attributes from TFDepthwiseConv2dNative
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ depthwiseconv2d->pad()->top(pad_data->pad()->top());
+ depthwiseconv2d->pad()->bottom(pad_data->pad()->bottom());
+ depthwiseconv2d->pad()->left(pad_data->pad()->left());
+ depthwiseconv2d->pad()->right(pad_data->pad()->right());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *depthwiseconv2d->pad() = infer_padding2d(input_plane_shape);
- *depthwiseconv2d->stride() = node_stride;
+ depthwiseconv2d->stride()->vertical(stride_data->stride()->vertical());
+ depthwiseconv2d->stride()->horizontal(stride_data->stride()->horizontal());
// update graph
auto node_A = node->input();
@@ -128,9 +175,25 @@ namespace moco
namespace tf
{
-bool DepthwiseConv2dNativeCanonicalizer::transform(TFDepthwiseConv2dNative *node) const
+bool DepthwiseConv2dNativeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_depthwiseconv2dnative(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_depthwiseconv2dnative(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
index 704e1ade9..9bb8c5ad8 100644
--- a/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/DepthwiseConv2dNativeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_DEPTHWISE_CONV2D_NATIVE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
namespace moco
{
@@ -30,13 +27,13 @@ namespace tf
/**
* @brief Convert TFDepthwiseConv2dNative to Canonical DepthwiseConv2D
*/
-class DepthwiseConv2dNativeCanonicalizer : public SimpleNodeTransform<moco::TFDepthwiseConv2dNative>
+class DepthwiseConv2dNativeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "DepthwiseConv2dNativeCanonicalizer"; }
public:
- bool transform(moco::TFDepthwiseConv2dNative *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
index 3b680cf04..c4d5d8063 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.cpp
@@ -18,15 +18,18 @@
#include "Convert.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/Log.h>
namespace
{
-bool canonicalize_identity(loco::Graph *graph, moco::TFIdentity *node)
+bool canonicalize_identity(loco::Graph *graph, moco::tf::TFIdentity *node)
{
LOGGER(l);
@@ -69,9 +72,25 @@ namespace moco
namespace tf
{
-bool IdentityCanonicalizer::transform(TFIdentity *node) const
+bool IdentityCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_identity(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_identity = dynamic_cast<moco::tf::TFIdentity *>(node);
+ if (tf_identity != nullptr)
+ {
+ if (canonicalize_identity(graph, tf_identity))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
index 59b2894c5..81aee178a 100644
--- a/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/IdentityCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_IDENTITY_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFIdentity to Canonical Forward
*/
-class IdentityCanonicalizer : public SimpleNodeTransform<moco::TFIdentity>
+class IdentityCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "IdentityCanonicalizer"; }
public:
- bool transform(moco::TFIdentity *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
index 06a605717..c46fbd208 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.cpp
@@ -16,17 +16,70 @@
#include "MaxPoolCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/PadData.h"
+#include "Annotations/StrideData.h"
+#include "Annotations/WindowData.h"
-#include "CodecHelper.h"
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
namespace
{
-bool canonicalize_maxpool2d(loco::Graph *graph, moco::TFMaxPool *node)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_enc->encoder(std::move(enc));
+}
+
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+
+ feature_dec->decoder(std::move(dec));
+}
+
+bool canonicalize_maxpool2d(loco::Graph *graph, moco::tf::TFMaxPool *node)
{
LOGGER(l);
@@ -58,31 +111,36 @@ bool canonicalize_maxpool2d(loco::Graph *graph, moco::TFMaxPool *node)
set_feature_dec(feature_dec, data_layout);
// paddata to pad
- auto input_shape = moco::node_shape(node->input());
- assert(input_shape.domain() != loco::Domain::Unknown);
+ auto pad_data = node->annot<moco::tf::PadData>();
+ assert(pad_data != nullptr);
- auto node_stride = moco::stride_of(node->strides(), node->data_layout());
- auto node_window = moco::window_of(node->ksize(), node->data_layout());
+ maxPool2d_node->pad()->top(pad_data->pad()->top());
+ maxPool2d_node->pad()->bottom(pad_data->pad()->bottom());
+ maxPool2d_node->pad()->left(pad_data->pad()->left());
+ maxPool2d_node->pad()->right(pad_data->pad()->right());
- moco::Padding2DInference infer_padding2d;
+ // windowdata to window (ksize to window)
+ auto window_data = node->annot<moco::tf::WindowData>();
+ assert(window_data != nullptr);
- infer_padding2d.padding(node->padding());
- infer_padding2d.stride(node_stride);
- infer_padding2d.window(node_window);
+ auto window = maxPool2d_node->window();
+ window->vertical(window_data->window()->vertical());
+ window->horizontal(window_data->window()->horizontal());
- auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
- auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // stridedata to stride (strides to stride)
+ auto stride_data = node->annot<moco::tf::StrideData>();
+ assert(stride_data != nullptr);
- *maxPool2d_node->pad() = infer_padding2d(input_plane_shape);
- *maxPool2d_node->stride() = node_stride;
- *maxPool2d_node->window() = node_window;
+ auto stride = maxPool2d_node->stride();
+ stride->vertical(stride_data->stride()->vertical());
+ stride->horizontal(stride_data->stride()->horizontal());
INFO(l) << "Canonicalize TFMaxPool pad = T " << maxPool2d_node->pad()->top() << ", L "
<< maxPool2d_node->pad()->left() << ", B " << maxPool2d_node->pad()->bottom() << ", R "
<< maxPool2d_node->pad()->right() << std::endl;
// update graph
- auto node_A = node->input();
+ auto node_A = node->value();
// update connections
feature_enc->input(node_A);
@@ -102,9 +160,25 @@ namespace moco
namespace tf
{
-bool MaxPoolCanonicalizer::transform(TFMaxPool *node) const
+bool MaxPoolCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_maxpool2d(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFMaxPool *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_maxpool2d(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
index c58ade528..a486c4caa 100644
--- a/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MaxPoolCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_MAXPOOL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFMaxPool to Canonical MaxPool2D
*/
-class MaxPoolCanonicalizer : public SimpleNodeTransform<moco::TFMaxPool>
+class MaxPoolCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "MaxPoolCanonicalizer"; }
public:
- bool transform(moco::TFMaxPool *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
deleted file mode 100644
index baff4d7ad..000000000
--- a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_MAXIMUM_CANONICALIZER_H__
-#define __MOCO_TF_MAXIMUM_CANONICALIZER_H__
-
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Convert TFMaximum to Canonical EltwiseMax
- */
-class MaximumCanonicalizer : public SimpleNodeTransform<moco::TFMaximum>
-{
-public:
- const char *name(void) const final { return "MaximumCanonicalizer"; }
-
-public:
- bool transform(moco::TFMaximum *) const final;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_MAXIMUM_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
index d02f71361..78d0ebc48 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "MulCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool MulCanonicalizer::transform(moco::TFMul *node) const
+bool MulCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFMul *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
index 480eec700..680f4c315 100644
--- a/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/MulCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_MUL_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFMul to Canonical EltwiseMul
*/
-class MulCanonicalizer : public SimpleNodeTransform<moco::TFMul>
+class MulCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "MulCanonicalizer"; }
public:
- bool transform(moco::TFMul *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
deleted file mode 100644
index 36136aed4..000000000
--- a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.cpp
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "PadCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "loco/Service/TypeInference.h"
-
-#include <stdex/Memory.h>
-
-namespace
-{
-
-bool canonicalize_pad(loco::Graph *graph, moco::TFPad *node)
-{
- /**
- * @note This will replace TFPad node with Canonical TensorConstantPad
- *
- * Before
- * input --- TFPad -- C
- * paddings --/
- * After
- * paddings ------- TFPad --
- * /
- * input ----------- TensorConstantPad -- C
- * ConstGen --------/
- * Where
- * input : input of TFPad
- * paddings : paddings of TFPad. it becomes TensorConstantPad's attribute.
- * C : a node that uses TFPad as an input. TFPad is disconnected from C.
- * ConstGen : constant value of Pad. TFPad has zero value by default.
- */
-
- auto pad_node = graph->nodes()->create<loco::TensorConstantPad>();
-
- auto constant_node = graph->nodes()->create<loco::ConstGen>();
-
- auto input_node = node->input();
- // TODO: support other dtype.
- assert(loco::dtype_get(input_node) == loco::DataType::FLOAT32);
- constant_node->dtype(loco::DataType::FLOAT32);
- // TODO: constant node changes to scalar when it is implemented.
- constant_node->shape({1});
- constant_node->size<loco::DataType::FLOAT32>(1);
- constant_node->at<loco::DataType::FLOAT32>(0) = 0.0f;
-
- auto const_paddings_node = loco::must_cast<loco::ConstGen *>(node->paddings());
- // TODO: support S64 type.
- assert(const_paddings_node->dtype() == loco::DataType::S32);
- assert(const_paddings_node->rank() == 2);
- assert(const_paddings_node->dim(1).value() == 2);
-
- auto padding = pad_node->padding();
- uint32_t padding_rank = const_paddings_node->dim(0).value();
- padding->rank(padding_rank);
-
- for (uint32_t i = 0; i < padding_rank; i++)
- {
- padding->front(i) = const_paddings_node->at<loco::DataType::S32>(i << 1);
- padding->back(i) = const_paddings_node->at<loco::DataType::S32>((i << 1) + 1);
- }
-
- // update connections
- pad_node->input(input_node);
- pad_node->constant(constant_node);
-
- // replace node
- replace(node).with(pad_node);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool PadCanonicalizer::transform(TFPad *node) const
-{
- return canonicalize_pad(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
deleted file mode 100644
index f568e909f..000000000
--- a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.cpp
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "PlaceholderCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include <moco/Names.h>
-#include <moco/Log.h>
-
-namespace
-{
-
-bool canonicalize_placeholder(loco::Graph *graph, moco::TFPlaceholder *node)
-{
- LOGGER(l);
-
- /**
- * @note This will replace TFPlaceholder node with Canonical Pull
- *
- * Before
- * TFPlaceholder -- C
- *
- * After
- * TFPlaceholder -
- * Pull -- C
- *
- * Where
- * C : a node that uses TFPlaceholder as an input
- * TFPlaceholder is disconnected from other nodes
- */
-
- INFO(l) << "PlaceholderCanonicalizer begin";
-
- auto pull_node = graph->nodes()->create<loco::Pull>();
-
- // copy properties
- auto dtype = node->dtype();
- pull_node->dtype(dtype);
-
- auto rank = node->rank();
-
- if (rank == 0)
- {
- // This routine implements a workaround that converts a scalar constant (rank-0 tensor)
- // into a rank-1 tensor of shape [1].
- //
- // TODO Revise this implementation later
- pull_node->rank(1);
- pull_node->dim(0) = 1;
- }
- else
- {
- pull_node->rank(rank);
-
- for (uint32_t r = 0; r < rank; ++r)
- {
- if (node->dim(r).known())
- pull_node->dim(r) = node->dim(r);
- else
- pull_node->dim(r).unset();
- }
- }
-
- // set loco::Pull GraphInputIndex
- pull_node->index(moco::index(node));
-
- // update graph
- replace(node).with(pull_node);
-
- INFO(l) << "PlaceholderCanonicalizer done";
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool PlaceholderCanonicalizer::transform(TFPlaceholder *node) const
-{
- return canonicalize_placeholder(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
index a448d85fa..9ad15150a 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "RealDivCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool RealDivCanonicalizer::transform(moco::TFRealDiv *node) const
+bool RealDivCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRealDiv *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
index 76e1bd377..8e6953396 100644
--- a/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RealDivCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_REALDIV_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRealDiv to Canonical EltwiseDiv
*/
-class RealDivCanonicalizer : public SimpleNodeTransform<moco::TFRealDiv>
+class RealDivCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "RealDivCanonicalizer"; }
public:
- bool transform(moco::TFRealDiv *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
index c53a880a8..07657244b 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.cpp
@@ -16,14 +16,17 @@
#include "Relu6Canonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu6(loco::Graph *graph, moco::TFRelu6 *node)
+bool canonicalize_relu6(loco::Graph *graph, moco::tf::TFRelu6 *node)
{
/**
* @note This will replace TFRelu6 node with Canonical ReLU6
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool Relu6Canonicalizer::transform(TFRelu6 *node) const
+bool Relu6Canonicalizer::run(loco::Graph *graph)
{
- return canonicalize_relu6(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRelu6 *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_relu6(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
index d8ad5db8e..aa1580f28 100644
--- a/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/Relu6Canonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RELU6_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRelu6 to Canonical ReLU6
*/
-class Relu6Canonicalizer : public SimpleNodeTransform<moco::TFRelu6>
+class Relu6Canonicalizer : public Transform
{
public:
const char *name(void) const final { return "Relu6Canonicalizer"; }
public:
- bool transform(moco::TFRelu6 *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
index 7965dc931..20cd0bab9 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "ReluCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_relu(loco::Graph *graph, moco::TFRelu *node)
+bool canonicalize_relu(loco::Graph *graph, moco::tf::TFRelu *node)
{
/**
* @note This will replace TFRelu node with Canonical ReLU
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool ReluCanonicalizer::transform(TFRelu *node) const
+bool ReluCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_relu(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRelu *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_relu(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
index e27abe158..97adba308 100644
--- a/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReluCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RELU_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRelu to Canonical ReLU
*/
-class ReluCanonicalizer : public SimpleNodeTransform<moco::TFRelu>
+class ReluCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ReluCanonicalizer"; }
public:
- bool transform(moco::TFRelu *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
index b944568e0..3771d549a 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.cpp
@@ -16,11 +16,11 @@
#include "ReshapeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
#include <plier/tf/Convert.h>
-#include <oops/UserExn.h>
#include <cassert>
@@ -31,7 +31,7 @@ using plier::tf::DataLayout;
/**
* @brief Check whether given 'new shape' arg is a fixed shape input for Reshape
*
- * ConstNode can be moco::TFConst or loco::ConstGen
+ * ConstNode can be moco::tf::TFConst or loco::ConstGen
*/
template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_input)
{
@@ -54,16 +54,13 @@ template <typename ConstNode> bool is_fixed_shape_input(ConstNode *const_shape_i
// has wildcard dimension, i.e. dynamic reshape
return false;
}
- if (!(shape_dim >= 1))
- {
- throw oops::UserExn("New shape of Reshape has invalid dimension");
- }
+ assert(shape_dim >= 1 && "Unknown behavior: New shape of Reshape has invalid dimension");
}
return true;
}
/// @note Currently only supports to canonicalize Fixed Reshape
-bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
+bool canonicalize_reshape(loco::Graph *graph, moco::tf::TFReshape *node)
{
LOGGER(l);
INFO(l) << "TFNodeCanonicalize TFReshape begin";
@@ -102,17 +99,14 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
// Supports 2 cases for Reshape's shape input:
// TF-dialect TFConst or Canonical ConstGen
loco::Node *shape_input = node->shape();
- auto tfconst_shape_input = dynamic_cast<moco::TFConst *>(shape_input);
+ auto tfconst_shape_input = dynamic_cast<moco::tf::TFConst *>(shape_input);
auto constgen_shape_input = dynamic_cast<loco::ConstGen *>(shape_input);
if (tfconst_shape_input)
{
// Only support fixed reshape
// TODO support dynamic reshape
- if (!(is_fixed_shape_input(tfconst_shape_input)))
- {
- throw oops::UserExn("Supports only fixed reshape", node->name());
- }
+ assert(is_fixed_shape_input(tfconst_shape_input));
auto rank = tfconst_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -124,10 +118,7 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
else if (constgen_shape_input)
{
// ditto
- if (!(is_fixed_shape_input(constgen_shape_input)))
- {
- throw oops::UserExn("Supports only fixed reshape", node->name());
- }
+ assert(is_fixed_shape_input(constgen_shape_input));
auto rank = constgen_shape_input->dim(0).value();
fixed_reshape->rank(rank);
@@ -139,7 +130,7 @@ bool canonicalize_reshape(loco::Graph *graph, moco::TFReshape *node)
else
{
// TODO support dynamic reshape from not const node
- throw oops::UserExn("Supports only const node as input shape", node->name());
+ throw std::runtime_error("ReshapeCanonicalizer: only support const node as input shape");
}
// replace
@@ -160,9 +151,25 @@ namespace moco
namespace tf
{
-bool ReshapeCanonicalizer::transform(TFReshape *node) const
+bool ReshapeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_reshape(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_reshape = dynamic_cast<moco::tf::TFReshape *>(node);
+ if (tf_reshape != nullptr)
+ {
+ if (canonicalize_reshape(graph, tf_reshape))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
index 1a792024e..c9deee7a4 100644
--- a/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/ReshapeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RESHAPE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFReshape to Canonical Reshape
*/
-class ReshapeCanonicalizer : public SimpleNodeTransform<moco::TFReshape>
+class ReshapeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "ReshapeCanonicalizer"; }
public:
- bool transform(moco::TFReshape *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
index c31dbf6d6..b4fbcac3c 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.cpp
@@ -16,25 +16,29 @@
#include "RsqrtCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
#include <loco/Service/TypeInference.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
namespace
{
template <typename T>
-bool prepare_const_gen(loco::ConstGen *const_node, const loco::TensorShape &tensorshape, T value);
+void prepare_const_gen(loco::ConstGen *const_node, const moco::tf::ShapeInferenceData *shapedata,
+ T value);
template <>
-bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShape &tensorshape,
- float value)
+void prepare_const_gen<float>(loco::ConstGen *const_node,
+ const moco::tf::ShapeInferenceData *shapedata, float value)
{
LOGGER(l);
@@ -43,18 +47,18 @@ bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShap
auto dtype = loco::DataType::FLOAT32;
const_node->dtype(dtype);
- auto rank = tensorshape.rank();
+ auto rank = shapedata->rank();
const_node->rank(rank);
for (uint32_t r = 0; r < rank; ++r)
{
- if (tensorshape.dim(r).known())
- const_node->dim(r) = tensorshape.dim(r);
+ if (shapedata->dim(r).known())
+ const_node->dim(r) = shapedata->dim(r);
else
- return false;
+ throw std::runtime_error("Cannot handle unknown shape");
- assert(tensorshape.dim(r).value() > 0);
+ assert(shapedata->dim(r).value() > 0);
- const_num_elements *= tensorshape.dim(r).value();
+ const_num_elements *= shapedata->dim(r).value();
}
INFO(l) << "prepare_const_gen : Elements = " << const_num_elements;
@@ -64,11 +68,9 @@ bool prepare_const_gen<float>(loco::ConstGen *const_node, const loco::TensorShap
{
const_node->at<loco::DataType::FLOAT32>(i) = value;
}
-
- return true;
}
-bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
+bool canonicalize_rsqrt(loco::Graph *graph, moco::tf::TFRsqrt *node)
{
/**
* @note This will replace TFRsqrt node with Canonical EltwiseSqrt + EltwiseRealDiv
@@ -89,14 +91,13 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
* TFRsqrt is converted to 1 / EltwiseSqrt
*/
- auto nodeshape = moco::node_shape(node);
- if (nodeshape.domain() == loco::Domain::Unknown)
+ auto rsqrt_shapedata = node->annot<moco::tf::ShapeInferenceData>();
+ if (rsqrt_shapedata == nullptr)
{
// We need this shape information
assert(false); // this shouldn't happen, let's add an alarm
return false;
}
- auto tensorshape = nodeshape.as<loco::TensorShape>();
if (!loco::dtype_known(node))
{
@@ -113,12 +114,11 @@ bool canonicalize_rsqrt(loco::Graph *graph, moco::TFRsqrt *node)
switch (dtype)
{
case loco::DataType::FLOAT32:
- if (!prepare_const_gen<float>(const_node, tensorshape, 1.0f))
- throw oops::UserExn("Cannot handle unknown shape", node->name());
+ prepare_const_gen<float>(const_node, rsqrt_shapedata, 1.0f);
break;
default:
- throw oops::UserExn("Unsupported data type", node->name());
+ throw std::runtime_error("NYI for this DataType");
}
auto node_A = node->x();
@@ -141,9 +141,25 @@ namespace moco
namespace tf
{
-bool RsqrtCanonicalizer::transform(TFRsqrt *node) const
+bool RsqrtCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_rsqrt(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFRsqrt *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_rsqrt(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
index 7fd4ff697..a58c0adcb 100644
--- a/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/RsqrtCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_RSQRT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFRsqrt to Canonical EltwiseDiv + EltwiseSqrt
*/
-class RsqrtCanonicalizer : public SimpleNodeTransform<moco::TFRsqrt>
+class RsqrtCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "RsqrtCanonicalizer"; }
public:
- bool transform(moco::TFRsqrt *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
index 98af7b693..3b5043fa7 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.cpp
@@ -16,15 +16,19 @@
#include "SoftmaxCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
+bool canonicalize_softmax(loco::Graph *graph, moco::tf::TFSoftmax *node)
{
LOGGER(l);
@@ -42,11 +46,12 @@ bool canonicalize_softmax(loco::Graph *graph, moco::TFSoftmax *node)
* In ---- TensorSoftmax ----- Out(s)
*/
- auto nodeshape = moco::node_shape(node);
+ auto softmax_shape = node->annot<moco::tf::ShapeInferenceData>();
+
// Canonicalization into TensorSoftmax is valid when softmax has shape info
- assert(nodeshape.domain() != loco::Domain::Unknown);
+ assert(softmax_shape);
- auto softmax_tensor_shape = nodeshape.as<loco::TensorShape>();
+ auto softmax_tensor_shape = softmax_shape->tensor_shape();
// Create loco node to replace
auto softmax = graph->nodes()->create<loco::TensorSoftmax>();
@@ -69,9 +74,25 @@ namespace moco
namespace tf
{
-bool SoftmaxCanonicalizer::transform(TFSoftmax *node) const
+bool SoftmaxCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_softmax(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_softmax = dynamic_cast<moco::tf::TFSoftmax *>(node);
+ if (tf_softmax != nullptr)
+ {
+ if (canonicalize_softmax(graph, tf_softmax))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
index ebaf04cfe..6debf4194 100644
--- a/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SoftmaxCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SOFTMAx_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFSoftmax into canonical Softmax node
*/
-class SoftmaxCanonicalizer : public SimpleNodeTransform<moco::TFSoftmax>
+class SoftmaxCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SoftmaxCanonicalizer"; }
public:
- bool transform(moco::TFSoftmax *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
index 89b9b8a44..347265121 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.cpp
@@ -16,12 +16,15 @@
#include "SqrtCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
namespace
{
-bool canonicalize_sqrt(loco::Graph *graph, moco::TFSqrt *node)
+bool canonicalize_sqrt(loco::Graph *graph, moco::tf::TFSqrt *node)
{
/**
* @note This will replace TFSqrt node with Canonical EltwiseSqrt
@@ -59,9 +62,25 @@ namespace moco
namespace tf
{
-bool SqrtCanonicalizer::transform(TFSqrt *node) const
+bool SqrtCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_sqrt(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSqrt *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_sqrt(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
index 3f7ffead8..b4e6da09a 100644
--- a/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqrtCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SQRT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFsqrt to Canonical EltwiseSqrt
*/
-class SqrtCanonicalizer : public SimpleNodeTransform<moco::TFSqrt>
+class SqrtCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SqrtCanonicalizer"; }
public:
- bool transform(moco::TFSqrt *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
new file mode 100644
index 000000000..4eb7a7217
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SquaredDifferenceCanonicalizer.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
+
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool canonicalize_sqdiff(loco::Graph *graph, moco::tf::TFSquaredDifference *node)
+{
+ /**
+ * @note This will replace TFSquaredDifference node with Canonical EltwiseSub and EltwiseMul
+ *
+ * Before
+ * A --- TFSquaredDifference -- C
+ * B --/
+ * After
+ * A --- TFSquaredDifference --
+ * B --/
+ * A --- EltwiseSub == EltwiseMul -- C
+ * B --/
+ * Where
+ * A : x of TFSquaredDifference
+ * B : y of TFSquaredDifference
+ * C : a node that uses TFSquaredDifference as an input
+ * TFSquaredDifference is disconnected from C
+ * A and B are drawn multiple times to simplify the diagram
+ */
+
+ auto node_A = node->x();
+ auto node_B = node->y();
+
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
+
+ if (!(x_shape == y_shape))
+ {
+ // TODO support broadcast
+ return false;
+ }
+
+ auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
+ auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
+
+ // update connections
+ sub_node->lhs(node_A);
+ sub_node->rhs(node_B);
+ mul_node->lhs(sub_node);
+ mul_node->rhs(sub_node);
+
+ // replace node
+ replace(node).with(mul_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool SquaredDifferenceCanonicalizer::run(loco::Graph *graph)
+{
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSquaredDifference *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_sqdiff(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h
new file mode 100644
index 000000000..afd65be32
--- /dev/null
+++ b/compiler/moco-tf/src/Canonicalization/SquaredDifferenceCanonicalizer.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
+#define __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
+
+#include "Transform.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Convert TFSquaredDifference to Canonical EltwiseSub and EltwiseMul
+ */
+class SquaredDifferenceCanonicalizer final : public Transform
+{
+public:
+ const char *name(void) const final { return "SquaredDifferenceCanonicalizer"; }
+
+public:
+ bool run(loco::Graph *graph) final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_SQUAREDDIFFERENCE_CANONICALIZER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
index f5b991206..a3fcc3b47 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.cpp
@@ -16,15 +16,19 @@
#include "SqueezeCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/Support/TFShapeInferenceHelper.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
+bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::tf::TFSqueeze *node)
{
LOGGER(l);
@@ -42,12 +46,12 @@ bool canonicalize_squeeze_to_reshape(loco::Graph *graph, moco::TFSqueeze *node)
* In ---- FixedReshape ----- Out(s)
*/
- auto nodeshape = moco::node_shape(node);
+ auto squeeze_shape = node->annot<moco::tf::ShapeInferenceData>();
// canonicalize into FixedReshape is valid when squeeze has shape info
// TODO Support general Squeeze case
- assert(nodeshape.domain() != loco::Domain::Unknown);
+ assert(squeeze_shape);
- auto squeeze_tensor_shape = nodeshape.as<loco::TensorShape>();
+ auto squeeze_tensor_shape = squeeze_shape->tensor_shape();
// Create loco node to replace
auto reshape = graph->nodes()->create<loco::FixedReshape>();
@@ -77,9 +81,25 @@ namespace moco
namespace tf
{
-bool SqueezeCanonicalizer::transform(TFSqueeze *node) const
+bool SqueezeCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_squeeze_to_reshape(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_squeeze = dynamic_cast<moco::tf::TFSqueeze *>(node);
+ if (tf_squeeze != nullptr)
+ {
+ if (canonicalize_squeeze_to_reshape(graph, tf_squeeze))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
index 28a1442bd..dc5b2d7b1 100644
--- a/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SqueezeCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SQUEEZE_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -34,13 +31,13 @@ namespace tf
*
* @note There is no canonical Squeeze node
*/
-class SqueezeCanonicalizer : public SimpleNodeTransform<moco::TFSqueeze>
+class SqueezeCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SqueezeCanonicalizer"; }
public:
- bool transform(moco::TFSqueeze *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
index 574fa3993..a52af05a5 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "StopGradientCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
namespace
{
-bool canonicalize_stopgradient(loco::Graph *graph, moco::TFStopGradient *node)
+bool canonicalize_stopgradient(loco::Graph *graph, moco::tf::TFStopGradient *node)
{
LOGGER(l);
@@ -62,9 +65,25 @@ namespace moco
namespace tf
{
-bool StopGradientCanonicalizer::transform(TFStopGradient *node) const
+bool StopGradientCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_stopgradient(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_stopgradient = dynamic_cast<moco::tf::TFStopGradient *>(node);
+ if (tf_stopgradient != nullptr)
+ {
+ if (canonicalize_stopgradient(graph, tf_stopgradient))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
index 6a17728a6..a23a801f0 100644
--- a/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/StopGradientCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_STOPGRADIENT_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Canonicalize TF-dialect TFStopGradient into canonical Forward node
*/
-class StopGradientCanonicalizer : public SimpleNodeTransform<moco::TFStopGradient>
+class StopGradientCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "StopGradientCanonicalizer"; }
public:
- bool transform(moco::TFStopGradient *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
index c518b7d64..21f4210eb 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.cpp
@@ -16,7 +16,8 @@
#include "SubCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "TFEltwiseBinaryCanonicalzeHelper.h"
@@ -25,9 +26,25 @@ namespace moco
namespace tf
{
-bool SubCanonicalizer::transform(moco::TFSub *node) const
+bool SubCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_eltwise_binary_node(node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFSub *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_eltwise_binary_node(tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
index f715cc86c..4ab470685 100644
--- a/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/SubCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_SUB_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFSub to Canonical EltwiseSub
*/
-class SubCanonicalizer : public SimpleNodeTransform<moco::TFSub>
+class SubCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "SubCanonicalizer"; }
public:
- bool transform(moco::TFSub *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
deleted file mode 100644
index 081e0e5f9..000000000
--- a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.cpp
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TFPushCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include <stdex/Memory.h>
-
-namespace
-{
-
-bool canonicalize_push(loco::Graph *graph, moco::TFPush *node)
-{
- /**
- * @note This will replace TFRelu node with Canonical ReLU
- *
- * Before
- * A --- TFPush
- * After
- * +- TFPush
- * |
- * A -+- Push
- *
- * Where
- * A : from of TFPush
- * TFPush will have no GraphOutputIndex
- * Push will have GraphOutputIndex that from TFPush
- */
-
- auto push_node = graph->nodes()->create<loco::Push>();
-
- auto node_A = node->from();
-
- // update connections
- push_node->from(node_A);
-
- // update output index
- push_node->index(node->index());
- node->index_reset();
-
- // replace node
- replace(node).with(push_node);
-
- return true;
-}
-
-} // namespace
-
-namespace moco
-{
-namespace tf
-{
-
-bool TFPushCanonicalizer::transform(TFPush *node) const
-{
- return canonicalize_push(node->graph(), node);
-}
-
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
index 3f48a50fc..9b7b073e1 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.cpp
@@ -16,14 +16,17 @@
#include "TanhCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+#include "Dialect/TFNodeVisitor.h"
+#include "Dialect/TFNodeImpl.h"
#include <stdex/Memory.h>
namespace
{
-bool canonicalize_tanh(loco::Graph *graph, moco::TFTanh *node)
+bool canonicalize_tanh(loco::Graph *graph, moco::tf::TFTanh *node)
{
/**
* @note This will replace TFTanh node with Canonical Tanh
@@ -61,9 +64,25 @@ namespace moco
namespace tf
{
-bool TanhCanonicalizer::transform(TFTanh *node) const
+bool TanhCanonicalizer::run(loco::Graph *graph)
{
- return canonicalize_tanh(node->graph(), node);
+ auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
+ bool changed = false;
+
+ for (auto node : active_nodes)
+ {
+ if (node->dialect() == TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFTanh *>(node);
+ if (tf_node != nullptr)
+ {
+ if (canonicalize_tanh(graph, tf_node))
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
index af5e79fb5..cf566a4d4 100644
--- a/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
+++ b/compiler/moco-tf/src/Canonicalization/TanhCanonicalizer.h
@@ -18,9 +18,6 @@
#define __MOCO_TF_TANH_CANONICALIZER_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -32,13 +29,13 @@ namespace tf
/**
* @brief Convert TFTanh to Canonical Tanh
*/
-class TanhCanonicalizer : public SimpleNodeTransform<moco::TFTanh>
+class TanhCanonicalizer : public Transform
{
public:
const char *name(void) const final { return "TanhCanonicalizer"; }
public:
- bool transform(moco::TFTanh *) const override;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
diff --git a/compiler/moco-tf/src/Canonicalizer.cpp b/compiler/moco-tf/src/Canonicalizer.cpp
index 04bc7c57a..c705d686b 100644
--- a/compiler/moco-tf/src/Canonicalizer.cpp
+++ b/compiler/moco-tf/src/Canonicalizer.cpp
@@ -27,16 +27,11 @@
#include "Canonicalization/BiasAddCanonicalizer.h"
#include "Canonicalization/ConcatV2Canonicalizer.h"
#include "Canonicalization/ConstCanonicalizer.h"
-#include "Canonicalization/Conv2DBackpropInputCanonicalizer.h"
#include "Canonicalization/Conv2DCanonicalizer.h"
#include "Canonicalization/DepthwiseConv2dNativeCanonicalizer.h"
#include "Canonicalization/IdentityCanonicalizer.h"
-#include "Canonicalization/MaximumCanonicalizer.h"
#include "Canonicalization/MaxPoolCanonicalizer.h"
-#include "Canonicalization/MeanCanonicalizer.h"
#include "Canonicalization/MulCanonicalizer.h"
-#include "Canonicalization/PadCanonicalizer.h"
-#include "Canonicalization/PlaceholderCanonicalizer.h"
#include "Canonicalization/RealDivCanonicalizer.h"
#include "Canonicalization/ReluCanonicalizer.h"
#include "Canonicalization/Relu6Canonicalizer.h"
@@ -44,15 +39,14 @@
#include "Canonicalization/RsqrtCanonicalizer.h"
#include "Canonicalization/SoftmaxCanonicalizer.h"
#include "Canonicalization/SqrtCanonicalizer.h"
+#include "Canonicalization/SquaredDifferenceCanonicalizer.h"
#include "Canonicalization/SqueezeCanonicalizer.h"
#include "Canonicalization/StopGradientCanonicalizer.h"
#include "Canonicalization/SubCanonicalizer.h"
#include "Canonicalization/TanhCanonicalizer.h"
-// For virtual nodes
-#include "Canonicalization/TFPushCanonicalizer.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include <logo/Phase.h>
@@ -71,7 +65,7 @@ bool has_tf_nodes(loco::Graph *g)
auto active_nodes = loco::active_nodes(loco::output_nodes(g));
for (auto node : active_nodes)
{
- if (node->dialect() == moco::TFDialect::get())
+ if (node->dialect() == moco::tf::TFDialect::get())
{
return true;
}
@@ -102,17 +96,12 @@ void Canonicalizer::canonicalize(loco::Graph *g) const
phase.emplace_back(stdex::make_unique<ConcatV2Canonicalizer>());
if (moco::tf::get<moco::tf::Knob::CanonicalizeConst>())
phase.emplace_back(stdex::make_unique<ConstCanonicalizer>());
- phase.emplace_back(stdex::make_unique<Conv2DBackpropInputCanonicalizer>());
if (moco::tf::get<moco::tf::Knob::CanonicalizeConv2D>())
phase.emplace_back(stdex::make_unique<Conv2DCanonicalizer>());
phase.emplace_back(stdex::make_unique<DepthwiseConv2dNativeCanonicalizer>());
phase.emplace_back(stdex::make_unique<IdentityCanonicalizer>());
- phase.emplace_back(stdex::make_unique<MaximumCanonicalizer>());
phase.emplace_back(stdex::make_unique<MaxPoolCanonicalizer>());
- phase.emplace_back(stdex::make_unique<MeanCanonicalizer>());
phase.emplace_back(stdex::make_unique<MulCanonicalizer>());
- phase.emplace_back(stdex::make_unique<PadCanonicalizer>());
- phase.emplace_back(stdex::make_unique<PlaceholderCanonicalizer>());
phase.emplace_back(stdex::make_unique<RealDivCanonicalizer>());
phase.emplace_back(stdex::make_unique<ReluCanonicalizer>());
phase.emplace_back(stdex::make_unique<Relu6Canonicalizer>());
@@ -120,13 +109,11 @@ void Canonicalizer::canonicalize(loco::Graph *g) const
phase.emplace_back(stdex::make_unique<RsqrtCanonicalizer>());
phase.emplace_back(stdex::make_unique<SoftmaxCanonicalizer>());
phase.emplace_back(stdex::make_unique<SqrtCanonicalizer>());
- // NOTE SquaredDifference is handled in ResolveSquaredDifference
+ phase.emplace_back(stdex::make_unique<SquaredDifferenceCanonicalizer>());
phase.emplace_back(stdex::make_unique<SqueezeCanonicalizer>());
phase.emplace_back(stdex::make_unique<StopGradientCanonicalizer>());
phase.emplace_back(stdex::make_unique<SubCanonicalizer>());
phase.emplace_back(stdex::make_unique<TanhCanonicalizer>());
- // For virtual nodes
- phase.emplace_back(stdex::make_unique<TFPushCanonicalizer>());
/* TRANSFORM DECLARATION END */
ProgressReporter prog(g, logo::PhaseStrategy::Restart);
diff --git a/compiler/moco-tf/src/CodecHelper.h b/compiler/moco-tf/src/CodecHelper.h
deleted file mode 100644
index 85e4e2164..000000000
--- a/compiler/moco-tf/src/CodecHelper.h
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __CODEC_HELPER_H__
-#define __CODEC_HELPER_H__
-
-#include <plier/tf/Convert.h>
-#include <stdex/Memory.h>
-
-namespace
-{
-
-using plier::tf::DataLayout;
-
-void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
-{
- auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Height) = 1;
- enc->perm()->axis(loco::FeatureAxis::Width) = 2;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- enc->perm()->axis(loco::FeatureAxis::Count) = 0;
- enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
- enc->perm()->axis(loco::FeatureAxis::Height) = 2;
- enc->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_enc->encoder(std::move(enc));
-}
-
-void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
-{
- auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
-
- if (data_layout == DataLayout::NHWC)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Height) = 1;
- dec->perm()->axis(loco::FeatureAxis::Width) = 2;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
- }
- else if (data_layout == DataLayout::NCHW)
- {
- dec->perm()->axis(loco::FeatureAxis::Count) = 0;
- dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
- dec->perm()->axis(loco::FeatureAxis::Height) = 2;
- dec->perm()->axis(loco::FeatureAxis::Width) = 3;
- }
-
- feature_dec->decoder(std::move(dec));
-}
-
-} // namespace
-
-#endif // __CODEC_HELPER_H__
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp b/compiler/moco-tf/src/Dialect/TFDialect.cpp
index 69eaf7900..730224753 100644
--- a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Dialect/TFDialect.cpp
@@ -14,17 +14,17 @@
* limitations under the License.
*/
-#include "MeanCanonicalizer.h"
-#include "TFReduceCanonicalzeHelper.h"
+#include "TFDialect.h"
namespace moco
{
namespace tf
{
-bool MeanCanonicalizer::transform(moco::TFMean *node) const
+loco::Dialect *TFDialect::get(void)
{
- return canonicalize_reduce_node(node);
+ static TFDialect d;
+ return &d;
}
} // namespace tf
diff --git a/compiler/moco-tf/src/Dialect/TFDialect.h b/compiler/moco-tf/src/Dialect/TFDialect.h
new file mode 100644
index 000000000..9074e18c7
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFDialect.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFDIALECT_H__
+#define __MOCO_TF_DIALECT_TFDIALECT_H__
+
+#include <loco/IR/Dialect.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief A singleton for TensorFlow Dialect
+ */
+class TFDialect final : public loco::Dialect
+{
+private:
+ TFDialect() = default;
+
+public:
+ TFDialect(const TFDialect &) = delete;
+ TFDialect(TFDialect &&) = delete;
+
+public:
+ static loco::Dialect *get(void);
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFDIALECT_H__
diff --git a/compiler/moco-tf/src/Dialect/TFDialect.test.cpp b/compiler/moco-tf/src/Dialect/TFDialect.test.cpp
new file mode 100644
index 000000000..f89eaaf96
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFDialect.test.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFDialectTest, get)
+{
+ auto d = moco::tf::TFDialect::get();
+
+ // get() SHOULD return a valid(non-null) pointer
+ ASSERT_NE(d, nullptr);
+ // The return value SHOULD be stable across multiple invocations
+ ASSERT_EQ(d, moco::tf::TFDialect::get());
+}
diff --git a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp b/compiler/moco-tf/src/Dialect/TFNode.cpp
index 92634d01f..e9fc3149c 100644
--- a/compiler/moco-tf/src/Canonicalization/MaximumCanonicalizer.cpp
+++ b/compiler/moco-tf/src/Dialect/TFNode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,21 +14,15 @@
* limitations under the License.
*/
-#include "MaximumCanonicalizer.h"
-
-#include <moco/IR/TFDialect.h>
-
-#include "TFEltwiseBinaryCanonicalzeHelper.h"
+#include "TFNode.h"
+#include "TFDialect.h"
namespace moco
{
namespace tf
{
-bool MaximumCanonicalizer::transform(moco::TFMaximum *node) const
-{
- return canonicalize_eltwise_binary_node(node);
-}
+const loco::Dialect *TFNode::dialect(void) const { return TFDialect::get(); }
} // namespace tf
} // namespace moco
diff --git a/compiler/moco-tf/src/Dialect/TFNode.h b/compiler/moco-tf/src/Dialect/TFNode.h
new file mode 100644
index 000000000..3cd12af23
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNode.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODE_H__
+#define __MOCO_TF_DIALECT_TFNODE_H__
+
+#include "TFNodeDecl.h"
+#include "TFNodeImpl.h"
+
+#endif // __MOCO_TF_DIALECT_TFNODE_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodeDecl.h b/compiler/moco-tf/src/Dialect/TFNodeDecl.h
new file mode 100644
index 000000000..922165b01
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodeDecl.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODE_DECL_H__
+#define __MOCO_TF_DIALECT_TFNODE_DECL_H__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/Dialect.h>
+
+#include "TFOpcode.h"
+#include "TFNodeVisitor.forward.h"
+
+#include <array>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note The aliases below may be introduced as separate classes
+using TFDataLayout = std::string;
+using TFPadding = std::string;
+
+struct TFNode : public loco::Node
+{
+ virtual ~TFNode() = default;
+
+ const loco::Dialect *dialect(void) const final;
+ virtual TFOpcode opcode(void) const = 0;
+
+ template <typename T> T accept(TFNodeVisitorBase<T> *) const;
+ template <typename T> T accept(TFNodeMutableVisitorBase<T> *);
+};
+
+template <TFOpcode Code> struct TFNodeImpl : public TFNode
+{
+ virtual ~TFNodeImpl() = default;
+
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
+ TFOpcode opcode(void) const final { return Code; }
+};
+
+/**
+ * @brief Nodes with the fixed number of inputs
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+ FixedArityNode()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::unique_ptr<loco::Use>{new loco::Use{this}};
+ }
+ }
+
+ virtual ~FixedArityNode() = default;
+
+public:
+ unsigned arity(void) const final { return N; }
+
+ loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+  // This API allows inherited classes to access the "_args" field.
+ loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+ std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFNODE_DECL_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodeImpl.h b/compiler/moco-tf/src/Dialect/TFNodeImpl.h
new file mode 100644
index 000000000..39e8830b4
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodeImpl.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODE_IMPL_H__
+#define __MOCO_TF_DIALECT_TFNODE_IMPL_H__
+
+#include "TFNodes.h"
+#include "TFNodeVisitor.h"
+
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+template <typename T> T TFNode::accept(TFNodeVisitorBase<T> *v) const
+{
+ switch (this->opcode())
+ {
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+ case TFOpcode::OPCODE: \
+ return v->visit(dynamic_cast<const CLASS *>(this));
+
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+ default:
+ break;
+ }
+
+ throw std::runtime_error{"NYI"};
+}
+
+template <typename T> T TFNode::accept(TFNodeMutableVisitorBase<T> *v)
+{
+ switch (this->opcode())
+ {
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+ case TFOpcode::OPCODE: \
+ return v->visit(dynamic_cast<CLASS *>(this));
+
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+ default:
+ break;
+ }
+
+ throw std::runtime_error{"NYI"};
+}
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFNODE_IMPL_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodeVisitor.forward.h b/compiler/moco-tf/src/Dialect/TFNodeVisitor.forward.h
new file mode 100644
index 000000000..513c0ae1e
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodeVisitor.forward.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODE_VISITOR_FORWARD_H__
+#define __MOCO_TF_DIALECT_TFNODE_VISITOR_FORWARD_H__
+
+namespace moco
+{
+namespace tf
+{
+
+// NOTE These forward declarations SHOULD BE aligned with Node declarations in
+// "TFNodeVisitor.h"
+template <typename T> struct TFNodeVisitorBase;
+template <typename T> struct TFNodeMutableVisitorBase;
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFNODE_VISITOR_FORWARD_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodeVisitor.h b/compiler/moco-tf/src/Dialect/TFNodeVisitor.h
new file mode 100644
index 000000000..aff9bca4e
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodeVisitor.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODE_VISITOR_H__
+#define __MOCO_TF_DIALECT_TFNODE_VISITOR_H__
+
+#include "TFNodes.h"
+
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * DO NOT use this class. Use TFNodeVisitor instead.
+ */
+template <typename T> struct TFNodeVisitorBase
+{
+ virtual ~TFNodeVisitorBase() = default;
+
+#define TENSORFLOW_NODE(OPCODE, CLASS) virtual T visit(const CLASS *) = 0;
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+};
+
+template <typename T> struct TFNodeVisitor : public TFNodeVisitorBase<T>
+{
+ virtual ~TFNodeVisitor() = default;
+
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+ virtual T visit(const CLASS *node) { return visit(static_cast<const TFNode *>(node)); }
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+
+ /// @brief Default fallback
+ virtual T visit(const TFNode *) { throw std::runtime_error{"Not implemented, yet"}; }
+};
+
+/**
+ * DO NOT use this class. Use TFNodeMutableVisitor instead.
+ */
+template <typename T> struct TFNodeMutableVisitorBase
+{
+ virtual ~TFNodeMutableVisitorBase() = default;
+
+#define TENSORFLOW_NODE(OPCODE, CLASS) virtual T visit(CLASS *) = 0;
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+};
+
+template <typename T> struct TFNodeMutableVisitor : public TFNodeMutableVisitorBase<T>
+{
+ virtual ~TFNodeMutableVisitor() = default;
+
+#define TENSORFLOW_NODE(OPCODE, CLASS) \
+ virtual T visit(CLASS *node) { return visit(static_cast<TFNode *>(node)); }
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+
+ /// @brief Default fallback
+ virtual T visit(TFNode *) { throw std::runtime_error{"Not implemented, yet"}; }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFNODE_VISITOR_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodes.h b/compiler/moco-tf/src/Dialect/TFNodes.h
new file mode 100644
index 000000000..c7c63f0b6
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodes.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFNODES_H__
+#define __MOCO_TF_DIALECT_TFNODES_H__
+
+#include "IR/TFAdd.h"
+#include "IR/TFAvgPool.h"
+#include "IR/TFBiasAdd.h"
+#include "IR/TFConcatV2.h"
+#include "IR/TFConst.h"
+#include "IR/TFConv2D.h"
+#include "IR/TFDepthwiseConv2dNative.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "IR/TFIdentity.h"
+#include "IR/TFMaxPool.h"
+#include "IR/TFMean.h"
+#include "IR/TFMul.h"
+#include "IR/TFRealDiv.h"
+#include "IR/TFRelu.h"
+#include "IR/TFRelu6.h"
+#include "IR/TFReshape.h"
+#include "IR/TFRsqrt.h"
+#include "IR/TFShape.h"
+#include "IR/TFSoftmax.h"
+#include "IR/TFSqrt.h"
+#include "IR/TFSquaredDifference.h"
+#include "IR/TFSqueeze.h"
+#include "IR/TFStopGradient.h"
+#include "IR/TFSub.h"
+#include "IR/TFTanh.h"
+
+#endif // __MOCO_TF_DIALECT_TFNODES_H__
diff --git a/compiler/moco-tf/src/Dialect/TFNodes.lst b/compiler/moco-tf/src/Dialect/TFNodes.lst
new file mode 100644
index 000000000..20730bb69
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFNodes.lst
@@ -0,0 +1,34 @@
+#ifndef TENSORFLOW_NODE
+#error "Define TENSORFLOW_NODE"
+#endif // TENSORFLOW_NODE
+
+//
+// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
+//
+
+// TENSORFLOW_NODE(OPCODE, CLASS)
+TENSORFLOW_NODE(Add, TFAdd)
+TENSORFLOW_NODE(AvgPool, TFAvgPool)
+TENSORFLOW_NODE(BiasAdd, TFBiasAdd)
+TENSORFLOW_NODE(ConcatV2, TFConcatV2)
+TENSORFLOW_NODE(Const, TFConst)
+TENSORFLOW_NODE(Conv2D, TFConv2D)
+TENSORFLOW_NODE(DepthwiseConv2dNative, TFDepthwiseConv2dNative)
+TENSORFLOW_NODE(FusedBatchNorm, TFFusedBatchNorm)
+TENSORFLOW_NODE(Identity, TFIdentity)
+TENSORFLOW_NODE(MaxPool, TFMaxPool)
+TENSORFLOW_NODE(Mean, TFMean)
+TENSORFLOW_NODE(Mul, TFMul)
+TENSORFLOW_NODE(RealDiv, TFRealDiv)
+TENSORFLOW_NODE(Relu, TFRelu)
+TENSORFLOW_NODE(Relu6, TFRelu6)
+TENSORFLOW_NODE(Reshape, TFReshape)
+TENSORFLOW_NODE(Rsqrt, TFRsqrt)
+TENSORFLOW_NODE(Shape, TFShape)
+TENSORFLOW_NODE(Softmax, TFSoftmax)
+TENSORFLOW_NODE(Sqrt, TFSqrt)
+TENSORFLOW_NODE(SquaredDifference, TFSquaredDifference)
+TENSORFLOW_NODE(Squeeze, TFSqueeze)
+TENSORFLOW_NODE(StopGradient, TFStopGradient)
+TENSORFLOW_NODE(Sub, TFSub)
+TENSORFLOW_NODE(Tanh, TFTanh)
diff --git a/compiler/moco-tf/src/Dialect/TFOpcode.h b/compiler/moco-tf/src/Dialect/TFOpcode.h
new file mode 100644
index 000000000..13e9ca119
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFOpcode.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_DIALECT_TFOPCODE_H__
+#define __MOCO_TF_DIALECT_TFOPCODE_H__
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief TensorFlow Node Opcode
+ */
+enum class TFOpcode
+{
+#define TENSORFLOW_NODE(OPCODE, CLASS) OPCODE,
+#include "TFNodes.lst"
+#undef TENSORFLOW_NODE
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_DIALECT_TFOPCODE_H__
diff --git a/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp
new file mode 100644
index 000000000..b25ad0c17
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.cpp
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFShapeInferenceRule.h"
+
+#include "TFDialect.h"
+#include "TFNode.h"
+
+#include "Annotations/ShapeInferenceData.h"
+
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+bool TFShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ // handle only TensorFlow dialect
+ return TFDialect::get() == d;
+}
+
+bool TFShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+ assert(node->dialect() == TFDialect::get());
+ assert(dynamic_cast<const TFNode *>(node) != nullptr);
+
+ if (auto shapedata = node->annot<ShapeInferenceData>())
+ {
+ assert(shapedata->domain() == loco::Domain::Tensor);
+
+ shape.set(shapedata->tensor_shape());
+
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.h b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.h
new file mode 100644
index 000000000..84b3b47c6
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFShapeInferenceRule.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_SHAPE_INFERENCE_RULE_H__
+#define __MOCO_TF_SHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInferenceRule.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Shape inference rule for TensorFlow dialect
+ */
+struct TFShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::NodeShape &) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_SHAPE_INFERENCE_RULE_H__
diff --git a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp
new file mode 100644
index 000000000..8525768db
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFTypeInferenceRule.h"
+
+#include "TFDialect.h"
+#include "TFNodeVisitor.h"
+#include "TFNodes.h"
+
+#include "TFNodeImpl.h"
+
+#include <cassert>
+
+namespace
+{
+
+using namespace moco::tf;
+
+struct TypeForwardAlgorithm final : public moco::tf::TFNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->values(0)); }
+
+ loco::DataType visit(const TFConst *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFConv2D *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFDepthwiseConv2dNative *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFFusedBatchNorm *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFIdentity *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFMean *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFRelu6 *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFReshape *node) { return dtype_get(node->tensor()); }
+ loco::DataType visit(const TFRsqrt *node) { return dtype_get(node->x()); }
+
+ loco::DataType visit(const TFShape *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFSoftmax *node) { return dtype_get(node->logits()); }
+ loco::DataType visit(const TFSqrt *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFSquaredDifference *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFSqueeze *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFStopGradient *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFSub *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFTanh *node) { return dtype_get(node->x()); }
+};
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool TFTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ // This rule recognizes only "TFDialect" dialect!
+ return TFDialect::get() == d;
+}
+
+bool TFTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == TFDialect::get());
+
+ TypeForwardAlgorithm alg;
+
+// clang-format off
+#define TENSORFLOW_NODE(OPCODE,CLASS) \
+ if (dynamic_cast<const moco::tf::CLASS *>(node)) \
+ { \
+ auto tfnode = dynamic_cast<const moco::tf::CLASS *>(node); \
+ dtype = tfnode->accept(&alg); \
+ assert(dtype != loco::DataType::Unknown); \
+ return true; \
+ }
+#include "Dialect/TFNodes.lst"
+#undef TENSORFLOW_NODE
+ // clang-format on
+
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h
new file mode 100644
index 000000000..3e6a64712
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/TFTypeInferenceRule.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_TYPE_INFERENCE_RULE_H__
+#define __MOCO_TF_TYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Type Inference Rule for TFDialect
+ */
+struct TFTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_TYPE_INFERENCE_RULE_H__
diff --git a/compiler/moco-tf/src/Dialect/VariadicArityNode.h b/compiler/moco-tf/src/Dialect/VariadicArityNode.h
new file mode 100644
index 000000000..407b8314c
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/VariadicArityNode.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TF_DIALECT_VARIADIC_ARITY_NODE_H__
+#define __TF_DIALECT_VARIADIC_ARITY_NODE_H__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/Use.h>
+
+#include <vector>
+#include <memory>
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Nodes with variadic inputs
+ */
+template <typename Base> class VariadicArityNode : public Base
+{
+public:
+ VariadicArityNode(uint32_t arity)
+ {
+ for (uint32_t n = 0; n < arity; ++n)
+ {
+ _args.emplace_back(std::move(std::unique_ptr<loco::Use>{new loco::Use{this}}));
+ }
+ };
+
+ virtual ~VariadicArityNode() = default;
+
+public:
+ uint32_t arity(void) const final { return _args.size(); }
+
+ loco::Node *arg(uint32_t n) const final
+ {
+ assert(n < _args.size());
+ return _args.at(n)->node();
+ }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < _args.size(); ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+ // This API allows inherited classes to access "_args" field.
+ loco::Use *at(uint32_t n) const
+ {
+ assert(n < _args.size());
+ return _args.at(n).get();
+ }
+
+private:
+ std::vector<std::unique_ptr<loco::Use>> _args;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __TF_DIALECT_VARIADIC_ARITY_NODE_H__
diff --git a/compiler/moco-tf/src/Dialect/VariadicArityNode.test.cpp b/compiler/moco-tf/src/Dialect/VariadicArityNode.test.cpp
new file mode 100644
index 000000000..0b2d8795e
--- /dev/null
+++ b/compiler/moco-tf/src/Dialect/VariadicArityNode.test.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "VariadicArityNode.h"
+
+#include <loco/IR/Nodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class ArbitraryInputNode : public VariadicArityNode<loco::Node>
+{
+public:
+ ArbitraryInputNode(uint32_t arity) : VariadicArityNode<loco::Node>(arity) {}
+
+ void input(uint32_t idx, loco::Node *node) { at(idx)->node(node); }
+ loco::Node *input(uint32_t idx) const { return at(idx)->node(); }
+
+ const loco::Dialect *dialect(void) const { return nullptr; } // this won't be called for testing
+ uint32_t opnum(void) const { return -1; } // this won't be called for testing
+};
+
+} // namespace
+
+TEST(CustomOpTest, VariadicArityNode_arity_n)
+{
+ loco::ConstGen cg0, cg1, cg2;
+
+ ArbitraryInputNode a_node(3);
+ a_node.input(0, &cg0);
+ a_node.input(1, &cg1);
+ a_node.input(2, &cg2);
+
+ ASSERT_EQ(a_node.arity(), 3);
+ ASSERT_EQ(a_node.input(0), &cg0);
+ ASSERT_EQ(a_node.input(1), &cg1);
+ ASSERT_EQ(a_node.input(2), &cg2);
+}
diff --git a/compiler/moco-tf/src/Frontend.cpp b/compiler/moco-tf/src/Frontend.cpp
index a17d5dd0e..e76580785 100644
--- a/compiler/moco-tf/src/Frontend.cpp
+++ b/compiler/moco-tf/src/Frontend.cpp
@@ -15,13 +15,10 @@
*/
#include <moco/tf/Frontend.h>
-#include <moco/Importer.h>
-#include <moco/IR/TFNode.h>
#include <moco/Log.h>
-#include <moco/Import/GraphBuilderRegistry.h>
-
#include "Canonicalizer.h"
+#include "Importer.h"
#include "Optimizer.h"
#include "TFOptimizer.h"
@@ -31,8 +28,8 @@
#include <loco/Service/ShapeInference.h>
+#include <cwrap/Fildes.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
@@ -40,7 +37,6 @@
#include <iostream>
#include <sstream>
-#include <fstream>
#include <stdexcept>
#include <fcntl.h>
@@ -49,6 +45,13 @@
namespace
{
+bool load_text(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+ google::protobuf::io::FileInputStream fis(fildes.get());
+
+ return google::protobuf::TextFormat::Parse(&fis, &graph_def);
+}
+
bool load_text(std::istream *stream, tensorflow::GraphDef &graph_def)
{
google::protobuf::io::IstreamInputStream iis(stream);
@@ -56,6 +59,14 @@ bool load_text(std::istream *stream, tensorflow::GraphDef &graph_def)
return google::protobuf::TextFormat::Parse(&iis, &graph_def);
}
+bool load_binary(const cwrap::Fildes &fildes, tensorflow::GraphDef &graph_def)
+{
+ google::protobuf::io::FileInputStream fis(fildes.get());
+ google::protobuf::io::CodedInputStream cis(&fis);
+
+ return graph_def.ParseFromCodedStream(&cis);
+}
+
bool load_binary(std::istream *stream, tensorflow::GraphDef &graph_def)
{
google::protobuf::io::IstreamInputStream iis(stream);
@@ -64,94 +75,42 @@ bool load_binary(std::istream *stream, tensorflow::GraphDef &graph_def)
return graph_def.ParseFromCodedStream(&cis);
}
-void load_tf(std::istream *stream, moco::tf::Frontend::FileType type,
+void load_tf(const std::string &path, moco::tf::Frontend::FileType type,
tensorflow::GraphDef &graph_def)
{
- bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(stream, graph_def)
- : load_binary(stream, graph_def);
- if (!result)
+ cwrap::Fildes fildes{open(path.c_str(), O_RDONLY)};
+
+ if (fildes.get() < 0)
{
- throw oops::UserExn("Failed to parse prototxt from stream");
+ throw std::runtime_error{"Error: " + path + " not found"};
}
-}
-// If Placeholder has no shape attribute, set unknown_rank property to true.
-void set_unknown_rank(tensorflow::GraphDef &tf_graph_def)
-{
- for (auto &n : *tf_graph_def.mutable_node())
+ bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(fildes, graph_def)
+ : load_binary(fildes, graph_def);
+ if (!result)
{
- if (n.op().compare("Placeholder"))
- continue;
-
- auto iter = n.attr().find("shape");
- if (iter == n.attr().end())
- {
- tensorflow::AttrValue attr;
- attr.mutable_shape()->set_unknown_rank(true);
- n.mutable_attr()->insert({"shape", attr});
- }
+ throw std::runtime_error{"Error: Failed to parse prototxt " + path};
}
}
-/**
- * @brief Set input shape according to signature if node has unknown shape in GraphDef.
- *
- * @note If shape you provided is wrong or not enough, it returns false.
- */
-bool set_input_shape(const moco::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def)
+void load_tf(std::istream *stream, moco::tf::Frontend::FileType type,
+ tensorflow::GraphDef &graph_def)
{
- for (auto &n : *tf_graph_def.mutable_node())
+ bool result = (type == moco::tf::Frontend::FileType::Text) ? load_text(stream, graph_def)
+ : load_binary(stream, graph_def);
+ if (!result)
{
- if (n.op().compare("Placeholder"))
- continue;
-
- auto node_shape = n.mutable_attr()->at("shape").mutable_shape();
- auto sig_shape = signature.shape(n.name() + ":0");
-
- if (node_shape->unknown_rank() || !node_shape->dim_size())
- {
- // If shape in GraphDef is unknown, user must provide the shape info.
- if (sig_shape == nullptr)
- return false;
- node_shape->clear_unknown_rank();
- for (uint32_t i = 0; i < sig_shape->rank(); i++)
- node_shape->add_dim()->set_size(-1);
- }
-
- for (uint32_t d = 0; d < node_shape->dim_size(); d++)
- {
- if (node_shape->mutable_dim(d)->size() == -1)
- {
- if (sig_shape == nullptr)
- return false;
- node_shape->mutable_dim(d)->set_size(sig_shape->dim(d));
- }
- else
- {
- // If User provide shape info though it already exists in GraphDef, make sure it matches
- // the shape of GraphDef.
- if (sig_shape && node_shape->dim(d).size() != sig_shape->dim(d))
- return false;
- }
- }
+ throw std::runtime_error{"Error: Failed to parse prototxt from stream"};
}
- return true;
-}
-
-void transform_tf(const moco::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def)
-{
- set_unknown_rank(tf_graph_def);
- if (!set_input_shape(signature, tf_graph_def))
- oops::UserExn("Info you provided may be wrong or not enough. Please check the info file.");
}
/**
* @brief Returns GraphBuilderRegistry that looks up default registry and additions
* such as custom op
*/
-moco::GraphBuilderRegistry make_graph_builder_registry(const moco::ModelSignature &sig)
+moco::tf::GraphBuilderRegistry make_graph_builder_registry(const moco::tf::ModelSignature &sig)
{
- moco::GraphBuilderRegistry registry{&moco::GraphBuilderRegistry::get()};
+ moco::tf::GraphBuilderRegistry registry{&moco::tf::GraphBuilderRegistry::get()};
// build a COpCallGraphBuilder per custom op type
for (const auto &custom_op : sig.customops())
@@ -185,6 +144,49 @@ namespace moco
namespace tf
{
+void ModelSignature::add_customop(const std::string &op)
+{
+ if (std::find(_customops.begin(), _customops.end(), op) == _customops.end())
+ _customops.emplace_back(op);
+ else
+ throw std::runtime_error{"Duplicated custom op: " + op};
+}
+
+void ModelSignature::shape(const std::string &node_name,
+ const nncc::core::ADT::tensor::Shape &shape)
+{
+ if (_shapes.find(node_name) != _shapes.end())
+ throw std::runtime_error{"Duplicated node name: " + node_name};
+
+ _shapes[node_name] = shape;
+}
+
+const nncc::core::ADT::tensor::Shape *ModelSignature::shape(const std::string &node_name) const
+{
+ auto res = _shapes.find(node_name);
+ if (res == _shapes.end())
+ return nullptr;
+ else
+ return &res->second;
+}
+
+void ModelSignature::dtype(const std::string &node_name, loco::DataType dtype)
+{
+ if (_dtypes.find(node_name) != _dtypes.end())
+ throw std::runtime_error{"Duplicated node name: " + node_name};
+
+ _dtypes[node_name] = dtype;
+}
+
+loco::DataType ModelSignature::dtype(const std::string &node_name) const
+{
+ auto res = _dtypes.find(node_name);
+ if (res == _dtypes.end())
+ return loco::DataType::Unknown;
+ else
+ return res->second;
+}
+
Frontend::Frontend()
{
// DO NOTHING
@@ -193,9 +195,13 @@ Frontend::Frontend()
std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, const char *modelfile,
FileType type) const
{
- // Using c++ standard library, rather than file descriptor, makes these lines portable
- std::ifstream ifs{modelfile, std::ios::in | std::ios::binary};
- return load(signature, &ifs, type);
+ tensorflow::GraphDef tf_graph_def;
+
+ load_tf(modelfile, type, tf_graph_def);
+
+ auto graph = import(signature, tf_graph_def);
+
+ return std::move(graph);
}
std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, std::istream *stream,
@@ -205,13 +211,25 @@ std::unique_ptr<loco::Graph> Frontend::load(const ModelSignature &signature, std
load_tf(stream, type, tf_graph_def);
- transform_tf(signature, tf_graph_def);
-
auto graph = import(signature, tf_graph_def);
return std::move(graph);
}
+void cleanup(loco::Graph *graph)
+{
+ std::vector<std::unique_ptr<moco::tf::Transform>> finalize;
+
+ finalize.emplace_back(stdex::make_unique<moco::tf::ClearAnnotTransform>());
+ // TODO add more cleanup transformations
+
+ // Run finalize to cleanup temporary annotations
+ for (auto &tr : finalize)
+ {
+ tr->run(graph);
+ }
+}
+
std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature,
tensorflow::GraphDef &tf_graph_def) const
{
@@ -241,7 +259,7 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature,
for (uint32_t n = 0; n < graph->inputs()->size(); ++n)
{
auto input = graph->inputs()->at(n);
- auto input_node = moco::placeholder_node(graph.get(), n);
+ auto input_node = loco::pull_node(graph.get(), n);
assert(input_node != nullptr);
input->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(input_node)));
}
@@ -249,9 +267,9 @@ std::unique_ptr<loco::Graph> Frontend::import(const ModelSignature &signature,
for (uint32_t n = 0; n < graph->outputs()->size(); ++n)
{
auto output = graph->outputs()->at(n);
- auto output_node = moco::push_node(graph.get(), n);
+ auto output_node = loco::push_node(graph.get(), n);
assert(output_node != nullptr);
- output->shape(stdex::make_unique<loco::TensorShape>(::tensor_shape(output_node)));
+ output->shape(stdex::make_unique<loco::TensorShape>(tensor_shape(output_node)));
}
// Convert graph to hold only Canonical dialect
diff --git a/compiler/moco-tf/src/Frontend.test.cpp b/compiler/moco-tf/src/Frontend.test.cpp
index c665bd9e3..57a9eb7f7 100644
--- a/compiler/moco-tf/src/Frontend.test.cpp
+++ b/compiler/moco-tf/src/Frontend.test.cpp
@@ -60,11 +60,10 @@ node {
TEST(FrontendTests, testcase_000)
{
moco::tf::Frontend frontend;
- moco::ModelSignature signature;
+ moco::tf::ModelSignature signature;
- signature.add_input(moco::TensorName("Placeholder", 0));
- signature.shape("Placeholder:0", angkor::TensorShape{4});
- signature.add_output(moco::TensorName("Identity", 0));
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("Identity", 0));
std::stringstream ss{pbtxt_000};
diff --git a/compiler/moco-tf/src/GraphBuilder.h b/compiler/moco-tf/src/GraphBuilder.h
new file mode 100644
index 000000000..b18bd2716
--- /dev/null
+++ b/compiler/moco-tf/src/GraphBuilder.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BUILDER_H__
+#define __GRAPH_BUILDER_H__
+
+#include "GraphBuilderContext.h"
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Interface to convert a TF NodeDef to a loco::Node (e.g., Conv2DGraphBuilder)
+ */
+class GraphBuilder
+{
+public:
+ virtual bool validate(const tensorflow::NodeDef &) const = 0;
+ virtual void build(const tensorflow::NodeDef &, GraphBuilderContext *) const = 0;
+ virtual ~GraphBuilder() {}
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __GRAPH_BUILDER_H__
diff --git a/compiler/moco-tf/src/GraphBuilderContext.cpp b/compiler/moco-tf/src/GraphBuilderContext.cpp
new file mode 100644
index 000000000..04fb8cd88
--- /dev/null
+++ b/compiler/moco-tf/src/GraphBuilderContext.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilderContext.h"
+
+#include <stdexcept>
+#include <string>
+
+namespace moco
+{
+namespace tf
+{
+
+void NodeDefTable::enroll(const std::string &node_name, const tensorflow::NodeDef *node)
+{
+ MapNameNode_t::iterator iter = _table.find(node_name);
+
+ if (iter != _table.end())
+ {
+ throw std::runtime_error{"Error: Duplicate node name in TensorFlow GraphDef: " + node_name};
+ }
+
+ _table[node_name] = node;
+}
+
+const tensorflow::NodeDef *NodeDefTable::node(const std::string &node_name) const
+{
+ MapNameNode_t::const_iterator iter = _table.find(node_name);
+
+ if (iter == _table.end())
+ {
+ throw std::runtime_error{"Error: Cannot find node with name in TensorFlow GraphDef: " +
+ node_name};
+ }
+
+ return iter->second;
+}
+
+void SymbolTable::enroll(const TensorName &tensor_name, loco::Node *node)
+{
+ MapNameNode_t::iterator iter = _table.find(tensor_name);
+
+ if (iter != _table.end())
+ {
+ throw std::runtime_error{"Error: Duplicate node name in Graph: " + tensor_name.name()};
+ }
+
+ _table[tensor_name] = node;
+}
+
+loco::Node *SymbolTable::node(const TensorName &tensor_name) const
+{
+ MapNameNode_t::const_iterator iter = _table.find(tensor_name);
+
+ if (iter == _table.end())
+ {
+ throw std::runtime_error{"Error: Cannot find node with name in Graph: " + tensor_name.name()};
+ }
+
+ return iter->second;
+}
+
+void UpdateQueue::enroll(std::unique_ptr<GraphUpdate> &&update)
+{
+ _queue.push_back(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/GraphBuilderContext.h b/compiler/moco-tf/src/GraphBuilderContext.h
new file mode 100644
index 000000000..ca474823c
--- /dev/null
+++ b/compiler/moco-tf/src/GraphBuilderContext.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPHBUILDERCONTEXT_H__
+#define __GRAPHBUILDERCONTEXT_H__
+
+#include <moco/tf/Names.h>
+
+#include <loco.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Class to store and query tensorflow::NodeDef* with string name key
+ */
+class NodeDefTable
+{
+public:
+ /**
+ * @brief Registers a name with corresponding tensorflow::NodeDef*
+ */
+ void enroll(const std::string &node_name, const tensorflow::NodeDef *node);
+  /**
+   * @brief Queries the enrolled (registered) node by name and returns it if found.
+   *        Throws std::runtime_error if not found.
+   */
+ const tensorflow::NodeDef *node(const std::string &node_name) const;
+
+private:
+ using MapNameNode_t = std::map<std::string, const tensorflow::NodeDef *>;
+
+ MapNameNode_t _table;
+};
+
+/**
+ * @brief Class to store and query loco::Node* with string name key
+ */
+class SymbolTable
+{
+public:
+ /**
+ * @brief Registers a name with corresponding loco::Node *
+ */
+ void enroll(const TensorName &tensor_name, loco::Node *node);
+  /**
+   * @brief Queries the enrolled (registered) node by name and returns it if found.
+   *        Throws std::runtime_error if not found.
+   */
+ loco::Node *node(const TensorName &tensor_name) const;
+
+private:
+ using MapNameNode_t = std::map<TensorName, loco::Node *, TensorNameCompare>;
+
+ MapNameNode_t _table;
+};
+
+/**
+ * @brief Interface to connect the graph
+ */
+class GraphUpdate
+{
+public:
+ virtual ~GraphUpdate() = default;
+
+public:
+ /**
+ * @brief Do the graph input connections using the SymbolTable
+ */
+ virtual void input(const SymbolTable *) const = 0;
+};
+
+/**
+ * @brief Class to store GraphUpdate objects
+ */
+class UpdateQueue final
+{
+public:
+ /**
+ * @brief Registers GraphUpdate objects
+ */
+ void enroll(std::unique_ptr<GraphUpdate> &&update);
+
+public:
+ using Queue = std::vector<std::unique_ptr<GraphUpdate>>;
+
+ const Queue &queue() const { return _queue; }
+
+private:
+ Queue _queue;
+};
+
+/**
+ * @brief Class to store context to build loco graph IR from TensorFlow
+ */
+class GraphBuilderContext
+{
+public:
+ GraphBuilderContext(loco::Graph *g, NodeDefTable *nodedef, SymbolTable *tensor_names,
+ UpdateQueue *updates)
+ : _g(g), _nodedef(nodedef), _tensor_names(tensor_names), _updates(updates)
+ {
+ // DO NOTHING
+ }
+
+ GraphBuilderContext(const GraphBuilderContext &) = delete;
+ GraphBuilderContext(GraphBuilderContext &&) = delete;
+
+public:
+ loco::Graph *graph() { return _g; }
+ NodeDefTable *nodedef() { return _nodedef; }
+ SymbolTable *tensor_names() { return _tensor_names; }
+ UpdateQueue *updates() { return _updates; }
+
+private:
+ loco::Graph *_g;
+ NodeDefTable *_nodedef;
+ SymbolTable *_tensor_names;
+ UpdateQueue *_updates;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __GRAPHBUILDERCONTEXT_H__
diff --git a/compiler/moco-tf/src/GraphBuilderContext.test.cpp b/compiler/moco-tf/src/GraphBuilderContext.test.cpp
new file mode 100644
index 000000000..03993b281
--- /dev/null
+++ b/compiler/moco-tf/src/GraphBuilderContext.test.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilderContext.h"
+#include <moco/tf/Names.h>
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(GraphBuilderContext, ctor)
+{
+ auto graph = loco::make_graph();
+ moco::tf::NodeDefTable nodedef;
+ moco::tf::SymbolTable nodes;
+ moco::tf::UpdateQueue updates;
+
+ moco::tf::GraphBuilderContext context(graph.get(), &nodedef, &nodes, &updates);
+
+ ASSERT_EQ(context.graph(), graph.get());
+ ASSERT_EQ(context.nodedef(), &nodedef);
+ ASSERT_EQ(context.tensor_names(), &nodes);
+ ASSERT_EQ(context.updates(), &updates);
+}
+
+TEST(SymbolTable, node_name)
+{
+ moco::tf::SymbolTable table;
+ loco::Pull pull_node;
+ moco::tf::TensorName name("input", 0);
+ moco::tf::TensorName invalid("invalid", 0);
+
+ table.enroll(name, &pull_node);
+ ASSERT_EQ(table.node(name), &pull_node);
+ // duplicate name should throw
+ EXPECT_THROW(table.enroll(name, &pull_node), std::runtime_error);
+ // unregistered name should throw
+ EXPECT_THROW(table.node(invalid), std::runtime_error);
+}
+
+namespace
+{
+
+class TestGraphUpdate final : public moco::tf::GraphUpdate
+{
+public:
+ void input(const moco::tf::SymbolTable *) const override;
+};
+
+void TestGraphUpdate::input(const moco::tf::SymbolTable *) const {}
+
+} // namespace
+
+TEST(GraphUpdateQueue, queue)
+{
+ std::unique_ptr<TestGraphUpdate> update(new TestGraphUpdate());
+ moco::tf::UpdateQueue updates;
+
+ updates.enroll(std::move(update));
+ auto &queue = updates.queue();
+ ASSERT_EQ(queue.size(), 1);
+}
diff --git a/compiler/moco-tf/src/GraphBuilderRegistry.h b/compiler/moco-tf/src/GraphBuilderRegistry.h
new file mode 100644
index 000000000..f902ec228
--- /dev/null
+++ b/compiler/moco-tf/src/GraphBuilderRegistry.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BUILDER_REGISTRY_H__
+#define __GRAPH_BUILDER_REGISTRY_H__
+
+#include "GraphBuilder.h"
+
+#include <map>
+#include <memory>
+#include <string>
+
+namespace moco
+{
+namespace tf
+{
+
+struct GraphBuilderSource
+{
+ virtual ~GraphBuilderSource() = default;
+
+ /**
+ * @brief Returns registered GraphBuilder pointer for operator (nullptr if not present)
+ */
+ virtual const GraphBuilder *lookup(const std::string &op) const = 0;
+};
+
+/**
+ * @brief Class to return graph builder for TF nodes
+ */
+class GraphBuilderRegistry final : public GraphBuilderSource
+{
+public:
+ GraphBuilderRegistry() = default;
+
+public:
+ GraphBuilderRegistry(const GraphBuilderSource *parent) : _parent{parent}
+ {
+ // DO NOTHING
+ }
+
+public:
+ /**
+ * @brief Returns registered GraphBuilder pointer for operator or
+ * nullptr if not registered
+ */
+ const GraphBuilder *lookup(const std::string &op) const final
+ {
+ if (_builder_map.find(op) == _builder_map.end())
+ return (_parent == nullptr) ? nullptr : _parent->lookup(op);
+
+ return _builder_map.at(op).get();
+ }
+
+ static GraphBuilderRegistry &get()
+ {
+ static GraphBuilderRegistry me;
+ return me;
+ }
+
+public:
+ void add(const std::string op, std::unique_ptr<GraphBuilder> &&builder)
+ {
+ _builder_map[op] = std::move(builder);
+ }
+
+private:
+ const GraphBuilderSource *_parent = nullptr;
+
+private:
+ std::map<const std::string, std::unique_ptr<GraphBuilder>> _builder_map;
+};
+
+} // namespace tf
+} // namespace moco
+
+#include <stdex/Memory.h>
+
+#define REGISTER_OP_BUILDER(NAME, BUILDER) \
+ namespace \
+ { \
+ __attribute__((constructor)) void reg_op(void) \
+ { \
+ std::unique_ptr<moco::tf::BUILDER> builder = stdex::make_unique<moco::tf::BUILDER>(); \
+ moco::tf::GraphBuilderRegistry::get().add(#NAME, std::move(builder)); \
+ } \
+ }
+
+#endif // __GRAPH_BUILDER_REGISTRY_H__
diff --git a/compiler/moco-tf/src/IR/TFAdd.h b/compiler/moco-tf/src/IR/TFAdd.h
new file mode 100644
index 000000000..d2489fb5c
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFAdd.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFADD_H__
+#define __MOCO_TF_IR_TFADD_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFAdd corresponds to the following GraphDef
+/*
+node {
+ name: "add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Add>>
+{
+public:
+ TFAdd() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+
+ Node *y(void) const { return at(1)->node(); }
+ void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFADD_H__
diff --git a/compiler/moco-tf/src/IR/TFAdd.test.cpp b/compiler/moco-tf/src/IR/TFAdd.test.cpp
new file mode 100644
index 000000000..3134f8610
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFAdd.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFAdd.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFAddTest, constructor)
+{
+ moco::tf::TFAdd add_node;
+
+ ASSERT_EQ(add_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(add_node.opcode(), moco::tf::TFOpcode::Add);
+
+ ASSERT_EQ(add_node.x(), nullptr);
+ ASSERT_EQ(add_node.y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFAvgPool.h b/compiler/moco-tf/src/IR/TFAvgPool.h
new file mode 100644
index 000000000..93a72bb30
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFAvgPool.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFAVGPOOL_H__
+#define __MOCO_TF_IR_TFAVGPOOL_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFAvgPool corresponds to the following GraphDef
+/*
+node {
+ name: "avgpool"
+ op: "AvgPool"
+ input: "placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "ksize"
+ value {
+ list {
+ i: 1 i: 3 i: 3 i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "SAME"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1 i: 1 i: 1 i: 1
+ }
+ }
+ }
+}
+*/
+
+class TFAvgPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::AvgPool>>
+{
+public:
+ TFAvgPool() = default;
+
+public:
+ Node *value(void) const { return at(0)->node(); }
+ void value(Node *node) { return at(0)->node(node); }
+
+public:
+ const TFDataLayout &data_layout(void) const { return _data_layout; }
+ void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+ const TFPadding &padding(void) const { return _padding; }
+ void padding(const TFPadding &padding) { _padding = padding; }
+
+ const std::vector<int64_t> &ksize(void) const { return _ksize; }
+ void ksize(const std::vector<int64_t> &ksize) { _ksize = ksize; }
+
+ const std::vector<int64_t> &strides(void) const { return _strides; }
+ void strides(const std::vector<int64_t> &strides) { _strides = strides; }
+
+private:
+ TFDataLayout _data_layout;
+ TFPadding _padding;
+ std::vector<int64_t> _ksize;
+ std::vector<int64_t> _strides;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFAVGPOOL_H__
diff --git a/compiler/moco-tf/src/IR/TFAvgPool.test.cpp b/compiler/moco-tf/src/IR/TFAvgPool.test.cpp
new file mode 100644
index 000000000..1e17122fd
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFAvgPool.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFAvgPool.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFAvgPoolTest, constructor)
+{
+ moco::tf::TFAvgPool avgpool;
+
+ ASSERT_EQ(avgpool.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(avgpool.opcode(), moco::tf::TFOpcode::AvgPool);
+
+ ASSERT_EQ(avgpool.value(), nullptr);
+ ASSERT_EQ(avgpool.data_layout(), "");
+ ASSERT_EQ(avgpool.padding(), "");
+ ASSERT_EQ(avgpool.ksize(), std::vector<int64_t>({}));
+ ASSERT_EQ(avgpool.strides(), std::vector<int64_t>({}));
+}
diff --git a/compiler/moco-tf/src/IR/TFBiasAdd.h b/compiler/moco-tf/src/IR/TFBiasAdd.h
new file mode 100644
index 000000000..468e02dad
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFBiasAdd.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFBIASADD_H__
+#define __MOCO_TF_IR_TFBIASADD_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFBiasAdd corresponds to the following GraphDef
+/*
+node {
+ name: "bias_add_01"
+ op: "BiasAdd"
+ input: "input_01"
+ input: "bias_add_01/bias"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+}
+*/
+
+class TFBiasAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::BiasAdd>>
+{
+public:
+ TFBiasAdd() = default;
+
+public:
+ Node *value(void) const { return at(0)->node(); }
+ void value(Node *node) { return at(0)->node(node); }
+
+ Node *bias(void) const { return at(1)->node(); }
+ void bias(Node *node) { return at(1)->node(node); }
+
+ const TFDataLayout data_layout(void) const { return _data_layout; }
+ void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+private:
+ TFDataLayout _data_layout;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFBIASADD_H__
diff --git a/compiler/moco-tf/src/IR/TFBiasAdd.test.cpp b/compiler/moco-tf/src/IR/TFBiasAdd.test.cpp
new file mode 100644
index 000000000..8fc1cdd6e
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFBiasAdd.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFBiasAdd.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFBiasAddTest, constructor)
+{
+ moco::tf::TFBiasAdd bias_add;
+
+ ASSERT_EQ(bias_add.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(bias_add.opcode(), moco::tf::TFOpcode::BiasAdd);
+
+ ASSERT_EQ(bias_add.value(), nullptr);
+ ASSERT_EQ(bias_add.bias(), nullptr);
+ ASSERT_EQ(bias_add.data_layout(), "");
+}
diff --git a/compiler/moco-tf/src/IR/TFConcatV2.h b/compiler/moco-tf/src/IR/TFConcatV2.h
new file mode 100644
index 000000000..1db44cdb0
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConcatV2.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONCATV2_H__
+#define __MOCO_TF_IR_TFCONCATV2_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include "Dialect/VariadicArityNode.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFConcatV2 corresponds to the following GraphDef
+/*
+node {
+ name: "Concat"
+ op: "ConcatV2"
+ input: "Input01"
+ input: "Input02"
+ input: "Axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+*/
+
+class TFConcatV2 final : public VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>
+{
+public:
+ TFConcatV2(uint32_t arity) : VariadicArityNode<TFNodeImpl<TFOpcode::ConcatV2>>(arity + 1)
+ {
+    // arity + 1 passed to the VariadicArityNode ctor accounts for the extra axis input;
+    // at least one value input is required
+ assert(arity >= 1);
+ }
+
+public:
+ uint32_t num_values(void) const
+ {
+ // last one is for axis
+ return arity() - 1;
+ }
+
+public:
+ Node *values(uint32_t index) const
+ {
+ assert(index < num_values());
+ return at(index)->node();
+ }
+ void values(uint32_t index, Node *node)
+ {
+ assert(index < num_values());
+ at(index)->node(node);
+ }
+
+ Node *axis(void) const { return at(num_values())->node(); }
+ void axis(Node *node) { at(num_values())->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONCATV2_H__
diff --git a/compiler/moco-tf/src/IR/TFConcatV2.test.cpp b/compiler/moco-tf/src/IR/TFConcatV2.test.cpp
new file mode 100644
index 000000000..89eac1bce
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConcatV2.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConcatV2.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFConcatV2Test, constructor)
+{
+ moco::tf::TFConcatV2 concatv2_node(3); // num of values
+
+ ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2);
+
+ ASSERT_EQ(concatv2_node.num_values(), 3);
+ ASSERT_EQ(concatv2_node.values(0), nullptr);
+ ASSERT_EQ(concatv2_node.values(1), nullptr);
+ ASSERT_EQ(concatv2_node.values(2), nullptr);
+ ASSERT_EQ(concatv2_node.axis(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFConst.cpp b/compiler/moco-tf/src/IR/TFConst.cpp
new file mode 100644
index 000000000..e59e6644a
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConst.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConst.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace tf
+{
+
+template <loco::DataType DT> uint32_t TFConst::size(void) const
+{
+ assert(dtype() == DT);
+ assert(_data.size() % sizeof(typename loco::DataTypeImpl<DT>::Type) == 0);
+ return _data.size() / sizeof(typename loco::DataTypeImpl<DT>::Type);
+}
+
+template <loco::DataType DT> void TFConst::size(uint32_t l)
+{
+ assert(dtype() == DT);
+ _data.resize(l * sizeof(typename loco::DataTypeImpl<DT>::Type));
+}
+
+template <loco::DataType DT>
+const typename loco::DataTypeImpl<DT>::Type &TFConst::at(uint32_t n) const
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<const typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &TFConst::at(uint32_t n)
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<typename loco::DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+#define INSTANTIATE(DT) \
+ template uint32_t TFConst::size<DT>(void) const; \
+ template void TFConst::size<DT>(uint32_t); \
+ template const typename loco::DataTypeImpl<DT>::Type &TFConst::at<DT>(uint32_t) const; \
+ template typename loco::DataTypeImpl<DT>::Type &TFConst::at<DT>(uint32_t);
+
+INSTANTIATE(loco::DataType::S32);
+INSTANTIATE(loco::DataType::FLOAT32);
+
+#undef INSTANTIATE
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/IR/TFConst.h b/compiler/moco-tf/src/IR/TFConst.h
new file mode 100644
index 000000000..b63d37db7
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConst.h
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONSTANT_H__
+#define __MOCO_TF_IR_TFCONSTANT_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <loco/IR/DataTypeTraits.h>
+#include <loco/IR/NodeMixins.h>
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFConst corresponds to the following GraphDef
+/*
+node {
+ name: "val"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 3 }
+ dim { size: 4 }
+ dim { size: 4 }
+ }
+ float_val: 2.1
+ }
+ }
+ }
+}
+*/
+
+/**
+ * @brief IR for tf.constant
+ *
+ * @note Implementation for this class came from Canonical ConstGen
+ * Read comments in loco::ConstGen for details
+ */
+class TFConst final : public FixedArityNode<0, TFNodeImpl<TFOpcode::Const>>,
+ public loco::NodeMixin<loco::NodeTrait::DataType>,
+ public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+public:
+ TFConst() = default;
+
+public:
+ template <loco::DataType DT> uint32_t size(void) const;
+ template <loco::DataType DT> void size(uint32_t size);
+
+ template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
+ template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n);
+
+private:
+ std::vector<uint8_t> _data;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONSTANT_H__
diff --git a/compiler/moco-tf/src/IR/TFConst.test.cpp b/compiler/moco-tf/src/IR/TFConst.test.cpp
new file mode 100644
index 000000000..6963122d8
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConst.test.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConst.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFConstantTest, constructor)
+{
+ moco::tf::TFConst constant;
+
+ ASSERT_EQ(constant.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(constant.opcode(), moco::tf::TFOpcode::Const);
+
+ ASSERT_EQ(constant.dtype(), loco::DataType::Unknown);
+ ASSERT_EQ(constant.rank(), 0);
+
+ constant.dtype(loco::DataType::FLOAT32);
+ ASSERT_EQ(constant.dtype(), loco::DataType::FLOAT32);
+
+ constant.rank(2);
+ ASSERT_EQ(constant.rank(), 2);
+
+ constant.dim(0) = 2;
+ constant.dim(1) = 3;
+
+ ASSERT_TRUE(constant.dim(0).known());
+ ASSERT_TRUE(constant.dim(1).known());
+
+ ASSERT_EQ(constant.dim(0), 2);
+ ASSERT_EQ(constant.dim(1), 3);
+
+ constant.size<loco::DataType::FLOAT32>(6);
+
+ ASSERT_EQ(constant.size<loco::DataType::FLOAT32>(), 6);
+
+ constant.at<loco::DataType::FLOAT32>(0) = 0.0f; // Set 0,0
+ constant.at<loco::DataType::FLOAT32>(1) = 1.0f; // Set 0,1
+ constant.at<loco::DataType::FLOAT32>(2) = 2.0f; // Set 0,2
+ constant.at<loco::DataType::FLOAT32>(3) = 3.0f; // Set 1,0
+ constant.at<loco::DataType::FLOAT32>(4) = 4.0f; // Set 1,1
+ constant.at<loco::DataType::FLOAT32>(5) = 5.0f; // Set 1,2
+
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(0), 0.0f);
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(1), 1.0f);
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(2), 2.0f);
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(3), 3.0f);
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(4), 4.0f);
+ ASSERT_EQ(constant.at<loco::DataType::FLOAT32>(5), 5.0f);
+}
diff --git a/compiler/moco-tf/src/IR/TFConv2D.h b/compiler/moco-tf/src/IR/TFConv2D.h
new file mode 100644
index 000000000..f9a9a1218
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConv2D.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONV2D_H__
+#define __MOCO_TF_IR_TFCONV2D_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+class TFConv2D final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Conv2D>>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(Node *node) { at(1)->node(node); }
+
+public:
+ const TFPadding &padding(void) const { return _padding; }
+ void padding(const TFPadding &padding) { _padding = padding; }
+
+ const TFDataLayout &data_layout(void) const { return _data_layout; }
+ void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+ const std::vector<int64_t> &strides(void) const { return _strides; }
+ void strides(const std::vector<int64_t> &strides) { _strides = strides; }
+
+private:
+ TFPadding _padding;
+ TFDataLayout _data_layout;
+ std::vector<int64_t> _strides;
+ // TODO Support "Dilation"
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONV2D_H__
diff --git a/compiler/moco-tf/src/IR/TFConv2D.test.cpp b/compiler/moco-tf/src/IR/TFConv2D.test.cpp
new file mode 100644
index 000000000..e1fa4ed6d
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConv2D.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConv2D.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFConv2DTest, constructor)
+{
+ moco::tf::TFConv2D conv2d_node;
+
+ ASSERT_EQ(conv2d_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(conv2d_node.opcode(), moco::tf::TFOpcode::Conv2D);
+
+ ASSERT_EQ(conv2d_node.input(), nullptr);
+ ASSERT_EQ(conv2d_node.filter(), nullptr);
+ ASSERT_EQ(conv2d_node.padding(), "");
+ ASSERT_EQ(conv2d_node.data_layout(), "");
+ ASSERT_EQ(conv2d_node.strides().size(), 0);
+}
diff --git a/compiler/moco-tf/src/IR/TFConv2DBackpropInput.h b/compiler/moco-tf/src/IR/TFConv2DBackpropInput.h
new file mode 100644
index 000000000..af78d6abd
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFConv2DBackpropInput.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONV2DBACKPROPINPUT_H__
+#define __MOCO_TF_IR_TFCONV2DBACKPROPINPUT_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFConv2DBackpropInput corresponds to the following GraphDef
+/*
+node {
+ name: "conv2d_backprop_input"
+ op: "Conv2DBackpropInput"
+ input: "input_sizes"
+ input: "filter"
+ input: "out_backprop"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NHWC" }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list { i: 1 i: 1 i: 1 i: 1 }
+ }
+ }
+ attr {
+ key: "padding"
+ value { s: "SAME" }
+ }
+ attr {
+ key: "strides"
+ value {
+ list { i: 1 i: 2 i: 2 i: 1 }
+ }
+ }
+}
+*/
+
+/**
+ * @note For TensorFlow Conv2DBackpropInput, 'input_sizes' refers to the size of
+ *       the node's actual output, and 'out_backprop' refers to its actual input.
+ *       The reason, as the name suggests, is that it is inspired from the
+ *       backpropagation of convolution: 'out_backprop' of Conv2DBackpropInput is
+ *       its actual input feature map, and 'input_sizes' means desired output
+ *       node's size. Note that this convention is against loco canonical's.
+ */
+class TFConv2DBackpropInput final
+ : public FixedArityNode<3, TFNodeImpl<TFOpcode::Conv2DBackpropInput>>
+{
+public:
+ loco::Node *input_sizes(void) const { return at(0)->node(); }
+ void input_sizes(Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(Node *node) { at(1)->node(node); }
+
+ loco::Node *out_backprop(void) const { return at(2)->node(); }
+ void out_backprop(Node *node) { at(2)->node(node); }
+
+public:
+ const TFPadding &padding(void) const { return _padding; }
+ void padding(const TFPadding &padding) { _padding = padding; }
+
+ const TFDataLayout &data_layout(void) const { return _data_layout; }
+ void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+ const std::vector<int64_t> &strides(void) const { return _strides; }
+ void strides(const std::vector<int64_t> &strides) { _strides = strides; }
+
+private:
+ TFPadding _padding;
+ TFDataLayout _data_layout;
+ std::vector<int64_t> _strides;
+ // TODO Support "Dilation"
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONV2DBACKPROPINPUT_H__
diff --git a/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.h b/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.h
new file mode 100644
index 000000000..9ffc79281
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFDEPTHWISECONV2DNATIVE_H__
+#define __MOCO_TF_IR_TFDEPTHWISECONV2DNATIVE_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include "Convert.h"
+
+#include <loco/IR/Stride.h>
+#include <loco/IR/Padding2D.h>
+
+#include <string>
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+class TFDepthwiseConv2dNative final
+ : public FixedArityNode<2, TFNodeImpl<TFOpcode::DepthwiseConv2dNative>>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ loco::Node *filter(void) const { return at(1)->node(); }
+ void filter(Node *node) { at(1)->node(node); }
+
+public:
+ const TFPadding &padding(void) const { return _padding; }
+ void padding(const TFPadding &padding) { _padding = padding; }
+
+ const TFDataLayout &data_layout(void) const { return _data_layout; }
+ void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+ const std::vector<int64_t> &strides(void) const { return _strides; }
+ void strides(const std::vector<int64_t> &strides) { _strides = strides; }
+
+private:
+ TFPadding _padding;
+ TFDataLayout _data_layout;
+ std::vector<int64_t> _strides;
+ // TODO Support "Dilation"
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFDEPTHWISECONV2DNATIVE_H__
diff --git a/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.test.cpp b/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.test.cpp
new file mode 100644
index 000000000..086145635
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFDepthwiseConv2dNative.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFDepthwiseConv2dNativeTest, constructor)
+{
+ moco::tf::TFDepthwiseConv2dNative depthwiseConv2dnative_node;
+
+ ASSERT_EQ(depthwiseConv2dnative_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(depthwiseConv2dnative_node.opcode(), moco::tf::TFOpcode::DepthwiseConv2dNative);
+
+ ASSERT_EQ(depthwiseConv2dnative_node.input(), nullptr);
+ ASSERT_EQ(depthwiseConv2dnative_node.filter(), nullptr);
+ ASSERT_EQ(depthwiseConv2dnative_node.padding(), "");
+ ASSERT_EQ(depthwiseConv2dnative_node.data_layout(), "");
+ ASSERT_EQ(depthwiseConv2dnative_node.strides().size(), 0);
+}
diff --git a/compiler/moco-tf/src/IR/TFFusedBatchNorm.h b/compiler/moco-tf/src/IR/TFFusedBatchNorm.h
new file mode 100644
index 000000000..297f439a1
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFFusedBatchNorm.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFFUSEDBATCHNORM_H__
+#define __MOCO_TF_IR_TFFUSEDBATCHNORM_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+class TFFusedBatchNorm final : public FixedArityNode<5, TFNodeImpl<TFOpcode::FusedBatchNorm>>
+{
+public:
+ TFFusedBatchNorm() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ Node *gamma(void) const { return at(1)->node(); }
+ void gamma(Node *node) { at(1)->node(node); }
+
+ Node *beta(void) const { return at(2)->node(); }
+ void beta(Node *node) { at(2)->node(node); }
+
+ Node *mean(void) const { return at(3)->node(); }
+ void mean(Node *node) { at(3)->node(node); }
+
+ Node *variance(void) const { return at(4)->node(); }
+ void variance(Node *node) { at(4)->node(node); }
+
+ float epsilon(void) const { return _epsilon; }
+ void epsilon(float epsilon) { _epsilon = epsilon; }
+
+private:
+ float _epsilon = 0.001f;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFFUSEDBATCHNORM_H__
diff --git a/compiler/moco-tf/src/IR/TFFusedBatchNorm.test.cpp b/compiler/moco-tf/src/IR/TFFusedBatchNorm.test.cpp
new file mode 100644
index 000000000..38db8cf33
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFFusedBatchNorm.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFFusedBatchNormTest, constructor)
+{
+  moco::tf::TFFusedBatchNorm fbn_node;
+
+  ASSERT_EQ(fbn_node.dialect(), moco::tf::TFDialect::get());
+  ASSERT_EQ(fbn_node.opcode(), moco::tf::TFOpcode::FusedBatchNorm);
+
+  ASSERT_EQ(fbn_node.input(), nullptr);
+  ASSERT_EQ(fbn_node.gamma(), nullptr);
+  ASSERT_EQ(fbn_node.beta(), nullptr);
+  ASSERT_EQ(fbn_node.mean(), nullptr);
+  ASSERT_EQ(fbn_node.variance(), nullptr);
+  // Pin the documented default (TFFusedBatchNorm::_epsilon = 0.001f) instead of
+  // the weaker "not zero" check, so a silently changed default fails the test.
+  ASSERT_FLOAT_EQ(fbn_node.epsilon(), 0.001f);
+}
diff --git a/compiler/moco-tf/src/IR/TFIdentity.h b/compiler/moco-tf/src/IR/TFIdentity.h
new file mode 100644
index 000000000..9eeab8d11
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFIdentity.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFIDENTITY_H__
+#define __MOCO_TF_IR_TFIDENTITY_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFIdentity corresponds to the following GraphDef
+/*
+node {
+ name: "identity"
+ op: "Identity"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFIdentity final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Identity>>
+{
+public:
+ TFIdentity() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFIDENTITY_H__
diff --git a/compiler/moco-tf/src/IR/TFIdentity.test.cpp b/compiler/moco-tf/src/IR/TFIdentity.test.cpp
new file mode 100644
index 000000000..4ea3e7acf
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFIdentity.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFIdentity.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+// Fix test-suite name typo: "TFIdentitu" -> "TFIdentity".
+TEST(TFIdentityTest, constructor)
+{
+  moco::tf::TFIdentity identity_node;
+
+  ASSERT_EQ(identity_node.dialect(), moco::tf::TFDialect::get());
+  ASSERT_EQ(identity_node.opcode(), moco::tf::TFOpcode::Identity);
+
+  ASSERT_EQ(identity_node.input(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFMaxPool.h b/compiler/moco-tf/src/IR/TFMaxPool.h
new file mode 100644
index 000000000..14dae7009
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMaxPool.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFMAXPOOL_H__
+#define __MOCO_TF_IR_TFMAXPOOL_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFMaxPool corresponds to the following GraphDef
+/*
+node {
+ name: "maxpool2d"
+ op: "MaxPool"
+ input: "placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "ksize"
+ value {
+ list {
+ i: 1 i: 2 i: 2 i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1 i: 1 i: 1 i: 1
+ }
+ }
+ }
+}
+*/
+
+class TFMaxPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::MaxPool>>
+{
+public:
+  TFMaxPool() = default;
+
+public:
+  Node *value(void) const { return at(0)->node(); }
+  // NOTE dropped a stray 'return' before the void expression; every other
+  //      setter in this dialect writes the argument without 'return'.
+  void value(Node *node) { at(0)->node(node); }
+
+public:
+  const TFDataLayout &data_layout(void) const { return _data_layout; }
+  void data_layout(const TFDataLayout &data_layout) { _data_layout = data_layout; }
+
+  const TFPadding &padding(void) const { return _padding; }
+  void padding(const TFPadding &padding) { _padding = padding; }
+
+  const std::vector<int64_t> &ksize(void) const { return _ksize; }
+  void ksize(const std::vector<int64_t> &ksize) { _ksize = ksize; }
+
+  const std::vector<int64_t> &strides(void) const { return _strides; }
+  void strides(const std::vector<int64_t> &strides) { _strides = strides; }
+
+private:
+  TFDataLayout _data_layout;
+  TFPadding _padding;
+  std::vector<int64_t> _ksize;
+  std::vector<int64_t> _strides;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFMAXPOOL_H__
diff --git a/compiler/moco-tf/src/IR/TFMaxPool.test.cpp b/compiler/moco-tf/src/IR/TFMaxPool.test.cpp
new file mode 100644
index 000000000..b86e21eab
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMaxPool.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMaxPool.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFMaxPoolTest, constructor)
+{
+ moco::tf::TFMaxPool maxpool;
+
+ ASSERT_EQ(maxpool.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(maxpool.opcode(), moco::tf::TFOpcode::MaxPool);
+
+ ASSERT_EQ(maxpool.value(), nullptr);
+ ASSERT_EQ(maxpool.data_layout(), "");
+ ASSERT_EQ(maxpool.padding(), "");
+ ASSERT_EQ(maxpool.ksize(), std::vector<int64_t>({}));
+ ASSERT_EQ(maxpool.strides(), std::vector<int64_t>({}));
+}
diff --git a/compiler/moco-tf/src/IR/TFMean.h b/compiler/moco-tf/src/IR/TFMean.h
new file mode 100644
index 000000000..508887bd0
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMean.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFMEAN_H__
+#define __MOCO_TF_IR_TFMEAN_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+class TFMean final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Mean>>
+{
+public:
+ TFMean() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ Node *reduction_indices(void) const { return at(1)->node(); }
+ void reduction_indices(Node *node) { at(1)->node(node); }
+
+public:
+ bool keep_dims(void) const { return _keep_dims; }
+ void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
+
+private:
+ bool _keep_dims = false;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFMEAN_H__
diff --git a/compiler/moco-tf/src/IR/TFMean.test.cpp b/compiler/moco-tf/src/IR/TFMean.test.cpp
new file mode 100644
index 000000000..3c580c08e
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMean.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMean.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFMeanTest, constructor)
+{
+ moco::tf::TFMean mean_node;
+
+ ASSERT_EQ(mean_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(mean_node.opcode(), moco::tf::TFOpcode::Mean);
+
+ ASSERT_EQ(mean_node.input(), nullptr);
+ ASSERT_EQ(mean_node.reduction_indices(), nullptr);
+ ASSERT_EQ(mean_node.keep_dims(), false);
+}
diff --git a/compiler/moco-tf/src/IR/TFMul.h b/compiler/moco-tf/src/IR/TFMul.h
new file mode 100644
index 000000000..95826f05a
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMul.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFMUL_H__
+#define __MOCO_TF_IR_TFMUL_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFMul corresponds to the following GraphDef
+/*
+node {
+ name: "mul"
+ op: "Mul"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFMul final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
+{
+public:
+ TFMul() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+
+ Node *y(void) const { return at(1)->node(); }
+ void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFMUL_H__
diff --git a/compiler/moco-tf/src/IR/TFMul.test.cpp b/compiler/moco-tf/src/IR/TFMul.test.cpp
new file mode 100644
index 000000000..cc7c5880b
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFMul.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMul.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFMulTest, constructor)
+{
+ moco::tf::TFMul mul_node;
+
+ ASSERT_EQ(mul_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(mul_node.opcode(), moco::tf::TFOpcode::Mul);
+
+ ASSERT_EQ(mul_node.x(), nullptr);
+ ASSERT_EQ(mul_node.y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFRealDiv.h b/compiler/moco-tf/src/IR/TFRealDiv.h
new file mode 100644
index 000000000..8ef37861a
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRealDiv.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFREALDIV_H__
+#define __MOCO_TF_IR_TFREALDIV_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFRealDiv corresponds to the following GraphDef
+/*
+node {
+ name: "div"
+ op: "RealDiv"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFRealDiv final : public FixedArityNode<2, TFNodeImpl<TFOpcode::RealDiv>>
+{
+public:
+ TFRealDiv() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+
+ Node *y(void) const { return at(1)->node(); }
+ void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFREALDIV_H__
diff --git a/compiler/moco-tf/src/IR/TFRealDiv.test.cpp b/compiler/moco-tf/src/IR/TFRealDiv.test.cpp
new file mode 100644
index 000000000..1c7029f87
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRealDiv.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRealDiv.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFRealDivTest, constructor)
+{
+ moco::tf::TFRealDiv div_node;
+
+ ASSERT_EQ(div_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(div_node.opcode(), moco::tf::TFOpcode::RealDiv);
+
+ ASSERT_EQ(div_node.x(), nullptr);
+ ASSERT_EQ(div_node.y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFRelu.h b/compiler/moco-tf/src/IR/TFRelu.h
new file mode 100644
index 000000000..7df958b11
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRelu.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFRELU_H__
+#define __MOCO_TF_IR_TFRELU_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+class TFRelu final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu>>
+{
+public:
+ TFRelu() = default;
+
+public:
+ Node *features(void) const { return at(0)->node(); }
+ void features(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFRELU_H__
diff --git a/compiler/moco-tf/src/IR/TFRelu.test.cpp b/compiler/moco-tf/src/IR/TFRelu.test.cpp
new file mode 100644
index 000000000..776207966
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRelu.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRelu.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFReluTest, constructor)
+{
+ moco::tf::TFRelu relu_node;
+
+ ASSERT_EQ(relu_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(relu_node.opcode(), moco::tf::TFOpcode::Relu);
+
+ ASSERT_EQ(relu_node.features(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFRelu6.h b/compiler/moco-tf/src/IR/TFRelu6.h
new file mode 100644
index 000000000..eba83a9f7
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRelu6.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFRELU6_H__
+#define __MOCO_TF_IR_TFRELU6_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+class TFRelu6 final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu6>>
+{
+public:
+ TFRelu6() = default;
+
+public:
+ Node *features(void) const { return at(0)->node(); }
+ void features(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFRELU6_H__
diff --git a/compiler/moco-tf/src/IR/TFRelu6.test.cpp b/compiler/moco-tf/src/IR/TFRelu6.test.cpp
new file mode 100644
index 000000000..d342ccd5d
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRelu6.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRelu6.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFRelu6Test, constructor)
+{
+ moco::tf::TFRelu6 relu6_node;
+
+ ASSERT_EQ(relu6_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(relu6_node.opcode(), moco::tf::TFOpcode::Relu6);
+
+ ASSERT_EQ(relu6_node.features(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFReshape.h b/compiler/moco-tf/src/IR/TFReshape.h
new file mode 100644
index 000000000..4359a49b5
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFReshape.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFRESHAPE_H__
+#define __MOCO_TF_IR_TFRESHAPE_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFReshape corresponds to the following GraphDef
+/*
+node {
+ name: "reshape"
+ op: "Reshape"
+ input: "tensor"
+ input: "shape"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+}
+*/
+
+class TFReshape final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Reshape>>
+{
+public:
+ TFReshape() = default;
+
+public:
+ Node *tensor(void) const { return at(0)->node(); }
+ void tensor(Node *node) { at(0)->node(node); }
+
+ Node *shape(void) const { return at(1)->node(); }
+ void shape(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFRESHAPE_H__
diff --git a/compiler/moco-tf/src/IR/TFReshape.test.cpp b/compiler/moco-tf/src/IR/TFReshape.test.cpp
new file mode 100644
index 000000000..39d77e4b1
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFReshape.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFReshape.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFReshapeTest, constructor)
+{
+ moco::tf::TFReshape reshape_node;
+
+ ASSERT_EQ(reshape_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(reshape_node.opcode(), moco::tf::TFOpcode::Reshape);
+
+ ASSERT_EQ(reshape_node.tensor(), nullptr);
+ ASSERT_EQ(reshape_node.shape(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFRsqrt.h b/compiler/moco-tf/src/IR/TFRsqrt.h
new file mode 100644
index 000000000..f371e39ab
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRsqrt.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFRSQRT_H__
+#define __MOCO_TF_IR_TFRSQRT_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFRsqrt corresponds to the following GraphDef
+/*
+node {
+ name: "Rsqrt"
+ op: "Rsqrt"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFRsqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Rsqrt>>
+{
+public:
+ TFRsqrt() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFRSQRT_H__
diff --git a/compiler/moco-tf/src/IR/TFRsqrt.test.cpp b/compiler/moco-tf/src/IR/TFRsqrt.test.cpp
new file mode 100644
index 000000000..7f92704ba
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFRsqrt.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRsqrt.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFRsqrtTest, constructor)
+{
+ moco::tf::TFRsqrt rsqrt_node;
+
+ ASSERT_EQ(rsqrt_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(rsqrt_node.opcode(), moco::tf::TFOpcode::Rsqrt);
+
+ ASSERT_EQ(rsqrt_node.x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFShape.h b/compiler/moco-tf/src/IR/TFShape.h
new file mode 100644
index 000000000..d50cabf79
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFShape.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSHAPE_H__
+#define __MOCO_TF_IR_TFSHAPE_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <loco/IR/NodeMixins.h>
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFShape corresponds to the following GraphDef
+/*
+node {
+ name: "Shape"
+ op: "Shape"
+ input: "some_input"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "out_type"
+ value { type: DT_INT32 }
+ }
+}
+*/
+
+/// @note Mixed in dtype() is for 'out_type' attribute
+class TFShape final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Shape>>,
+ public loco::NodeMixin<loco::NodeTrait::DataType>
+{
+public:
+ TFShape() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSHAPE_H__
diff --git a/compiler/moco-tf/src/IR/TFShape.test.cpp b/compiler/moco-tf/src/IR/TFShape.test.cpp
new file mode 100644
index 000000000..6c68888cc
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFShape.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFShape.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFShapeTest, constructor)
+{
+ moco::tf::TFShape shape_node;
+
+ ASSERT_EQ(shape_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(shape_node.opcode(), moco::tf::TFOpcode::Shape);
+
+ ASSERT_EQ(shape_node.input(), nullptr);
+ ASSERT_EQ(shape_node.dtype(), loco::DataType::Unknown);
+}
diff --git a/compiler/moco-tf/src/IR/TFSoftmax.h b/compiler/moco-tf/src/IR/TFSoftmax.h
new file mode 100644
index 000000000..22b7b9eca
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSoftmax.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSOFTMAX_H__
+#define __MOCO_TF_IR_TFSOFTMAX_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+class TFSoftmax final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Softmax>>
+{
+public:
+ TFSoftmax() = default;
+
+public:
+ Node *logits(void) const { return at(0)->node(); }
+ void logits(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSOFTMAX_H__
diff --git a/compiler/moco-tf/src/IR/TFSoftmax.test.cpp b/compiler/moco-tf/src/IR/TFSoftmax.test.cpp
new file mode 100644
index 000000000..99c7cbc3c
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSoftmax.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSoftmax.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFSoftmaxTest, constructor)
+{
+ moco::tf::TFSoftmax softmax_node;
+
+ ASSERT_EQ(softmax_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(softmax_node.opcode(), moco::tf::TFOpcode::Softmax);
+
+ ASSERT_EQ(softmax_node.logits(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFSqrt.h b/compiler/moco-tf/src/IR/TFSqrt.h
new file mode 100644
index 000000000..fda032e2d
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSqrt.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSQRT_H__
+#define __MOCO_TF_IR_TFSQRT_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFSqrt corresponds to the following GraphDef
+/*
+node {
+ name: "Sqrt"
+ op: "Sqrt"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFSqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Sqrt>>
+{
+public:
+ TFSqrt() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSQRT_H__
diff --git a/compiler/moco-tf/src/IR/TFSqrt.test.cpp b/compiler/moco-tf/src/IR/TFSqrt.test.cpp
new file mode 100644
index 000000000..9048d5729
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSqrt.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSqrt.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFSqrtTest, constructor)
+{
+ moco::tf::TFSqrt sqrt_node;
+
+ ASSERT_EQ(sqrt_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(sqrt_node.opcode(), moco::tf::TFOpcode::Sqrt);
+
+ ASSERT_EQ(sqrt_node.x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFSquaredDifference.h b/compiler/moco-tf/src/IR/TFSquaredDifference.h
new file mode 100644
index 000000000..83ecdb86b
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSquaredDifference.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSQUAREDDIFFERENCE_H__
+#define __MOCO_TF_IR_TFSQUAREDDIFFERENCE_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFSquaredDifference corresponds to the following GraphDef
+/*
+node {
+ name: "SquaredDifference"
+ op: "SquaredDifference"
+ input: "input_x"
+ input: "input_y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFSquaredDifference final : public FixedArityNode<2, TFNodeImpl<TFOpcode::SquaredDifference>>
+{
+public:
+ TFSquaredDifference() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+
+ Node *y(void) const { return at(1)->node(); }
+ void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSQUAREDDIFFERENCE_H__
diff --git a/compiler/moco-tf/src/IR/TFSquaredDifference.test.cpp b/compiler/moco-tf/src/IR/TFSquaredDifference.test.cpp
new file mode 100644
index 000000000..f83d28caf
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSquaredDifference.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSquaredDifference.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFSquaredDifferenceTest, constructor)
+{
+ moco::tf::TFSquaredDifference sd_node;
+
+ ASSERT_EQ(sd_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(sd_node.opcode(), moco::tf::TFOpcode::SquaredDifference);
+
+ ASSERT_EQ(sd_node.x(), nullptr);
+ ASSERT_EQ(sd_node.y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFSqueeze.h b/compiler/moco-tf/src/IR/TFSqueeze.h
new file mode 100644
index 000000000..e98644101
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSqueeze.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSQUEEZE_H__
+#define __MOCO_TF_IR_TFSQUEEZE_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFSqueeze corresponds to the following GraphDef
+/*
+node {
+ name: "squeeze"
+ op: "Squeeze"
+ input: "x"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "squeeze_dims"
+ value {
+ list {
+ i: a
+ i: b
+ ..
+ }
+ }
+ }
+}
+*/
+
+class TFSqueeze final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Squeeze>>
+{
+public:
+ TFSqueeze() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+public:
+ const std::vector<int64_t> &squeeze_dims(void) const { return _squeeze_dims; }
+ void squeeze_dims(const std::vector<int64_t> &squeeze_dims) { _squeeze_dims = squeeze_dims; }
+
+private:
+ std::vector<int64_t> _squeeze_dims;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSQUEEZE_H__
diff --git a/compiler/moco-tf/src/IR/TFSqueeze.test.cpp b/compiler/moco-tf/src/IR/TFSqueeze.test.cpp
new file mode 100644
index 000000000..1ab219b2f
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSqueeze.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSqueeze.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFSqueezeTest, constructor)
+{
+ moco::tf::TFSqueeze squeeze_node;
+
+ ASSERT_EQ(squeeze_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(squeeze_node.opcode(), moco::tf::TFOpcode::Squeeze);
+
+ ASSERT_EQ(squeeze_node.input(), nullptr);
+ ASSERT_EQ(squeeze_node.squeeze_dims().size(), 0);
+}
diff --git a/compiler/moco-tf/src/IR/TFStopGradient.h b/compiler/moco-tf/src/IR/TFStopGradient.h
new file mode 100644
index 000000000..4b8f1b843
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFStopGradient.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSTOPGRADIENT_H__
+#define __MOCO_TF_IR_TFSTOPGRADIENT_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFStopGradient corresponds to the following GraphDef
+/*
+node {
+ name: "StopGradient"
+ op: "StopGradient"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFStopGradient final : public FixedArityNode<1, TFNodeImpl<TFOpcode::StopGradient>>
+{
+public:
+ TFStopGradient() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSTOPGRADIENT_H__
diff --git a/compiler/moco-tf/src/IR/TFStopGradient.test.cpp b/compiler/moco-tf/src/IR/TFStopGradient.test.cpp
new file mode 100644
index 000000000..dafd1a5e7
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFStopGradient.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFStopGradient.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFStopGradientTest, constructor)
+{
+ moco::tf::TFStopGradient node;
+
+ ASSERT_EQ(node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(node.opcode(), moco::tf::TFOpcode::StopGradient);
+
+ ASSERT_EQ(node.input(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFSub.h b/compiler/moco-tf/src/IR/TFSub.h
new file mode 100644
index 000000000..5f4e48b63
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSub.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFSUB_H__
+#define __MOCO_TF_IR_TFSUB_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFSub corresponds to the following GraphDef
+/*
+node {
+ name: "sub"
+ op: "Sub"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+*/
+
+class TFSub final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Sub>>
+{
+public:
+ TFSub() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+
+ Node *y(void) const { return at(1)->node(); }
+ void y(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFSUB_H__
diff --git a/compiler/moco-tf/src/IR/TFSub.test.cpp b/compiler/moco-tf/src/IR/TFSub.test.cpp
new file mode 100644
index 000000000..79f746681
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFSub.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSub.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFSubTest, constructor)
+{
+ moco::tf::TFSub sub_node;
+
+ ASSERT_EQ(sub_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(sub_node.opcode(), moco::tf::TFOpcode::Sub);
+
+ ASSERT_EQ(sub_node.x(), nullptr);
+ ASSERT_EQ(sub_node.y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/IR/TFTanh.h b/compiler/moco-tf/src/IR/TFTanh.h
new file mode 100644
index 000000000..c85663e69
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFTanh.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFTANH_H__
+#define __MOCO_TF_IR_TFTANH_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+class TFTanh final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Tanh>>
+{
+public:
+ TFTanh() = default;
+
+public:
+ Node *x(void) const { return at(0)->node(); }
+ void x(Node *node) { at(0)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFTANH_H__
diff --git a/compiler/moco-tf/src/IR/TFTanh.test.cpp b/compiler/moco-tf/src/IR/TFTanh.test.cpp
new file mode 100644
index 000000000..0ff1af6a4
--- /dev/null
+++ b/compiler/moco-tf/src/IR/TFTanh.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFTanh.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFTanhTest, constructor)
+{
+ moco::tf::TFTanh tanh_node;
+
+ ASSERT_EQ(tanh_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(tanh_node.opcode(), moco::tf::TFOpcode::Tanh);
+
+ ASSERT_EQ(tanh_node.x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/ImportTarget.h b/compiler/moco-tf/src/ImportTarget.h
new file mode 100644
index 000000000..cd169f53b
--- /dev/null
+++ b/compiler/moco-tf/src/ImportTarget.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __IMPORT_TARGET_H__
+#define __IMPORT_TARGET_H__
+
+enum class ImportTarget
+{
+ Canonical, // Emit Canonical Dialect
+ TensorFlow, // Emit T/F Dialect
+};
+
+#endif // __IMPORT_TARGET_H__
diff --git a/compiler/moco-tf/src/Importer.cpp b/compiler/moco-tf/src/Importer.cpp
new file mode 100644
index 000000000..7899a4dcf
--- /dev/null
+++ b/compiler/moco-tf/src/Importer.cpp
@@ -0,0 +1,290 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Importer.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "GraphBuilderRegistry.h"
+#include "Transforms.h"
+#include "ProgressReporter.h"
+
+#include "Annotations/ShapeInferenceData.h"
+
+#include <moco/Log.h>
+
+#include <loco/IR/Verifier.h>
+#include <locop/FormattedGraph.h>
+#include <stdex/Memory.h>
+
+#include <logo/Phase.h>
+
+#include <cassert>
+#include <sstream>
+#include <stdexcept>
+
+namespace
+{
+
+void convert_graph(const moco::tf::GraphBuilderSource &source,
+ const moco::tf::ModelSignature &signature, tensorflow::GraphDef &tf_graph_def,
+ loco::Graph *graph)
+{
+ auto nodedef = stdex::make_unique<moco::tf::NodeDefTable>();
+ auto tensor_names = stdex::make_unique<moco::tf::SymbolTable>();
+ auto updates = stdex::make_unique<moco::tf::UpdateQueue>();
+
+ moco::tf::GraphBuilderContext gb_context(graph, nodedef.get(), tensor_names.get(), updates.get());
+
+ // Building a loco graph
+ // 1. Convert all the nodes to loco::Node
+ // 2. Connect inputs: set all node input(from a string) to actual node object
+ // 3. Set graph input
+ // 4. Create loco::Push node and set input and set graph output
+
+ /**
+ * @brief Prepare tensorflow::NodeDef search table from name
+ */
+ for (const auto &n : tf_graph_def.node())
+ {
+ nodedef->enroll(n.name(), &n);
+ }
+
+ /**
+ * @brief 1. Convert all the nodes to loco::Node
+ *
+ * @note In each build for a TF node, four things happen
+ * 1) create corresponding loco::Node(s)
+ * 2) read and set the attributes to created loco::Node(s)
+ * 3) register name-loco::Node(last one of Nodes) that will be used as the output
+ * 4) queue a task to set the input of the loco::Node(first one of the Nodes)
+ * this is done only for required nodes depending on the operator
+ *
+ * @example Placeholder("in") - Identity("out")
+ * %1 = Pull --> 0x1001 (loco::Node* object address)
+ * (symboltable: register %1, after the registeration table will contain as below;
+ * "in" : 0x1001
+ * )
+ * (queue: this will be empty as Pull does not queue a task to set input;
+ * )
+ *
+ * %2 = Forward --> 0x1002
+ * (symboltable: register %2 and table will look like below;
+ * "in" : 0x1001
+ * "out" : 0x1002
+ * )
+ * (queue: Forward will queue a task with input "in";
+ * 0x1002: {"in"}
+ * )
+ */
+ for (const auto &n : tf_graph_def.node())
+ {
+ if (const auto *graph_builder = source.lookup(n.op()))
+ {
+ if (!graph_builder->validate(n))
+ {
+ throw std::runtime_error{"Invalid operator: " + n.op()};
+ }
+
+ graph_builder->build(n, &gb_context);
+ }
+ else
+ {
+ throw std::runtime_error{"Not supported: " + n.op()};
+ }
+ }
+
+ /**
+ * @brief 2. Connect inputs: Iterate updates and call each update input method
+ *
+ * @note Continue from above example graph, connecting inputs is done in following steps
+ * a) iterate queue
+ * b) call the input method for each update
+ * c) each update has the loco::Node *node and names of the input to connect
+ * node = 0x1002 and names = {"in"}
+ * d) from symbol table, "in" will return 0x1001
+ * e) set input of 0x1002 with 0x1001
+ */
+ for (auto &update : updates->queue())
+ {
+ update->input(tensor_names.get());
+ }
+
+ /**
+ * @brief 3. Set graph input
+ */
+ for (auto input : signature.inputs())
+ {
+ auto node = tensor_names->node(input);
+ assert(node != nullptr);
+
+ auto graph_input = graph->inputs()->create();
+
+ loco::Pull *pull_node = dynamic_cast<loco::Pull *>(node);
+ assert(pull_node != nullptr);
+
+ graph_input->name(input.nodeName());
+ // This implementation works as "PlaceholderGraphBuilder in Op/PlaceholderGraphBuilder.cpp"
+ // accepts only TF_FLOAT32 as of now.
+ //
+ // TODO Support other types
+ graph_input->dtype(loco::DataType::FLOAT32);
+ loco::link(graph_input, pull_node);
+ }
+
+ /**
+ * @brief 4. Create loco::Push node and set graph input and output
+ */
+ for (auto output : signature.outputs())
+ {
+ auto output_node = tensor_names->node(output);
+ assert(output_node);
+
+ // create loco::Push for output of graph
+ auto push_node = graph->nodes()->create<loco::Push>();
+ push_node->from(output_node); // set input of Push to output node
+
+ // set the graph output name and node object
+ auto graph_output = graph->outputs()->create();
+ graph_output->name(output.nodeName());
+ // TODO Support other types
+ graph_output->dtype(loco::DataType::FLOAT32);
+ loco::link(graph_output, push_node);
+ }
+
+ // validate graph
+ assert(loco::valid(graph));
+}
+
+void dump_shapeinferencedata(loco::Node *node, const std::string &name)
+{
+ LOGGER(node_shapeinferencedata);
+
+ const moco::tf::ShapeInferenceData *shapedata = node->annot<moco::tf::ShapeInferenceData>();
+ if (shapedata == nullptr)
+ {
+ INFO(node_shapeinferencedata) << "ShapeInferenceData is null for " << name << ":" << node
+ << std::endl;
+ }
+ else
+ {
+ std::stringstream ss;
+
+ ss << "ShapeInferenceData for " << name << ":" << node;
+ // clang-format off
+ switch (shapedata->domain())
+ {
+ case loco::Domain::Tensor: ss << " (Tensor)"; break;
+ case loco::Domain::Feature: ss << " (Feature)"; break;
+ case loco::Domain::Filter: ss << " (Filter)"; break;
+ case loco::Domain::Bias: ss << " (Bias)"; break;
+ default: assert(false && "Unknown Domain"); break;
+ }
+ // clang-format on
+ ss << " rank(" << shapedata->rank() << ") [";
+ for (uint32_t index = 0; index < shapedata->rank(); ++index)
+ {
+ if (index)
+ ss << ",";
+ if (shapedata->dim(index).known())
+ ss << shapedata->dim(index).value();
+ else
+ ss << "?";
+ }
+ ss << "]";
+
+ INFO(node_shapeinferencedata) << ss.str() << std::endl;
+ }
+}
+
+void transform_graph(loco::Graph *graph)
+{
+ LOGGER(transform_graph);
+
+ std::vector<std::unique_ptr<moco::tf::Transform>> prepare;
+ logo::Phase transforms;
+
+ // Transforms that run only once for preparation and finalization
+ {
+ // TODO add one time preparation when needed
+ }
+
+ // Transforms that run multiple times until there is no transform occured
+ {
+ transforms.emplace_back(stdex::make_unique<moco::tf::FixShapeTransform>());
+ // TODO add more TensorFlow related transformations
+ }
+
+ // Run preparation
+ for (auto &tr : prepare)
+ {
+ tr->run(graph);
+ }
+
+ moco::tf::ProgressReporter prog(graph, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> runner{graph};
+
+ runner.attach(&prog);
+ runner.run(transforms);
+
+ // TODO would be better to run this code only when log is enabled
+ {
+ for (uint32_t i = 0; i < graph->outputs()->size(); ++i)
+ {
+ loco::Node *node = loco::push_node(graph, i);
+ std::string name = "Output(" + std::to_string(i) + ")";
+ dump_shapeinferencedata(node, name);
+ }
+ }
+
+ // validate graph
+ assert(loco::valid(graph));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+Importer::Importer()
+{
+ // DO NOTHING
+}
+
+std::unique_ptr<loco::Graph> Importer::import(const ModelSignature &signature,
+ tensorflow::GraphDef &tf_graph_def) const
+{
+ auto graph = loco::make_graph();
+
+ const GraphBuilderSource *source_ptr = &moco::tf::GraphBuilderRegistry::get();
+
+ if (_source != nullptr)
+ {
+ // Use user-defined GraphBuilderSource
+ source_ptr = _source;
+ }
+
+ convert_graph(*source_ptr, signature, tf_graph_def, graph.get());
+
+ transform_graph(graph.get());
+
+ return std::move(graph);
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h b/compiler/moco-tf/src/Importer.h
index bc37bb9cb..e5faafd62 100644
--- a/compiler/moco-tf/src/Canonicalization/Conv2DBackpropInputCanonicalizer.h
+++ b/compiler/moco-tf/src/Importer.h
@@ -14,32 +14,44 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
-#define __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
+#ifndef __IMPORT_H__
+#define __IMPORT_H__
-#include "Transform.h"
-#include "SimpleNodeTransform.h"
+#include <moco/tf/Frontend.h>
+#include <moco/tf/Names.h>
-#include <moco/IR/TFNodes.h>
+#include "GraphBuilderRegistry.h"
#include <loco.h>
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <memory>
+
namespace moco
{
namespace tf
{
-/// @brief Convert TFConv2DBackpropInput to Canonical TransposedConv2D
-class Conv2DBackpropInputCanonicalizer : public SimpleNodeTransform<moco::TFConv2DBackpropInput>
+class Importer final
{
public:
- const char *name(void) const final { return "Conv2DBackpropInputCanonicalizer"; }
+ Importer();
public:
- bool transform(moco::TFConv2DBackpropInput *) const final;
+ explicit Importer(const GraphBuilderSource *source) : _source{source}
+ {
+ // DO NOTHING
+ }
+
+public:
+ std::unique_ptr<loco::Graph> import(const ModelSignature &, tensorflow::GraphDef &) const;
+
+private:
+ const GraphBuilderSource *_source = nullptr;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_CONV2DBACKPROPINPUT_CANONICALIZER_H__
+#endif // __IMPORT_H__
diff --git a/compiler/moco-tf/src/Importer.test.cpp b/compiler/moco-tf/src/Importer.test.cpp
new file mode 100644
index 000000000..770984b0f
--- /dev/null
+++ b/compiler/moco-tf/src/Importer.test.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Importer.h"
+
+#include "TestHelper.h"
+
+#include "IR/TFIdentity.h"
+#include "Op/Identity.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+TEST(TensorFlowImport, Dummy) { moco::tf::Importer import; }
+
+namespace
+{
+
+// clang-format off
+const char *basic_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+}
+node {
+ name: "output/identity"
+ op: "Identity"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, load_model_withio)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("output/identity", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+ using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - import reads Pull
+ // - import reads Forward
+ // - attribute values should match
+
+ auto pull = find_first_node_bytype<loco::Pull>(graph.get());
+ ASSERT_NE(pull, nullptr);
+ auto forward = find_first_node_bytype<loco::Forward>(graph.get());
+ ASSERT_NE(forward, nullptr);
+
+ ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32);
+ ASSERT_EQ(pull->rank(), 4);
+ loco::Dimension dim1 = 1;
+ loco::Dimension dim2 = 2;
+ ASSERT_EQ(pull->dim(0).value(), dim1.value());
+ ASSERT_EQ(pull->dim(1).value(), dim2.value());
+ ASSERT_EQ(pull->dim(2).value(), dim1.value());
+ ASSERT_EQ(pull->dim(3).value(), dim2.value());
+}
+
+TEST(TensorFlowImport, load_model_withio_tf)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("output/identity", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
+
+ using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ // TODO add Placeholder
+ r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - import reads Placeholder
+ // - import reads Identity
+ // - attribute values should match
+
+ auto tfidentity = find_first_node_bytype<moco::tf::TFIdentity>(graph.get());
+ ASSERT_NE(tfidentity, nullptr);
+ ASSERT_NE(tfidentity->input(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Knob.lst b/compiler/moco-tf/src/Knob.lst
index b88e064c7..fd8b4cd8a 100644
--- a/compiler/moco-tf/src/Knob.lst
+++ b/compiler/moco-tf/src/Knob.lst
@@ -4,19 +4,24 @@
// KNOB_BOOL(NAME, DEFAULT_VALUE, DESCRIPTION)
+// Imports
+KNOB_BOOL(ImportAsTFAvgPool, true, Import AvgPool2D node as TFAvgPool node)
+KNOB_BOOL(ImportAsTFBiasAdd, true, Import BiasAdd node as TFBiasAdd node)
+KNOB_BOOL(ImportAsTFConcatV2, true, Import ConcatV2 node as TFConcatV2 node)
+KNOB_BOOL(ImportAsTFConst, true, Import Const node as TFConst node)
+KNOB_BOOL(ImportAsTFConv2D, true, Import Conv2D node as TFConv2D node)
+KNOB_BOOL(ImportAsTFIdentity, true, Import Identity node as TFIdentity node)
+KNOB_BOOL(ImportAsTFMaxPool, true, Import MaxPool node as TFMaxPool node)
+KNOB_BOOL(ImportAsTFRelu, true, Import Relu node as TFRelu node)
+KNOB_BOOL(ImportAsTFRelu6, true, Import Relu6 node as TFRelu6 node)
+
// TensorFlow dialect transforms
KNOB_BOOL(FuseBinaryIntoPreceding, true, Fuse Binary node to preceding node)
KNOB_BOOL(ResolveFusedBatchNorm, true, Enable ResolveFusedBatchNorm transform)
KNOB_BOOL(ResolveConstantShape, true, Replace determined TFShape to TFConst)
KNOB_BOOL(ResolveReshapeWildcardDim, true, Resolve wildcard dimension in TFReshape node)
-KNOB_BOOL(ResolveSquaredDifference, true, Resolve SquaredDifference node)
+KNOB_BOOL(ResolveSquaredDifference, false, Resolve SquaredDifference node)
KNOB_BOOL(RemoveTFIdentityNode, true, Enable RemoveTFIdentityNode optimization)
-KNOB_BOOL(SqueezeReduceNode, true, Insert TFSqueeze if ReduceNode do not keep dimensions)
-// Constant folding
-KNOB_BOOL(ConstantFoldAdd, false, Constant fold for Add node)
-KNOB_BOOL(ConstantFoldMul, false, Constant fold for Mul node)
-KNOB_BOOL(ConstantFoldPack, false, Constant fold for Pack node)
-KNOB_BOOL(ConstantFoldStridedSlice, false, Constant fold for StridedSlice node)
// Canonicalization
KNOB_BOOL(CanonicalizeBiasAdd, true, Enable Canonicalize for BiasAdd node)
@@ -32,8 +37,8 @@ KNOB_BOOL(ReorderDecodeReLU, true, Reorder FeatureDecode-ReLU)
KNOB_BOOL(ReorderDecodeTensorBiasAdd, true, Reorder FeatureDecode-TensorBiasAdd)
// END
KNOB_BOOL(SimplifyDomainConversion, true, Enable SimplifyDomainConversion optimization)
-KNOB_BOOL(ResolveDuplicateReshape, true, Resolve duplicated Reshape nodes)
-KNOB_BOOL(ResolveRedundantReshape, true, Resolve redundant Reshape node)
+KNOB_BOOL(ResolveDuplicateReshape, false, Resolve duplicated Reshape nodes)
+KNOB_BOOL(ResolveRedundantReshape, false, Resolve redundant Reshape node)
// Graph transformations
KNOB_BOOL(RemoveDeadNode, true, Enable RemoveDeadNode optimization)
diff --git a/compiler/moco-tf/src/LogHelper.cpp b/compiler/moco-tf/src/LogHelper.cpp
index 92ff75569..1a38eb7c3 100644
--- a/compiler/moco-tf/src/LogHelper.cpp
+++ b/compiler/moco-tf/src/LogHelper.cpp
@@ -56,6 +56,22 @@ std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad)
} // namespace loco
+namespace moco
+{
+namespace tf
+{
+
+std::ostream &operator<<(std::ostream &os, const moco::tf::PadData &pad_data)
+{
+ os << "[TLBR " << pad_data.pad()->top() << "," << pad_data.pad()->left() << ","
+ << pad_data.pad()->bottom() << "," << pad_data.pad()->right() << "]";
+
+ return os;
+}
+
+} // namespace tf
+} // namespace moco
+
std::ostream &operator<<(std::ostream &os, const std::vector<int64_t> &vi64)
{
for (auto vi : vi64)
diff --git a/compiler/moco-tf/src/LogHelper.h b/compiler/moco-tf/src/LogHelper.h
index 4e3cb5dac..fc60f9fef 100644
--- a/compiler/moco-tf/src/LogHelper.h
+++ b/compiler/moco-tf/src/LogHelper.h
@@ -23,6 +23,8 @@
#include <loco/IR/FilterShape.h>
#include <loco/IR/TensorShape.h>
+#include "Annotations/PadData.h"
+
#include <sstream>
#include <vector>
@@ -51,6 +53,19 @@ std::ostream &operator<<(std::ostream &os, const loco::Padding2D &pad);
} // namespace loco
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief dump moco::tf::PadData
+ */
+std::ostream &operator<<(std::ostream &os, const moco::tf::PadData &pad_data);
+
+} // namespace tf
+} // namespace moco
+
/**
* @brief dump std::vector<int64_t> values to stream
*/
diff --git a/compiler/moco-tf/src/Op/Add.cpp b/compiler/moco-tf/src/Op/Add.cpp
new file mode 100644
index 000000000..b957cf4dc
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Add.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFAdd.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Add node
+ */
+class TFAddGraphUpdate final : public GraphUpdate
+{
+public:
+ TFAddGraphUpdate(TFAdd *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFAdd *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFAddGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs == 2);
+
+ _node->x(tensor_names->node(_names[0]));
+ _node->y(tensor_names->node(_names[1]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Add node
+ */
+class AddGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool AddGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ return true;
+}
+
+void AddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Add node
+ auto tf_add = graph->nodes()->create<TFAdd>();
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_add);
+
+ std::vector<TensorName> add_input_names;
+ add_input_names.push_back(TensorName(node.input(0))); // x
+ add_input_names.push_back(TensorName(node.input(1))); // y
+
+ auto tf_add_update = stdex::make_unique<TFAddGraphUpdate>(tf_add, add_input_names);
+ updates->enroll(std::move(tf_add_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Add, AddGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Add.test.cpp b/compiler/moco-tf/src/Op/Add.test.cpp
new file mode 100644
index 000000000..dc53f37b2
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Add.test.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFAdd.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *add_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input_01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "input_02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "ADD_01"
+ op: "Add"
+ input: "input_01"
+ input: "input_02"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_add_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("ADD_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(add_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFAdd node should exist
+ // - both inputs x() and y() should not be null
+
+ auto add_node = moco::tf::test::find_first_node_bytype<moco::tf::TFAdd>(graph.get());
+
+ ASSERT_NE(add_node, nullptr);
+ ASSERT_NE(add_node->x(), nullptr);
+ ASSERT_NE(add_node->y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/AvgPool.cpp b/compiler/moco-tf/src/Op/AvgPool.cpp
new file mode 100644
index 000000000..dda7fce15
--- /dev/null
+++ b/compiler/moco-tf/src/Op/AvgPool.cpp
@@ -0,0 +1,325 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AvgPool.h"
+
+#include "Convert.h"
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFAvgPool.h"
+
+#include "Annotations/PaddingData.h"
+
+#include <moco/tf/Names.h>
+
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <cassert>
+#include <stdexcept>
+
+using namespace plier::tf;
+
+namespace
+{
+
+using namespace moco::tf;
+
+class AvgPoolGraphUpdate final : public GraphUpdate
+{
+public:
+ AvgPoolGraphUpdate(loco::FeatureEncode *node, const TensorName &name)
+ : _encode_node(node), _input_name(name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::FeatureEncode *_encode_node;
+ const TensorName _input_name;
+};
+
+class TFAvgPoolGraphUpdate final : public GraphUpdate
+{
+public:
+ TFAvgPoolGraphUpdate(moco::tf::TFAvgPool *node, const TensorName &name)
+ : _avgpool_node(node), _value_name(name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFAvgPool *_avgpool_node;
+ const TensorName _value_name;
+};
+
+void AvgPoolGraphUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *input_node = node_table->node(_input_name);
+ _encode_node->input(input_node);
+}
+
+void TFAvgPoolGraphUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *value_node = node_table->node(_value_name);
+ _avgpool_node->value(value_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for AvgPool node
+ */
+class AvgPoolGraphBuilder final : public AvgPoolGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool AvgPoolGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ // note: even though "data_format" is not entered when a model is written,
+ // TF seems to generate "data_format" field into a pb file
+ if (!plier::tf::has_attrs(node, {"T", "data_format", "ksize", "padding", "strides"}))
+ return false;
+
+ if (node.input_size() != 1)
+ return false;
+
+ return true;
+}
+
+void AvgPoolGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFAvgPool>())
+ {
+ AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ return builder.build(node, context);
+ }
+ else
+ {
+ AvgPoolGraphBuilderImpl<ImportTarget::Canonical> builder;
+ return builder.build(node, context);
+ }
+}
+
+void AvgPoolGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ using plier::tf::DataLayout;
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // create loco nodes
+ auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
+ auto avgPool2d_node = graph->nodes()->create<loco::AvgPool2D>();
+ auto decode_node = graph->nodes()->create<loco::FeatureDecode>();
+
+ // name of loco nodes
+ ::std::string avgPool2d_name = node.name();
+
+ // tensorflow padding convention is valid
+ avgPool2d_node->convention(loco::AvgPool2D::Convention::Valid);
+
+ // tensorflow data_format, e.g., NHWC, NCHW, etc.
+ auto data_layout = plier::tf::get_data_layout(node, "data_format");
+ if (!(data_layout == DataLayout::NHWC || data_layout == DataLayout::NCHW))
+ {
+ throw std::runtime_error("Not supported data layout at AvgPoolGraphBuilder");
+ }
+
+ // FeatureEncode
+ {
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ assert(false);
+
+ encode_node->encoder(std::move(enc));
+ }
+
+ // AvgPool
+ {
+ // let's convert attrs:
+ // TensorFlow attr : T, data_format, ksize, padding, strides
+ // to loco attr: not defined, TBD, window, annot, stride
+
+ // tf ksize -> loco window
+ auto tf_ksize = plier::tf::get_list_attr(node, "ksize");
+ auto window = avgPool2d_node->window();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ window->vertical(tf_ksize.i(1));
+ window->horizontal(tf_ksize.i(2));
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ window->vertical(tf_ksize.i(2));
+ window->horizontal(tf_ksize.i(3));
+ }
+ else
+ assert(false);
+
+ // tf strides -> loco stride
+ auto tf_strides = plier::tf::get_list_attr(node, "strides");
+ auto stride = avgPool2d_node->stride();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ stride->vertical(tf_strides.i(1));
+ stride->horizontal(tf_strides.i(2));
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ stride->vertical(tf_strides.i(2));
+ stride->horizontal(tf_strides.i(3));
+ }
+ else
+ assert(false);
+
+ // tf paddings -> PaddingData annotation
+ auto tf_padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
+ auto padding_data = stdex::make_unique<PaddingData>(tf_padding);
+ avgPool2d_node->annot(std::move(padding_data));
+ }
+
+ // FeatureDecode
+ {
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ assert(false);
+
+ decode_node->decoder(std::move(dec));
+ }
+
+ // link nodes
+ avgPool2d_node->ifm(encode_node);
+ decode_node->input(avgPool2d_node);
+
+ // To set the input node of encode_node with avgPool2d_name
+ TensorName output_name(avgPool2d_name, 0);
+ tensor_names->enroll(output_name, decode_node);
+
+ // Record ifm inputs to featureEncode_node
+ auto update = stdex::make_unique<AvgPoolGraphUpdate>(encode_node, TensorName(node.input(0)));
+
+ updates->enroll(std::move(update));
+}
+
+void AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ ::std::string avgPool2d_name = node.name();
+
+ // tensorflow data_format: one of NHWC or NCHW.
+ auto data_layout = get_string_attr(node, "data_format");
+ auto avgPool_node = graph->nodes()->create<moco::tf::TFAvgPool>();
+ avgPool_node->data_layout(data_layout);
+
+ // padding
+ auto padding = moco::str_toupper(get_string_attr(node, "padding"));
+ avgPool_node->padding(padding);
+
+ // ksize
+ auto tf_ksize = get_list_attr(node, "ksize");
+ auto ksize = as_int64_list(tf_ksize);
+ if (ksize.size() != 4)
+ {
+ // TODO support ksize length for 1 and 2
+ throw std::runtime_error("AvgPool only supports ksize length 4");
+ }
+ avgPool_node->ksize(ksize);
+
+ // strides
+ auto tf_strides = get_list_attr(node, "strides");
+ auto strides = as_int64_list(tf_strides);
+ if (strides.size() != 4)
+ {
+ // TODO support strides length for 1 and 2
+ throw std::runtime_error("AvgPool only supports strides length 4");
+ }
+ avgPool_node->strides(strides);
+
+ // To set the input node of encode_node with avgPool2d_name
+ TensorName output_name(avgPool2d_name, 0);
+ tensor_names->enroll(output_name, avgPool_node);
+
+ // Record ifm inputs to featureEncode_node
+ auto update = stdex::make_unique<TFAvgPoolGraphUpdate>(avgPool_node, TensorName(node.input(0)));
+
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(AvgPool, AvgPoolGraphBuilder)
+
+// TODO Consider a case when TF AvgPool is for 3D.
+// AvgPool works for 2D and other Dimensions, such as 3D
+// So, in future, some other GraphBuilder decide if AvgPoolGraphBuilder is used or
+// other GraphBuilder is used for TF AvgPool
diff --git a/compiler/moco-tf/src/Op/AvgPool.h b/compiler/moco-tf/src/Op/AvgPool.h
new file mode 100644
index 000000000..ec9075a81
--- /dev/null
+++ b/compiler/moco-tf/src/Op/AvgPool.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_AVG_POOL_H__
+#define __OP_AVG_POOL_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct AvgPoolGraphBuilderBase : public GraphBuilder
+{
+ virtual ~AvgPoolGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class AvgPoolGraphBuilderImpl;
+
+template <>
+struct AvgPoolGraphBuilderImpl<ImportTarget::Canonical> final : public AvgPoolGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow> final : public AvgPoolGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_AVG_POOL_H__
diff --git a/compiler/moco-tf/src/Op/AvgPool.test.cpp b/compiler/moco-tf/src/Op/AvgPool.test.cpp
new file mode 100644
index 000000000..d5fb2082c
--- /dev/null
+++ b/compiler/moco-tf/src/Op/AvgPool.test.cpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "AvgPool.h"
+
+#include "IR/TFAvgPool.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <loco/IR/TensorShape.h>
+#include <loco/IR/FeatureShape.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *avgpool_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "avgpool"
+ op: "AvgPool"
+ input: "const/float"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "ksize"
+ value {
+ list {
+ i: 1
+ i: 2
+ i: 3
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 3
+ i: 2
+ i: 1
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, AvgPool_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("avgpool", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(avgpool_01_pbtxtdata, graph_def));
+
+ // Test "AvgPoolGraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // what to test:
+ // - there should exist AvgPool2D
+ // - input node should be FeatureEncode
+ // - following node should be FeatureDecode
+ // - stride values should match
+ // - window values should match
+
+ using AvgPoolGraphBuilder = AvgPoolGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("AvgPool", stdex::make_unique<AvgPoolGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ loco::AvgPool2D *avgpool2d_node =
+ moco::tf::test::find_first_node_bytype<loco::AvgPool2D>(graph.get());
+ ASSERT_NE(avgpool2d_node, nullptr);
+
+ loco::Node *previous_node = avgpool2d_node->ifm();
+ auto following_nodes = loco::succs(avgpool2d_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+
+ loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(previous_node);
+ loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(following_node);
+
+ ASSERT_NE(enc_node, nullptr);
+ ASSERT_NE(dec_node, nullptr);
+
+ // attrs inside AvgPool2D
+ auto avgpool2d = avgpool2d_node; // TODO remove this new variable
+ // convention
+ ASSERT_EQ(avgpool2d->convention(), loco::AvgPool2D::Convention::Valid);
+
+ // stride
+ ASSERT_EQ(avgpool2d->stride()->vertical(), 3);
+ ASSERT_EQ(avgpool2d->stride()->horizontal(), 2);
+
+ // window
+ ASSERT_EQ(avgpool2d->window()->vertical(), 2);
+ ASSERT_EQ(avgpool2d->window()->horizontal(), 3);
+ }
+
+ // Test "AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ // what to test:
+ // - there should exist TFAvgPool
+ // - attributes value should match
+
+ using AvgPoolGraphBuilder = AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("AvgPool", stdex::make_unique<AvgPoolGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ moco::tf::TFAvgPool *avgpool_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFAvgPool>(graph.get());
+ ASSERT_NE(avgpool_node, nullptr);
+
+ loco::Node *previous_node = avgpool_node->value();
+ auto following_nodes = loco::succs(avgpool_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+
+ // attrs inside TFAvgPool2D
+ ASSERT_EQ(avgpool_node->data_layout(), "NHWC");
+ ASSERT_EQ(avgpool_node->padding(), "VALID");
+ ASSERT_EQ(avgpool_node->ksize(), std::vector<int64_t>({1, 2, 3, 1}));
+ ASSERT_EQ(avgpool_node->strides(), std::vector<int64_t>({1, 3, 2, 1}));
+ }
+}
diff --git a/compiler/moco-tf/src/Op/BiasAdd.cpp b/compiler/moco-tf/src/Op/BiasAdd.cpp
new file mode 100644
index 000000000..6862a522d
--- /dev/null
+++ b/compiler/moco-tf/src/Op/BiasAdd.cpp
@@ -0,0 +1,240 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BiasAdd.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFBiasAdd.h"
+
+#include <moco/tf/Names.h>
+
+#include <loco.h>
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+#include <vector>
+
+namespace
+{
+using namespace moco::tf;
+
+class ValueInputUpdate final : public GraphUpdate
+{
+public:
+ ValueInputUpdate(loco::BiasAdd<loco::Domain::Tensor> *bias_add, const TensorName &&input_name)
+ : _bias_add(bias_add), _input_name(input_name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::BiasAdd<loco::Domain::Tensor> *_bias_add;
+ const TensorName _input_name;
+};
+
+void ValueInputUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *input_node = node_table->node(_input_name);
+ _bias_add->value(input_node);
+}
+
+class BiasInputUpdate final : public GraphUpdate
+{
+public:
+ BiasInputUpdate(loco::BiasEncode *bias_enc, const TensorName &&input_name)
+ : _bias_enc(bias_enc), _input_name(input_name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::BiasEncode *_bias_enc;
+ const TensorName _input_name;
+};
+
+void BiasInputUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *input_node = node_table->node(_input_name);
+ _bias_enc->input(input_node);
+}
+
+class TFBiasAddGraphUpdate final : public GraphUpdate
+{
+public:
+ TFBiasAddGraphUpdate(moco::tf::TFBiasAdd *biasadd, std::vector<TensorName> &names)
+ : _biasadd(biasadd), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFBiasAdd *_biasadd;
+ std::vector<TensorName> _names;
+};
+
+void TFBiasAddGraphUpdate::input(const SymbolTable *node_table) const
+{
+ assert(_names.size() == 2);
+
+ auto value_node = node_table->node(_names[0]);
+ auto bias_node = node_table->node(_names[1]);
+ assert(value_node != nullptr);
+ assert(bias_node != nullptr);
+
+ _biasadd->value(value_node);
+ _biasadd->bias(bias_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for BiasAdd node
+ */
+class BiasAddGraphBuilder final : public BiasAddGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool BiasAddGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ // note: even though "data_format" is not entered when a model is written,
+ // TF seems to generate "data_format" field into a pb file
+ if (!plier::tf::has_attrs(node, {"T", "data_format"}))
+ return false;
+
+ // TODO add type check
+ // type of input and bias should be same (except using quantization)
+
+ // Note In case of TF.nn.bias_add,
+ // "value may have any number of dimensions." ...
+ // but "data_format: A string. 'NHWC' and 'NCHW' are supported."
+ // Not sure if value should be 4-D tensor. Let's skip this check for now.
+
+ return true;
+}
+
+void BiasAddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFBiasAdd>())
+ {
+ BiasAddGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ return builder.build(node, context);
+ }
+ else
+ {
+ BiasAddGraphBuilderImpl<ImportTarget::Canonical> builder;
+ return builder.build(node, context);
+ }
+}
+
+void BiasAddGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // tensorflow data_format: one of NHWC or NCHW.
+ auto data_layout = plier::tf::get_data_layout(node, "data_format");
+
+ // creating loco nodes
+ auto bias_enc = graph->nodes()->create<loco::BiasEncode>();
+
+ auto bias_add = graph->nodes()->create<loco::BiasAdd<loco::Domain::Tensor>>();
+ {
+ using plier::tf::DataLayout;
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ bias_add->axis(3);
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ bias_add->axis(1); // Channel
+ // Note: the following description of TF 1.13 at
+ // https://www.tensorflow.org/api_docs/python/tf/nn/bias_add seems wrong:
+ // "bias: A 1-D Tensor with size matching the last dimension of value."
+ // because providing the size of W (last dimension) to bias throws an error with TensorFlow
+ }
+ }
+
+ // link nodes
+ bias_add->bias(bias_enc);
+
+ // To set the input node of encode_node with biasAdd_name
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, bias_add);
+
+ // Queue updates to connect value input to bias_add and bias input to bias_enc
+ auto value_update = stdex::make_unique<ValueInputUpdate>(bias_add, TensorName(node.input(0)));
+ auto bias_update = stdex::make_unique<BiasInputUpdate>(bias_enc, TensorName(node.input(1)));
+
+ updates->enroll(std::move(value_update));
+ updates->enroll(std::move(bias_update));
+}
+
+void BiasAddGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // tensorflow data_format: one of NHWC or NCHW.
+ auto data_layout = plier::tf::get_string_attr(node, "data_format");
+ auto tf_bias_add = graph->nodes()->create<moco::tf::TFBiasAdd>();
+
+ tf_bias_add->data_layout(data_layout);
+
+ // To set the input node of encode_node with biasAdd_name
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_bias_add);
+
+ std::vector<TensorName> input_names;
+ input_names.push_back(TensorName(node.input(0)));
+ input_names.push_back(TensorName(node.input(1)));
+
+ auto update = stdex::make_unique<TFBiasAddGraphUpdate>(tf_bias_add, input_names);
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(BiasAdd, BiasAddGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/BiasAdd.h b/compiler/moco-tf/src/Op/BiasAdd.h
new file mode 100644
index 000000000..890ca65b1
--- /dev/null
+++ b/compiler/moco-tf/src/Op/BiasAdd.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_BIAS_ADD_H__
+#define __OP_BIAS_ADD_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct BiasAddGraphBuilderBase : public GraphBuilder
+{
+ virtual ~BiasAddGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class BiasAddGraphBuilderImpl;
+
+template <>
+struct BiasAddGraphBuilderImpl<ImportTarget::Canonical> final : public BiasAddGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct BiasAddGraphBuilderImpl<ImportTarget::TensorFlow> final : public BiasAddGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_BIAS_ADD_H__
diff --git a/compiler/moco-tf/src/Op/BiasAdd.test.cpp b/compiler/moco-tf/src/Op/BiasAdd.test.cpp
new file mode 100644
index 000000000..823b66f42
--- /dev/null
+++ b/compiler/moco-tf/src/Op/BiasAdd.test.cpp
@@ -0,0 +1,301 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BiasAdd.h"
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFBiasAdd.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+
+// clang-format off
+const char *bias_add_01_pbtxtdata = STRING_CONTENT(
+
+node {
+ name: "val"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 5 }
+ dim { size: 5 }
+ dim { size: 3 }
+ }
+ float_val: 2.1
+ }
+ }
+ }
+}
+node {
+ name: "bias"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 3 }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "out"
+ op: "BiasAdd"
+ input: "val"
+ input: "bias"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NHWC" }
+ }
+}
+
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, bias_add_01)
+{
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("out", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(bias_add_01_pbtxtdata, graph_def));
+
+ // Test "BiasAddGraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ using BiasAddGraphBuilder = BiasAddGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("BiasAdd", stdex::make_unique<BiasAddGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFBiasAdd
+ // - value() should not be nullptr
+ // - bias() should not be nullptr
+ // - data_layout should match
+
+ moco::tf::TFBiasAdd *bias_add =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFBiasAdd>(graph.get());
+
+ ASSERT_NE(bias_add, nullptr);
+
+ ASSERT_NE(bias_add->value(), nullptr);
+ ASSERT_NE(bias_add->bias(), nullptr);
+
+ ASSERT_TRUE(bias_add->data_layout() == "NHWC");
+ }
+ // Test "BiasAddGraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ using BiasAddGraphBuilder = BiasAddGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("BiasAdd", stdex::make_unique<BiasAddGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist BiasAdd
+ // - value() should not be nullptr
+ // - bias() input should be BiasEncode
+ // - axis should match
+
+ // loco node : ------------+-- BiasAdd<Domain::Tensor> --
+ // BiasEncode -/
+
+ loco::BiasAdd<loco::Domain::Tensor> *bias_add =
+ moco::tf::test::find_first_node_bytype<loco::BiasAdd<loco::Domain::Tensor>>(graph.get());
+
+ ASSERT_NE(bias_add, nullptr);
+ ASSERT_NE(bias_add->value(), nullptr);
+
+ auto bias_enc = dynamic_cast<loco::BiasEncode *>(bias_add->bias());
+ ASSERT_NE(bias_enc, nullptr);
+
+ ASSERT_EQ(bias_add->axis(), 3); // NHWC
+ }
+}
+
+namespace
+{
+
+// clang-format off
+const char *bias_add_NCHW_pbtxtdata = STRING_CONTENT(
+
+node {
+ name: "val"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 3 }
+ dim { size: 299 }
+ dim { size: 299 }
+ }
+ float_val: 2.1
+ }
+ }
+ }
+}
+node {
+ name: "bias"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 3 }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "out"
+ op: "BiasAdd"
+ input: "val"
+ input: "bias"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NCHW" }
+ }
+}
+
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, bias_add_NCHW_axis)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("out", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(bias_add_NCHW_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // Test "BiasAddGraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ using BiasAddGraphBuilder = BiasAddGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("BiasAdd", stdex::make_unique<BiasAddGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFBiasAdd
+ // - value() should not be nullptr
+ // - bias() should not be nullptr
+ // - data_layout should match
+
+ moco::tf::TFBiasAdd *bias_add =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFBiasAdd>(graph.get());
+
+ ASSERT_NE(bias_add, nullptr);
+
+ ASSERT_NE(bias_add->value(), nullptr);
+ ASSERT_NE(bias_add->bias(), nullptr);
+
+ ASSERT_TRUE(bias_add->data_layout() == "NCHW");
+ }
+ // Test "BiasAddGraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ using BiasAddGraphBuilder = BiasAddGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("BiasAdd", stdex::make_unique<BiasAddGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist BiasAdd
+ // - value() should not be nullptr
+ // - bias() input should be BiasEncode
+ // - axis should match
+
+ // loco node : ------------+-- BiasAdd<Domain::Tensor> --
+ // BiasEncode -/
+
+ loco::BiasAdd<loco::Domain::Tensor> *bias_add =
+ moco::tf::test::find_first_node_bytype<loco::BiasAdd<loco::Domain::Tensor>>(graph.get());
+
+ ASSERT_NE(bias_add, nullptr);
+ ASSERT_NE(bias_add->value(), nullptr);
+
+ auto bias_enc = dynamic_cast<loco::BiasEncode *>(bias_add->bias());
+ ASSERT_NE(bias_enc, nullptr);
+
+ ASSERT_EQ(bias_add->axis(), 1); // NCHW
+ }
+}
diff --git a/compiler/moco-tf/src/Op/COpCall.cpp b/compiler/moco-tf/src/Op/COpCall.cpp
index 801196f0f..2bd3fcdfc 100644
--- a/compiler/moco-tf/src/Op/COpCall.cpp
+++ b/compiler/moco-tf/src/Op/COpCall.cpp
@@ -17,14 +17,17 @@
#include "COpCall.h"
#include "Convert.h"
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
#include <locoex/COpCall.h>
#include <locoex/COpAttrTypes.h>
-#include <moco/Names.h>
+#include <moco/tf/Names.h>
#include <moco/tf/Frontend.h>
#include <loco.h>
#include <stdex/Memory.h>
-#include <oops/UserExn.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
#include <vector>
#include <cassert>
@@ -33,22 +36,22 @@
namespace
{
-class COpCallGraphUpdate final : public moco::GraphUpdate
+class COpCallGraphUpdate final : public moco::tf::GraphUpdate
{
public:
- COpCallGraphUpdate(locoex::COpCall *node, const std::vector<moco::TensorName> &input_names)
+ COpCallGraphUpdate(locoex::COpCall *node, const std::vector<moco::tf::TensorName> &input_names)
: _node(node), _input_names(input_names)
{
}
- void input(const moco::SymbolTable *) const override;
+ void input(const moco::tf::SymbolTable *) const override;
private:
locoex::COpCall *_node;
- const std::vector<moco::TensorName> _input_names;
+ const std::vector<moco::tf::TensorName> _input_names;
};
-void COpCallGraphUpdate::input(const moco::SymbolTable *tensor_names) const
+void COpCallGraphUpdate::input(const moco::tf::SymbolTable *tensor_names) const
{
for (int n = 0; n < _input_names.size(); n++)
{
@@ -103,7 +106,7 @@ void COpCallGraphBuilder::build(const tensorflow::NodeDef &tf_node,
// TODO define more types
else
{
- throw oops::UserExn("Unsupported attribute type", tf_node.name());
+ throw std::runtime_error("not supported attribute type");
}
}
}
diff --git a/compiler/moco-tf/src/Op/COpCall.h b/compiler/moco-tf/src/Op/COpCall.h
index 0bb8a93c9..ea81d3a98 100644
--- a/compiler/moco-tf/src/Op/COpCall.h
+++ b/compiler/moco-tf/src/Op/COpCall.h
@@ -17,9 +17,12 @@
#ifndef __OP_COP_CALL_H__
#define __OP_COP_CALL_H__
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
#include <moco/tf/Frontend.h>
-#include <moco/Import/GraphBuilder.h>
+#include <tensorflow/core/framework/graph.pb.h>
namespace moco
{
diff --git a/compiler/moco-tf/src/Op/COpCall.test.cpp b/compiler/moco-tf/src/Op/COpCall.test.cpp
index f13118292..7b6f8d5c4 100644
--- a/compiler/moco-tf/src/Op/COpCall.test.cpp
+++ b/compiler/moco-tf/src/Op/COpCall.test.cpp
@@ -18,16 +18,14 @@
#include "TestHelper.h"
+#include "Importer.h"
#include "Canonicalizer.h"
-#include <moco/Importer.h>
-
#include <locoex/COpCall.h>
#include <locoex/COpAttrTypes.h>
#include <loco.h>
#include <plier/tf/TestHelper.h>
-#include <stdex/Memory.h>
#include <gtest/gtest.h>
@@ -77,10 +75,10 @@ node {
TEST(Call_Test, Call_01)
{
- moco::ModelSignature signature;
+ moco::tf::ModelSignature signature;
{
- signature.add_input(moco::TensorName("input1", 0));
- signature.add_output(moco::TensorName("my/customOp/000", 0));
+ signature.add_input(moco::tf::TensorName("input1", 0));
+ signature.add_output(moco::tf::TensorName("my/customOp/000", 0));
signature.add_customop("new_custom_op");
signature.dtype("my/customOp/000", loco::DataType::FLOAT32);
signature.shape("my/customOp/000", {1, 2});
@@ -90,10 +88,10 @@ TEST(Call_Test, Call_01)
EXPECT_TRUE(plier::tf::parse_graphdef(customop_01_pbtxtdata, graph_def));
// import
- moco::GraphBuilderRegistry registry{&moco::GraphBuilderRegistry::get()};
+ moco::tf::GraphBuilderRegistry registry{&moco::tf::GraphBuilderRegistry::get()};
registry.add("new_custom_op", stdex::make_unique<moco::tf::COpCallGraphBuilder>(&signature));
- moco::Importer importer(&registry);
+ moco::tf::Importer importer(&registry);
std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
// what to test:
diff --git a/compiler/moco-tf/src/Op/Concat.cpp b/compiler/moco-tf/src/Op/Concat.cpp
new file mode 100644
index 000000000..33788cdbf
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Concat.cpp
@@ -0,0 +1,276 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Concat.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConcatV2.h"
+
+#include "Annotations/ConcatData.h"
+
+#include <moco/tf/Names.h>
+
+#include <loco.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class ConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+ ConcatV2GraphUpdate(std::vector<loco::TensorConcat *> nodes, std::vector<TensorName> names)
+ : _nodes(nodes), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ std::vector<loco::TensorConcat *> _nodes;
+ std::vector<TensorName> _names;
+};
+
+class TFConcatV2GraphUpdate final : public GraphUpdate
+{
+public:
+ TFConcatV2GraphUpdate(moco::tf::TFConcatV2 *node, std::vector<TensorName> names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFConcatV2 *_node;
+ std::vector<TensorName> _names;
+};
+
+void ConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs >= 2);
+ assert(num_inputs == _nodes.size());
+
+ loco::Node *target;
+ // do "%0.lhs : %in[0].name" connection
+ target = tensor_names->node(_names[0]);
+ _nodes[0]->lhs(target);
+
+ for (int i = 1; i < num_inputs; ++i)
+ {
+ // do "%i.rhs : %in[i].name" connections
+ target = tensor_names->node(_names[i]);
+ _nodes[i]->rhs(target);
+ }
+}
+
+void TFConcatV2GraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ uint32_t num_values = _names.size() - 1; // exclude axis
+ assert(num_values >= 1);
+
+ for (uint32_t i = 0; i < num_values; ++i)
+ {
+ auto input_node = tensor_names->node(_names[i]);
+ assert(input_node != nullptr);
+ _node->values(i, input_node);
+ }
+ auto axis_node = tensor_names->node(_names[num_values]);
+ assert(axis_node != nullptr);
+ _node->axis(axis_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ConcatV2GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ if (!plier::tf::has_attrs(node, {"T", "N", "Tidx"}))
+ return false;
+
+ // Concat node SHOULD have 3 or more inputs, that is 2 + axis
+ const int num_inputs = node.input_size() - 1;
+ assert(num_inputs >= 2);
+ assert(num_inputs == plier::tf::get_int_attr(node, "N"));
+ return (num_inputs >= 2) && (num_inputs == plier::tf::get_int_attr(node, "N"));
+}
+
+/**
+ * @brief GraphBuilder for Concat node of Tensor
+ */
+class ConcatV2GraphBuilder final : public ConcatV2GraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+void ConcatV2GraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConcatV2>())
+ {
+ ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ return builder.build(node, context);
+ }
+ else
+ {
+ ConcatV2GraphBuilderImpl<ImportTarget::Canonical> builder;
+ return builder.build(node, context);
+ }
+}
+
+void ConcatV2GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ NodeDefTable *nodedef = context->nodedef();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // Concat has 2 or more inputs and loco TensorConcat is fixed to 2 inputs
+ // for arbitrary N inputs (beginning from 0), TensorConcat will be created
+ // as follows;
+ // %0 = TensorConcat(%in[0], %in[1])
+ // %1 = %0 --> this is to match index of input name
+ // %2 = TensorConcat(%1, %in[2])
+ // ...
+ // %(N-1) = TensorConcat(%(N-2), %in[N-1]))
+ // %N = TensorConcat(%(N-1), %in[N]))
+ //
+ // Output of this sub graph will be set to %N with node.name()
+ //
+ // As we know that each input exist, one of input(lhs) can be linked while creating
+ // %2.lhs = %1
+ // %3.lhs = %2
+ // ...
+ // %(N-1).lhs = %(N-2)
+ // %N.lhs = %(N-1)
+
+ const int num_inputs = node.input_size() - 1;
+
+ std::vector<loco::TensorConcat *> concat_nodes;
+ std::vector<TensorName> input_names;
+
+ auto concat_node = graph->nodes()->create<loco::TensorConcat>();
+ loco::TensorConcat *last_concat = concat_node;
+
+ // Queue node input update
+ concat_nodes.push_back(concat_node); // used for LHS of connection -> %0
+ concat_nodes.push_back(concat_node); // used for RHS of connection -> %1
+ input_names.push_back(TensorName(node.input(0))); // for first concat (%0) LHS
+ input_names.push_back(TensorName(node.input(1))); // for first concat (%1) RHS
+
+ for (int ni = 2; ni < num_inputs; ++ni)
+ {
+ auto concat_node_next = graph->nodes()->create<loco::TensorConcat>();
+
+ concat_nodes.push_back(concat_node_next);
+ input_names.push_back(TensorName(node.input(ni)));
+
+ // connect LHS as we know the nodes
+ concat_node_next->lhs(last_concat);
+
+ // update last concat node
+ last_concat = concat_node_next;
+ }
+
+ // register string-name to the last node as output of concat(s)
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, last_concat);
+
+ // Find axis tensorflow::NodeDef and get the axis number
+ std::string axis_name = node.input(num_inputs);
+ const tensorflow::NodeDef *tfnode = nodedef->node(axis_name);
+ // assume data type is int32
+ assert(plier::tf::get_datatype_attr(*tfnode, "dtype") == tensorflow::DataType::DT_INT32);
+ const auto &tensor = plier::tf::get_tensor_attr(*tfnode, "value");
+ assert(tensor.int_val_size() == 1);
+ auto axis_value_read = tensor.int_val(0);
+
+ // set axis for all concat(s) as temporary data
+ // as the first and the second items are actually the same one, skip it.
+ std::vector<loco::TensorConcat *>::iterator iter = concat_nodes.begin();
+ for (++iter; iter != concat_nodes.end(); ++iter)
+ {
+ auto concat_node = *iter;
+ auto concat_data = stdex::make_unique<ConcatData>(axis_value_read);
+
+ concat_node->annot(std::move(concat_data));
+ }
+
+ // Input name queue is created like this in 'concat_nodes' and 'input_names'
+ // %0.lhs : %in[0].name
+ // %1.rhs : %in[1].name (as %0 == %1)
+ // %2.rhs : %in[2].name
+ // %3.rhs : %in[3].name
+ // ...
+ // %(N-2).rhs : %in[N-2].name
+ // %(N-1).rhs : %in[N-1].name
+ auto update = stdex::make_unique<ConcatV2GraphUpdate>(concat_nodes, input_names);
+ updates->enroll(std::move(update));
+}
+
+void ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ NodeDefTable *nodedef = context->nodedef();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ const int num_inputs = node.input_size() - 1;
+ std::vector<TensorName> input_names;
+ auto concat_node = graph->nodes()->create<TFConcatV2>(num_inputs);
+
+ for (int ni = 0; ni < num_inputs; ++ni)
+ {
+ input_names.push_back(TensorName(node.input(ni)));
+ }
+ // last one is the axis
+ input_names.push_back(TensorName(node.input(num_inputs)));
+
+ // register string-name to the last node as output of concat(s)
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, concat_node);
+
+ auto update = stdex::make_unique<TFConcatV2GraphUpdate>(concat_node, input_names);
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(ConcatV2, ConcatV2GraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Concat.h b/compiler/moco-tf/src/Op/Concat.h
new file mode 100644
index 000000000..6a5a857e3
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Concat.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_CONCAT_H__
+#define __OP_CONCAT_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct ConcatV2GraphBuilderBase : public GraphBuilder
+{
+ virtual ~ConcatV2GraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class ConcatV2GraphBuilderImpl;
+
+template <>
+struct ConcatV2GraphBuilderImpl<ImportTarget::Canonical> final : public ConcatV2GraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow> final : public ConcatV2GraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_CONCAT_H__
diff --git a/compiler/moco-tf/src/Op/Concat.test.cpp b/compiler/moco-tf/src/Op/Concat.test.cpp
new file mode 100644
index 000000000..e9ddd6bea
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Concat.test.cpp
@@ -0,0 +1,449 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Concat.h"
+
+#include "IR/TFConcatV2.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+
+// clang-format off
+const char *concat_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Input01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1
+ float_val: 2
+ float_val: 3
+ float_val: 4
+ float_val: 5
+ float_val: 6
+ }
+ }
+ }
+}
+node {
+ name: "Input02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 7
+ float_val: 8
+ float_val: 9
+ float_val: 10
+ float_val: 11
+ float_val: 12
+ }
+ }
+ }
+}
+node {
+ name: "Axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Concat"
+ op: "ConcatV2"
+ input: "Input01"
+ input: "Input02"
+ input: "Axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, concat_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ moco::tf::TensorName output("Concat", 0);
+ signature.add_output(output);
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(concat_01_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // TODO fix indent
+ // clang-format off
+
+ // what to test:
+ // - there should exist TensorConcat
+ // - lhs() should not be nullptr
+ // - rhs() should not be nullptr
+ // - axis() should match
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ loco::TensorConcat *concat_node =
+ moco::tf::test::find_first_node_bytype<loco::TensorConcat>(graph.get());
+
+ ASSERT_NE(concat_node, nullptr);
+ ASSERT_NE(concat_node->lhs(), nullptr);
+ ASSERT_NE(concat_node->rhs(), nullptr);
+ ASSERT_EQ(concat_node->axis(), 0);
+
+ // clang-format on
+ }
+
+  // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ // what to test:
+ // - there should exist TFConcatV2
+ // - there should be two values
+ // - values(idx) should not be nullptr
+ // - axis() should not be nullptr
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+ ASSERT_NE(concat_node, nullptr);
+ ASSERT_EQ(concat_node->num_values(), 2);
+ ASSERT_NE(concat_node->values(0), nullptr);
+ ASSERT_NE(concat_node->values(1), nullptr);
+ ASSERT_NE(concat_node->axis(), nullptr);
+ }
+}
+
+namespace
+{
+
+// clang-format off
+const char *concat_02_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Input01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1
+ float_val: 2
+ float_val: 3
+ float_val: 4
+ float_val: 5
+ float_val: 6
+ }
+ }
+ }
+}
+node {
+ name: "Input02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 7
+ float_val: 8
+ float_val: 9
+ float_val: 10
+ float_val: 11
+ float_val: 12
+ }
+ }
+ }
+}
+node {
+ name: "Input03"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 13
+ float_val: 14
+ float_val: 15
+ float_val: 16
+ float_val: 17
+ float_val: 18
+ }
+ }
+ }
+}
+node {
+ name: "Axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Concat"
+ op: "ConcatV2"
+ input: "Input01"
+ input: "Input02"
+ input: "Input03"
+ input: "Axis"
+ attr {
+ key: "N"
+ value {
+ i: 3
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, concat_02)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ moco::tf::TensorName output("Concat", 0);
+ signature.add_output(output);
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(concat_02_pbtxtdata, graph_def));
+
+ // Test "ConcatV2GraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // TODO fix indent
+ // clang-format off
+
+ // what to test: Concat has 3 inputs --> Importer creates 2 TensorConcat
+ // - there should exist two TensorConcat
+ // - lhs() of #1 should not be nullptr
+ // - rhs() of #1 should not be nullptr
+ // - lhs() of #2 should be #1
+ // - rhs() of #2 should not be nullptr
+ // - axis() should match
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ std::vector<loco::TensorConcat *> concat_nodes =
+ moco::tf::test::find_nodes_bytype<loco::TensorConcat>(graph.get());
+ ASSERT_EQ(concat_nodes.size(), 2);
+ loco::TensorConcat *concat_node0 = concat_nodes.at(0);
+ loco::TensorConcat *concat_node1 = concat_nodes.at(1);
+
+ ASSERT_NE(concat_node0, nullptr);
+ ASSERT_NE(concat_node1, nullptr);
+ ASSERT_NE(concat_node0->lhs(), nullptr);
+ ASSERT_NE(concat_node0->rhs(), nullptr);
+ ASSERT_NE(concat_node1->lhs(), nullptr);
+ ASSERT_NE(concat_node1->rhs(), nullptr);
+ ASSERT_TRUE(concat_node0->lhs() == concat_node1 || concat_node1->lhs() == concat_node0);
+ ASSERT_EQ(concat_node0->axis(), 0);
+ ASSERT_EQ(concat_node1->axis(), 0);
+
+ // clang-format on
+ }
+
+ // Test "ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ // what to test: TFConcatV2 has 3 inputs
+ // - there should exist TFConcatV2
+ // - values(idx) should not be nullptr
+ // - axis() should not be nullptr
+
+ using ConcatV2GraphBuilder = ConcatV2GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("ConcatV2", stdex::make_unique<ConcatV2GraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ auto concat_node = moco::tf::test::find_first_node_bytype<moco::tf::TFConcatV2>(graph.get());
+
+ ASSERT_NE(concat_node, nullptr);
+ ASSERT_EQ(concat_node->num_values(), 3);
+ ASSERT_NE(concat_node->values(0), nullptr);
+ ASSERT_NE(concat_node->values(1), nullptr);
+ ASSERT_NE(concat_node->values(2), nullptr);
+ ASSERT_NE(concat_node->axis(), nullptr);
+ }
+}
diff --git a/compiler/moco-tf/src/Op/Const.cpp b/compiler/moco-tf/src/Op/Const.cpp
new file mode 100644
index 000000000..f69645187
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Const.cpp
@@ -0,0 +1,359 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Const.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConst.h"
+
+#include <moco/tf/Names.h>
+#include <loco.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+#include <string>
+
+namespace
+{
+
+void read_value_int32(loco::ConstGen *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::S32>(num_elements);
+
+ int32_t input_elements = input_tensor.int_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(int32_t))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const int32_t *s32_ptr = reinterpret_cast<const int32_t *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = *(s32_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+void read_value_float32(loco::ConstGen *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+ int32_t input_elements = input_tensor.float_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(float))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const float *float_ptr = reinterpret_cast<const float *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = *(float_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+} // namespace
+
+namespace
+{
+
+void read_value_int32(moco::tf::TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::S32>(num_elements);
+
+ int32_t input_elements = input_tensor.int_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(int32_t))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const int32_t *s32_ptr = reinterpret_cast<const int32_t *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = *(s32_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+void read_value_float32(moco::tf::TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+ int32_t input_elements = input_tensor.float_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(float))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const float *float_ptr = reinterpret_cast<const float *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = *(float_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Const node
+ */
+class ConstGraphBuilder final : public ConstGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool ConstGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ return plier::tf::has_attrs(node, {"dtype", "value"});
+}
+
+void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConst>())
+ {
+ ConstGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ builder.build(node, context);
+ }
+ else
+ {
+ ConstGraphBuilderImpl<ImportTarget::Canonical> builder;
+ builder.build(node, context);
+ }
+}
+
+void ConstGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+
+ // Create a "ConstGen" node for Const
+ auto const_node = graph->nodes()->create<loco::ConstGen>();
+
+ // set dtype
+ auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
+ const_node->dtype(dtype);
+
+ // import shape and value
+ const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+ assert(input_shape.dim_size() <= 6);
+ const_node->rank(input_shape.dim_size());
+ int index = 0;
+ bool zero_sized_shape = false;
+ for (auto &d : input_dims)
+ {
+ if (d.size() > std::numeric_limits<int>::max())
+ throw std::runtime_error("Shape element overflows");
+ if (d.size() == 0)
+ zero_sized_shape = true;
+
+ if (d.size() >= 0)
+ const_node->dim(index++) = d.size();
+ else
+ throw std::runtime_error{"Error: Unknown dim size for " + node.name()};
+ }
+
+ int num_elements = 1;
+ if (zero_sized_shape)
+ {
+ const_node->rank(0);
+ num_elements = 0;
+ }
+ else
+ {
+ for (int d = 0; d < const_node->rank(); d++)
+ {
+ num_elements *= const_node->dim(d).value();
+ }
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::S32:
+ read_value_int32(const_node, num_elements, input_tensor);
+ break;
+
+ case loco::DataType::FLOAT32:
+ read_value_float32(const_node, num_elements, input_tensor);
+ break;
+
+ // TODO support other types
+
+ default:
+ throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+ }
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, const_node);
+}
+
+void ConstGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+
+  // Create a "TFConst" node for Const
+ auto const_node = graph->nodes()->create<moco::tf::TFConst>();
+
+ // set dtype
+ auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
+ const_node->dtype(dtype);
+
+ // import shape and value
+ const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+ assert(input_shape.dim_size() <= 6);
+ const_node->rank(input_shape.dim_size());
+ int index = 0;
+ bool zero_sized_shape = false;
+ for (auto &d : input_dims)
+ {
+ if (d.size() > std::numeric_limits<int>::max())
+ throw std::runtime_error("Shape element overflows");
+ if (d.size() == 0)
+ zero_sized_shape = true;
+
+ if (d.size() >= 0)
+ const_node->dim(index++) = d.size();
+ else
+ throw std::runtime_error{"Error: Unknown dim size for " + node.name()};
+ }
+
+ int num_elements = 1;
+ if (zero_sized_shape)
+ {
+ const_node->rank(0);
+ num_elements = 0;
+ }
+ else
+ {
+ for (int d = 0; d < const_node->rank(); d++)
+ {
+ num_elements *= const_node->dim(d).value();
+ }
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::S32:
+ read_value_int32(const_node, num_elements, input_tensor);
+ break;
+
+ case loco::DataType::FLOAT32:
+ read_value_float32(const_node, num_elements, input_tensor);
+ break;
+
+ // TODO support other types
+
+ default:
+ throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+ }
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, const_node);
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Const, ConstGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Const.h b/compiler/moco-tf/src/Op/Const.h
new file mode 100644
index 000000000..4e727f06a
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Const.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_CONST_H__
+#define __OP_CONST_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct ConstGraphBuilderBase : public GraphBuilder
+{
+ virtual ~ConstGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class ConstGraphBuilderImpl;
+
+template <>
+struct ConstGraphBuilderImpl<ImportTarget::Canonical> final : public ConstGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct ConstGraphBuilderImpl<ImportTarget::TensorFlow> final : public ConstGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_CONST_H__
diff --git a/compiler/moco-tf/src/Op/Const.test.cpp b/compiler/moco-tf/src/Op/Const.test.cpp
new file mode 100644
index 000000000..20c6c0e77
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Const.test.cpp
@@ -0,0 +1,464 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Const.h"
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFConst.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+
+template <ImportTarget Target>
+std::unique_ptr<loco::Graph> import(const moco::tf::ModelSignature &sig, tensorflow::GraphDef &def)
+{
+ using ConstGraphBuilder = ConstGraphBuilderImpl<Target>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Const", stdex::make_unique<ConstGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ return importer.import(sig, def);
+}
+
+// Test case for "input_tensor.float_val_size() == num_elements"
+
+// clang-format off
+const char *const_float_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1.1
+ float_val: 2.2
+ float_val: 3.3
+ float_val: 4.4
+ float_val: 5.5
+ float_val: 6.6
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, const_float_01)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("const/float", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(const_float_01_pbtxtdata, graph_def));
+
+ // Test "tf.GraphDef -> loco.TF" importer
+ {
+ auto graph = import<ImportTarget::TensorFlow>(signature, graph_def);
+
+ moco::tf::TFConst *node0 =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
+ ASSERT_NE(node0, nullptr);
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+ }
+
+ // Test "tf.GraphDef -> loco.Canonical" importer
+ {
+ auto graph = import<ImportTarget::Canonical>(signature, graph_def);
+
+ loco::ConstGen *node0 = moco::tf::test::find_first_node_bytype<loco::ConstGen>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+ }
+}
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() == 1"
+
+// clang-format off
+const char *const_float_02_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, const_float_02)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("const/float", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(const_float_02_pbtxtdata, graph_def));
+
+ // Test "tf.GraphDef -> loco.TF" importer
+ {
+ auto graph = import<ImportTarget::TensorFlow>(signature, graph_def);
+
+ moco::tf::TFConst *node0 =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
+ ASSERT_NE(node0, nullptr);
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 1.1f);
+ }
+
+ // Test "tf.GraphDef -> loco.Canonical" importer
+ {
+ auto graph = import<ImportTarget::Canonical>(signature, graph_def);
+
+ loco::ConstGen *node0 = moco::tf::test::find_first_node_bytype<loco::ConstGen>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 1.1f);
+ }
+}
+
+namespace
+{
+// Test case for "input_tensor.tensor_content().size() == num_elements * sizeof(float)"
+// Generated with tfkit tool: "cat ./test.pbtxt | ./tfkit pack"
+
+// clang-format off
+const char *const_float_03_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ tensor_content: "\315\314\214?\315\314\014@33S@\315\314\214@\000\000\260@33\323@"
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, const_float_03)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("const/float", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(const_float_03_pbtxtdata, graph_def));
+
+ // Test "tf.GraphDef -> loco.TF" importer
+ {
+ auto graph = import<ImportTarget::TensorFlow>(signature, graph_def);
+
+ moco::tf::TFConst *node0 =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
+ ASSERT_NE(node0, nullptr);
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+ }
+
+ // Test "tf.GraphDef -> loco.Canonical" importer
+ {
+ auto graph = import<ImportTarget::Canonical>(signature, graph_def);
+
+ loco::ConstGen *node0 = moco::tf::test::find_first_node_bytype<loco::ConstGen>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+ }
+}
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() < num_elements"
+
+// clang-format off
+const char *const_float_04_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1.1
+ float_val: 2.2
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, const_float_04)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("const/float", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(const_float_04_pbtxtdata, graph_def));
+
+ // Test "tf.GraphDef -> loco.TF" importer
+ {
+ auto graph = import<ImportTarget::TensorFlow>(signature, graph_def);
+
+ moco::tf::TFConst *node0 =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 2.2f);
+ }
+
+ // Test "tf.GraphDef -> loco.Canonical" importer
+ {
+ auto graph = import<ImportTarget::Canonical>(signature, graph_def);
+
+ loco::ConstGen *node0 = moco::tf::test::find_first_node_bytype<loco::ConstGen>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 2.2f);
+ }
+}
+
+namespace
+{
+// Test case for "input_tensor.int_val_size() < num_elements"
+
+// clang-format off
+const char *const_int32_04_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/int"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ int_val: 1
+ int_val: 2
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, const_int32_04)
+{
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("const/int", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(const_int32_04_pbtxtdata, graph_def));
+
+// TODO Re-enable this
+#if 0
+ loco::Graph::OutputContext *outputs = graph->outputs();
+ ASSERT_EQ(outputs->size(), 1);
+ loco::GraphOutput *output = outputs->at(0);
+ loco::Push *push = output->node();
+
+ loco::Graph::NodeContext *nodes = graph->nodes();
+ ASSERT_EQ(nodes->size(), 2);
+#endif
+
+ // Test "tf.GraphDef -> loco.TF" importer
+ {
+ auto graph = import<ImportTarget::TensorFlow>(signature, graph_def);
+
+ moco::tf::TFConst *node0 =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConst>(graph.get());
+ ASSERT_NE(node0, nullptr);
+
+ ASSERT_EQ(node0->size<loco::DataType::S32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(0), 1);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(1), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(2), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(3), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(4), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(5), 2);
+ }
+
+ // Test "tf.GraphDef -> loco.Canonical" importer
+ {
+ auto graph = import<ImportTarget::Canonical>(signature, graph_def);
+
+ loco::ConstGen *node0 = moco::tf::test::find_first_node_bytype<loco::ConstGen>(graph.get());
+
+ ASSERT_EQ(node0->size<loco::DataType::S32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(0), 1);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(1), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(2), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(3), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(4), 2);
+ ASSERT_EQ(node0->at<loco::DataType::S32>(5), 2);
+ }
+}
diff --git a/compiler/moco-tf/src/Op/Conv2D.cpp b/compiler/moco-tf/src/Op/Conv2D.cpp
new file mode 100644
index 000000000..7e011a7e1
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Conv2D.cpp
@@ -0,0 +1,322 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv2D.h"
+
+#include "Convert.h"
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConv2D.h"
+
+#include "Annotations/PaddingData.h"
+#include "Annotations/PadData.h"
+
+#include <moco/tf/Names.h>
+
+#include <loco.h>
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+using namespace moco::tf;
+
+// Deferred graph-link records: at import time the producer nodes referenced by
+// name may not exist yet, so these updates are queued and resolved later
+// against the SymbolTable once all nodes have been created.
+
+class IFMUpdate final : public GraphUpdate
+{
+public:
+ // NOTE(review): a 'const TensorName &&' parameter cannot be moved from
+ // (const rvalue); the member init below copies regardless. Pass-by-value or
+ // 'const TensorName &' would express the intent better — confirm and clean up.
+ IFMUpdate(loco::FeatureEncode *ifm_enc, const TensorName &&ifm_name)
+ : _ifm_enc(ifm_enc), _ifm_name(ifm_name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::FeatureEncode *_ifm_enc;
+ const TensorName _ifm_name;
+};
+
+class KernelUpdate final : public GraphUpdate
+{
+public:
+ // NOTE(review): same 'const &&' concern as IFMUpdate above.
+ KernelUpdate(loco::FilterEncode *ker_enc, const TensorName &&ker_name)
+ : _ker_enc(ker_enc), _ker_name(ker_name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::FilterEncode *_ker_enc;
+ const TensorName _ker_name;
+};
+
+// Resolve the recorded IFM tensor name and wire it as the FeatureEncode input.
+void IFMUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *ifm_node = node_table->node(_ifm_name);
+ _ifm_enc->input(ifm_node);
+}
+
+// Resolve the recorded kernel tensor name and wire it as the FilterEncode input.
+void KernelUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *ker_node = node_table->node(_ker_name);
+ _ker_enc->input(ker_node);
+}
+
+// Deferred link for the TF-dialect node: expects exactly two names,
+// [0] = input tensor, [1] = filter (kernel) tensor.
+class TFConv2DGraphUpdate final : public GraphUpdate
+{
+public:
+ TFConv2DGraphUpdate(TFConv2D *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFConv2D *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFConv2DGraphUpdate::input(const SymbolTable *node_table) const
+{
+ assert(_names.size() == 2);
+
+ auto input_node = node_table->node(_names[0]);
+ auto filter_node = node_table->node(_names[1]);
+ assert(input_node != nullptr);
+ assert(filter_node != nullptr);
+
+ _node->input(input_node);
+ _node->filter(filter_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Conv2D node
+ *
+ * Concrete entry point registered with the op registry; forwards to one of
+ * the two Conv2DGraphBuilderImpl specializations (see build() below).
+ */
+class Conv2DGraphBuilder final : public Conv2DGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+// Accept only Conv2D nodes that carry every attribute both builders read.
+bool Conv2DGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ // note: even though "data_format" is not entered when a model is written,
+ // TF seems to generate "data_format" field into a pb file
+ return plier::tf::has_attrs(node, {"T", "data_format", "dilations", "padding", "strides"});
+}
+
+// Dispatch to the TF-dialect or canonical implementation depending on the
+// ImportAsTFConv2D knob (runtime-configurable import strategy).
+void Conv2DGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
+ {
+ Conv2DGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ builder.build(node, context);
+ }
+ else
+ {
+ Conv2DGraphBuilderImpl<ImportTarget::Canonical> builder;
+ builder.build(node, context);
+ }
+}
+
+// Canonical import: lowers a TF Conv2D into
+//   FeatureEncode --\
+//                    loco::Conv2D -> FeatureDecode
+//   FilterEncode  --/
+// so the canonical node operates on domain-typed (Feature/Filter) operands.
+void Conv2DGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ using plier::tf::DataLayout;
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ std::string conv2d_name = node.name();
+
+ // tensorflow data_format, e.g., NHWC, NCHW, etc.
+ auto data_layout = plier::tf::get_data_layout(node, "data_format");
+
+ auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
+ {
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ throw std::runtime_error("Not yet supported");
+
+ feature_enc->encoder(std::move(enc));
+ }
+
+ auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
+ {
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+ // In TensorFlow, conv2d filter is a 4-D tensor of following shape:
+ // [filter_height, filter_width, in_channels, out_channels] -> HWIO (HWCN)
+
+ enc->perm()->axis(loco::FilterAxis::Height) = 0;
+ enc->perm()->axis(loco::FilterAxis::Width) = 1;
+ enc->perm()->axis(loco::FilterAxis::Depth) = 2;
+ enc->perm()->axis(loco::FilterAxis::Count) = 3;
+
+ filter_enc->encoder(std::move(enc));
+ }
+
+ auto conv2d = graph->nodes()->create<loco::Conv2D>();
+ {
+ // let's convert attrs:
+ // TensorFlow attr : T, data_format, dilations, padding, strides
+ // to loco attr: not defined, TBD, TBD, TBD, stride
+
+ // tf strides -> loco stride
+ auto tf_strides = plier::tf::get_list_attr(node, "strides");
+ auto stride = conv2d->stride();
+
+ // Any layout other than NHWC/NCHW was already rejected above (throw), so
+ // the missing 'else' here cannot leave the stride unset.
+ if (data_layout == DataLayout::NHWC)
+ {
+ stride->vertical(tf_strides.i(1));
+ stride->horizontal(tf_strides.i(2));
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ stride->vertical(tf_strides.i(2));
+ stride->horizontal(tf_strides.i(3));
+ }
+
+ // tf paddings -> PaddingData annotation
+ // Padding cannot be resolved yet (input shapes unknown at import time), so
+ // keep the raw TF padding string as an annotation for a shape-aware pass.
+ auto tf_padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
+ auto padding_data = stdex::make_unique<PaddingData>(tf_padding);
+ conv2d->annot(std::move(padding_data));
+ }
+
+ auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
+ {
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ throw std::runtime_error("Not supported data layout");
+
+ feature_dec->decoder(std::move(dec));
+ }
+
+ // link nodes
+ conv2d->ifm(feature_enc);
+ conv2d->ker(filter_enc);
+ feature_dec->input(conv2d);
+
+ // To set the input node of encode_node with conv2d_name
+ // The decoder is what downstream consumers see under this op's output name.
+ TensorName output_name(conv2d_name, 0);
+ tensor_names->enroll(output_name, feature_dec);
+
+ // Record ifm inputs to featureEncode_node
+ auto ifm_update = stdex::make_unique<IFMUpdate>(feature_enc, TensorName(node.input(0)));
+ auto ker_update = stdex::make_unique<KernelUpdate>(filter_enc, TensorName(node.input(1)));
+
+ updates->enroll(std::move(ifm_update));
+ updates->enroll(std::move(ker_update));
+}
+
+// TF-dialect import: creates a single TFConv2D node that keeps the original
+// TensorFlow attributes (data_format / strides / padding) verbatim; lowering
+// to canonical loco is deferred to a later canonicalization pass.
+void Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ std::string conv2d_name = node.name();
+
+ auto conv2d = graph->nodes()->create<TFConv2D>();
+
+ // read attributes
+ auto data_layout = plier::tf::get_string_attr(node, "data_format");
+ if (!(data_layout == "NHWC" || data_layout == "NCHW"))
+ {
+ throw std::runtime_error("Not yet supported");
+ }
+ conv2d->data_layout(data_layout);
+
+ auto tf_strides = plier::tf::get_list_attr(node, "strides");
+ auto strides = plier::tf::as_int64_list(tf_strides);
+ conv2d->strides(strides);
+
+ // NOTE(review): assert-only check; in NDEBUG builds an unexpected padding
+ // string would pass through silently — consider a hard error like above.
+ auto padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
+ assert(padding == "VALID" || padding == "SAME");
+ conv2d->padding(padding);
+
+ // save the name for graph link updates
+ TensorName output_name(conv2d_name, 0);
+ tensor_names->enroll(output_name, conv2d);
+
+ std::vector<TensorName> input_names;
+ input_names.push_back(TensorName(node.input(0))); // input
+ input_names.push_back(TensorName(node.input(1))); // kernel
+
+ // Record ifm inputs to featureEncode_node
+ auto tfconv2d_update = stdex::make_unique<TFConv2DGraphUpdate>(conv2d, input_names);
+
+ updates->enroll(std::move(tfconv2d_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Conv2D, Conv2DGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Conv2D.h b/compiler/moco-tf/src/Op/Conv2D.h
new file mode 100644
index 000000000..e88b8e399
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Conv2D.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_CONV_2D_H__
+#define __OP_CONV_2D_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+// Shared base for both import targets: attribute validation is identical,
+// so it lives here and is marked 'final'.
+struct Conv2DGraphBuilderBase : public GraphBuilder
+{
+ virtual ~Conv2DGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+// Primary template intentionally left undefined — only the two explicit
+// specializations below are valid import targets.
+template <ImportTarget T> class Conv2DGraphBuilderImpl;
+
+template <>
+struct Conv2DGraphBuilderImpl<ImportTarget::Canonical> final : public Conv2DGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct Conv2DGraphBuilderImpl<ImportTarget::TensorFlow> final : public Conv2DGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_CONV_2D_H__
diff --git a/compiler/moco-tf/src/Op/Conv2D.test.cpp b/compiler/moco-tf/src/Op/Conv2D.test.cpp
new file mode 100644
index 000000000..faf2977fc
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Conv2D.test.cpp
@@ -0,0 +1,513 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Conv2D.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFConv2D.h"
+
+#include <loco.h>
+#include <loco/IR/TensorShape.h>
+#include <loco/IR/FeatureShape.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *conv2d_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "ifm"
+ op: "Const"
+ attr { key: "dtype" value { type: DT_FLOAT } }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 4 }
+ dim { size: 4 }
+ dim { size: 3 }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "ker"
+ op: "Const"
+ attr { key: "dtype" value { type: DT_FLOAT } }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 2 }
+ dim { size: 2 }
+ dim { size: 3 }
+ dim { size: 100 }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "conv2d"
+ op: "Conv2D"
+ input: "ifm"
+ input: "ker"
+ attr { key: "T" value { type: DT_FLOAT } }
+ attr { key: "data_format" value { s: "NHWC" } }
+ attr { key: "dilations" value { list { i: 1 i: 1 i: 1 i: 1 } } }
+ attr { key: "padding" value { s: "VALID" } }
+ attr { key: "strides" value { list { i: 1 i: 2 i: 3 i: 1 } } }
+ attr { key: "use_cudnn_on_gpu" value { b: false } }
+}
+);
+// clang-format on
+} // namespace
+
+namespace
+{
+
+// Checks the canonical-dialect import result for the conv2d_01 pbtxt model:
+// graph topology plus the encode/decode permutations and stride attributes.
+void verify_Conv2D_01(loco::Graph *graph)
+{
+ // what to test:
+ // - Con2D node should exist
+ // - ifm() should be FeatureEncode
+ // - ker() should be FilterEncode
+ // - following node should be FeatureDecode
+ // - FeatureEncode encoder should encode Count-Height-Width-Depth order
+ // - FeatureDecode decoder should decode Count-Height-Width-Depth order
+
+ // test 1.
+ // loco node : ConstGen - FeatureEncode -- Conv2D - FeatureDecode - Push
+ // ConstGen - FilterEncode /
+
+ loco::Conv2D *conv2d = moco::tf::test::find_first_node_bytype<loco::Conv2D>(graph);
+ ASSERT_NE(conv2d, nullptr);
+
+ loco::FeatureEncode *ifm_enc = dynamic_cast<loco::FeatureEncode *>(conv2d->ifm());
+ loco::FilterEncode *ker_enc = dynamic_cast<loco::FilterEncode *>(conv2d->ker());
+ ASSERT_NE(ifm_enc, nullptr);
+ ASSERT_NE(ker_enc, nullptr);
+
+ auto following_nodes = loco::succs(conv2d);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+ loco::FeatureDecode *dec = dynamic_cast<loco::FeatureDecode *>(following_node);
+ ASSERT_NE(dec, nullptr);
+
+ // test 2.
+ // attrs inside Conv2D
+ {
+ // stride
+ ASSERT_EQ(conv2d->stride()->vertical(), 2);
+ ASSERT_EQ(conv2d->stride()->horizontal(), 3);
+
+ // TODO add padding test
+ }
+
+ // test 3.
+ // attrs inside FeatureEncoder
+ // (probe the encoder with an arbitrary NHWC tensor shape; only the
+ // permutation matters, not the model's actual input size)
+ {
+ auto ifm_encoder = ifm_enc->encoder();
+ ASSERT_TRUE(ifm_encoder != nullptr);
+
+ loco::TensorShape tensor_shape;
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 1; // COUNT
+ tensor_shape.dim(1) = 720; // HEIGHT
+ tensor_shape.dim(2) = 1280; // WIDTH
+ tensor_shape.dim(3) = 3; // DEPTH
+
+ // Get the feature shape corresponding to a given image
+ auto feature_shape = ifm_encoder->shape(tensor_shape);
+
+ ASSERT_EQ(feature_shape.count().value(), 1);
+ ASSERT_EQ(feature_shape.height().value(), 720);
+ ASSERT_EQ(feature_shape.width().value(), 1280);
+ ASSERT_EQ(feature_shape.depth().value(), 3);
+ }
+
+ // test 4.
+ // attrs inside FilterEncoder
+ {
+ auto ker_encoder = ker_enc->encoder();
+ ASSERT_TRUE(ker_encoder != nullptr);
+
+ loco::TensorShape tensor_shape;
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 2; // H
+ tensor_shape.dim(1) = 4; // W
+ tensor_shape.dim(2) = 3; // I (C)
+ tensor_shape.dim(3) = 7; // O (N)
+
+ // Get the feature shape corresponding to a given image
+ auto ker_shape = ker_encoder->shape(tensor_shape);
+
+ ASSERT_EQ(ker_shape.height().value(), 2);
+ ASSERT_EQ(ker_shape.width().value(), 4);
+ ASSERT_EQ(ker_shape.depth().value(), 3);
+ ASSERT_EQ(ker_shape.count().value(), 7);
+ }
+
+ // test 5
+ // attrs inside FeatureDecoder
+ {
+ auto decoder = dec->decoder();
+ ASSERT_TRUE(decoder != nullptr);
+
+ loco::FeatureShape feature_shape;
+ feature_shape.count() = 1;
+ feature_shape.height() = 720;
+ feature_shape.width() = 1280;
+ feature_shape.depth() = 3;
+
+ // Get the tensor shape corresponding to a given image
+ auto tensor_shape = decoder->shape(feature_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0).value(), 1); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1).value(), 720); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2).value(), 1280); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3).value(), 3); // DEPTH
+ }
+}
+
+// Checks the TF-dialect import result for the conv2d_01 pbtxt model.
+void verify_TFConv2D_01(loco::Graph *graph)
+{
+ // what to test:
+ // - Con2D node should exist
+ // - ifm() should not be nullptr
+ // - ker() should not be nullptr
+ // - attribute values should match
+
+ // loco node : ConstGen - TFConv2D - Push
+ // ConstGen /
+ moco::tf::TFConv2D *tfconv2d = moco::tf::test::find_first_node_bytype<moco::tf::TFConv2D>(graph);
+ ASSERT_NE(tfconv2d, nullptr);
+ ASSERT_NE(tfconv2d->input(), nullptr);
+ ASSERT_NE(tfconv2d->filter(), nullptr);
+
+ // attrs inside TFConv2D
+ ASSERT_EQ(tfconv2d->padding(), "VALID");
+ ASSERT_EQ(tfconv2d->data_layout(), "NHWC");
+ // NOTE(review): only the count of strides is checked; the individual
+ // values (1,2,3,1 in the pbtxt) could be asserted as well.
+ auto strides = tfconv2d->strides();
+ ASSERT_EQ(strides.size(), 4);
+ // TODO add verify dilation
+}
+
+} // namespace
+
+// Imports the conv2d_01 model through both registries and verifies each
+// dialect's resulting graph.
+TEST(TensorFlowImport, Conv2D_01)
+{
+ // NOTE(review): this outer 'importer' is never used — each scope below
+ // constructs its own Importer with a customized registry. Candidate for
+ // removal.
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("conv2d", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(conv2d_01_pbtxtdata, graph_def));
+
+ // Test loco.TF Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ verify_TFConv2D_01(graph.get());
+ }
+
+ // Test loco.Canonical Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ verify_Conv2D_01(graph.get());
+ }
+}
+
+namespace
+{
+// clang-format off
+const char *conv2d_inception_pbtxtdata = STRING_CONTENT(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype" value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 1 }
+ dim { size: 299 }
+ dim { size: 299 }
+ dim { size: 3 }
+ }
+ }
+ }
+}
+node {
+ name: "InceptionV3/Conv2d_1a_3x3/weights/read/_3__cf__3"
+ op: "Const"
+ attr { key: "dtype" value { type: DT_FLOAT } }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 3 }
+ dim { size: 3 }
+ dim { size: 3 }
+ dim { size: 32 }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D"
+ op: "Conv2D"
+ input: "input:0"
+ input: "InceptionV3/Conv2d_1a_3x3/weights/read/_3__cf__3"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NHWC" }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list { i: 1 i: 1 i: 1 i: 1 }
+ }
+ }
+ attr {
+ key: "padding"
+ value { s: "VALID" }
+ }
+ attr {
+ key: "strides"
+ value {
+ list { i: 1 i: 2 i: 2 i: 1 }
+ }
+ }
+ attr {
+ key: "use_cudnn_on_gpu"
+ value { b: true }
+ }
+}
+);
+} // namespace
+
+namespace
+{
+
+// Canonical-dialect check for the Inception snippet; exercises indexed input
+// names ("input:0") in addition to the structural checks of verify_Conv2D_01.
+void verify_Conv2D_inception_indexed_tensor_name(loco::Graph *graph)
+{
+ // what to test: name with ':0' should be treated correctly
+ // - Con2D node should exist
+ // - ifm() should be FeatureEncode
+ // - ker() should be FilterEncode
+ // - following node should be FeatureDecode
+ // - FeatureEncode encoder should encode Count-Height-Width-Depth order
+ // - FeatureDecode decoder should decode Count-Height-Width-Depth order
+
+ // test 1.
+ // loco node : Pull - FeatureEncode -- Conv2D - FeatureDecode - Push
+ // ConstGen - FilterEncode /
+
+ loco::Conv2D *conv2d =
+ moco::tf::test::find_first_node_bytype<loco::Conv2D>(graph);
+ ASSERT_NE(conv2d, nullptr);
+
+ loco::FeatureEncode *ifm_enc = dynamic_cast<loco::FeatureEncode *>(conv2d->ifm());
+ loco::FilterEncode *ker_enc = dynamic_cast<loco::FilterEncode *>(conv2d->ker());
+ ASSERT_NE(ifm_enc, nullptr);
+ ASSERT_NE(ker_enc, nullptr);
+
+ auto following_nodes = loco::succs(conv2d);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+ loco::FeatureDecode *dec = dynamic_cast<loco::FeatureDecode *>(following_node);
+ ASSERT_NE(dec, nullptr);
+
+ // TODO remove below tests as it's duplicate as in verify_Conv2D_01
+
+ // test 2.
+ // attrs inside Conv2D
+ {
+ // stride
+ ASSERT_EQ(conv2d->stride()->vertical(), 2);
+ ASSERT_EQ(conv2d->stride()->horizontal(), 2);
+
+ // TODO add padding test
+ }
+
+ // test 3.
+ // attrs inside FeatureEncoder
+ {
+ auto ifm_encoder = ifm_enc->encoder();
+ ASSERT_TRUE(ifm_encoder != nullptr);
+
+ loco::TensorShape tensor_shape;
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 1; // COUNT
+ tensor_shape.dim(1) = 299; // HEIGHT
+ tensor_shape.dim(2) = 299; // WIDTH
+ tensor_shape.dim(3) = 3; // DEPTH
+
+ // Get the feature shape corresponding to a given image
+ auto feature_shape = ifm_encoder->shape(tensor_shape);
+
+ ASSERT_EQ(feature_shape.count().value(), 1);
+ ASSERT_EQ(feature_shape.height().value(), 299);
+ ASSERT_EQ(feature_shape.width().value(), 299);
+ ASSERT_EQ(feature_shape.depth().value(), 3);
+ }
+
+ // test 4.
+ // attrs inside FilterEncoder
+ {
+ auto ker_encoder = ker_enc->encoder();
+ ASSERT_TRUE(ker_encoder != nullptr);
+
+ loco::TensorShape tensor_shape;
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 3; // H
+ tensor_shape.dim(1) = 3; // W
+ tensor_shape.dim(2) = 3; // I (C)
+ tensor_shape.dim(3) = 32; // O (N)
+
+ // Get the feature shape corresponding to a given image
+ auto ker_shape = ker_encoder->shape(tensor_shape);
+
+ ASSERT_EQ(ker_shape.height().value(), 3);
+ ASSERT_EQ(ker_shape.width().value(), 3);
+ ASSERT_EQ(ker_shape.depth().value(), 3);
+ ASSERT_EQ(ker_shape.count().value(), 32);
+ }
+
+ // test 5
+ // attrs inside FeatureDecoder
+ {
+ auto decoder = dec->decoder();
+ ASSERT_TRUE(decoder != nullptr);
+
+ loco::FeatureShape feature_shape;
+ feature_shape.count() = 1;
+ feature_shape.height() = 299;
+ feature_shape.width() = 299;
+ feature_shape.depth() = 3;
+
+ // Get the tensor shape corresponding to a given image
+ auto tensor_shape = decoder->shape(feature_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0).value(), 1); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1).value(), 299); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2).value(), 299); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3).value(), 3); // DEPTH
+ }
+}
+
+// TF-dialect check for the Inception snippet: indexed input names ("input:0")
+// must still resolve to the producing nodes.
+void verify_TFConv2D_inception_indexed_tensor_name(loco::Graph *graph)
+{
+ // what to test: name with ':0' should be treated correctly
+ // - Con2D node should exist
+ // - ifm() should not be nullptr
+ // - ker() should not be nullptr
+
+ // loco node : Pull - Conv2D - Push
+ // ConstGen /
+ moco::tf::TFConv2D *tfconv2d =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFConv2D>(graph);
+ ASSERT_NE(tfconv2d, nullptr);
+ ASSERT_NE(tfconv2d->input(), nullptr);
+ ASSERT_NE(tfconv2d->filter(), nullptr);
+}
+
+} // namespace
+
+// Imports the Inception snippet (whose Conv2D input is the indexed name
+// "input:0") through both registries and verifies each dialect's graph.
+TEST(TensorFlowImport, Conv2D_inception_indexed_tensor_name)
+{
+ // NOTE(review): this outer 'importer' and the top-level import below are
+ // dead — each scope builds its own Importer and graph. The first import's
+ // result ('graph' at function scope) is never inspected; candidate for
+ // removal.
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("input", 0));
+ signature.add_output(moco::tf::TensorName("InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(conv2d_inception_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // Test loco.TF Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ verify_TFConv2D_inception_indexed_tensor_name(graph.get());
+ }
+
+ // Test loco.Canonical Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ verify_Conv2D_inception_indexed_tensor_name(graph.get());
+ }
+}
diff --git a/compiler/moco-tf/src/Op/DepthwiseConv2dNative.cpp b/compiler/moco-tf/src/Op/DepthwiseConv2dNative.cpp
new file mode 100644
index 000000000..33f5fa4cd
--- /dev/null
+++ b/compiler/moco-tf/src/Op/DepthwiseConv2dNative.cpp
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <plier/tf/Convert.h>
+#include "GraphBuilder.h"
+
+#include "IR/TFDepthwiseConv2dNative.h"
+
+#include "Annotations/PaddingData.h"
+#include "Annotations/PadData.h"
+
+#include <moco/tf/Names.h>
+
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+
+using namespace plier::tf;
+
+namespace
+{
+using namespace moco::tf;
+
+// Deferred link for TFDepthwiseConv2dNative: expects exactly two names,
+// [0] = input tensor, [1] = filter (kernel) tensor, resolved later against
+// the SymbolTable once all producer nodes exist.
+class TFDepthwiseConv2dNativeGraphUpdate final : public GraphUpdate
+{
+public:
+ TFDepthwiseConv2dNativeGraphUpdate(TFDepthwiseConv2dNative *node, std::vector<TensorName> names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFDepthwiseConv2dNative *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFDepthwiseConv2dNativeGraphUpdate::input(const SymbolTable *node_table) const
+{
+ assert(_names.size() == 2);
+
+ auto input_node = node_table->node(_names[0]);
+ auto filter_node = node_table->node(_names[1]);
+ assert(input_node != nullptr);
+ assert(filter_node != nullptr);
+
+ _node->input(input_node);
+ _node->filter(filter_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for DepthwiseConv2dNative node
+ *
+ * Imports into the TF dialect only (TFDepthwiseConv2dNative); there is no
+ * canonical-target variant for this op here.
+ */
+class DepthwiseConv2dNativeGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+// Validate a DepthwiseConv2dNative node before building.
+// NOTE(review): several contract issues worth confirming:
+// - throws for unsupported data_format instead of returning false, which is
+//   inconsistent with this function's bool contract;
+// - attributes are read (get_string_attr/get_list_attr) BEFORE the final
+//   has_attrs() check, so a node missing "padding"/"strides" would likely
+//   fail inside plier rather than be reported here — verify plier behavior;
+// - stride checks are assert-only, so they vanish under NDEBUG and the
+//   stride_* locals then trigger unused-variable warnings.
+bool DepthwiseConv2dNativeGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ auto data_layout = get_string_attr(node, "data_format");
+ if (!(data_layout == "NHWC" || data_layout == "NCHW"))
+ {
+ throw std::runtime_error("Not yet supported");
+ }
+
+ auto padding = moco::str_toupper(get_string_attr(node, "padding"));
+ assert(padding == "VALID" || padding == "SAME");
+
+ auto tf_strides = get_list_attr(node, "strides");
+ auto strides = as_int64_list(tf_strides);
+ assert(strides.size() == 4);
+ auto stride_n = strides.at(0);
+ auto stride_h = strides.at(1);
+ auto stride_w = strides.at(2);
+ auto stride_c = strides.at(3);
+ assert(stride_n == 1 && stride_c == 1);
+ assert(stride_h == stride_w);
+
+ // note: even though "data_format" and "dilations" are not entered when a model is written,
+ // TF seems to generate those field into a pb file.
+ return has_attrs(node, {"T", "data_format", "dilations", "padding", "strides"});
+}
+
+// Build a TFDepthwiseConv2dNative node carrying the original TF attributes
+// (data_format / strides / padding); input links are deferred via the
+// update queue until all producer nodes exist.
+void DepthwiseConv2dNativeGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ auto depthwiseconv2d_native_node = graph->nodes()->create<TFDepthwiseConv2dNative>();
+
+ // read attributes
+ // (attribute values were already checked by validate())
+ auto data_layout = get_string_attr(node, "data_format");
+ depthwiseconv2d_native_node->data_layout(data_layout);
+
+ auto tf_strides = get_list_attr(node, "strides");
+ auto strides = as_int64_list(tf_strides);
+ depthwiseconv2d_native_node->strides(strides);
+
+ auto padding = moco::str_toupper(get_string_attr(node, "padding"));
+ depthwiseconv2d_native_node->padding(padding);
+
+ // save the name for graph link updates
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, depthwiseconv2d_native_node);
+
+ std::vector<TensorName> input_names;
+ input_names.push_back(TensorName(node.input(0))); // input
+ input_names.push_back(TensorName(node.input(1))); // kernel
+
+ // Record ifm inputs to featureEncode_node
+ auto tfdepthwiseconv2dnative_update = stdex::make_unique<TFDepthwiseConv2dNativeGraphUpdate>(
+ depthwiseconv2d_native_node, input_names);
+
+ updates->enroll(std::move(tfdepthwiseconv2dnative_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(DepthwiseConv2dNative, DepthwiseConv2dNativeGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/DepthwiseConv2dNative.test.cpp b/compiler/moco-tf/src/Op/DepthwiseConv2dNative.test.cpp
new file mode 100644
index 000000000..64ae27da8
--- /dev/null
+++ b/compiler/moco-tf/src/Op/DepthwiseConv2dNative.test.cpp
@@ -0,0 +1,219 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFDepthwiseConv2dNative.h"
+
+#include <loco/IR/TensorShape.h>
+#include <loco/IR/FeatureShape.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *depthwise_conv2d_native_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 4
+ }
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "filter"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+}
+node {
+ name: "depthwise/Shape"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ int_val: 2
+ int_val: 2
+ int_val: 3
+ int_val: 2
+ }
+ }
+ }
+}
+node {
+ name: "depthwise/dilation_rate"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ int_val: 1
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "depthwise"
+ op: "DepthwiseConv2dNative"
+ input: "input"
+ input: "filter"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+}
+);
+// clang-format on
+} // namespace
+
+// Imports the depthwise_conv2d_native_01 model with the default registry and
+// checks the resulting TFDepthwiseConv2dNative node's links and attributes.
+TEST(TensorFlowImport, Depthwise_conv2d_native)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("input", 0));
+ signature.add_output(moco::tf::TensorName("depthwise", 0));
+
+ tensorflow::GraphDef graph_def;
+
+ EXPECT_TRUE(plier::tf::parse_graphdef(depthwise_conv2d_native_01_pbtxtdata, graph_def));
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ moco::tf::TFDepthwiseConv2dNative *tfdepthwiseconv2dnative =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFDepthwiseConv2dNative>(graph.get());
+ ASSERT_NE(tfdepthwiseconv2dnative, nullptr);
+ ASSERT_NE(tfdepthwiseconv2dnative->input(), nullptr);
+ ASSERT_NE(tfdepthwiseconv2dnative->filter(), nullptr);
+
+ // NOTE(review): only the strides count is checked; the individual values
+ // (all 1 in the pbtxt) could be asserted as well.
+ ASSERT_EQ(tfdepthwiseconv2dnative->padding(), "VALID");
+ ASSERT_EQ(tfdepthwiseconv2dnative->data_layout(), "NHWC");
+ ASSERT_EQ(tfdepthwiseconv2dnative->strides().size(), 4);
+}
diff --git a/compiler/moco-tf/src/Op/FusedBatchNorm.cpp b/compiler/moco-tf/src/Op/FusedBatchNorm.cpp
new file mode 100644
index 000000000..d22b690bf
--- /dev/null
+++ b/compiler/moco-tf/src/Op/FusedBatchNorm.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for FusedBatchNorm node
+ */
+class FusedBatchNormGraphUpdate final : public GraphUpdate
+{
+public:
+ FusedBatchNormGraphUpdate(TFFusedBatchNorm *node, std::vector<TensorName> names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFFusedBatchNorm *_node;
+ std::vector<TensorName> _names;
+};
+
+void FusedBatchNormGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs == 5);
+
+ _node->input(tensor_names->node(_names[0]));
+ _node->gamma(tensor_names->node(_names[1]));
+ _node->beta(tensor_names->node(_names[2]));
+ _node->mean(tensor_names->node(_names[3]));
+ _node->variance(tensor_names->node(_names[4]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for FusedBatchNorm node
+ */
+class FusedBatchNormGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool FusedBatchNormGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 5);
+
+ return plier::tf::has_attrs(node, {"epsilon"});
+}
+
+void FusedBatchNormGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ float epsilon = plier::tf::get_float_attr(node, "epsilon");
+
+ // creating TF dialect FusedBatchNorm node
+ auto tf_fbn = graph->nodes()->create<TFFusedBatchNorm>();
+ tf_fbn->epsilon(epsilon);
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_fbn);
+
+ std::vector<TensorName> fbn_input_names;
+ fbn_input_names.push_back(TensorName(node.input(0))); // input
+ fbn_input_names.push_back(TensorName(node.input(1))); // scale
+ fbn_input_names.push_back(TensorName(node.input(2))); // offset
+ fbn_input_names.push_back(TensorName(node.input(3))); // mean
+ fbn_input_names.push_back(TensorName(node.input(4))); // variance
+
+ auto tf_fbn_update = stdex::make_unique<FusedBatchNormGraphUpdate>(tf_fbn, fbn_input_names);
+ updates->enroll(std::move(tf_fbn_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(FusedBatchNorm, FusedBatchNormGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/FusedBatchNorm.test.cpp b/compiler/moco-tf/src/Op/FusedBatchNorm.test.cpp
new file mode 100644
index 000000000..d9c45bca0
--- /dev/null
+++ b/compiler/moco-tf/src/Op/FusedBatchNorm.test.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *fbn_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 4 }
+ dim { size: 4 }
+ dim { size: 1 }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "gamma"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "beta"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/mean"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/variance"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01"
+ op: "FusedBatchNorm"
+ input: "input"
+ input: "gamma"
+ input: "beta"
+ input: "FBN_01/mean"
+ input: "FBN_01/variance"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "epsilon"
+ value {
+ f: 0.001
+ }
+ }
+ attr {
+ key: "is_training"
+ value {
+ b: false
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_fbn_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("FBN_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(fbn_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist a TFFusedBatchNorm
+ // - input() should not be nullptr
+ // - gamma() should not be nullptr
+ // - beta() should not be nullptr
+ // - mean() should not be nullptr
+ // - variance() should not be nullptr
+ // - epsilon() value should match
+
+ moco::tf::TFFusedBatchNorm *fbn_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFFusedBatchNorm>(graph.get());
+
+ ASSERT_NE(fbn_node->input(), nullptr);
+ ASSERT_NE(fbn_node->gamma(), nullptr);
+ ASSERT_NE(fbn_node->beta(), nullptr);
+ ASSERT_NE(fbn_node->mean(), nullptr);
+ ASSERT_NE(fbn_node->variance(), nullptr);
+ ASSERT_EQ(fbn_node->epsilon(), 0.001f);
+}
diff --git a/compiler/moco-tf/src/Op/Identity.cpp b/compiler/moco-tf/src/Op/Identity.cpp
new file mode 100644
index 000000000..03dac6d26
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Identity.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Identity.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFIdentity.h"
+
+#include <moco/tf/Names.h>
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class IdentityGraphUpdate final : public GraphUpdate
+{
+public:
+ IdentityGraphUpdate(loco::Forward *node, const std::vector<TensorName> &names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::Forward *_node;
+ const std::vector<TensorName> _names;
+};
+
+class TFIdentityGraphUpdate final : public GraphUpdate
+{
+public:
+ TFIdentityGraphUpdate(moco::tf::TFIdentity *node, const std::vector<TensorName> &names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFIdentity *_node;
+ const std::vector<TensorName> _names;
+};
+
+void IdentityGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ for (auto &name : _names)
+ {
+ loco::Node *target = tensor_names->node(name);
+ _node->input(target);
+ }
+}
+
+void TFIdentityGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ for (auto &name : _names)
+ {
+ loco::Node *target = tensor_names->node(name);
+ _node->input(target);
+ }
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Identity node
+ */
+class IdentityGraphBuilder final : public IdentityGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool IdentityGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ if (node.input_size() < 1) // from TensorFlow lite toco
+ return false;
+
+ return true;
+}
+
+void IdentityGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFIdentity>())
+ {
+ IdentityGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ return builder.build(node, context);
+ }
+ else
+ {
+ IdentityGraphBuilderImpl<ImportTarget::Canonical> builder;
+ return builder.build(node, context);
+ }
+}
+
+void IdentityGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // Create a "Forward" node for Identity
+ auto forward_node = graph->nodes()->create<loco::Forward>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, forward_node);
+
+ // Queue node input update
+ // TODO: Check if we really need multiple input handlings
+ std::vector<TensorName> names;
+ for (int i = 0; i < node.input_size(); ++i)
+ {
+ names.emplace_back(TensorName(node.input(i)));
+ }
+ auto update = stdex::make_unique<IdentityGraphUpdate>(forward_node, names);
+ updates->enroll(std::move(update));
+}
+
+void IdentityGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // Create a Identity node
+ auto identity_node = graph->nodes()->create<moco::tf::TFIdentity>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, identity_node);
+
+ // Queue node input update
+ // TODO: Check if we really need multiple input handlings
+ std::vector<TensorName> names;
+ for (int i = 0; i < node.input_size(); ++i)
+ {
+ names.emplace_back(TensorName(node.input(i)));
+ }
+ auto update = stdex::make_unique<TFIdentityGraphUpdate>(identity_node, names);
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Identity, IdentityGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Identity.h b/compiler/moco-tf/src/Op/Identity.h
new file mode 100644
index 000000000..55da0070e
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Identity.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_IDENTITY_H__
+#define __OP_IDENTITY_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct IdentityGraphBuilderBase : public GraphBuilder
+{
+ virtual ~IdentityGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class IdentityGraphBuilderImpl;
+
+template <>
+struct IdentityGraphBuilderImpl<ImportTarget::Canonical> final : public IdentityGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct IdentityGraphBuilderImpl<ImportTarget::TensorFlow> final : public IdentityGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_IDENTITY_H__
diff --git a/compiler/moco-tf/src/Op/MaxPool.cpp b/compiler/moco-tf/src/Op/MaxPool.cpp
new file mode 100644
index 000000000..079d91448
--- /dev/null
+++ b/compiler/moco-tf/src/Op/MaxPool.cpp
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MaxPool.h"
+
+#include "Convert.h"
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFMaxPool.h"
+
+#include "Annotations/PaddingData.h"
+
+#include <moco/tf/Names.h>
+#include <loco.h>
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class MaxPoolGraphUpdate final : public GraphUpdate
+{
+public:
+ MaxPoolGraphUpdate(loco::FeatureEncode *node, const TensorName &&name)
+ : _encode_node(node), _input_name(name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ loco::FeatureEncode *_encode_node;
+ const TensorName _input_name;
+};
+
+class TFMaxPoolGraphUpdate final : public GraphUpdate
+{
+public:
+ TFMaxPoolGraphUpdate(moco::tf::TFMaxPool *node, const TensorName &name)
+ : _maxpool_node(node), _value_name(name)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ moco::tf::TFMaxPool *_maxpool_node;
+ const TensorName _value_name;
+};
+
+void MaxPoolGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ loco::Node *input_node = tensor_names->node(_input_name);
+ _encode_node->input(input_node);
+}
+
+void TFMaxPoolGraphUpdate::input(const SymbolTable *node_table) const
+{
+ loco::Node *value_node = node_table->node(_value_name);
+ _maxpool_node->value(value_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for MaxPool node
+ */
+class MaxPoolGraphBuilder final : public MaxPoolGraphBuilderBase
+{
+public:
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool MaxPoolGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+ // note: even though "data_format" is not entered when a model is written,
+ // TF seems to generate "data_format" field into a pb file
+ return plier::tf::has_attrs(node, {"T", "data_format", "ksize", "padding", "strides"});
+}
+
+void MaxPoolGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFMaxPool>())
+ {
+ MaxPoolGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ return builder.build(node, context);
+ }
+ else
+ {
+ MaxPoolGraphBuilderImpl<ImportTarget::Canonical> builder;
+ return builder.build(node, context);
+ }
+}
+
+void MaxPoolGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ using plier::tf::DataLayout;
+
+ assert(context != nullptr);
+ assert(node.input_size() == 1);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // create loco nodes
+ auto encode_node = graph->nodes()->create<loco::FeatureEncode>();
+ auto maxpool2d_node = graph->nodes()->create<loco::MaxPool2D>();
+ auto decode_node = graph->nodes()->create<loco::FeatureDecode>();
+
+ // name of loco nodes
+ ::std::string maxpool2d_name = node.name();
+
+ // tensorflow data_format, e.g., NHWC, NCHW, etc.
+ auto data_layout = plier::tf::get_data_layout(node, "data_format");
+
+ // FeatureEncode
+ {
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else
+ throw std::runtime_error("Not supported data layout");
+
+ encode_node->encoder(std::move(enc));
+ }
+
+ // MaxPool
+ {
+ // let's convert attrs:
+ // TensorFlow attr : T, data_format, ksize, padding, strides
+ // to loco attr: not defined, TBD, window, TBD, stride
+
+ // tf ksize -> loco window
+ auto tf_ksize = plier::tf::get_list_attr(node, "ksize");
+ auto window = maxpool2d_node->window();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ window->vertical(tf_ksize.i(1));
+ window->horizontal(tf_ksize.i(2));
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ window->vertical(tf_ksize.i(2));
+ window->horizontal(tf_ksize.i(3));
+ }
+
+ // tf strides -> loco stride
+ auto tf_strides = plier::tf::get_list_attr(node, "strides");
+ auto stride = maxpool2d_node->stride();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ stride->vertical(tf_strides.i(1));
+ stride->horizontal(tf_strides.i(2));
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ stride->vertical(tf_strides.i(2));
+ stride->horizontal(tf_strides.i(3));
+ }
+
+ // tf paddings -> PaddingData annotation
+ auto tf_padding = plier::tf::get_string_attr(node, "padding");
+ auto padding_data = stdex::make_unique<PaddingData>(tf_padding);
+ maxpool2d_node->annot(std::move(padding_data));
+ }
+
+ // FeatureDecode
+ {
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else
+ throw std::runtime_error("Not supported data layout");
+
+ decode_node->decoder(std::move(dec));
+ }
+
+ // link nodes
+ maxpool2d_node->ifm(encode_node);
+ decode_node->input(maxpool2d_node);
+
+ // To set the input node of encode_node with maxpool2d_name
+ TensorName output_name(maxpool2d_name, 0);
+ tensor_names->enroll(output_name, decode_node);
+
+ // Record ifm inputs to featureEncode_node
+ auto update = stdex::make_unique<MaxPoolGraphUpdate>(encode_node, TensorName(node.input(0)));
+
+ updates->enroll(std::move(update));
+}
+
+void MaxPoolGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ ::std::string node_name = node.name();
+
+ // tensorflow data_format: one of NHWC or NCHW.
+ auto data_layout = plier::tf::get_string_attr(node, "data_format");
+ auto maxPool_node = graph->nodes()->create<moco::tf::TFMaxPool>();
+ maxPool_node->data_layout(data_layout);
+
+ // padding
+ auto padding = moco::str_toupper(plier::tf::get_string_attr(node, "padding"));
+ maxPool_node->padding(padding);
+
+ // ksize
+ auto tf_ksize = plier::tf::get_list_attr(node, "ksize");
+ auto ksize = plier::tf::as_int64_list(tf_ksize);
+ if (ksize.size() != 4)
+ {
+ // TODO support ksize length for 1 and 2
+ throw std::runtime_error("MaxPool only supports ksize length 4");
+ }
+ maxPool_node->ksize(ksize);
+
+ // strides
+ auto tf_strides = plier::tf::get_list_attr(node, "strides");
+ auto strides = plier::tf::as_int64_list(tf_strides);
+ if (strides.size() != 4)
+ {
+ // TODO support strides length for 1 and 2
+ throw std::runtime_error("MaxPool only supports strides length 4");
+ }
+ maxPool_node->strides(strides);
+
+ // To set the input node of encode_node with node_name
+ TensorName output_name(node_name, 0);
+ tensor_names->enroll(output_name, maxPool_node);
+
+ // Record ifm inputs to featureEncode_node
+ auto update = stdex::make_unique<TFMaxPoolGraphUpdate>(maxPool_node, TensorName(node.input(0)));
+
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(MaxPool, MaxPoolGraphBuilder)
+
+// TODO Consider a case when TF MaxPool is for 3D.
+// MaxPool works for 2D and other Dimensions, such as 3D
+// So, in future, some other GraphBuilder decide if MaxPoolGraphBuilder is used or
+// other GraphBuilder is used for TF MaxPool
diff --git a/compiler/moco-tf/src/Op/MaxPool.h b/compiler/moco-tf/src/Op/MaxPool.h
new file mode 100644
index 000000000..e95f19e31
--- /dev/null
+++ b/compiler/moco-tf/src/Op/MaxPool.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_MAX_POOL_H__
+#define __OP_MAX_POOL_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct MaxPoolGraphBuilderBase : public GraphBuilder
+{
+ virtual ~MaxPoolGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class MaxPoolGraphBuilderImpl;
+
+template <>
+struct MaxPoolGraphBuilderImpl<ImportTarget::Canonical> final : public MaxPoolGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct MaxPoolGraphBuilderImpl<ImportTarget::TensorFlow> final : public MaxPoolGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_MAX_POOL_H__
diff --git a/compiler/moco-tf/src/Op/MaxPool.test.cpp b/compiler/moco-tf/src/Op/MaxPool.test.cpp
new file mode 100644
index 000000000..308520d46
--- /dev/null
+++ b/compiler/moco-tf/src/Op/MaxPool.test.cpp
@@ -0,0 +1,299 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MaxPool.h"
+
+#include "IR/TFMaxPool.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <loco/IR/TensorShape.h>
+#include <loco/IR/FeatureShape.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *maxpool_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+node {
+ name: "maxpool"
+ op: "MaxPool"
+ input: "const/float"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "ksize"
+ value {
+ list {
+ i: 1
+ i: 2
+ i: 3
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 3
+ i: 2
+ i: 1
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, MaxPool_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("maxpool", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(maxpool_01_pbtxtdata, graph_def));
+
+ // Test "MaxPoolGraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // what to test:
+ // - there should exist MaxPool2D
+ // - ifm node should be FeatureEncode
+ // - following node should be FeatureDecode
+ // - stride values should match
+ // - window values should match
+
+ using MaxPoolGraphBuilder = MaxPoolGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("MaxPool", stdex::make_unique<MaxPoolGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ loco::MaxPool2D *maxpool2d_node =
+ moco::tf::test::find_first_node_bytype<loco::MaxPool2D>(graph.get());
+ ASSERT_NE(maxpool2d_node, nullptr);
+
+ loco::Node *previous_node = maxpool2d_node->ifm();
+ auto following_nodes = loco::succs(maxpool2d_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+
+ loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(previous_node);
+ loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(following_node);
+
+ ASSERT_NE(enc_node, nullptr);
+ ASSERT_NE(dec_node, nullptr);
+
+ // attrs inside MaxPool2D
+ auto maxpool2d = maxpool2d_node; // TODO remove this new variable
+
+ // stride
+ ASSERT_EQ(maxpool2d->stride()->vertical(), 3);
+ ASSERT_EQ(maxpool2d->stride()->horizontal(), 2);
+
+ // window
+ ASSERT_EQ(maxpool2d->window()->vertical(), 2);
+ ASSERT_EQ(maxpool2d->window()->horizontal(), 3);
+ }
+
+ // Test "MaxPoolGraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ // what to test:
+ // - there should exist TFMaxPool
+ // - attributes value should match
+
+ using MaxPoolGraphBuilder = MaxPoolGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("MaxPool", stdex::make_unique<MaxPoolGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ moco::tf::TFMaxPool *maxpool_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFMaxPool>(graph.get());
+ ASSERT_NE(maxpool_node, nullptr);
+
+ loco::Node *previous_node = maxpool_node->value();
+ auto following_nodes = loco::succs(maxpool_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+
+ // attrs inside TFMaxPool
+ ASSERT_EQ(maxpool_node->data_layout(), "NHWC");
+ ASSERT_EQ(maxpool_node->padding(), "VALID");
+ ASSERT_EQ(maxpool_node->ksize(), std::vector<int64_t>({1, 2, 3, 1}));
+ ASSERT_EQ(maxpool_node->strides(), std::vector<int64_t>({1, 3, 2, 1}));
+ }
+}
+
+TEST(TensorFlowImport, MaxPool_02)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_output(moco::tf::TensorName("maxpool", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(maxpool_01_pbtxtdata, graph_def));
+
+ // Test "MaxPoolGraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // TODO: fix indentation
+ // clang-format off
+
+ // what to test: Encoder and Decoder dimension order
+ // - there should exist MaxPool2D
+ // - ifm node should be FeatureEncode
+ // - following node should be FeatureDecode
+ // - FeatureEncode encoder should encode Count-Height-Width-Depth order
+ // - FeatureDecode decoder should decode Count-Height-Width-Depth order
+
+ using MaxPoolGraphBuilder = MaxPoolGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("MaxPool", stdex::make_unique<MaxPoolGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ loco::MaxPool2D *maxpool2d_node =
+ moco::tf::test::find_first_node_bytype<loco::MaxPool2D>(graph.get());
+ ASSERT_NE(maxpool2d_node, nullptr);
+
+ loco::Node *previous_node = maxpool2d_node->ifm();
+ auto following_nodes = loco::succs(maxpool2d_node);
+ ASSERT_EQ(following_nodes.size(), 1);
+ loco::Node *following_node = *following_nodes.begin();
+ ASSERT_NE(following_node, nullptr);
+
+ loco::FeatureEncode *enc_node = dynamic_cast<loco::FeatureEncode *>(previous_node);
+ loco::FeatureDecode *dec_node = dynamic_cast<loco::FeatureDecode *>(following_node);
+
+ ASSERT_NE(enc_node, nullptr);
+ ASSERT_NE(dec_node, nullptr);
+
+ // attrs inside FeatureEncoder
+ auto encoder = enc_node->encoder();
+ ASSERT_TRUE(encoder != nullptr);
+
+ loco::TensorShape tensor_shape;
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 1; // COUNT
+ tensor_shape.dim(1) = 720; // HEIGHT
+ tensor_shape.dim(2) = 1280; // WIDTH
+ tensor_shape.dim(3) = 3; // DEPTH
+
+ // Get the feature shape corresponding to a given image
+ auto feature_shape = encoder->shape(tensor_shape);
+
+ ASSERT_EQ(feature_shape.count().value(), 1);
+ ASSERT_EQ(feature_shape.depth().value(), 3);
+ ASSERT_EQ(feature_shape.height().value(), 720);
+ ASSERT_EQ(feature_shape.width().value(), 1280);
+
+ // attrs inside FeatureDecoder
+ auto decoder = dec_node->decoder();
+ ASSERT_TRUE(decoder != nullptr);
+
+ feature_shape.count() = 1;
+ feature_shape.depth() = 3;
+ feature_shape.height() = 720;
+ feature_shape.width() = 1280;
+
+ // Get the tensor shape corresponding to a given image
+ tensor_shape = decoder->shape(feature_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0).value(), 1); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1).value(), 720); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2).value(), 1280); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3).value(), 3); // DEPTH
+
+ // clang-format on
+ }
+
+ // Skip Test "AvgPoolGraphBuilderImpl<ImportTarget::TensorFlow>"
+ // There is no FeatureEncode nor FeatureDecode to test
+}
diff --git a/compiler/moco-tf/src/Op/Mul.cpp b/compiler/moco-tf/src/Op/Mul.cpp
new file mode 100644
index 000000000..5fa5b68aa
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Mul.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMul.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Mul node
+ */
+class TFMulGraphUpdate final : public GraphUpdate
+{
+public:
+  TFMulGraphUpdate(TFMul *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+  void input(const SymbolTable *) const override; // resolves _names via the symbol table into x()/y()
+
+private:
+  TFMul *_node;                   // TFMul node whose operands are wired up later
+  std::vector<TensorName> _names; // tensor names of the two operands, in (x, y) order
+};
+
+void TFMulGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+  // Mul always has exactly two operands: x and y
+  assert(_names.size() == 2);
+
+  _node->x(tensor_names->node(_names[0]));
+  _node->y(tensor_names->node(_names[1]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Mul node
+ */
+class MulGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool MulGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+  // Check arity explicitly: assert() is compiled out in NDEBUG builds,
+  // which would otherwise let malformed Mul nodes pass validation.
+  return node.input_size() == 2;
+}
+
+void MulGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // creating TF dialect Mul node
+  auto tf_mul = graph->nodes()->create<TFMul>();
+
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, tf_mul);
+
+  std::vector<TensorName> mul_input_names;
+  mul_input_names.push_back(TensorName(node.input(0))); // x
+  mul_input_names.push_back(TensorName(node.input(1))); // y
+
+  auto tf_mul_update = stdex::make_unique<TFMulGraphUpdate>(tf_mul, std::move(mul_input_names));
+  updates->enroll(std::move(tf_mul_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Mul, MulGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Mul.test.cpp b/compiler/moco-tf/src/Op/Mul.test.cpp
new file mode 100644
index 000000000..7bc138656
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Mul.test.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFMul.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *mul_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input_01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "input_02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MUL_01"
+ op: "Mul"
+ input: "input_01"
+ input: "input_02"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_mul_basic)
+{
+  // load graph
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+  signature.add_output(moco::tf::TensorName("MUL_01", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(plier::tf::parse_graphdef(mul_basic_pbtxt, graph_def));
+  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // what to test:
+  // - TFMul node should exist
+  // - both inputs x() and y() should not be null
+
+  auto mul_node = moco::tf::test::find_first_node_bytype<moco::tf::TFMul>(graph.get()); // nullptr when absent
+
+  ASSERT_NE(mul_node, nullptr);
+  ASSERT_NE(mul_node->x(), nullptr);
+  ASSERT_NE(mul_node->y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Placeholder.cpp b/compiler/moco-tf/src/Op/Placeholder.cpp
new file mode 100644
index 000000000..e0b24d5df
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Placeholder.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <moco/tf/Names.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Placeholder node
+ */
+class PlaceholderGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool PlaceholderGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+  return plier::tf::has_attrs(node, {"dtype", "shape"}); // both attrs are mandatory for import
+}
+
+void PlaceholderGraphBuilder::build(const tensorflow::NodeDef &node,
+                                    GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+
+  loco::DataType dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
+  const auto &shape = plier::tf::get_shape_attr(node, "shape");
+  // TODO handle for unknown rank
+  assert(!shape.unknown_rank());
+  int64_t num_dims = shape.dim_size();
+
+  // TODO support other types
+  assert(dtype == loco::DataType::FLOAT32);
+
+  // Create a "pull" node as an input
+  auto pull_node = graph->nodes()->create<loco::Pull>();
+
+  pull_node->dtype(dtype);
+
+  // Setting shape info.
+  pull_node->rank(num_dims);
+  for (int64_t d = 0; d < num_dims; d++)
+  {
+    assert(shape.dim(d).size() < std::numeric_limits<uint32_t>::max());
+    int64_t dim_value = shape.dim(d).size();
+    if (dim_value >= 0) // NOTE comparing with 0ULL promoted dim_value to unsigned: negatives passed
+    {
+      uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
+      pull_node->dim(d) = dim_value32;
+    }
+    else
+    {
+      pull_node->dim(d).unset();
+      // TODO Remove assert() and do implement
+      // NOTE Current implementation assumes dim is all know
+      assert(false);
+    }
+  }
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, pull_node);
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Placeholder, PlaceholderGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Placeholder.test.cpp b/compiler/moco-tf/src/Op/Placeholder.test.cpp
new file mode 100644
index 000000000..0fe32af37
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Placeholder.test.cpp
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <loco/IR/TensorShape.h>
+#include <loco/IR/FeatureShape.h>
+#include <nncc/core/ADT/tensor/Shape.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *known_batch_pbtxt = STRING_CONTENT(
+node {
+ name: "placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype" value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 1024 }
+ dim { size: 2 }
+ dim { size: 3 }
+ dim { size: 4 }
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Identity"
+ input: "placeholder"
+ attr {
+ key: "T" value { type: DT_FLOAT }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, placeholder_known_batch)
+{
+  // load graph
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+  signature.add_output(moco::tf::TensorName("output", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(plier::tf::parse_graphdef(known_batch_pbtxt, graph_def));
+  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // get loco::Pull
+  loco::Graph::NodeContext *loco_nodes = graph->nodes();
+  loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_nodes->at(0));
+  ASSERT_NE(pull_node, nullptr); // guard before dereferencing the dynamic_cast result
+
+  // Check dim
+  ASSERT_TRUE(pull_node->dim(0).known() && pull_node->dim(0).value() == 1024);
+  ASSERT_TRUE(pull_node->dim(1).known() && pull_node->dim(1).value() == 2);
+  ASSERT_TRUE(pull_node->dim(2).known() && pull_node->dim(2).value() == 3);
+  ASSERT_TRUE(pull_node->dim(3).known() && pull_node->dim(3).value() == 4);
+}
diff --git a/compiler/moco-tf/src/Op/RealDiv.cpp b/compiler/moco-tf/src/Op/RealDiv.cpp
new file mode 100644
index 000000000..4d96f7457
--- /dev/null
+++ b/compiler/moco-tf/src/Op/RealDiv.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRealDiv.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF RealDiv node
+ */
+class TFRealDivGraphUpdate final : public GraphUpdate
+{
+public:
+  TFRealDivGraphUpdate(TFRealDiv *node, std::vector<TensorName> names) : _node(node), _names(names)
+  {
+  }
+
+  void input(const SymbolTable *) const override; // resolves _names via the symbol table into x()/y()
+
+private:
+  TFRealDiv *_node;               // TFRealDiv node whose operands are wired up later
+  std::vector<TensorName> _names; // tensor names of the two operands, in (x, y) order
+};
+
+void TFRealDivGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+  // RealDiv always has exactly two operands: x and y
+  assert(_names.size() == 2);
+
+  _node->x(tensor_names->node(_names[0]));
+  _node->y(tensor_names->node(_names[1]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for RealDiv node
+ */
+class RealDivGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool RealDivGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+  // Check arity explicitly: assert() is compiled out in NDEBUG builds,
+  // which would otherwise let malformed RealDiv nodes pass validation.
+  return node.input_size() == 2;
+}
+
+void RealDivGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // creating TF dialect RealDiv node
+  auto tf_div = graph->nodes()->create<TFRealDiv>();
+
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, tf_div); // output is addressable before inputs are resolved
+
+  std::vector<TensorName> div_input_names;
+  div_input_names.push_back(TensorName(node.input(0))); // x
+  div_input_names.push_back(TensorName(node.input(1))); // y
+
+  auto tf_div_update = stdex::make_unique<TFRealDivGraphUpdate>(tf_div, div_input_names);
+  updates->enroll(std::move(tf_div_update)); // input wiring deferred until all nodes exist
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(RealDiv, RealDivGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/RealDiv.test.cpp b/compiler/moco-tf/src/Op/RealDiv.test.cpp
new file mode 100644
index 000000000..40e55b276
--- /dev/null
+++ b/compiler/moco-tf/src/Op/RealDiv.test.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFRealDiv.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *div_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input_01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "input_02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "DIV_01"
+ op: "RealDiv"
+ input: "input_01"
+ input: "input_02"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_div_basic)
+{
+  // load graph
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+  signature.add_output(moco::tf::TensorName("DIV_01", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(plier::tf::parse_graphdef(div_basic_pbtxt, graph_def));
+  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // what to test:
+  // - TFRealDiv node should exist
+  // - both inputs x() and y() should not be null
+
+  auto div_node = moco::tf::test::find_first_node_bytype<moco::tf::TFRealDiv>(graph.get()); // nullptr when absent
+
+  ASSERT_NE(div_node, nullptr);
+  ASSERT_NE(div_node->x(), nullptr);
+  ASSERT_NE(div_node->y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Relu.cpp b/compiler/moco-tf/src/Op/Relu.cpp
new file mode 100644
index 000000000..0c3ceeec6
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFRelu.h"
+
+#include <moco/tf/Names.h>
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class ReLUGraphUpdate final : public GraphUpdate
+{
+public:
+  ReLUGraphUpdate(loco::ReLU *node, const TensorName &name) : _node(node), _name(name) {}
+
+  void input(const SymbolTable *) const override;
+
+private:
+  loco::ReLU *_node;      // canonical ReLU node whose input is wired up later
+  const TensorName _name; // tensor name of the single input (const&& added nothing: it cannot be moved from)
+};
+
+class TFReluGraphUpdate final : public GraphUpdate
+{
+public:
+  TFReluGraphUpdate(moco::tf::TFRelu *node, const TensorName &name) : _node(node), _name(name) {}
+
+  void input(const SymbolTable *) const override;
+
+private:
+  moco::tf::TFRelu *_node; // TF dialect Relu node whose input is wired up later
+  const TensorName _name;  // tensor name of the single input (const&& added nothing: it cannot be moved from)
+};
+
+void ReLUGraphUpdate::input(const SymbolTable *table) const
+{
+  loco::Node *target = table->node(_name); // resolve recorded tensor name to its producer node
+  _node->input(target);
+}
+
+void TFReluGraphUpdate::input(const SymbolTable *table) const
+{
+  loco::Node *target = table->node(_name);
+  _node->features(target); // TFRelu names its single input "features"
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Relu node
+ */
+class ReluGraphBuilder final : public ReluGraphBuilderBase
+{
+public:
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; // dispatches on Knob::ImportAsTFRelu
+};
+
+bool ReluGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+  // ReLU node SHOULD have only one input
+  if (node.input_size() != 1)
+    return false;
+
+  return true;
+}
+
+void ReluGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  if (moco::tf::get<moco::tf::Knob::ImportAsTFRelu>()) // knob selects TF dialect vs canonical import
+  {
+    ReluGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+    return builder.build(node, context);
+  }
+  else
+  {
+    ReluGraphBuilderImpl<ImportTarget::Canonical> builder;
+    return builder.build(node, context);
+  }
+}
+
+void ReluGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+                                                          GraphBuilderContext *context) const
+{
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // Create a "ReLU" node for Relu
+  auto relu_node = graph->nodes()->create<loco::ReLU>();
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, relu_node);
+
+  // Queue node input update
+  auto update = stdex::make_unique<ReLUGraphUpdate>(relu_node, TensorName(node.input(0)));
+  updates->enroll(std::move(update)); // executed after every node has been enrolled
+}
+
+void ReluGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+                                                           GraphBuilderContext *context) const
+{
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // Create a "TFRelu" node for Relu
+  auto relu_node = graph->nodes()->create<moco::tf::TFRelu>();
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, relu_node);
+
+  // Queue node input update
+  auto update = stdex::make_unique<TFReluGraphUpdate>(relu_node, TensorName(node.input(0)));
+  updates->enroll(std::move(update)); // executed after every node has been enrolled
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Relu, ReluGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Relu.h b/compiler/moco-tf/src/Op/Relu.h
new file mode 100644
index 000000000..7d75f8a03
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_RELU_H__
+#define __OP_RELU_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct ReluGraphBuilderBase : public GraphBuilder
+{
+  virtual ~ReluGraphBuilderBase() = default;
+
+  bool validate(const tensorflow::NodeDef &) const final; // shared arity check for both import targets
+};
+
+template <ImportTarget T> class ReluGraphBuilderImpl; // specialized below per import target
+
+template <> struct ReluGraphBuilderImpl<ImportTarget::Canonical> final : public ReluGraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; // emits loco::ReLU
+};
+
+template <>
+struct ReluGraphBuilderImpl<ImportTarget::TensorFlow> final : public ReluGraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; // emits moco::tf::TFRelu
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_RELU_H__
diff --git a/compiler/moco-tf/src/Op/Relu.test.cpp b/compiler/moco-tf/src/Op/Relu.test.cpp
new file mode 100644
index 000000000..bdd1152c3
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu.test.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu.h"
+
+#include "IR/TFRelu.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+
+// clang-format off
+const char *relu_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ReLU"
+ op: "Relu"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, relu_01)
+{
+  moco::tf::Importer importer;
+  moco::tf::ModelSignature signature;
+
+  signature.add_input(moco::tf::TensorName("Placeholder", 0));
+  signature.add_output(moco::tf::TensorName("ReLU", 0));
+
+  tensorflow::GraphDef graph_def;
+  EXPECT_TRUE(plier::tf::parse_graphdef(relu_01_pbtxtdata, graph_def));
+
+  // Test "ReluGraphBuilderImpl<ImportTarget::Canonical>"
+  {
+    // TODO: fix indentation
+    // clang-format off
+
+  using ReluGraphBuilder = ReluGraphBuilderImpl<ImportTarget::Canonical>;
+
+  moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; // default builders, "Relu" overridden below
+  r.add("Relu", stdex::make_unique<ReluGraphBuilder>());
+  moco::tf::Importer importer{&r};
+
+  std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+  // what to test:
+  // - there should exist ReLU
+  // - input node should not be nullptr
+
+  loco::ReLU *relu_node = moco::tf::test::find_first_node_bytype<loco::ReLU>(graph.get());
+
+  ASSERT_NE(relu_node, nullptr);
+  ASSERT_NE(relu_node->input(), nullptr);
+  // clang-format on
+  }
+
+  // Test "ReluGraphBuilderImpl<ImportTarget::TensorFlow>"
+  {
+    using ReluGraphBuilder = ReluGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+    moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()}; // default builders, "Relu" overridden below
+    r.add("Relu", stdex::make_unique<ReluGraphBuilder>());
+    moco::tf::Importer importer{&r};
+
+    std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+    // what to test:
+    // - there should exist TFRelu
+    // - features node should not be nullptr
+
+    auto relu_node = moco::tf::test::find_first_node_bytype<moco::tf::TFRelu>(graph.get());
+
+    ASSERT_NE(relu_node, nullptr);
+    ASSERT_NE(relu_node->features(), nullptr);
+  }
+}
diff --git a/compiler/moco-tf/src/Op/Relu6.cpp b/compiler/moco-tf/src/Op/Relu6.cpp
new file mode 100644
index 000000000..8f697cc6f
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu6.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu6.h"
+
+#include "GraphBuilder.h"
+#include "Knob.h"
+
+#include "IR/TFRelu6.h"
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+class ReLU6GraphUpdate final : public GraphUpdate
+{
+public:
+  ReLU6GraphUpdate(loco::ReLU6 *node, const TensorName &name) : _node(node), _name(name) {}
+
+  void input(const SymbolTable *) const override;
+
+private:
+  loco::ReLU6 *_node;     // canonical ReLU6 node whose input is wired up later
+  const TensorName _name; // tensor name of the single input (const&& added nothing: it cannot be moved from)
+};
+
+class TFRelu6GraphUpdate final : public GraphUpdate
+{
+public:
+  TFRelu6GraphUpdate(moco::tf::TFRelu6 *node, const TensorName &name) : _node(node), _name(name) {}
+
+  void input(const SymbolTable *) const override;
+
+private:
+  moco::tf::TFRelu6 *_node; // TF dialect Relu6 node whose input is wired up later
+  const TensorName _name;   // tensor name of the single input (const&& added nothing: it cannot be moved from)
+};
+
+void ReLU6GraphUpdate::input(const SymbolTable *table) const
+{
+  loco::Node *target = table->node(_name); // resolve recorded tensor name to its producer node
+  _node->input(target);
+}
+
+void TFRelu6GraphUpdate::input(const SymbolTable *table) const
+{
+  loco::Node *target = table->node(_name);
+  _node->features(target); // TFRelu6 names its single input "features"
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+/**
+ * @brief GraphBuilder for Relu6 node
+ */
+class Relu6GraphBuilder final : public Relu6GraphBuilderBase
+{
+public:
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; // dispatches on Knob::ImportAsTFRelu6
+};
+
+bool Relu6GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+{
+  // ReLU6 node SHOULD have only one input
+  if (node.input_size() != 1)
+    return false;
+  return true;
+}
+
+void Relu6GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+  assert(context != nullptr);
+
+  if (moco::tf::get<moco::tf::Knob::ImportAsTFRelu6>()) // knob selects TF dialect vs canonical import
+  {
+    Relu6GraphBuilderImpl<ImportTarget::TensorFlow> builder;
+    return builder.build(node, context);
+  }
+  else
+  {
+    Relu6GraphBuilderImpl<ImportTarget::Canonical> builder;
+    return builder.build(node, context);
+  }
+}
+
+void Relu6GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+                                                           GraphBuilderContext *context) const
+{
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // Create a "ReLU6" node for Relu6
+  auto relu6_node = graph->nodes()->create<loco::ReLU6>();
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, relu6_node);
+
+  // Queue node input update
+  auto update = stdex::make_unique<ReLU6GraphUpdate>(relu6_node, TensorName(node.input(0)));
+  updates->enroll(std::move(update)); // executed after every node has been enrolled
+}
+
+void Relu6GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+                                                            GraphBuilderContext *context) const
+{
+  loco::Graph *graph = context->graph();
+  SymbolTable *tensor_names = context->tensor_names();
+  UpdateQueue *updates = context->updates();
+
+  // Create a "TFRelu6" node for Relu6
+  auto relu_node = graph->nodes()->create<moco::tf::TFRelu6>();
+
+  // register string-name to node
+  TensorName output_name(node.name(), 0);
+  tensor_names->enroll(output_name, relu_node);
+
+  // Queue node input update
+  auto update = stdex::make_unique<TFRelu6GraphUpdate>(relu_node, TensorName(node.input(0)));
+  updates->enroll(std::move(update)); // executed after every node has been enrolled
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Relu6, Relu6GraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Relu6.h b/compiler/moco-tf/src/Op/Relu6.h
new file mode 100644
index 000000000..8bbadee1d
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu6.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_RELU6_H__
+#define __OP_RELU6_H__
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct Relu6GraphBuilderBase : public GraphBuilder
+{
+  virtual ~Relu6GraphBuilderBase() = default;
+
+  bool validate(const tensorflow::NodeDef &) const final; // shared arity check for both import targets
+};
+
+template <ImportTarget T> class Relu6GraphBuilderImpl; // specialized below per import target
+
+template <>
+struct Relu6GraphBuilderImpl<ImportTarget::Canonical> final : public Relu6GraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; // emits loco::ReLU6
+};
+
+template <>
+struct Relu6GraphBuilderImpl<ImportTarget::TensorFlow> final : public Relu6GraphBuilderBase
+{
+  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final; // emits moco::tf::TFRelu6
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_RELU6_H__
diff --git a/compiler/moco-tf/src/Op/Relu6.test.cpp b/compiler/moco-tf/src/Op/Relu6.test.cpp
new file mode 100644
index 000000000..4d6832353
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Relu6.test.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Relu6.h"
+
+#include "IR/TFRelu6.h"
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using namespace moco::tf;
+using namespace moco::tf::test;
+
+namespace
+{
+
+// clang-format off
+const char *relu6_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "ReLU6"
+ op: "Relu6"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, relu6_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("ReLU6", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(relu6_01_pbtxtdata, graph_def));
+
+ // Test "Relu6GraphBuilderImpl<ImportTarget::Canonical>"
+ {
+ // TODO: fix indentation
+ // clang-format off
+
+ using ReluGraphBuilder = Relu6GraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Relu6", stdex::make_unique<ReluGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist ReLU6
+ // - input node should not be nullptr
+
+ loco::ReLU6 *relu6_node = moco::tf::test::find_first_node_bytype<loco::ReLU6>(graph.get());
+
+ ASSERT_NE(relu6_node, nullptr);
+ ASSERT_NE(relu6_node->input(), nullptr);
+ // clang-format on
+ }
+
+  // Test "Relu6GraphBuilderImpl<ImportTarget::TensorFlow>"
+ {
+ using ReluGraphBuilder = Relu6GraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Relu6", stdex::make_unique<ReluGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFRelu6
+ // - features node should not be null
+
+ auto relu_node = moco::tf::test::find_first_node_bytype<moco::tf::TFRelu6>(graph.get());
+
+ ASSERT_NE(relu_node, nullptr);
+ ASSERT_NE(relu_node->features(), nullptr);
+ }
+}
diff --git a/compiler/moco-tf/src/Op/Reshape.cpp b/compiler/moco-tf/src/Op/Reshape.cpp
new file mode 100644
index 000000000..08931f7e5
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Reshape.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include "IR/TFReshape.h"
+
+#include <moco/tf/Names.h>
+#include <plier/tf/Convert.h>
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+using namespace moco::tf;
+
+class ReshapeGraphUpdate final : public GraphUpdate
+{
+public:
+ ReshapeGraphUpdate(TFReshape *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFReshape *_node;
+ std::vector<TensorName> _names;
+};
+
+void ReshapeGraphUpdate::input(const SymbolTable *node_table) const
+{
+ assert(_names.size() == 2);
+
+ auto tensor_node = node_table->node(_names[0]);
+ auto shape_node = node_table->node(_names[1]);
+
+ assert(tensor_node != nullptr);
+ assert(shape_node != nullptr);
+
+ _node->tensor(tensor_node);
+ _node->shape(shape_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/// @brief GraphBuilder for Reshape node
+class ReshapeGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool ReshapeGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ // Tensorflow Reshape has 2 inputs: tensor & shape
+ if (node.input_size() != 2)
+ return false;
+
+ // TODO Assert Tshape value is DT_INT32?
+ return plier::tf::has_attrs(node, {"T", "Tshape"});
+}
+
+void ReshapeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // name of loco nodes
+ std::string reshape_name = node.name();
+
+ auto reshape = graph->nodes()->create<TFReshape>();
+
+ // save the name for graph link updates
+ TensorName output_name(reshape_name, 0);
+ tensor_names->enroll(output_name, reshape);
+
+ std::vector<TensorName> input_names;
+ input_names.push_back(TensorName(node.input(0))); // tensor
+ input_names.push_back(TensorName(node.input(1))); // shape
+
+ // Queue node input update
+ auto update = stdex::make_unique<ReshapeGraphUpdate>(reshape, input_names);
+
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Reshape, ReshapeGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Reshape.test.cpp b/compiler/moco-tf/src/Op/Reshape.test.cpp
new file mode 100644
index 000000000..66d4f0054
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Reshape.test.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFReshape.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+namespace
+{
+
+// clang-format off
+const char *reshape_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 4 }
+ }
+ }
+ }
+}
+node {
+ name: "shape"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_INT32 }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim { size: 2 }
+ }
+ int_val: 2
+ int_val: 2
+ }
+ }
+ }
+}
+node {
+ name: "reshape"
+ op: "Reshape"
+ input: "placeholder"
+ input: "shape"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "Tshape"
+ value { type: DT_INT32 }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, reshape_01)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("placeholder", 0));
+ signature.add_output(moco::tf::TensorName("reshape", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(reshape_01_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ moco::tf::TFReshape *reshape_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFReshape>(graph.get());
+
+ ASSERT_NE(reshape_node, nullptr);
+ ASSERT_NE(reshape_node->tensor(), nullptr);
+ ASSERT_NE(reshape_node->shape(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Rsqrt.cpp b/compiler/moco-tf/src/Op/Rsqrt.cpp
new file mode 100644
index 000000000..e3b7fcc98
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Rsqrt.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFRsqrt.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Rsqrt node
+ */
+class TFRsqrtGraphUpdate final : public GraphUpdate
+{
+public:
+ TFRsqrtGraphUpdate(TFRsqrt *node, TensorName &&name) : _node(node), _name(name) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFRsqrt *_node;
+ TensorName _name;
+};
+
+void TFRsqrtGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *target = table->node(_name);
+ _node->x(target);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Rsqrt node
+ */
+class RsqrtGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool RsqrtGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return true;
+}
+
+void RsqrtGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Rsqrt node
+ auto tf_rsqrt = graph->nodes()->create<TFRsqrt>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_rsqrt);
+
+ // Queue node input update
+ auto tf_rsqrt_update =
+ stdex::make_unique<TFRsqrtGraphUpdate>(tf_rsqrt, TensorName(node.input(0)));
+ updates->enroll(std::move(tf_rsqrt_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Rsqrt, RsqrtGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Rsqrt.test.cpp b/compiler/moco-tf/src/Op/Rsqrt.test.cpp
new file mode 100644
index 000000000..0fd76d472
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Rsqrt.test.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFRsqrt.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *rsqrt_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+}
+node {
+ name: "RSQRT_01"
+ op: "Rsqrt"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_rsqrt_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("RSQRT_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(rsqrt_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFRsqrt node should exist
+ // - input x() should not be null
+
+ auto rsqrt_node = moco::tf::test::find_first_node_bytype<moco::tf::TFRsqrt>(graph.get());
+
+ ASSERT_NE(rsqrt_node, nullptr);
+ ASSERT_NE(rsqrt_node->x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Shape.cpp b/compiler/moco-tf/src/Op/Shape.cpp
new file mode 100644
index 000000000..9e2f00bb4
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Shape.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFShape.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for Shape node
+ */
+class ShapeGraphUpdate final : public GraphUpdate
+{
+public:
+ ShapeGraphUpdate(TFShape *node, const TensorName &&input_name)
+ : _node(node), _input_name(input_name)
+ {
+ // DO NOTHING
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFShape *_node;
+ const TensorName _input_name;
+};
+
+void ShapeGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *input_node = table->node(_input_name);
+ _node->input(input_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Shape node
+ */
+class ShapeGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool ShapeGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return plier::tf::has_attrs(node, {"T"});
+}
+
+void ShapeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // create TF dialect Shape node
+ auto tf_shape = graph->nodes()->create<TFShape>();
+
+ if (plier::tf::has_attrs(node, {"out_type"}))
+ {
+ auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "out_type"));
+ // TODO Support other dtype like S64
+ assert(dtype == loco::DataType::S32);
+
+ tf_shape->dtype(dtype);
+ }
+ else
+ {
+ // Set to S32, TF-documented default value for 'out_type'
+ tf_shape->dtype(loco::DataType::S32);
+ }
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_shape);
+
+ auto update = stdex::make_unique<ShapeGraphUpdate>(tf_shape, TensorName(node.input(0)));
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Shape, ShapeGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Shape.test.cpp b/compiler/moco-tf/src/Op/Shape.test.cpp
new file mode 100644
index 000000000..6abefb071
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Shape.test.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFShape.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+namespace
+{
+
+// clang-format off
+const char *shape_000_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 1 }
+ dim { size: 2 }
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+ }
+ }
+}
+node {
+ name: "Shape"
+ op: "Shape"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "out_type"
+ value { type: DT_INT32 }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, shape_000)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("Shape", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(shape_000_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFShape
+ // - input node should not be null
+ // - dtype attribute is set same as out_type attribute of pbtxt
+
+ moco::tf::TFShape *shape_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFShape>(graph.get());
+
+ ASSERT_NE(shape_node, nullptr);
+ ASSERT_NE(shape_node->input(), nullptr);
+ ASSERT_EQ(shape_node->dtype(), loco::DataType::S32);
+}
diff --git a/compiler/moco-tf/src/Op/Softmax.cpp b/compiler/moco-tf/src/Op/Softmax.cpp
new file mode 100644
index 000000000..d813b9d3d
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Softmax.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSoftmax.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for Softmax node
+ */
+class SoftmaxGraphUpdate final : public GraphUpdate
+{
+public:
+ SoftmaxGraphUpdate(TFSoftmax *node, const TensorName &&input_name)
+ : _node(node), _input_name(input_name)
+ {
+ // DO NOTHING
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFSoftmax *_node;
+ const TensorName _input_name;
+};
+
+void SoftmaxGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *input_node = table->node(_input_name);
+ _node->logits(input_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Softmax node
+ */
+class SoftmaxGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool SoftmaxGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return plier::tf::has_attrs(node, {"T"});
+}
+
+void SoftmaxGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Softmax node
+ auto tf_softmax = graph->nodes()->create<TFSoftmax>();
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_softmax);
+
+ auto update = stdex::make_unique<SoftmaxGraphUpdate>(tf_softmax, TensorName(node.input(0)));
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Softmax, SoftmaxGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Softmax.test.cpp b/compiler/moco-tf/src/Op/Softmax.test.cpp
new file mode 100644
index 000000000..d4f9fc1c2
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Softmax.test.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFSoftmax.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+namespace
+{
+
+// clang-format off
+const char *softmax_2d_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "Softmax"
+ op: "Softmax"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, softmax_2d)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("Softmax", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(softmax_2d_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFSoftmax
+ // - logits node should not be null
+
+ moco::tf::TFSoftmax *softmax_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFSoftmax>(graph.get());
+
+ ASSERT_NE(softmax_node, nullptr);
+ ASSERT_NE(softmax_node->logits(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Sqrt.cpp b/compiler/moco-tf/src/Op/Sqrt.cpp
new file mode 100644
index 000000000..6b7dec61b
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Sqrt.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSqrt.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Sqrt node
+ */
+class TFSqrtGraphUpdate final : public GraphUpdate
+{
+public:
+ TFSqrtGraphUpdate(TFSqrt *node, TensorName &&name) : _node(node), _name(name) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFSqrt *_node;
+ TensorName _name;
+};
+
+void TFSqrtGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *target = table->node(_name);
+ _node->x(target);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Sqrt node
+ */
+class SqrtGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool SqrtGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return true;
+}
+
+void SqrtGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Sqrt node
+ auto tf_sqrt = graph->nodes()->create<TFSqrt>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_sqrt);
+
+ // Queue node input update
+ auto tf_sqrt_update = stdex::make_unique<TFSqrtGraphUpdate>(tf_sqrt, TensorName(node.input(0)));
+ updates->enroll(std::move(tf_sqrt_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Sqrt, SqrtGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Sqrt.test.cpp b/compiler/moco-tf/src/Op/Sqrt.test.cpp
new file mode 100644
index 000000000..2c55c602e
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Sqrt.test.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFSqrt.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *sqrt_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+}
+node {
+ name: "SQRT_01"
+ op: "Sqrt"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_sqrt_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("SQRT_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(sqrt_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFSqrt node should exist
+ // - input x() should not be null
+
+ auto sqrt_node = moco::tf::test::find_first_node_bytype<moco::tf::TFSqrt>(graph.get());
+
+ ASSERT_NE(sqrt_node, nullptr);
+ ASSERT_NE(sqrt_node->x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/SquaredDifference.cpp b/compiler/moco-tf/src/Op/SquaredDifference.cpp
new file mode 100644
index 000000000..bbccad757
--- /dev/null
+++ b/compiler/moco-tf/src/Op/SquaredDifference.cpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSquaredDifference.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF SquaredDifference node
+ */
+class TFSquaredDifferenceGraphUpdate final : public GraphUpdate
+{
+public:
+ TFSquaredDifferenceGraphUpdate(TFSquaredDifference *node, std::vector<TensorName> names)
+ : _node(node), _names(names)
+ {
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFSquaredDifference *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFSquaredDifferenceGraphUpdate::input(const SymbolTable *table) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs == 2);
+
+ _node->x(table->node(_names[0]));
+ _node->y(table->node(_names[1]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for SquaredDifference node
+ */
+class SquaredDifferenceGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool SquaredDifferenceGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ return true;
+}
+
+void SquaredDifferenceGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect SquaredDifference node
+ auto tf_sqdiff = graph->nodes()->create<TFSquaredDifference>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_sqdiff);
+
+ std::vector<TensorName> add_input_names;
+ add_input_names.push_back(TensorName(node.input(0))); // x
+ add_input_names.push_back(TensorName(node.input(1))); // y
+
+ // Queue node input update
+ auto tf_sqrt_update =
+ stdex::make_unique<TFSquaredDifferenceGraphUpdate>(tf_sqdiff, add_input_names);
+ updates->enroll(std::move(tf_sqrt_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(SquaredDifference, SquaredDifferenceGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/SquaredDifference.test.cpp b/compiler/moco-tf/src/Op/SquaredDifference.test.cpp
new file mode 100644
index 000000000..1efe2ef48
--- /dev/null
+++ b/compiler/moco-tf/src/Op/SquaredDifference.test.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFSquaredDifference.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *sqdiff_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input_01"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "input_02"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "squared_difference"
+ op: "SquaredDifference"
+ input: "input_01"
+ input: "input_02"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_squdiff_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("squared_difference", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(sqdiff_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFSquaredDifference node should exist
+ // - input x() should not be null
+ // - input y() should not be null
+
+ auto sqdiff_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFSquaredDifference>(graph.get());
+
+ ASSERT_NE(sqdiff_node, nullptr);
+ ASSERT_NE(sqdiff_node->x(), nullptr);
+ ASSERT_NE(sqdiff_node->y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Squeeze.cpp b/compiler/moco-tf/src/Op/Squeeze.cpp
new file mode 100644
index 000000000..a7aca3790
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Squeeze.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSqueeze.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for Squeeze node
+ */
+class SqueezeGraphUpdate final : public GraphUpdate
+{
+public:
+ SqueezeGraphUpdate(TFSqueeze *node, TensorName &&input_name)
+ : _node(node), _input_name(input_name)
+ {
+ // DO NOTHING
+ }
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFSqueeze *_node;
+ const TensorName _input_name;
+};
+
+void SqueezeGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *input_node = table->node(_input_name);
+ _node->input(input_node);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Squeeze node
+ */
+class SqueezeGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool SqueezeGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return plier::tf::has_attrs(node, {"T"});
+}
+
+void SqueezeGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ if (plier::tf::has_attrs(node, {"axis"}))
+ {
+ // TODO support 'axis' attribute
+ throw std::runtime_error("Import Squeeze: 'axis' attribute is not supported yet");
+ }
+
+ std::vector<int64_t> squeeze_dims;
+ if (plier::tf::has_attrs(node, {"squeeze_dims"}))
+ {
+ auto squeeze_dim_list = plier::tf::get_list_attr(node, {"squeeze_dims"});
+ // TODO assert squeeze_dims are mutually different?
+ squeeze_dims = plier::tf::as_int64_list(squeeze_dim_list);
+ }
+ // Note that it is possible that NodeDef does not have squeeze_dims attribute.
+ // In that case, TFSqueeze also has empty squeeze_dims,
+
+ // creating TF dialect Squeeze node
+ auto tf_squeeze = graph->nodes()->create<TFSqueeze>();
+ tf_squeeze->squeeze_dims(squeeze_dims);
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_squeeze);
+
+ auto update = stdex::make_unique<SqueezeGraphUpdate>(tf_squeeze, TensorName(node.input(0)));
+ updates->enroll(std::move(update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Squeeze, SqueezeGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Squeeze.test.cpp b/compiler/moco-tf/src/Op/Squeeze.test.cpp
new file mode 100644
index 000000000..179183b6c
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Squeeze.test.cpp
@@ -0,0 +1,162 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+#include "IR/TFSqueeze.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+namespace
+{
+
+// clang-format off
+const char *squeeze_all_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 2 }
+ dim { size: 1 }
+ dim { size: 3 }
+ dim { size: 1 }
+ }
+ }
+ }
+}
+node {
+ name: "Squeeze"
+ op: "Squeeze"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, squeeze_all)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("Squeeze", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(squeeze_all_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFSqueeze
+ // - input node should not be null
+ // - squeeze_dims attribute is set same as pbtxt
+
+ moco::tf::TFSqueeze *squeeze_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFSqueeze>(graph.get());
+
+ ASSERT_NE(squeeze_node, nullptr);
+ ASSERT_NE(squeeze_node->input(), nullptr);
+ ASSERT_EQ(squeeze_node->squeeze_dims().size(), 0);
+}
+
+namespace
+{
+
+// clang-format off
+const char *squeeze_some_pbtxtdata = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 2 }
+ dim { size: 1 }
+ dim { size: 3 }
+ dim { size: 1 }
+ }
+ }
+ }
+}
+node {
+ name: "Squeeze"
+ op: "Squeeze"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "squeeze_dims"
+ value {
+ list { i: 1 }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, squeeze_some)
+{
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+
+ signature.add_input(moco::tf::TensorName("Placeholder", 0));
+ signature.add_output(moco::tf::TensorName("Squeeze", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(squeeze_some_pbtxtdata, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - there should exist TFSqueeze
+ // - input node should not be null
+ // - squeeze_dims attribute is set same as pbtxt
+
+ moco::tf::TFSqueeze *squeeze_node =
+ moco::tf::test::find_first_node_bytype<moco::tf::TFSqueeze>(graph.get());
+
+ ASSERT_NE(squeeze_node, nullptr);
+ ASSERT_NE(squeeze_node->input(), nullptr);
+ ASSERT_EQ(squeeze_node->squeeze_dims().size(), 1);
+ ASSERT_EQ(squeeze_node->squeeze_dims().at(0), 1);
+}
+
+// TODO Add test case for negative squeeze dim
diff --git a/compiler/moco-tf/src/Op/StopGradient.cpp b/compiler/moco-tf/src/Op/StopGradient.cpp
new file mode 100644
index 000000000..dc28d6dfb
--- /dev/null
+++ b/compiler/moco-tf/src/Op/StopGradient.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFStopGradient.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <plier/tf/Convert.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF StopGradient node
+ */
+class TFStopGradientGraphUpdate final : public GraphUpdate
+{
+public:
+ TFStopGradientGraphUpdate(TFStopGradient *node, TensorName &&name) : _node(node), _name(name) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFStopGradient *_node;
+ TensorName _name;
+};
+
+void TFStopGradientGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *target = table->node(_name);
+ _node->input(target);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for StopGradient node
+ */
+class StopGradientGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool StopGradientGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return plier::tf::has_attrs(node, {"T"});
+}
+
+void StopGradientGraphBuilder::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect StopGradient node
+ auto tf_stopgradient = graph->nodes()->create<TFStopGradient>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_stopgradient);
+
+ // Queue node input update
+ auto tf_stopgradient_update =
+ stdex::make_unique<TFStopGradientGraphUpdate>(tf_stopgradient, TensorName(node.input(0)));
+ updates->enroll(std::move(tf_stopgradient_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(StopGradient, StopGradientGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/StopGradient.test.cpp b/compiler/moco-tf/src/Op/StopGradient.test.cpp
new file mode 100644
index 000000000..dd92fb8f8
--- /dev/null
+++ b/compiler/moco-tf/src/Op/StopGradient.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFStopGradient.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *stopgradient_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+}
+node {
+ name: "StopGradient_01"
+ op: "StopGradient"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_stopgradient_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("StopGradient_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(stopgradient_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFStopGradient node should exist
+ // - input() should not be null
+
+ auto node = moco::tf::test::find_first_node_bytype<moco::tf::TFStopGradient>(graph.get());
+
+ ASSERT_NE(node, nullptr);
+ ASSERT_NE(node->input(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Sub.cpp b/compiler/moco-tf/src/Op/Sub.cpp
new file mode 100644
index 000000000..2629b5aa8
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Sub.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFSub.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Sub node
+ */
+class TFSubGraphUpdate final : public GraphUpdate
+{
+public:
+ TFSubGraphUpdate(TFSub *node, std::vector<TensorName> names) : _node(node), _names(names) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFSub *_node;
+ std::vector<TensorName> _names;
+};
+
+void TFSubGraphUpdate::input(const SymbolTable *tensor_names) const
+{
+ int num_inputs = _names.size();
+ assert(num_inputs == 2);
+
+ _node->x(tensor_names->node(_names[0]));
+ _node->y(tensor_names->node(_names[1]));
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Sub node
+ */
+class SubGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool SubGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 2);
+
+ return true;
+}
+
+void SubGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Sub node
+ auto tf_sub = graph->nodes()->create<TFSub>();
+
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_sub);
+
+ std::vector<TensorName> sub_input_names;
+ sub_input_names.push_back(TensorName(node.input(0))); // x
+ sub_input_names.push_back(TensorName(node.input(1))); // y
+
+ auto tf_sub_update = stdex::make_unique<TFSubGraphUpdate>(tf_sub, sub_input_names);
+ updates->enroll(std::move(tf_sub_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Sub, SubGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Sub.test.cpp b/compiler/moco-tf/src/Op/Sub.test.cpp
new file mode 100644
index 000000000..ad2ad55fc
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Sub.test.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFSub.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *sub_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input_01"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "input_02"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 4
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "SUB_01"
+ op: "Sub"
+ input: "input_01"
+ input: "input_02"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_sub_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("SUB_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(sub_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFSub node should exist
+ // - both inputs x() and y() should not be null
+
+ auto sub_node = moco::tf::test::find_first_node_bytype<moco::tf::TFSub>(graph.get());
+
+ ASSERT_NE(sub_node, nullptr);
+ ASSERT_NE(sub_node->x(), nullptr);
+ ASSERT_NE(sub_node->y(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Op/Tanh.cpp b/compiler/moco-tf/src/Op/Tanh.cpp
new file mode 100644
index 000000000..b465401d1
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Tanh.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFTanh.h"
+
+#include "GraphBuilder.h"
+#include "GraphBuilderContext.h"
+
+#include <loco.h>
+#include <stdex/Memory.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief GraphUpdate for TF Tanh node
+ */
+class TFTanhGraphUpdate final : public GraphUpdate
+{
+public:
+ TFTanhGraphUpdate(TFTanh *node, TensorName &&name) : _node(node), _name(name) {}
+
+ void input(const SymbolTable *) const override;
+
+private:
+ TFTanh *_node;
+ TensorName _name;
+};
+
+void TFTanhGraphUpdate::input(const SymbolTable *table) const
+{
+ loco::Node *target = table->node(_name);
+ _node->x(target);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief GraphBuilder for Tanh node
+ */
+class TanhGraphBuilder final : public GraphBuilder
+{
+public:
+ bool validate(const tensorflow::NodeDef &) const override;
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+};
+
+bool TanhGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ assert(node.input_size() == 1);
+
+ return true;
+}
+
+void TanhGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+ UpdateQueue *updates = context->updates();
+
+ // creating TF dialect Tanh node
+ auto tf_tanh = graph->nodes()->create<TFTanh>();
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, tf_tanh);
+
+ // Queue node input update
+ auto tf_tanh_update = stdex::make_unique<TFTanhGraphUpdate>(tf_tanh, TensorName(node.input(0)));
+ updates->enroll(std::move(tf_tanh_update));
+}
+
+} // namespace tf
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Tanh, TanhGraphBuilder)
diff --git a/compiler/moco-tf/src/Op/Tanh.test.cpp b/compiler/moco-tf/src/Op/Tanh.test.cpp
new file mode 100644
index 000000000..578ef2211
--- /dev/null
+++ b/compiler/moco-tf/src/Op/Tanh.test.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include "Importer.h"
+
+#include "IR/TFTanh.h"
+
+#include <loco.h>
+#include <plier/tf/TestHelper.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *tanh_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "Placeholder"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ }
+ }
+}
+node {
+ name: "output/tanh"
+ op: "Tanh"
+ input: "Placeholder"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowImport, tf_tanh_basic)
+{
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("output/tanh", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(tanh_basic_pbtxt, graph_def));
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
+ // what to test:
+ // - TFTanh node should exist
+ // - input x() should not be null
+
+ auto tanh_node = moco::tf::test::find_first_node_bytype<moco::tf::TFTanh>(graph.get());
+
+ ASSERT_NE(tanh_node, nullptr);
+ ASSERT_NE(tanh_node->x(), nullptr);
+}
diff --git a/compiler/moco-tf/src/Phase.cpp b/compiler/moco-tf/src/Phase.cpp
new file mode 100644
index 000000000..6764691c7
--- /dev/null
+++ b/compiler/moco-tf/src/Phase.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Phase.h"
+#include "LogHelper.h"
+
+#include <moco/Log.h>
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+void PhaseRunner<PhaseStrategy::Saturate>::run(const Phase &phase) const
+{
+ LOGGER(l);
+
+ INFO(l) << "==============================================================";
+ INFO(l) << "PhaseRunner<Saturate>";
+
+ INFO(l) << "Initial graph";
+ INFO(l) << fmt(_graph);
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &tr : phase)
+ {
+ bool chg_one = false;
+
+ INFO(l) << "--------------------------------------------------------------";
+ INFO(l) << "Before " << transform_name(tr.get());
+
+ if (tr->run(_graph))
+ {
+ changed = true;
+ chg_one = true;
+ }
+
+ INFO(l) << "After " << transform_name(tr.get()) << " (changed: " << to_char(chg_one) << ")";
+ INFO(l) << fmt(_graph);
+ }
+ }
+
+ INFO(l) << "PhaseRunner<Saturate> - done";
+}
+
+void PhaseRunner<PhaseStrategy::Restart>::run(const Phase &phase) const
+{
+ LOGGER(l);
+
+ INFO(l) << "==============================================================";
+ INFO(l) << "PhaseRunner<Restart>";
+
+ INFO(l) << "Initial graph";
+ INFO(l) << fmt(_graph);
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &tr : phase)
+ {
+ INFO(l) << "--------------------------------------------------------------";
+ INFO(l) << "Before " << transform_name(tr.get());
+
+ if (tr->run(_graph))
+ {
+ changed = true;
+ }
+
+ INFO(l) << "After " << transform_name(tr.get()) << " (changed: " << to_char(changed) << ")";
+ INFO(l) << fmt(_graph);
+
+ if (changed)
+ {
+ break;
+ }
+ }
+ }
+
+ INFO(l) << "PhaseRunner<Restart> - done";
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Phase.h b/compiler/moco-tf/src/Phase.h
new file mode 100644
index 000000000..cb1854b59
--- /dev/null
+++ b/compiler/moco-tf/src/Phase.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_PHASE_H__
+#define __MOCO_TF_PHASE_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+#include <vector>
+#include <memory>
+
+namespace moco
+{
+namespace tf
+{
+
+// Phase is a collection of Transform(s)
+using Phase = std::vector<std::unique_ptr<Transform>>;
+
+enum class PhaseStrategy
+{
+ // Run all the transforms until there is no transform that makes a change
+ Saturate,
+ // Same as Saturate but will restart from the first when there is a change
+ Restart,
+};
+
+template <PhaseStrategy S> class PhaseRunner;
+
+template <> class PhaseRunner<PhaseStrategy::Saturate>
+{
+public:
+ PhaseRunner(loco::Graph *graph) : _graph{graph}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ loco::Graph *_graph;
+};
+
+template <> class PhaseRunner<PhaseStrategy::Restart>
+{
+public:
+ PhaseRunner(loco::Graph *graph) : _graph{graph}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ loco::Graph *_graph;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_PHASE_H__
diff --git a/compiler/moco-tf/src/SimpleNodeTransform.h b/compiler/moco-tf/src/SimpleNodeTransform.h
deleted file mode 100644
index b69cbad6b..000000000
--- a/compiler/moco-tf/src/SimpleNodeTransform.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_SIMPLE_NODE_TRANSFORM_H__
-#define __MOCO_TF_SIMPLE_NODE_TRANSFORM_H__
-
-#include "Transform.h"
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Per-Node Transform
- */
-template <typename ConcreteNode> struct SimpleNodeTransform : public Transform
-{
- SimpleNodeTransform() = default;
-
- virtual ~SimpleNodeTransform() = default;
-
- // NOTE Users SHOULD implement this method
- virtual bool transform(ConcreteNode *node) const = 0;
-
- bool run(loco::Graph *graph) final
- {
- using loco::active_nodes;
- using loco::output_nodes;
-
- bool changed = false;
-
- for (auto node : active_nodes(output_nodes(graph)))
- {
- if (auto casted = dynamic_cast<ConcreteNode *>(node))
- {
- if (transform(casted))
- {
- changed = true;
- }
- }
- }
-
- return changed;
- }
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_SIMPLE_NODE_TRANSFORM_H__
diff --git a/compiler/moco-tf/src/SimpleNodeTransform.test.cpp b/compiler/moco-tf/src/SimpleNodeTransform.test.cpp
deleted file mode 100644
index 781a48781..000000000
--- a/compiler/moco-tf/src/SimpleNodeTransform.test.cpp
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "SimpleNodeTransform.h"
-
-#include <set>
-
-#include <gtest/gtest.h>
-
-TEST(SimpleNodeTransformTests, run)
-{
- class Transform final : public moco::tf::SimpleNodeTransform<loco::Push>
- {
- public:
- Transform(std::multiset<loco::Node *> *out) : _out{out}
- {
- // DO NOTHING
- }
-
- public:
- bool transform(loco::Push *node) const final
- {
- _out->insert(node);
- return false;
- }
-
- private:
- std::multiset<loco::Node *> *_out;
- };
-
- auto g = loco::make_graph();
- auto output_0 = g->outputs()->create();
- auto push = g->nodes()->create<loco::Push>();
- loco::link(output_0, push);
-
- std::multiset<loco::Node *> nodes;
- Transform transform{&nodes};
-
- transform.run(g.get());
-
- ASSERT_EQ(nodes.size(), 1);
- ASSERT_EQ(nodes.count(push), 1);
-}
diff --git a/compiler/moco-tf/src/TFEltwiseBinaryCanonicalzeHelper.h b/compiler/moco-tf/src/TFEltwiseBinaryCanonicalzeHelper.h
index df9aec144..c86f6b9a0 100644
--- a/compiler/moco-tf/src/TFEltwiseBinaryCanonicalzeHelper.h
+++ b/compiler/moco-tf/src/TFEltwiseBinaryCanonicalzeHelper.h
@@ -1,24 +1,8 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
#ifndef __TF_ELTWISE_BINARY_CANONICALIZE_HELPER_H__
#define __TF_ELTWISE_BINARY_CANONICALIZE_HELPER_H__
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "CanonicalEltwiseInputConnector.h"
#include "BroadcastHelper.h"
@@ -34,27 +18,22 @@ namespace
template <typename TFNodeT> struct EltwiseBinaryCanonicalizationRule;
-template <> struct EltwiseBinaryCanonicalizationRule<moco::TFAdd>
+template <> struct EltwiseBinaryCanonicalizationRule<moco::tf::TFAdd>
{
using CanonicalNode = loco::EltwiseAdd;
};
-template <> struct EltwiseBinaryCanonicalizationRule<moco::TFSub>
+template <> struct EltwiseBinaryCanonicalizationRule<moco::tf::TFSub>
{
using CanonicalNode = loco::EltwiseSub;
};
-template <> struct EltwiseBinaryCanonicalizationRule<moco::TFMaximum>
-{
- using CanonicalNode = loco::EltwiseMax;
-};
-
-template <> struct EltwiseBinaryCanonicalizationRule<moco::TFMul>
+template <> struct EltwiseBinaryCanonicalizationRule<moco::tf::TFMul>
{
using CanonicalNode = loco::EltwiseMul;
};
-template <> struct EltwiseBinaryCanonicalizationRule<moco::TFRealDiv>
+template <> struct EltwiseBinaryCanonicalizationRule<moco::tf::TFRealDiv>
{
using CanonicalNode = loco::EltwiseDiv;
};
@@ -67,34 +46,25 @@ template <typename TFNode> bool canonicalize_eltwise_binary_node(TFNode *node)
* This will replace T/F Eltwise Binary node with a corresponding Canonical Eltwise node
*
* BEFORE
- * A --- T/F Node --- C
- * /
- * B ----
+ * A --- T/F Node -- C
+ * B --/
*
* AFTER
- * A --- T/F Node ---
- * /
- * B ----
- *
- * A --- [FixedReshape] --- [TensorBroadcast] --- Canonical Node -- C
- * /
- * B --- [FixedReshape] --- [TensorBroadcast] ----
+ * +------ T/F Node --
+ * | /
+ * B -------
+ * | \
+ * A -+------ Canonical Node -- C
*
- * NOTE
- * - [...] means optional node. They may or may not be created during this procedure.
- * - T/F Node is disconnected from C after transformation.
+ * NOTE T/F Node is disconnected from C after transformation
*/
using CanonicalNodeT = typename EltwiseBinaryCanonicalizationRule<TFNode>::CanonicalNode;
+ // check condition: shape should be same
auto node_A = node->x();
auto node_B = node->y();
- if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
- return false;
- if (!loco::shape_known(node))
- return false;
-
auto out_shape = loco::shape_get(node).template as<loco::TensorShape>();
// Create a node
diff --git a/compiler/moco-tf/src/TFFormattedGraph.cpp b/compiler/moco-tf/src/TFFormattedGraph.cpp
index 2ea514a2b..08fad7a50 100644
--- a/compiler/moco-tf/src/TFFormattedGraph.cpp
+++ b/compiler/moco-tf/src/TFFormattedGraph.cpp
@@ -16,14 +16,15 @@
#include "TFFormattedGraph.h"
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
#include "LogHelper.h"
#include <pepper/str.h>
#include <locoex/Service/COpFormattedGraph.h>
-#include <oops/InternalExn.h>
#include <sstream>
@@ -34,12 +35,12 @@ std::string opname(uint32_t opnum)
{
static std::string prefix{"tf."};
- switch (static_cast<moco::TFOpcode>(opnum))
+ switch (static_cast<moco::tf::TFOpcode>(opnum))
{
#define TENSORFLOW_NODE(OPCODE, CLASS) \
- case moco::TFOpcode::OPCODE: \
+ case moco::tf::TFOpcode::OPCODE: \
return prefix + #OPCODE;
-#include <moco/IR/TFNodes.lst>
+#include "Dialect/TFNodes.lst"
#undef TENSORFLOW_NODE
default:
break;
@@ -48,7 +49,6 @@ std::string opname(uint32_t opnum)
return prefix + "Invalid";
}
-using namespace moco;
using namespace moco::tf;
/// TFNodeSummaryBuilder with default implementation
@@ -71,7 +71,7 @@ protected:
s.state(locop::NodeSummary::State::PartiallyKnown); \
return true; \
}
-#include <moco/IR/TFNodes.lst>
+#include "Dialect/TFNodes.lst"
#undef TENSORFLOW_NODE
protected:
@@ -99,24 +99,18 @@ private:
IMPLEMENT(TFConcatV2);
IMPLEMENT(TFConst);
IMPLEMENT(TFConv2D);
- IMPLEMENT(TFConv2DBackpropInput);
IMPLEMENT(TFDepthwiseConv2dNative);
IMPLEMENT(TFFusedBatchNorm);
- IMPLEMENT(TFMaximum);
IMPLEMENT(TFMaxPool);
IMPLEMENT(TFMean);
IMPLEMENT(TFMul);
- IMPLEMENT(TFPack);
IMPLEMENT(TFReshape);
IMPLEMENT(TFRsqrt);
IMPLEMENT(TFShape);
IMPLEMENT(TFSoftmax);
IMPLEMENT(TFSqueeze);
IMPLEMENT(TFStopGradient);
- IMPLEMENT(TFStridedSlice);
IMPLEMENT(TFTanh);
- // For virtual nodes
- IMPLEMENT(TFPush);
#undef IMPLEMENT
};
@@ -131,7 +125,7 @@ bool TFNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSummary
s.opname(opname(node->opnum())); \
return summary(dynamic_cast<const CLASS *>(node), s); \
}
-#include <moco/IR/TFNodes.lst>
+#include "Dialect/TFNodes.lst"
#undef TENSORFLOW_NODE
return false;
@@ -183,6 +177,10 @@ bool TFNodeSummaryBuilder::summary(const TFConst *node, locop::NodeSummary &s) c
{
std::ostringstream ss;
+ auto shapedata = node->annot<ShapeInferenceData>();
+ // TODO show real numbers like [1,2,3,4]
+ s.args().append("shape", shapedata ? "OK" : "?");
+
auto dtype = node->dtype();
switch (dtype)
{
@@ -193,7 +191,7 @@ bool TFNodeSummaryBuilder::summary(const TFConst *node, locop::NodeSummary &s) c
ss << node->size<loco::DataType::FLOAT32>();
break;
default:
- INTERNAL_EXN_V("Unsupported data type", node->name());
+ throw std::runtime_error("NYI for this DataType");
}
s.args().append("size", ss.str());
s.state(locop::NodeSummary::State::PartiallyKnown);
@@ -211,18 +209,6 @@ bool TFNodeSummaryBuilder::summary(const TFConv2D *node, locop::NodeSummary &s)
return true;
}
-bool TFNodeSummaryBuilder::summary(const TFConv2DBackpropInput *node, locop::NodeSummary &s) const
-{
- s.args().append("input_sizes", tbl()->lookup(node->input_sizes()));
- s.args().append("filter", tbl()->lookup(node->filter()));
- s.args().append("out_backprop", tbl()->lookup(node->out_backprop()));
- s.args().append("padding", node->padding());
- s.args().append("data_layout", node->data_layout());
- s.args().append("strides", pepper::str(node->strides()));
- s.state(locop::NodeSummary::State::PartiallyKnown);
- return true;
-}
-
bool TFNodeSummaryBuilder::summary(const TFDepthwiseConv2dNative *node, locop::NodeSummary &s) const
{
s.args().append("input", tbl()->lookup(node->input()));
@@ -236,9 +222,9 @@ bool TFNodeSummaryBuilder::summary(const TFDepthwiseConv2dNative *node, locop::N
bool TFNodeSummaryBuilder::summary(const TFFusedBatchNorm *node, locop::NodeSummary &s) const
{
- s.args().append("x", tbl()->lookup(node->x()));
- s.args().append("scale", tbl()->lookup(node->scale()));
- s.args().append("offset", tbl()->lookup(node->offset()));
+ s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("gamma", tbl()->lookup(node->gamma()));
+ s.args().append("beta", tbl()->lookup(node->beta()));
s.args().append("mean", tbl()->lookup(node->mean()));
s.args().append("variance", tbl()->lookup(node->variance()));
s.args().append("epsilon", pepper::str(node->epsilon()));
@@ -246,17 +232,9 @@ bool TFNodeSummaryBuilder::summary(const TFFusedBatchNorm *node, locop::NodeSumm
return true;
}
-bool TFNodeSummaryBuilder::summary(const TFMaximum *node, locop::NodeSummary &s) const
-{
- s.args().append("x", tbl()->lookup(node->x()));
- s.args().append("y", tbl()->lookup(node->y()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
bool TFNodeSummaryBuilder::summary(const TFMaxPool *node, locop::NodeSummary &s) const
{
- s.args().append("input", tbl()->lookup(node->input()));
+ s.args().append("value", tbl()->lookup(node->value()));
s.args().append("ksize", pepper::str(node->ksize()));
s.args().append("strides", pepper::str(node->strides()));
s.args().append("padding", node->padding());
@@ -283,16 +261,6 @@ bool TFNodeSummaryBuilder::summary(const TFMul *node, locop::NodeSummary &s) con
return true;
}
-bool TFNodeSummaryBuilder::summary(const TFPack *node, locop::NodeSummary &s) const
-{
- s.args().append("N", pepper::str(node->N()));
- s.args().append("axis", pepper::str(node->axis()));
- for (uint32_t n = 0; n < node->N(); ++n)
- s.args().append("values", tbl()->lookup(node->values(n)));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
bool TFNodeSummaryBuilder::summary(const TFReshape *node, locop::NodeSummary &s) const
{
s.args().append("tensor", tbl()->lookup(node->tensor()));
@@ -337,22 +305,6 @@ bool TFNodeSummaryBuilder::summary(const TFStopGradient *node, locop::NodeSummar
return true;
}
-bool TFNodeSummaryBuilder::summary(const TFStridedSlice *node, locop::NodeSummary &s) const
-{
- s.args().append("input", tbl()->lookup(node->input()));
- s.args().append("begin", tbl()->lookup(node->begin()));
- s.args().append("end", tbl()->lookup(node->end()));
- if (node->strides() != nullptr)
- s.args().append("strides", tbl()->lookup(node->strides()));
- s.args().append("begin_mask", pepper::str(node->begin_mask()));
- s.args().append("end_mask", pepper::str(node->end_mask()));
- s.args().append("ellipsis_mask", pepper::str(node->ellipsis_mask()));
- s.args().append("new_axis_mask", pepper::str(node->new_axis_mask()));
- s.args().append("shrink_axis_mask", pepper::str(node->shrink_axis_mask()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
bool TFNodeSummaryBuilder::summary(const TFTanh *node, locop::NodeSummary &s) const
{
s.args().append("x", tbl()->lookup(node->x()));
@@ -360,15 +312,6 @@ bool TFNodeSummaryBuilder::summary(const TFTanh *node, locop::NodeSummary &s) co
return true;
}
-// For virtual nodes
-bool TFNodeSummaryBuilder::summary(const TFPush *node, locop::NodeSummary &s) const
-{
- s.args().append("index", node->indexed() ? pepper::str(node->index()) : "?");
- s.args().append("from", tbl()->lookup(node->from()));
- s.state(locop::NodeSummary::State::Complete);
- return true;
-}
-
} // namespace
namespace moco
diff --git a/compiler/moco-tf/src/TFOptimizer.cpp b/compiler/moco-tf/src/TFOptimizer.cpp
index 2256b99b8..f6f76a718 100644
--- a/compiler/moco-tf/src/TFOptimizer.cpp
+++ b/compiler/moco-tf/src/TFOptimizer.cpp
@@ -36,37 +36,31 @@ void TFOptimizer::optimize(loco::Graph *g) const
/* TRANSFORM DECLARATION BEGIN */
if (moco::tf::get<moco::tf::Knob::ResolveFusedBatchNorm>())
{
- phase.emplace_back(stdex::make_unique<moco::ResolveFusedBatchNorm>());
+ phase.emplace_back(stdex::make_unique<moco::tf::ResolveFusedBatchNorm>());
}
if (moco::tf::get<moco::tf::Knob::FuseBinaryIntoPreceding>())
{
- phase.emplace_back(stdex::make_unique<moco::FuseBinaryIntoPreceding>());
+ phase.emplace_back(stdex::make_unique<moco::tf::FuseBinaryIntoPreceding>());
}
if (moco::tf::get<moco::tf::Knob::ResolveConstantShape>())
{
- phase.emplace_back(stdex::make_unique<moco::ResolveConstantShape>());
+ phase.emplace_back(stdex::make_unique<moco::tf::ResolveConstantShape>());
}
if (moco::tf::get<moco::tf::Knob::ResolveReshapeWildcardDim>())
{
- phase.emplace_back(stdex::make_unique<moco::ResolveReshapeWildcardDim>());
- }
- if (moco::tf::get<moco::tf::Knob::ResolveSquaredDifference>())
- {
- phase.emplace_back(stdex::make_unique<moco::ResolveSquaredDifference>());
+ phase.emplace_back(stdex::make_unique<moco::tf::ResolveReshapeWildcardDim>());
}
if (moco::tf::get<moco::tf::Knob::RemoveTFIdentityNode>())
{
- phase.emplace_back(stdex::make_unique<moco::RemoveTFIdentityNode>());
+ phase.emplace_back(stdex::make_unique<moco::tf::RemoveTFIdentityNodeTransform>());
}
if (moco::tf::get<moco::tf::Knob::RemoveDeadNode>())
{
phase.emplace_back(stdex::make_unique<logo::RemoveDeadNodePass>());
}
- if (moco::tf::get<moco::tf::Knob::SqueezeReduceNode>())
- {
- phase.emplace_back(stdex::make_unique<moco::SqueezeReduceNode>());
- }
- // Shape inference is needed for added nodes doing above transformations
+ // Fix shape and pad for added nodes doing above transformations
+ // TODO need to merge or remove the ones in importer
+ phase.emplace_back(stdex::make_unique<moco::tf::FixShapeTransform>());
phase.emplace_back(stdex::make_unique<moco::tf::ShapeInferencePass>());
phase.emplace_back(stdex::make_unique<moco::tf::TypeInferencePass>());
/* TRANSFORM DECLARATION END */
diff --git a/compiler/moco-tf/src/TFReduceCanonicalzeHelper.h b/compiler/moco-tf/src/TFReduceCanonicalzeHelper.h
deleted file mode 100644
index abd24cec8..000000000
--- a/compiler/moco-tf/src/TFReduceCanonicalzeHelper.h
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __TF_REDUCE_CANONICALIZE_HELPER_H__
-#define __TF_REDUCE_CANONICALIZE_HELPER_H__
-
-#include <moco/IR/TFDialect.h>
-#include <moco/IR/TFNodes.h>
-
-#include <loco/Service/ShapeInference.h>
-
-#include <moco/Log.h>
-
-namespace
-{
-
-template <typename TFNodeT> loco::ReduceFunc reduceFunc(void);
-
-template <> loco::ReduceFunc reduceFunc<moco::TFMean>(void) { return loco::ReduceFunc::Mean; }
-
-template <typename TFNode> bool canonicalize_reduce_node(TFNode *node)
-{
- LOGGER(l);
-
- INFO(l) << "TFNodeCanonicalize ReduceNode begin";
-
- auto graph = node->graph();
-
- /**
- * This will replace T/F Reduce node with a corresponding Canonical Reduce node
- *
- * BEFORE
- * reduction_indices -------- T/F Node -- C
- * input -------/
- *
- * AFTER
- * +------ T/F Node --
- * | /
- * reduction_indices -------
- * | \
- * input -+------ Canonical Node -- C
- *
- * NOTE
- * - T/F Node is disconnected from C after transformation
- */
-
- // TFSqueeze had to be inserted if keep_dims() was false
- assert(node->keep_dims());
-
- auto axes_node = node->reduction_indices();
- assert(axes_node != nullptr);
-
- auto node_tensor_shape = loco::shape_get(node).template as<loco::TensorShape>();
-
- // Canonicalization into TensorReduce is valid when reduction indices is constant
- // TODO Support general TensorReduce case
- std::vector<int32_t> axes_values;
- if (auto const_axes = dynamic_cast<moco::TFConst *>(axes_node))
- {
- // TODO Support S64 type
- assert(const_axes->dtype() == loco::DataType::S32);
-
- for (uint32_t i = 0; i < const_axes->size<loco::DataType::S32>(); ++i)
- {
- int32_t axis = const_axes->at<loco::DataType::S32>(i);
- if (axis < 0)
- axis += node_tensor_shape.rank();
- axes_values.push_back(axis);
- }
- }
- else if (auto const_axes = dynamic_cast<loco::ConstGen *>(axes_node))
- {
- // TODO Support S64 type
- assert(const_axes->dtype() == loco::DataType::S32);
-
- for (uint32_t i = 0; i < const_axes->size<loco::DataType::S32>(); ++i)
- {
- int32_t axis = const_axes->at<loco::DataType::S32>(i);
- if (axis < 0)
- axis += node_tensor_shape.rank();
- axes_values.push_back(axis);
- }
- }
- else
- return false;
-
- // Create loco node to replace
- auto reduce = graph->nodes()->template create<loco::TensorReduce>();
-
- // replace
- reduce->func(reduceFunc<TFNode>());
- reduce->input(node->input());
- for (uint32_t i = 0; i < axes_values.size(); ++i)
- reduce->axes()->insert(axes_values.at(i));
-
- replace(node).with(reduce);
-
- INFO(l) << "TFNodeCanonicalize ReduceNode done";
-
- return true;
-}
-
-} // namespace
-
-#endif // __TF_REDUCE_CANONICALIZE_HELPER_H__
diff --git a/compiler/moco-tf/src/TestHelper.cpp b/compiler/moco-tf/src/TestHelper.cpp
new file mode 100644
index 000000000..412556a68
--- /dev/null
+++ b/compiler/moco-tf/src/TestHelper.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <cstring>
+
+namespace moco
+{
+namespace tf
+{
+namespace test
+{
+
+void setup_output_node(loco::Graph *graph, loco::Node *last_node)
+{
+ // add push as output
+ auto push_node = graph->nodes()->create<loco::Push>();
+ push_node->from(last_node);
+
+ // set the graph output name and node object
+ auto graph_output = graph->outputs()->create();
+ graph_output->name("output");
+ graph_output->dtype(loco::DataType::FLOAT32);
+ loco::link(graph_output, push_node);
+}
+
+} // namespace test
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/TestHelper.h b/compiler/moco-tf/src/TestHelper.h
index dd32d4433..6978efaab 100644
--- a/compiler/moco-tf/src/TestHelper.h
+++ b/compiler/moco-tf/src/TestHelper.h
@@ -21,6 +21,8 @@
#include <tensorflow/core/framework/graph.pb.h>
+#include <istream>
+
#define STRING_CONTENT(content) #content
namespace moco
@@ -73,41 +75,4 @@ void setup_output_node(loco::Graph *graph, loco::Node *last_node);
} // namespace tf
} // namespace moco
-#include <moco/IR/TFNode.h>
-
-#include <moco/Import/GraphBuilder.h>
-
-#include <plier/tf/TestHelper.h>
-
-namespace moco
-{
-namespace tf
-{
-namespace test
-{
-
-class TFNodeBuildTester
-{
-public:
- TFNodeBuildTester();
-
-public:
- void inputs(const std::vector<std::string> &names);
- void output(const char *name);
- moco::TFNode *output(void);
-
- void run(tensorflow::NodeDef &node_def, moco::GraphBuilder &graph_builder);
-
-private:
- std::unique_ptr<moco::SymbolTable> _tensor_names;
- std::unique_ptr<loco::Graph> _graph;
-
- std::vector<moco::TFNode *> _inputs;
- const char *_output{nullptr};
-};
-
-} // namespace test
-} // namespace tf
-} // namespace moco
-
#endif // __TEST_HELPER_H__
diff --git a/compiler/moco-tf/src/TestHelper.test.cpp b/compiler/moco-tf/src/TestHelper.test.cpp
deleted file mode 100644
index 1e8c38e36..000000000
--- a/compiler/moco-tf/src/TestHelper.test.cpp
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TestHelper.h"
-
-#include <google/protobuf/io/coded_stream.h>
-#include <google/protobuf/io/zero_copy_stream_impl.h>
-#include <google/protobuf/text_format.h>
-
-#include <cstring>
-
-namespace moco
-{
-namespace tf
-{
-namespace test
-{
-
-void setup_output_node(loco::Graph *graph, loco::Node *last_node)
-{
- // add push as output
- auto push_node = graph->nodes()->create<loco::Push>();
- push_node->from(last_node);
-
- // set the graph output name and node object
- auto graph_output = graph->outputs()->create();
- graph_output->name("output");
- graph_output->dtype(loco::DataType::FLOAT32);
- loco::link(graph_output, push_node);
-}
-
-} // namespace test
-} // namespace tf
-} // namespace moco
-
-#include <moco/IR/Nodes/TFConst.h>
-
-#include <stdex/Memory.h>
-
-#include <gtest/gtest.h>
-
-namespace moco
-{
-namespace tf
-{
-namespace test
-{
-
-TFNodeBuildTester::TFNodeBuildTester()
-{
- _graph = loco::make_graph();
- _tensor_names = stdex::make_unique<moco::SymbolTable>();
-}
-
-void TFNodeBuildTester::inputs(const std::vector<std::string> &names)
-{
- for (auto name : names)
- {
- auto input = _graph->nodes()->create<moco::TFConst>();
- moco::TensorName name_01(name, 0);
- _tensor_names->enroll(name_01, input);
-
- _inputs.push_back(input);
- }
-}
-
-void TFNodeBuildTester::output(const char *name) { _output = name; }
-
-moco::TFNode *TFNodeBuildTester::output(void)
-{
- assert(_output != nullptr);
-
- moco::TensorName tname(_output, 0);
- return static_cast<moco::TFNode *>(_tensor_names->node(tname));
-}
-
-void TFNodeBuildTester::run(tensorflow::NodeDef &nodedef, moco::GraphBuilder &graphbuilder)
-{
- assert(_output != nullptr);
-
- auto node_defs = stdex::make_unique<moco::NodeDefTable>();
- auto updates = stdex::make_unique<moco::UpdateQueue>();
-
- moco::GraphBuilderContext gb_context(_graph.get(), node_defs.get(), _tensor_names.get(),
- updates.get());
-
- EXPECT_TRUE(graphbuilder.validate(nodedef));
- graphbuilder.build(nodedef, &gb_context);
-
- for (auto &update : updates->queue())
- {
- update->input(_tensor_names.get());
- }
-
- auto tfnode = output();
- ASSERT_NE(tfnode, nullptr);
-
- int idx = 0;
- ASSERT_EQ(tfnode->arity(), _inputs.size());
- for (auto input : _inputs)
- {
- ASSERT_EQ(tfnode->arg(idx++), input);
- }
-}
-
-} // namespace test
-} // namespace tf
-} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms.h b/compiler/moco-tf/src/Transforms.h
index f14b81675..653adbf3a 100644
--- a/compiler/moco-tf/src/Transforms.h
+++ b/compiler/moco-tf/src/Transforms.h
@@ -17,10 +17,16 @@
#ifndef __MOCO_TF_TRANSFORMS_H__
#define __MOCO_TF_TRANSFORMS_H__
+#include "Transforms/ClearAnnotTransform.h"
+#include "Transforms/FixShapeTransform.h"
+#include "Transforms/FuseBinaryIntoPreceding.h"
+#include "Transforms/RemoveTFIdentityNodeTransform.h"
+#include "Transforms/ResolveConstantShape.h"
+#include "Transforms/ResolveFusedBatchNorm.h"
+#include "Transforms/ResolveReshapeWildcardDim.h"
#include "Transforms/ShapeInferencePass.h"
#include "Transforms/TypeInferencePass.h"
#include <logo/Passes.h>
-#include <moco/Pass/Passes.h>
#endif // __MOCO_TF_TRANSFORMS_H__
diff --git a/compiler/moco-tf/src/Transforms/ClearAnnotTransform.cpp b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.cpp
new file mode 100644
index 000000000..37873cb04
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClearAnnotTransform.h"
+
+#include "Annotations/ConcatData.h"
+#include "Annotations/PadData.h"
+#include "Annotations/PaddingData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/StrideData.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+bool ClearAnnotTransform::run(loco::Graph *graph)
+{
+ for (auto node : loco::all_nodes(graph))
+ {
+// clang-format off
+#define MOCOANNOT(TYPE_NAME) \
+ { \
+ auto annot_data = node->annot<TYPE_NAME>(); \
+ if (annot_data != nullptr) \
+ { \
+ node->annot<TYPE_NAME>(nullptr); \
+ } \
+ }
+MOCOANNOT(ConcatData)
+MOCOANNOT(PadData)
+MOCOANNOT(PaddingData)
+MOCOANNOT(ShapeInferenceData)
+MOCOANNOT(StrideData)
+ // TODO add more annotation(s) to clear when defined
+#undef MOCOANNOT
+ // clang-format on
+ }
+
+ /** @note Current design requires to return boolean for changed but we don't
+ * need for this one-time-execution.
+ * It would be better to separate two transform types later.
+ */
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.h
index 569a71f82..ab56097e9 100644
--- a/compiler/moco-tf/src/Canonicalization/TFPushCanonicalizer.h
+++ b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.h
@@ -14,13 +14,10 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_PUSH_CANONICALIZER_H__
-#define __MOCO_TF_PUSH_CANONICALIZER_H__
+#ifndef __MOCO_TF_CLEAR_ANNOT_TRANSFORM_H__
+#define __MOCO_TF_CLEAR_ANNOT_TRANSFORM_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -30,18 +27,18 @@ namespace tf
{
/**
- * @brief Convert TFPush to Canonical Push
+ * @brief Clear(delete) annotation data no needed anymore
*/
-class TFPushCanonicalizer : public SimpleNodeTransform<moco::TFPush>
+class ClearAnnotTransform : public Transform
{
public:
- const char *name(void) const final { return "TFPushCanonicalizer"; }
+ const char *name(void) const final { return "ClearAnnotTransform"; }
public:
- bool transform(moco::TFPush *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_PUSH_CANONICALIZER_H__
+#endif // __MOCO_TF_CLEAR_ANNOT_TRANSFORM_H__
diff --git a/compiler/moco-tf/src/Transforms/ClearAnnotTransform.test.cpp b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.test.cpp
new file mode 100644
index 000000000..f2ea39be0
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ClearAnnotTransform.test.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ClearAnnotTransform.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(ClearAnnotTransform, ctor)
+{
+ moco::tf::ClearAnnotTransform catransform;
+ loco::Graph graph;
+
+ ASSERT_FALSE(catransform.run(&graph));
+}
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
new file mode 100644
index 000000000..93570dbbc
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
@@ -0,0 +1,1539 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FixShapeTransform.h"
+
+#include "LogHelper.h"
+
+#include "Annotations/ConcatData.h"
+#include "Annotations/PadData.h"
+#include "Annotations/ShapeInferenceData.h"
+#include "Annotations/StrideData.h"
+#include "Annotations/WindowData.h"
+#include "Dialect/TFNodes.h"
+
+#include <loco.h>
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+#include <plier/tf/Convert.h>
+#include <locoex/COpCall.h>
+
+#include <cassert>
+#include <sstream>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+/**
+ * @brief Return true if node has shape inference data for checking shape
+ * inference is done or not
+ */
+bool shape_inference_done(const loco::Node *node)
+{
+ auto shapedata = node->annot<ShapeInferenceData>();
+ return (shapedata != nullptr);
+}
+
+/**
+ * @brief Copy ShapeInferenceData values from src to dst
+ *
+ * @note T can be ShapeInferenceData or loco::Node based class having shape
+ * attributes like ConstGen, Pull and so on
+ */
+template <class T> void copy_shape_values(const T *src, ShapeInferenceData *dst)
+{
+ assert(src != nullptr);
+ assert(dst != nullptr);
+
+ uint32_t rank = src->rank();
+ dst->rank(rank);
+ for (uint32_t index = 0; index < rank; ++index)
+ {
+ if (src->dim(index).known())
+ dst->dim(index) = src->dim(index).value();
+ else
+ dst->dim(index).unset();
+ }
+}
+
+/**
+ * @brief Make copy of ShapeInferenceData from src
+ *
+ * @note T can be ShapeInferenceData or loco::Node based class having shape
+ * attributes like TFConst, COpCall and so on
+ */
+template <class T> std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const T *src)
+{
+ assert(src != nullptr);
+
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+
+ uint32_t rank = src->rank();
+ shape_data->rank(rank);
+ for (uint32_t index = 0; index < rank; ++index)
+ {
+ if (src->dim(index).known())
+ shape_data->dim(index) = src->dim(index).value();
+ else
+ shape_data->dim(index).unset();
+ }
+
+ return std::move(shape_data);
+}
+
+std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const loco::NodeShape &src)
+{
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+
+ switch (src.domain())
+ {
+ case loco::Domain::Tensor:
+ shape_data->tensor_shape(src.as<loco::TensorShape>());
+ break;
+
+ case loco::Domain::Feature:
+ shape_data->feature_shape(src.as<loco::FeatureShape>());
+ break;
+
+ case loco::Domain::Filter:
+ shape_data->filter_shape(src.as<loco::FilterShape>());
+ break;
+
+ case loco::Domain::DepthwiseFilter:
+ shape_data->depthwisefilter_shape(src.as<loco::DepthwiseFilterShape>());
+ break;
+
+ case loco::Domain::Bias:
+ shape_data->bias_shape(src.as<loco::BiasShape>());
+ break;
+
+ default:
+ throw std::runtime_error("Unsupported Domain in make_shape_inference_data");
+ }
+
+ return std::move(shape_data);
+}
+
+loco::NodeShape as_node_shape(const ShapeInferenceData *shapedata)
+{
+ switch (shapedata->domain())
+ {
+ case loco::Domain::Tensor:
+ return loco::NodeShape({shapedata->tensor_shape()});
+
+ case loco::Domain::Feature:
+ return loco::NodeShape({shapedata->feature_shape()});
+
+ case loco::Domain::Filter:
+ return loco::NodeShape({shapedata->filter_shape()});
+
+ case loco::Domain::DepthwiseFilter:
+ return loco::NodeShape({shapedata->depthwisefilter_shape()});
+
+ case loco::Domain::Bias:
+ return loco::NodeShape({shapedata->bias_shape()});
+ }
+
+ throw std::runtime_error("Unsupported Domain in as_node_shape");
+}
+
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ * auto expanded_tensor_shape = expand(tensor_shape).to(N);
+ */
+class TensorShapeExpander
+{
+public:
+ TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::TensorShape to(uint32_t output_rank)
+ {
+ auto const &input_shape = _shape;
+ uint32_t const input_rank = input_shape.rank();
+
+ assert(input_rank <= output_rank && "Cannot shrink rank");
+ uint32_t const axis_shift = output_rank - input_rank;
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(output_rank);
+ for (uint32_t axis = 0; axis < output_rank; ++axis)
+ {
+ output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+ }
+
+ return output_shape;
+ }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+/**
+ * @breif Expand shape x and y to same rank by align right and filling with 1
+ */
+void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
+{
+ auto x_rank = x.rank();
+ auto y_rank = y.rank();
+
+ if (x_rank == y_rank)
+ return;
+
+ TensorShapeExpander x_exp(x);
+ TensorShapeExpander y_exp(y);
+
+ auto xy_rank = std::max(x_rank, y_rank);
+
+ x = x_rank > y_rank ? x : x_exp.to(xy_rank);
+ y = y_rank > x_rank ? y : y_exp.to(xy_rank);
+}
+
+/**
+ * @breif Returns shape of expanded dimension of input x and y having same rank
+ */
+loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ assert(x.rank() == y.rank());
+
+ auto rank = x.rank();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(rank);
+ for (auto axis = 0; axis < rank; ++axis)
+ {
+ assert(x.dim(axis).known() && y.dim(axis).known());
+
+ auto x_dim = x.dim(axis).value();
+ auto y_dim = y.dim(axis).value();
+
+ // each dimension of x and y should be same or one must be 1 if different
+ if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
+ throw std::runtime_error("Cannot produce expand_dimension of two shapes");
+
+ output_shape.dim(axis) = std::max(x_dim, y_dim);
+ }
+
+ return output_shape;
+}
+
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ auto x_match = x;
+ auto y_match = y;
+
+ expand_rank(x_match, y_match);
+
+ auto output_shape = expand_dimension(x_match, y_match);
+
+ return output_shape;
+}
+
+/**
+ * @brief Copy ShapeInferenceData from loco::Node pointer src to dst
+ */
+bool copy_shapedata(const loco::Node *src, loco::Node *dst)
+{
+ // if dst already has ShapeInferenceData, skip
+ if (shape_inference_done(dst))
+ return false;
+
+ // if src has loco::NodeShape, use it
+ if (loco::shape_known(src))
+ {
+ auto shape_data = make_shape_inference_data(loco::shape_get(src));
+ dst->annot(std::move(shape_data));
+
+ return true;
+ }
+
+ // if src doesn't have ShapeInferenceData, skip
+ if (!shape_inference_done(src))
+ return false;
+
+ auto src_shapedata = src->annot<ShapeInferenceData>();
+ auto shape_data = make_shape_inference_data(src_shapedata);
+ dst->annot(std::move(shape_data));
+
+ return true;
+}
+
+/**
+ * @note This will find broadcast shape from two inputs lhs and rhs using
+ * broadcast_shape() and return that shape to dst
+ */
+bool copy_shapedata(const loco::Node *lhs, const loco::Node *rhs, loco::Node *dst)
+{
+ // if dst already has ShapeInferenceData, skip
+ if (shape_inference_done(dst))
+ return false;
+
+ auto get_node_shape = [](const loco::Node *node, loco::NodeShape &out) {
+ if (loco::shape_known(node))
+ {
+ out = loco::shape_get(node);
+ }
+ else
+ {
+ if (!shape_inference_done(node))
+ return false;
+
+ out = as_node_shape(node->annot<ShapeInferenceData>());
+ }
+
+ return true;
+ };
+
+ loco::NodeShape lhs_shape;
+ loco::NodeShape rhs_shape;
+
+ if (loco::shape_known(lhs))
+ {
+ lhs_shape = loco::shape_get(lhs);
+ }
+ else
+ {
+ if (!shape_inference_done(lhs))
+ return false;
+
+ lhs_shape = as_node_shape(lhs->annot<ShapeInferenceData>());
+ }
+
+ if (loco::shape_known(rhs))
+ {
+ rhs_shape = loco::shape_get(rhs);
+ }
+ else
+ {
+ if (!shape_inference_done(rhs))
+ return false;
+
+ rhs_shape = as_node_shape(rhs->annot<ShapeInferenceData>());
+ }
+
+ if (lhs_shape.domain() != loco::Domain::Tensor || rhs_shape.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("copy_shapedata supports only for Tensor");
+ }
+
+ loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>();
+ loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>();
+ loco::TensorShape sum_tensorshape = broadcast_shape(lhs_tensorshape, rhs_tensorshape);
+
+ loco::NodeShape sum_shape({sum_tensorshape});
+ auto shape_data = make_shape_inference_data(sum_shape);
+ dst->annot(std::move(shape_data));
+
+ LOGGER(l);
+
+ INFO(l) << "copy_shapedata " << lhs_tensorshape << " or " << rhs_tensorshape << " -> "
+ << sum_tensorshape << std::endl;
+
+ return true;
+}
+
+/**
+ * @note While in shape inference, Node maybe Canonical, TF dialect or other dialects
+ * This will provide common loco::NodeShape as shape information
+ */
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
+{
+ if (loco::shape_known(node))
+ {
+ nodeshape = loco::shape_get(node);
+ return true;
+ }
+
+ if (!shape_inference_done(node))
+ return false;
+
+ auto shapedata = node->annot<ShapeInferenceData>();
+
+ switch (shapedata->domain())
+ {
+ case loco::Domain::Tensor:
+ nodeshape.set(shapedata->tensor_shape());
+ break;
+
+ case loco::Domain::Feature:
+ nodeshape.set(shapedata->feature_shape());
+ break;
+
+ case loco::Domain::Filter:
+ nodeshape.set(shapedata->filter_shape());
+ break;
+
+ case loco::Domain::DepthwiseFilter:
+ nodeshape.set(shapedata->depthwisefilter_shape());
+ break;
+
+ case loco::Domain::Bias:
+ nodeshape.set(shapedata->bias_shape());
+ break;
+
+ default:
+ throw std::runtime_error("Unsupported Domain in node_shape()");
+ }
+ return true;
+}
+
+loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
+ const TFDataLayout &data_layout)
+{
+ if (nodeshape.domain() == loco::Domain::Feature)
+ return nodeshape.as<loco::FeatureShape>();
+
+ loco::FeatureShape feature_shape;
+
+ // only convert from tensor to feature
+ if (nodeshape.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("as_feature_shape: domain is not tensor");
+ }
+
+ loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
+
+ if (tensor_shape.rank() != 4)
+ {
+ throw std::runtime_error("as_feature_shape: rank is not 4");
+ }
+
+ if (data_layout == "NHWC")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.height() = tensor_shape.dim(1);
+ feature_shape.width() = tensor_shape.dim(2);
+ feature_shape.depth() = tensor_shape.dim(3);
+ }
+ else if (data_layout == "NCHW")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.depth() = tensor_shape.dim(1);
+ feature_shape.height() = tensor_shape.dim(2);
+ feature_shape.width() = tensor_shape.dim(3);
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+ }
+
+ return feature_shape;
+}
+
+/// @brief Geometry inputs consumed by calc_paddata() to derive padding
+struct FixPadContext
+{
+  uint32_t input_height;
+  uint32_t input_width;
+  uint32_t output_height;
+  uint32_t output_width;
+  uint32_t stride_height;
+  uint32_t stride_width;
+  uint32_t effective_window_height; // window extent after applying dilation
+  uint32_t effective_window_width;  // window extent after applying dilation
+};
+
+PadData calc_paddata(const FixPadContext &ctx)
+{
+ assert(ctx.output_height > 0);
+ assert(ctx.output_width > 0);
+
+ // calculate padding height, width
+ int64_t i_height = (int64_t)(ctx.output_height - 1) * (int64_t)ctx.stride_height +
+ (int64_t)ctx.effective_window_height - (int64_t)ctx.input_height;
+ int64_t i_width = (int64_t)(ctx.output_width - 1) * (int64_t)ctx.stride_width +
+ (int64_t)ctx.effective_window_width - (int64_t)ctx.input_width;
+ uint32_t pad_height = i_height >= 0 ? (uint32_t)i_height : 0U;
+ uint32_t pad_width = i_width >= 0 ? (uint32_t)i_width : 0U;
+
+ PadData pad_data;
+
+ pad_data.pad()->top(pad_height / 2);
+ pad_data.pad()->bottom(pad_height - pad_data.pad()->top());
+ pad_data.pad()->left(pad_width / 2);
+ pad_data.pad()->right(pad_width - pad_data.pad()->left());
+
+ return pad_data;
+}
+
+template <class T> void calc_annot_paddata(T *node, const FixPadContext &ctx)
+{
+ assert(node != nullptr);
+
+ PadData pd = calc_paddata(ctx);
+
+ // annotation of pad data
+ auto pad_data = stdex::make_unique<PadData>(pd);
+
+ node->annot(std::move(pad_data));
+
+ assert(node->template annot<PadData>() != nullptr);
+}
+
+template <class T> void update_stride_data(T *node)
+{
+ auto stride_data = stdex::make_unique<StrideData>();
+ auto strides = node->strides();
+ auto data_layout = plier::tf::as_data_layout(node->data_layout());
+ if (data_layout == plier::tf::DataLayout::NHWC)
+ {
+ stride_data->stride()->vertical(strides[1]);
+ stride_data->stride()->horizontal(strides[2]);
+ }
+ else if (data_layout == plier::tf::DataLayout::NCHW)
+ {
+ stride_data->stride()->vertical(strides[2]);
+ stride_data->stride()->horizontal(strides[3]);
+ }
+ node->annot(std::move(stride_data));
+}
+
+template <class T> void update_window_data(T *node)
+{
+ auto window_data = stdex::make_unique<WindowData>();
+ auto ksize = node->ksize();
+ auto data_layout = plier::tf::as_data_layout(node->data_layout());
+ if (data_layout == plier::tf::DataLayout::NHWC)
+ {
+ window_data->window()->vertical(ksize[1]);
+ window_data->window()->horizontal(ksize[2]);
+ }
+ else if (data_layout == plier::tf::DataLayout::NCHW)
+ {
+ window_data->window()->vertical(ksize[2]);
+ window_data->window()->horizontal(ksize[3]);
+ }
+ node->annot(std::move(window_data));
+}
+
+bool fix_shape(loco::Pull *node)
+{
+ if (shape_inference_done(node))
+ return false;
+
+ // Pull itself has shape information, copy them
+ auto shape_data = make_shape_inference_data(node);
+ node->annot(std::move(shape_data));
+
+ return true;
+}
+
+bool fix_shape(loco::Push *node)
+{
+ // Output shape is same as the from
+ auto from = node->from();
+ return copy_shapedata(from, node);
+}
+
+bool fix_shape(moco::tf::TFAdd *node)
+{
+ auto x = node->x();
+ auto y = node->y();
+ loco::NodeShape x_shape;
+ loco::NodeShape y_shape;
+
+ if (!node_shape(x, x_shape))
+ return false;
+ if (!node_shape(y, y_shape))
+ return false;
+
+ // Output shape is same as the input
+ return copy_shapedata(x, y, node);
+}
+
+bool fix_shape(moco::tf::TFAvgPool *node)
+{
+ LOGGER(l);
+
+ if (shape_inference_done(node))
+ return false;
+
+ auto value = node->value();
+ loco::NodeShape value_shape;
+ if (!node_shape(value, value_shape))
+ {
+ // input node shape inference is not ready
+ return false;
+ }
+
+ auto padding = node->padding();
+ assert(padding == "VALID" || padding == "SAME");
+
+ update_stride_data(node);
+ update_window_data(node);
+
+ auto value_feature_shape = as_feature_shape(value_shape, node->data_layout());
+
+ auto stride_data = node->annot<StrideData>();
+ assert(stride_data != nullptr);
+ auto window_data = node->annot<WindowData>();
+ assert(window_data != nullptr);
+
+ uint32_t input_height = value_feature_shape.height().value();
+ uint32_t input_width = value_feature_shape.width().value();
+ uint32_t stride_height = stride_data->stride()->vertical();
+ uint32_t stride_width = stride_data->stride()->horizontal();
+ uint32_t window_height = window_data->window()->vertical();
+ uint32_t window_width = window_data->window()->horizontal();
+ uint32_t dilation_height = 1; // dilation is 1
+ uint32_t dilation_width = 1;
+ uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+ uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+ uint32_t output_height;
+ uint32_t output_width;
+
+ if (padding == "VALID")
+ {
+ output_height = (input_height + stride_height - effective_window_height) / stride_height;
+ output_width = (input_width + stride_width - effective_window_width) / stride_width;
+ }
+ else if (padding == "SAME")
+ {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ }
+
+ loco::FeatureShape ofm_feature_shape;
+ ofm_feature_shape.count() = value_feature_shape.count();
+ ofm_feature_shape.height() = output_height;
+ ofm_feature_shape.width() = output_width;
+ ofm_feature_shape.depth() = value_feature_shape.depth();
+
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+ as_tensor_shape(*shape_data.get(), ofm_feature_shape, node->data_layout());
+ node->annot(std::move(shape_data));
+
+ FixPadContext ctx = {
+ input_height, input_width, output_height, output_width,
+ stride_height, stride_width, effective_window_height, effective_window_width};
+
+ calc_annot_paddata(node, ctx);
+
+ INFO(l) << "Fix TFAvgPool shape = ifm" << value_feature_shape << " --> ofm" << ofm_feature_shape;
+ INFO(l) << " pad = " << *node->annot<PadData>();
+
+ return true;
+}
+
+bool fix_shape(moco::tf::TFBiasAdd *node)
+{
+ auto value = node->value();
+ auto bias = node->bias();
+ loco::NodeShape value_shape;
+ loco::NodeShape bias_shape;
+ if (!node_shape(value, value_shape) || !node_shape(bias, bias_shape))
+ {
+ return false;
+ }
+
+ // Output shape is same as the value shape
+ return copy_shapedata(value, node);
+}
+
+template <class CONST_CLASS> bool valid_scala_value(CONST_CLASS *node)
+{
+ LOGGER(l);
+
+ loco::NodeShape nodeshape;
+ if (!node_shape(node, nodeshape))
+ {
+ return false;
+ }
+
+ if (node->dtype() != loco::DataType::S32)
+ {
+ INFO(l) << "valid_scala_value not S32";
+ return false;
+ }
+
+ auto tensor_shape = nodeshape.as<loco::TensorShape>();
+ if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
+ {
+ INFO(l) << "valid_scala_value rank not 0/1 : " << tensor_shape.rank();
+ return false;
+ }
+
+ return true;
+}
+
+template <class CONST_CLASS> int32_t scala_value(CONST_CLASS *node)
+{
+ loco::NodeShape nodeshape;
+ if (!node_shape(node, nodeshape))
+ {
+ return false;
+ }
+
+ assert(node->dtype() == loco::DataType::S32);
+
+ auto tensor_shape = nodeshape.as<loco::TensorShape>();
+ assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
+
+ return node->template at<loco::DataType::S32>(0);
+}
+
+// Shape inference for TFConcatV2: all value inputs must share a rank, the
+// axis input must be a scalar S32 constant (TFConst or ConstGen); the output
+// copies value(0)'s shape except along the concat axis, where dimensions of
+// all inputs are summed. Also attaches a ConcatData annotation for the axis.
+bool fix_shape(moco::tf::TFConcatV2 *node)
+{
+  LOGGER(l);
+
+  if (shape_inference_done(node))
+  {
+    INFO(l) << "Fix shape TFConcatV2 already done";
+    return false;
+  }
+  // ConcatData should be null
+  assert(node->annot<ConcatData>() == nullptr);
+
+  // Check shape inference data are all ready
+  // Check shape rank are all same
+  auto value_a = node->values(0);
+  loco::NodeShape value_a_shape;
+  if (!node_shape(value_a, value_a_shape))
+  {
+    // shape inference is not ready for this value
+    INFO(l) << "Fix shape TFConcatV2 value 0 shape_data not ready";
+    return false;
+  }
+  assert(value_a_shape.domain() == loco::Domain::Tensor);
+  auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
+  uint32_t a_rank = value_a_tensor_shape.rank();
+
+  // every remaining input must be ready and agree on the rank of value(0)
+  uint32_t num_values = node->num_values();
+  for (uint32_t ni = 1; ni < num_values; ++ni)
+  {
+    auto value_b = node->values(ni);
+    loco::NodeShape value_b_shape;
+    if (!node_shape(value_b, value_b_shape))
+    {
+      // shape inference is not ready for this value
+      INFO(l) << "Fix shape TFConcatV2 value " << ni << " shape_data not ready";
+      return false;
+    }
+    assert(value_b_shape.domain() == loco::Domain::Tensor);
+    auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
+    uint32_t b_rank = value_b_tensor_shape.rank();
+    assert(a_rank == b_rank);
+  }
+
+  // check for axis
+  auto axis_node = node->axis();
+  loco::NodeShape axis_shape;
+  if (!node_shape(axis_node, axis_shape))
+  {
+    // shape inference is not ready for axis_node
+    INFO(l) << "Fix shape TFConcatV2 axis shape_data not ready";
+    return false;
+  }
+
+  // the axis may come from either constant dialect; try both
+  int32_t axis_value = 0;
+  bool axis_available = false;
+  {
+    // check for axis is TFConst
+    auto tfconst = dynamic_cast<moco::tf::TFConst *>(axis_node);
+    if (tfconst != nullptr)
+    {
+      if (valid_scala_value(tfconst))
+      {
+        axis_value = scala_value(tfconst);
+        axis_available = true;
+      }
+    }
+  }
+  {
+    // check for axis is ConstGen
+    auto constgen = dynamic_cast<loco::ConstGen *>(axis_node);
+    if (constgen != nullptr)
+    {
+      if (valid_scala_value(constgen))
+      {
+        axis_value = scala_value(constgen);
+        axis_available = true;
+      }
+    }
+  }
+  if (!axis_available)
+  {
+    // we cannot find a valid axis value
+    INFO(l) << "Fix shape TFConcatV2 axis_available false";
+    return false;
+  }
+
+  auto concat_data = stdex::make_unique<ConcatData>(axis_value);
+  node->annot(std::move(concat_data));
+
+  // normalize a negative axis to its absolute position
+  uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->rank(a_rank);
+
+  for (uint32_t index = 0; index < a_rank; ++index)
+  {
+    if (value_a_tensor_shape.dim(index).known())
+    {
+      uint32_t dim = value_a_tensor_shape.dim(index).value();
+      if (index == axis_absolute)
+      {
+        // along the concat axis: sum the dimension over all inputs
+        uint32_t dim_acc = dim;
+        for (uint32_t ni = 1; ni < num_values; ++ni)
+        {
+          auto value_b = node->values(ni);
+          loco::NodeShape value_b_shape;
+          node_shape(value_b, value_b_shape);
+          assert(value_b_shape.domain() == loco::Domain::Tensor);
+          auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
+          assert(value_b_tensor_shape.dim(index).known());
+          dim_acc += value_b_tensor_shape.dim(index).value();
+        }
+        dim = dim_acc;
+      }
+      shape_data->dim(index) = dim;
+    }
+    else
+      shape_data->dim(index).unset();
+  }
+  node->annot(std::move(shape_data));
+
+  // NOTE(review): this streams the annotation *pointer*, not the shape value;
+  //               other fix_shape() logs print the tensor shape — confirm intent
+  INFO(l) << "Fix TFConcat shape = " << node->annot<ShapeInferenceData>();
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFConst *node)
+{
+ if (shape_inference_done(node))
+ return false;
+
+ // TFConst itself has shape information, copy them
+ auto shape_data = make_shape_inference_data(node);
+ node->annot(std::move(shape_data));
+
+ {
+ LOGGER(l);
+ auto shapedata = node->annot<ShapeInferenceData>();
+ assert(shapedata != nullptr);
+ INFO(l) << "Fix TFConst shape = " << shapedata->tensor_shape();
+ }
+
+ return true;
+}
+
+// Shape inference for TFConv2D: derives the NHWC output tensor shape from the
+// ifm (NHWC) and ker (HWIO) shapes, strides and padding scheme, and attaches
+// Stride and Pad annotations as side effects.
+bool fix_shape(moco::tf::TFConv2D *node)
+{
+  LOGGER(l);
+
+  if (shape_inference_done(node))
+    return false;
+
+  auto ifm = node->input();
+  loco::NodeShape ifm_shape;
+  if (!node_shape(ifm, ifm_shape))
+  {
+    // input node shape inference is not ready
+    return false;
+  }
+
+  auto ker = node->filter();
+  loco::NodeShape ker_shape;
+  if (!node_shape(ker, ker_shape))
+  {
+    return false;
+  }
+
+  auto padding = node->padding();
+  assert(padding == "VALID" || padding == "SAME");
+
+  update_stride_data(node);
+
+  auto stride_data = node->annot<StrideData>();
+  assert(stride_data != nullptr);
+  // TODO add and use 'stride_data->stride()' stream out
+  INFO(l) << "Fix TFConv2D strides = " << stride_data->stride()->vertical() << ", "
+          << stride_data->stride()->horizontal();
+
+  auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); // in NHWC
+  auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
+  assert(ifm_tensor_shape.rank() == 4);
+  assert(ker_tensor_shape.rank() == 4);
+
+  uint32_t input_height = ifm_tensor_shape.dim(1).value();
+  uint32_t input_width = ifm_tensor_shape.dim(2).value();
+  uint32_t stride_height = stride_data->stride()->vertical();
+  uint32_t stride_width = stride_data->stride()->horizontal();
+  uint32_t ker_height = ker_tensor_shape.dim(0).value();
+  uint32_t ker_width = ker_tensor_shape.dim(1).value();
+  uint32_t dilation_height = 1; // TODO Consider dilation
+  uint32_t dilation_width = 1;
+  // effective kernel extent once dilation is applied
+  uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
+  uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
+  uint32_t output_height;
+  uint32_t output_width;
+
+  if (padding == "VALID")
+  {
+    output_height = (input_height + stride_height - effective_ker_height) / stride_height;
+    output_width = (input_width + stride_width - effective_ker_width) / stride_width;
+  }
+  else if (padding == "SAME")
+  {
+    output_height = (input_height + stride_height - 1) / stride_height;
+    output_width = (input_width + stride_width - 1) / stride_width;
+  }
+  else
+  {
+    // NOTE(review): in NDEBUG builds this assert vanishes and output_height/
+    //               output_width would be read uninitialized — confirm whether
+    //               an explicit throw is preferable here
+    assert(false && "Unknown padding in fix_shape for TFConv2D");
+  }
+
+  // output: batch from ifm, H/W computed above, channels from ker's O axis
+  loco::TensorShape ofm_tensor_shape;
+  ofm_tensor_shape.rank(4);
+  ofm_tensor_shape.dim(0) = ifm_tensor_shape.dim(0);
+  ofm_tensor_shape.dim(1) = output_height;
+  ofm_tensor_shape.dim(2) = output_width;
+  ofm_tensor_shape.dim(3) = ker_tensor_shape.dim(3);
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->tensor_shape(ofm_tensor_shape);
+  node->annot(std::move(shape_data));
+
+  FixPadContext ctx = {input_height, input_width, output_height, output_width,
+                       stride_height, stride_width, effective_ker_height, effective_ker_width};
+
+  calc_annot_paddata(node, ctx);
+
+  INFO(l) << "Fix TFConv2D shape = ifm" << ifm_tensor_shape << " ker" << ker_tensor_shape
+          << " --> ofm" << ofm_tensor_shape;
+  INFO(l) << " pad = " << *node->annot<PadData>();
+
+  return true;
+}
+
+// Shape inference for TFDepthwiseConv2dNative: like TFConv2D but the kernel is
+// HWCM, so the output channel count is C * M (channels times depth multiplier).
+// Attaches Stride and Pad annotations as side effects.
+bool fix_shape(moco::tf::TFDepthwiseConv2dNative *node)
+{
+  LOGGER(l);
+
+  if (shape_inference_done(node))
+    return false;
+
+  auto ifm = node->input();
+  loco::NodeShape ifm_shape;
+  if (!node_shape(ifm, ifm_shape))
+  {
+    // input node shape inference is not ready
+    return false;
+  }
+
+  auto ker = node->filter();
+  loco::NodeShape ker_shape;
+  if (!node_shape(ker, ker_shape))
+  {
+    return false;
+  }
+
+  update_stride_data(node);
+
+  auto stride_data = node->annot<StrideData>();
+  assert(stride_data != nullptr);
+
+  INFO(l) << "FixShape TFDepthwiseConv2dNative strides = " << stride_data->stride()->vertical()
+          << ", " << stride_data->stride()->horizontal();
+
+  auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>(); // in NHWC
+  auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
+  assert(ifm_tensor_shape.rank() == 4);
+  assert(ker_tensor_shape.rank() == 4);
+
+  uint32_t input_height = ifm_tensor_shape.dim(1).value();
+  uint32_t input_width = ifm_tensor_shape.dim(2).value();
+  uint32_t stride_height = stride_data->stride()->vertical();
+  uint32_t stride_width = stride_data->stride()->horizontal();
+  uint32_t ker_height = ker_tensor_shape.dim(0).value();
+  uint32_t ker_width = ker_tensor_shape.dim(1).value();
+  uint32_t dilation_height = 1; // TODO Consider dilation
+  uint32_t dilation_width = 1;
+  // effective kernel extent once dilation is applied
+  uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
+  uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
+  uint32_t output_height;
+  uint32_t output_width;
+
+  auto padding = node->padding();
+  assert(padding == "VALID" || padding == "SAME");
+
+  if (padding == "VALID")
+  {
+    output_height = (input_height + stride_height - effective_ker_height) / stride_height;
+    output_width = (input_width + stride_width - effective_ker_width) / stride_width;
+  }
+  else // padding == "SAME"
+  {
+    output_height = (input_height + stride_height - 1) / stride_height;
+    output_width = (input_width + stride_width - 1) / stride_width;
+  }
+
+  // output channels = input channels (C) * depth multiplier (M)
+  loco::TensorShape ofm_tensor_shape;
+  ofm_tensor_shape.rank(4);
+  ofm_tensor_shape.dim(0) = ifm_tensor_shape.dim(0);
+  ofm_tensor_shape.dim(1) = output_height;
+  ofm_tensor_shape.dim(2) = output_width;
+  ofm_tensor_shape.dim(3) =
+      loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  shape_data->tensor_shape(ofm_tensor_shape);
+  node->annot(std::move(shape_data));
+
+  FixPadContext ctx = {input_height, input_width, output_height, output_width,
+                       stride_height, stride_width, effective_ker_height, effective_ker_width};
+
+  calc_annot_paddata(node, ctx);
+
+  INFO(l) << "Fix TFDepthwiseConv2dNative shape = ifm" << ifm_tensor_shape << " ker"
+          << ker_tensor_shape << " --> ofm" << ofm_tensor_shape;
+  INFO(l) << " pad = " << *node->annot<PadData>();
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFFusedBatchNorm *node)
+{
+ // Output shape is same as the input
+ auto input = node->input();
+ return copy_shapedata(input, node);
+}
+
+bool fix_shape(moco::tf::TFIdentity *node)
+{
+ // Output shape is same as the input
+ auto input = node->input();
+ return copy_shapedata(input, node);
+}
+
+// Shape inference for TFMaxPool: derives the output feature shape from the
+// input shape, strides, window and padding scheme, and attaches Stride,
+// Window and Pad annotations as side effects.
+bool fix_shape(moco::tf::TFMaxPool *node)
+{
+  LOGGER(l);
+
+  if (shape_inference_done(node))
+    return false;
+
+  auto value = node->value();
+  loco::NodeShape value_shape;
+  if (!node_shape(value, value_shape))
+  {
+    // input node shape inference is not ready
+    return false;
+  }
+
+  auto padding = node->padding();
+  assert(padding == "VALID" || padding == "SAME");
+
+  update_stride_data(node);
+  update_window_data(node);
+
+  auto stride_data = node->annot<StrideData>();
+  assert(stride_data != nullptr);
+  auto window_data = node->annot<WindowData>();
+  assert(window_data != nullptr);
+
+  auto value_feature_shape = as_feature_shape(value_shape, node->data_layout());
+
+  uint32_t input_height = value_feature_shape.height().value();
+  uint32_t input_width = value_feature_shape.width().value();
+  uint32_t stride_height = stride_data->stride()->vertical();
+  uint32_t stride_width = stride_data->stride()->horizontal();
+  uint32_t window_height = window_data->window()->vertical();
+  uint32_t window_width = window_data->window()->horizontal();
+  uint32_t dilation_height = 1; // dilation for MaxPool is 1
+  uint32_t dilation_width = 1;
+  uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+  uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+  uint32_t output_height;
+  uint32_t output_width;
+
+  if (padding == "VALID")
+  {
+    output_height = (input_height + stride_height - effective_window_height) / stride_height;
+    output_width = (input_width + stride_width - effective_window_width) / stride_width;
+  }
+  else if (padding == "SAME")
+  {
+    output_height = (input_height + stride_height - 1) / stride_height;
+    output_width = (input_width + stride_width - 1) / stride_width;
+  }
+  // NOTE(review): no else branch — if padding were neither VALID nor SAME and
+  //               the assert above is compiled out, output_height/output_width
+  //               would be read uninitialized below; confirm a defensive throw
+
+  loco::FeatureShape ofm_feature_shape;
+  ofm_feature_shape.count() = value_feature_shape.count();
+  ofm_feature_shape.height() = output_height;
+  ofm_feature_shape.width() = output_width;
+  ofm_feature_shape.depth() = value_feature_shape.depth();
+
+  auto shape_data = stdex::make_unique<ShapeInferenceData>();
+  as_tensor_shape(*shape_data.get(), ofm_feature_shape, node->data_layout());
+  node->annot(std::move(shape_data));
+
+  FixPadContext ctx = {
+      input_height, input_width, output_height, output_width,
+      stride_height, stride_width, effective_window_height, effective_window_width};
+
+  calc_annot_paddata(node, ctx);
+
+  INFO(l) << "Fix TFMaxPool shape = ifm" << value_feature_shape << " --> ofm" << ofm_feature_shape;
+  INFO(l) << " pad = " << *node->annot<PadData>();
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFMul *node)
+{
+ auto x = node->x();
+ auto y = node->y();
+ loco::NodeShape x_shape;
+ loco::NodeShape y_shape;
+
+ if (!node_shape(x, x_shape))
+ return false;
+ if (!node_shape(y, y_shape))
+ return false;
+
+ // Output shape is same as the input
+ return copy_shapedata(x, y, node);
+}
+
+bool fix_shape(moco::tf::TFMean *node)
+{
+  // Infer the output shape of TFMean from the input shape and the constant
+  // 'reduction_indices' operand. Returns true when a new shape annotation
+  // was attached; false when inference is already done or inputs are not ready.
+  if (shape_inference_done(node))
+    return false;
+
+  LOGGER(l);
+
+  auto input = node->input();
+  auto reduction_indices = node->reduction_indices();
+  loco::NodeShape input_shape;
+  loco::NodeShape reduction_indices_shape;
+
+  if (!node_shape(input, input_shape) || !node_shape(reduction_indices, reduction_indices_shape))
+  {
+    // Input and reduction_indices shape are required for TFMean shape inference
+    return false;
+  }
+
+  // Get constant values if reduction_indices is const; negative axes are
+  // normalized into [0, rank)
+  std::vector<int32_t> reduction_values;
+  if (auto tfconst = dynamic_cast<moco::tf::TFConst *>(reduction_indices))
+  {
+    assert(tfconst->dtype() == loco::DataType::S32);
+    auto const_size = tfconst->size<loco::DataType::S32>();
+    for (uint32_t i = 0; i < const_size; ++i)
+    {
+      int32_t axis = tfconst->at<loco::DataType::S32>(i);
+      if (axis < 0)
+        axis += input_shape.as<loco::TensorShape>().rank();
+      reduction_values.push_back(axis);
+    }
+  }
+  else
+  {
+    // we cannot find a valid reduction indices value
+    INFO(l) << "Fix shape TFMean fail : reduction indices are not constant or not valid";
+    return false;
+  }
+
+  loco::TensorShape shape_data;
+  loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+  if (node->keep_dims())
+  {
+    // Keep the input rank; reduced axes become 1
+    shape_data.rank(input_tensor_shape.rank());
+    for (uint32_t i = 0; i < input_tensor_shape.rank(); ++i)
+      shape_data.dim(i) = input_tensor_shape.dim(i);
+    for (uint32_t i = 0; i < reduction_values.size(); ++i)
+      shape_data.dim(reduction_values.at(i)) = 1;
+  }
+  else
+  {
+    // Reduced axes are dropped; surviving axes keep their input dimension
+    std::vector<bool> check_reduce(input_tensor_shape.rank(), false);
+    for (uint32_t i = 0; i < reduction_values.size(); ++i)
+      check_reduce.at(reduction_values.at(i)) = true;
+
+    uint32_t reduce_cnt = 0;
+    for (uint32_t i = 0; i < check_reduce.size(); ++i)
+      if (check_reduce.at(i))
+        ++reduce_cnt;
+
+    shape_data.rank(input_tensor_shape.rank() - reduce_cnt);
+    for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+      if (check_reduce.at(i) == false)
+      {
+        // BUGFIX: store the dimension VALUE of the surviving axis; the
+        // previous code stored the axis index 'i' as the dimension size.
+        shape_data.dim(j++) = input_tensor_shape.dim(i);
+      }
+  }
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(shape_data);
+  node->annot(std::move(shape_annot));
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFRealDiv *node)
+{
+  // RealDiv is element-wise: inference only propagates the operand shapes.
+  auto numerator = node->x();
+  auto denominator = node->y();
+
+  // Both operand shapes must be known before we can fix this node
+  loco::NodeShape numerator_shape;
+  if (!node_shape(numerator, numerator_shape))
+    return false;
+
+  loco::NodeShape denominator_shape;
+  if (!node_shape(denominator, denominator_shape))
+    return false;
+
+  // Output shape follows the operands
+  return copy_shapedata(numerator, denominator, node);
+}
+
+bool fix_shape(moco::tf::TFRelu *node)
+{
+  // ReLU is element-wise: output shape equals the 'features' input shape
+  return copy_shapedata(node->features(), node);
+}
+
+bool fix_shape(moco::tf::TFRelu6 *node)
+{
+  // ReLU6 is element-wise: output shape equals the 'features' input shape
+  return copy_shapedata(node->features(), node);
+}
+
+bool fix_shape(moco::tf::TFReshape *node)
+{
+  // Infer the output shape of TFReshape from its constant 'shape' input.
+  // Returns true when a shape annotation was attached; false when inference
+  // is already done, the 'shape' input is not (yet) a TFConst, or the new
+  // shape contains a -1 wildcard (dynamic reshape, unsupported here).
+  if (shape_inference_done(node))
+    return false;
+
+  // For now, we only consider Fixed Reshape, i.e. Reshape with determined
+  // 'shape' input. So here we only support case when 'shape' input of
+  // TFReshape is TFConst. If 'shape' input is not TFConst, another
+  // transform (e.g. constant folding) should be done beforehand to make
+  // it TFConst.
+  // TODO Support dynamic Reshape
+  // Note that 'shape()' here is 'shape' input, not node's shape information
+  auto const_shape_input = dynamic_cast<moco::tf::TFConst *>(node->shape());
+  if (!const_shape_input)
+  {
+    // 'shape' input of TFReshape is not TFConst, try next time when it becomes TFConst
+    return false;
+  }
+
+  // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1]
+  assert(const_shape_input->dtype() == loco::DataType::S32);
+  assert(const_shape_input->rank() == 1);
+
+  // The single dimension of the rank-1 'shape' tensor is the output rank
+  auto shape_rank = const_shape_input->dim(0).value();
+  assert(shape_rank > 0);
+
+  loco::TensorShape shape_data;
+  shape_data.rank(shape_rank);
+  for (uint32_t axis = 0; axis < shape_rank; ++axis)
+  {
+    auto shape_dim = const_shape_input->at<loco::DataType::S32>(axis);
+    if (shape_dim == -1)
+    {
+      // Reshape's new shape has wildcard dimension, i.e. dynamic reshape
+      return false;
+    }
+    assert(shape_dim >= 1);
+    shape_data.dim(axis) = shape_dim;
+  }
+
+  // TODO Compare 'tensor' input and validate coherency?
+  // Not sure this is appropriate stage for this task.
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(shape_data);
+  node->annot(std::move(shape_annot));
+
+  {
+    LOGGER(l);
+    auto shapedata = node->annot<ShapeInferenceData>();
+    assert(shapedata != nullptr);
+    INFO(l) << "Fix TFReshape shape = " << shapedata->tensor_shape();
+  }
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFRsqrt *node)
+{
+  // Rsqrt is element-wise: output shape equals the shape of input 'x'
+  return copy_shapedata(node->x(), node);
+}
+
+bool fix_shape(moco::tf::TFShape *node)
+{
+  if (shape_inference_done(node))
+    return false;
+
+  // Input shape is required for TFShape shape inference
+  loco::NodeShape input_shape;
+  if (!node_shape(node->input(), input_shape))
+    return false;
+
+  auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+  // TFShape emits the input shape as its runtime value, so its own shape is
+  // a rank-1 tensor with one entry per input dimension.
+  // (local renamed from 'node_shape' to stop shadowing the helper function)
+  loco::TensorShape output_shape;
+  output_shape.rank(1);
+  output_shape.dim(0) = input_tensor_shape.rank();
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(output_shape);
+  node->annot(std::move(shape_annot));
+
+  LOGGER(l);
+  INFO(l) << "Fix TFShape shape = " << output_shape;
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFSqrt *node)
+{
+  // Sqrt is element-wise: output shape equals the shape of input 'x'
+  return copy_shapedata(node->x(), node);
+}
+
+bool fix_shape(moco::tf::TFSoftmax *node)
+{
+  // Softmax is element-wise over the class axis: output shape equals 'logits'
+  return copy_shapedata(node->logits(), node);
+}
+
+bool fix_shape(moco::tf::TFSquaredDifference *node)
+{
+  // SquaredDifference is element-wise: propagate the shape from both operands
+  return copy_shapedata(node->x(), node->y(), node);
+}
+
+bool fix_shape(moco::tf::TFSqueeze *node)
+{
+  // Infer the output shape of TFSqueeze: remove size-1 dimensions from the
+  // input shape, either all of them (empty 'squeeze_dims') or only the ones
+  // listed in 'squeeze_dims' (negative entries count from the back).
+  // Throws std::runtime_error when 'squeeze_dims' is out of range.
+  if (shape_inference_done(node))
+    return false;
+
+  auto input = node->input();
+  loco::NodeShape input_shape;
+  if (!node_shape(input, input_shape))
+  {
+    // Input shape is required for TFSqueeze shape inference
+    return false;
+  }
+
+  // TODO Not sure Squeeze only get input as Tensor
+  // Note that tensor_shape() has assertion in it
+  auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+  auto squeeze_dims_vec = node->squeeze_dims();
+  std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+  loco::TensorShape node_shape;
+  uint32_t node_rank = 0;
+
+  if (squeeze_dims.empty())
+  {
+    // Remove all dimensions whose value is 1
+    for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+    {
+      assert(input_tensor_shape.dim(axis).known());
+      auto dim = input_tensor_shape.dim(axis).value();
+      if (dim != 1)
+      {
+        assert(dim > 1);
+        // Grow the output rank one axis at a time and append this dimension
+        node_shape.rank(++node_rank);
+        node_shape.dim(node_rank - 1) = dim;
+      }
+    }
+  }
+  else
+  {
+    uint32_t input_rank = input_tensor_shape.rank();
+
+    // Sanity check for 'squeeze_dims': fewer entries than input rank, and
+    // each entry within [-input_rank, input_rank)
+    auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+      if (!(squeeze_dims.size() < input_rank))
+        return false;
+      for (auto squeeze_dim : squeeze_dims)
+      {
+        if (!(squeeze_dim >= -(int64_t)input_rank))
+          return false;
+        if (!(squeeze_dim < (int64_t)input_rank))
+          return false;
+      }
+      return true;
+    };
+
+    if (!is_valid_squeeze_dims())
+    {
+      throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
+    }
+
+    // Resolve negative squeeze dimension
+    std::set<int64_t> resolved_squeeze_dims;
+    for (auto squeeze_dim : squeeze_dims)
+    {
+      if (squeeze_dim < 0)
+        resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
+      else
+        resolved_squeeze_dims.insert(squeeze_dim);
+    }
+
+    // Remove squeeze dimensions only
+    for (uint32_t axis = 0; axis < input_rank; ++axis)
+    {
+      assert(input_tensor_shape.dim(axis).known());
+      auto dim = input_tensor_shape.dim(axis).value();
+      if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
+      {
+        // Not squeeze dim
+        node_shape.rank(++node_rank);
+        node_shape.dim(node_rank - 1) = dim;
+      }
+      else
+      {
+        // Is squeeze dim
+        assert(dim == 1);
+        // DO NOTHING
+      }
+    }
+  }
+
+  // NOTE(review): squeezing every dimension (e.g. all-1 input with empty
+  // squeeze_dims) would leave rank 0 and trip this assert — presumably
+  // scalar outputs are not expected here; confirm with callers.
+  assert(node_shape.rank() > 0);
+
+  auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+  shape_annot->tensor_shape(node_shape);
+  node->annot(std::move(shape_annot));
+
+  LOGGER(l);
+  INFO(l) << "Fix TFSqueeze shape = " << node_shape;
+
+  return true;
+}
+
+bool fix_shape(moco::tf::TFStopGradient *node)
+{
+  // StopGradient is an identity at inference time: propagate the input shape
+  return copy_shapedata(node->input(), node);
+}
+
+bool fix_shape(moco::tf::TFSub *node)
+{
+  // Sub is element-wise: inference only propagates the operand shapes.
+  auto minuend = node->x();
+  auto subtrahend = node->y();
+
+  // Both operand shapes must be known before we can fix this node
+  loco::NodeShape minuend_shape;
+  if (!node_shape(minuend, minuend_shape))
+    return false;
+
+  loco::NodeShape subtrahend_shape;
+  if (!node_shape(subtrahend, subtrahend_shape))
+    return false;
+
+  // Output shape follows the operands
+  return copy_shapedata(minuend, subtrahend, node);
+}
+
+bool fix_shape(moco::tf::TFTanh *node)
+{
+  // Tanh is element-wise: output shape equals the shape of input 'x'
+  return copy_shapedata(node->x(), node);
+}
+
+bool fix_shape(locoex::COpCall *node)
+{
+  // Custom-op call: delegate shape construction to the dedicated helper
+  if (shape_inference_done(node))
+    return false;
+
+  node->annot(make_shape_inference_data(node));
+
+  return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool FixShapeTransform::run(loco::Graph *graph)
+{
+  // Dispatch each active node to the matching fix_shape() overload through
+  // the CANONICAL_NODE / TENSORFLOW_NODE X-macro lists; each macro expansion
+  // ends with a dangling 'else' so the expansions chain into one if/else-if
+  // ladder per node. Returns true if any shape annotation was newly attached.
+  bool changed = false;
+  for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+  {
+// clang-format off
+// TODO remove this block after Pull, Push is not used in import
+#define CANONICAL_NODE(TYPE_NAME) \
+  if (as<loco::TYPE_NAME>(node)) \
+  { \
+    if (fix_shape(as<loco::TYPE_NAME>(node))) \
+      changed = true; \
+  } \
+  else
+CANONICAL_NODE(Pull)
+CANONICAL_NODE(Push)
+#undef CANONICAL_NODE
+
+#define TENSORFLOW_NODE(OPCODE,CLASS) \
+  if (as<moco::tf::CLASS>(node)) \
+  { \
+    if (fix_shape(as<moco::tf::CLASS>(node))) \
+      changed = true; \
+  } \
+  else
+#include "Dialect/TFNodes.lst"
+#undef TENSORFLOW_NODE
+  // clang-format on
+
+  // Final link of the else-if ladder built by the macros above
+  if (as<locoex::COpCall>(node))
+  {
+    if (fix_shape(as<locoex::COpCall>(node)))
+      changed = true;
+  }
+  else
+  {
+    // Skip nodes that are not interested
+  }
+  }
+
+  return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h b/compiler/moco-tf/src/Transforms/FixShapeTransform.h
index 469d7e3cd..d790d0ec7 100644
--- a/compiler/moco-tf/src/Canonicalization/MeanCanonicalizer.h
+++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.h
@@ -14,13 +14,10 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_MEAN_CANONICALIZER_H__
-#define __MOCO_TF_MEAN_CANONICALIZER_H__
+#ifndef __MOCO_TF_FIX_SHAPE_TRANSFORM_H__
+#define __MOCO_TF_FIX_SHAPE_TRANSFORM_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-
-#include <moco/IR/TFNodes.h>
#include <loco.h>
@@ -30,18 +27,18 @@ namespace tf
{
/**
- * @brief Canonicalize TF-dialect TFMean into canonical TensorReduce(Mean) node
- */
-class MeanCanonicalizer : public SimpleNodeTransform<moco::TFMean>
+ * @brief Fix unknown shape to concrete shape for all nodes in the graph
+*/
+class FixShapeTransform : public Transform
{
public:
- const char *name(void) const final { return "MeanCanonicalizer"; }
+ const char *name(void) const final { return "FixShapeTransform"; }
public:
- bool transform(moco::TFMean *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_MEAN_CANONICALIZER_H__
+#endif // __MOCO_TF_FIX_SHAPE_TRANSFORM_H__
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp
new file mode 100644
index 000000000..bb346865f
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.test.cpp
@@ -0,0 +1,227 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FixShapeTransform.h"
+
+#include "TestHelper.h"
+
+#include "Annotations/PaddingData.h"
+#include "Annotations/ShapeInferenceData.h"
+
+#include "Dialect/TFNodes.h"
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+TEST(FixShapeTransform, ctor)
+{
+  moco::tf::FixShapeTransform fstransform;
+  loco::Graph graph;
+
+  // An empty graph has no nodes to fix, so the transform reports no change
+  ASSERT_FALSE(fstransform.run(&graph));
+}
+
+namespace
+{
+
+// Build a minimal graph: TFConst (annotated with a 1x3x3x1 NHWC shape) feeding
+// a TFAvgPool with 3x3 kernel and stride 1. Padding is left for the caller to
+// set. Returns the TFAvgPool node under test.
+moco::tf::TFAvgPool *avgpool_network_simple1331(loco::Graph *graph)
+{
+  auto avgpool_node = graph->nodes()->create<moco::tf::TFAvgPool>();
+
+  avgpool_node->data_layout("NHWC");
+  avgpool_node->ksize({1, 3, 3, 1});
+  avgpool_node->strides({1, 1, 1, 1});
+
+  // Dummy const node as ifm, just to fake FixShapeTransform for TFAvgPool.
+  // FixShapeTransform only cares about ShapeInferenceData of ifm()
+  auto const_node = graph->nodes()->create<moco::tf::TFConst>();
+  {
+    auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(1);
+    tshape.dim(1).set(3);
+    tshape.dim(2).set(3);
+    tshape.dim(3).set(1);
+    shapedata->tensor_shape(tshape);
+    const_node->annot(std::move(shapedata));
+  }
+  avgpool_node->value(const_node);
+
+  setup_output_node(graph, avgpool_node);
+
+  return avgpool_node;
+}
+
+} // namespace
+
+TEST(FixShapeTransform, avgpool_same)
+{
+  loco::Graph graph;
+
+  // 1x3x3x1 ifm, 3x3 kernel, stride 1, SAME padding -> output stays 1x3x3x1
+  auto avgpool_node = avgpool_network_simple1331(&graph);
+  avgpool_node->padding("SAME");
+
+  // (removed the unused duplicate FixShapeTransform local)
+  moco::tf::FixShapeTransform transform;
+  transform.run(&graph);
+
+  auto shapedata = avgpool_node->annot<moco::tf::ShapeInferenceData>();
+  ASSERT_NE(shapedata, nullptr);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), 1);
+  ASSERT_EQ(tshape.dim(1).value(), 3);
+  ASSERT_EQ(tshape.dim(2).value(), 3);
+  ASSERT_EQ(tshape.dim(3).value(), 1);
+}
+
+TEST(FixShapeTransform, avgpool_valid)
+{
+  loco::Graph graph;
+
+  // 1x3x3x1 ifm, 3x3 kernel, stride 1, VALID padding -> output shrinks to 1x1x1x1
+  auto avgpool_node = avgpool_network_simple1331(&graph);
+  avgpool_node->padding("VALID");
+
+  // (removed the unused duplicate FixShapeTransform local)
+  moco::tf::FixShapeTransform transform;
+  transform.run(&graph);
+
+  auto shapedata = avgpool_node->annot<moco::tf::ShapeInferenceData>();
+  ASSERT_NE(shapedata, nullptr);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), 1);
+  ASSERT_EQ(tshape.dim(1).value(), 1);
+  ASSERT_EQ(tshape.dim(2).value(), 1);
+  ASSERT_EQ(tshape.dim(3).value(), 1);
+}
+
+namespace
+{
+
+// Build a TFConv2D (NHWC) from the given ifm/kernel shapes, strides and
+// padding, run FixShapeTransform, and assert that the inferred output shape
+// equals 'expected_shape'. (Removed the unused FixShapeTransform local.)
+void conv2d_test(const std::array<uint32_t, 4> ifm_shape, const std::array<uint32_t, 4> ker_shape,
+                 const std::array<uint32_t, 2> stride_h_w, std::string padding,
+                 const std::array<uint32_t, 4> expected_shape)
+{
+  loco::Graph graph;
+
+  auto conv2d_node = graph.nodes()->create<moco::tf::TFConv2D>();
+  conv2d_node->data_layout("NHWC");
+  conv2d_node->strides({1, stride_h_w[0], stride_h_w[1], 1});
+  conv2d_node->padding(padding);
+
+  // Dummy ifm: only its ShapeInferenceData annotation matters
+  auto ifm_node = graph.nodes()->create<moco::tf::TFConst>();
+  {
+    auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(ifm_shape[0]);
+    tshape.dim(1).set(ifm_shape[1]);
+    tshape.dim(2).set(ifm_shape[2]);
+    tshape.dim(3).set(ifm_shape[3]);
+    shapedata->tensor_shape(tshape);
+    ifm_node->annot(std::move(shapedata));
+  }
+
+  // Dummy kernel: HWIO layout, again only the shape annotation matters
+  auto ker_node = graph.nodes()->create<loco::ConstGen>();
+  {
+    auto shapedata = stdex::make_unique<moco::tf::ShapeInferenceData>();
+    loco::TensorShape tshape;
+    tshape.rank(4);
+    tshape.dim(0).set(ker_shape[0]);
+    tshape.dim(1).set(ker_shape[1]);
+    tshape.dim(2).set(ker_shape[2]);
+    tshape.dim(3).set(ker_shape[3]);
+    shapedata->tensor_shape(tshape);
+    ker_node->annot(std::move(shapedata));
+  }
+
+  conv2d_node->input(ifm_node);
+  conv2d_node->filter(ker_node);
+
+  setup_output_node(&graph, conv2d_node);
+
+  moco::tf::FixShapeTransform transform;
+  transform.run(&graph);
+
+  auto shapedata = conv2d_node->annot<moco::tf::ShapeInferenceData>();
+  ASSERT_NE(shapedata, nullptr);
+  auto tshape = shapedata->tensor_shape();
+  ASSERT_EQ(tshape.rank(), 4);
+  ASSERT_EQ(tshape.dim(0).value(), expected_shape[0]);
+  ASSERT_EQ(tshape.dim(1).value(), expected_shape[1]);
+  ASSERT_EQ(tshape.dim(2).value(), expected_shape[2]);
+  ASSERT_EQ(tshape.dim(3).value(), expected_shape[3]);
+}
+
+} // namespace
+
+/*
+ Testing "InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D" Conv2D node in Inception_v3:
+ The result shape of this test is generated with the code below:
+
+ ifm = tf.constant(value=1.1, shape=[1, 299, 299, 3])
+ ker = tf.constant(value=1.1, shape=[3, 3, 3, 32])
+
+ out = tf.nn.conv2d(ifm, ker, strides = [1, 2, 2, 1], padding= 'VALID')
+
+ with tf.Session() as sess:
+ res = sess.run(out)
+ print(res.shape)
+ */
+TEST(FixShapeTransform, conv2d_VALID)
+{
+  // conv2d_test builds its own graph and transform, so the previously
+  // declared local FixShapeTransform/Graph here were dead code — removed.
+  conv2d_test({1, 299, 299, 3},   // ifm
+              {3, 3, 3, 32},      // ker
+              {2, 2},             // strides
+              "VALID",            // padding
+              {1, 149, 149, 32}); // expected shape after FixShape
+}
+
+/*
+ Testing "InceptionV3/InceptionV3/Conv2d_2b_3x3/Conv2D" Conv2D node in Inception_v3:
+ The result shape of this test is generated with the code below:
+
+ ifm = tf.constant(value=1.1, shape=[1, 147, 147, 32])
+ ker = tf.constant(value=1.1, shape=[3, 3, 32, 64])
+
+ out = tf.nn.conv2d(ifm, ker, strides = [1, 1, 1, 1], padding= 'SAME')
+
+ with tf.Session() as sess:
+ res = sess.run(out)
+ print(res.shape)
+ */
+TEST(FixShapeTransform, conv2d_SAME)
+{
+  // conv2d_test builds its own graph and transform, so the previously
+  // declared local FixShapeTransform/Graph here were dead code — removed.
+  conv2d_test({1, 147, 147, 32},  // ifm
+              {3, 3, 32, 64},     // ker
+              {1, 1},             // strides
+              "SAME",             // padding
+              {1, 147, 147, 64}); // expected shape after FixShape
+}
diff --git a/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp
new file mode 100644
index 000000000..2edcae72e
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.cpp
@@ -0,0 +1,547 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "FuseBinaryIntoPreceding.h"
+
+#include "Annotations/ShapeInferenceData.h"
+#include "Dialect/TFDialect.h"
+
+#include "IR/TFAdd.h"
+#include "IR/TFBiasAdd.h"
+#include "IR/TFConst.h"
+#include "IR/TFConv2D.h"
+#include "IR/TFDepthwiseConv2dNative.h"
+#include "IR/TFMul.h"
+
+#include <loco.h>
+#include <moco/Log.h>
+
+#include <cassert>
+#include <memory>
+
+namespace
+{
+
+/**
+ * @brief Fusable operation type
+ */
+enum class FuseType
+{
+ Conv2D,
+ DepthwiseConv2D,
+ // TODO Support FullyConnected
+};
+
+// TODO rename this method when there is a better name
+bool is_only_one_valid(moco::tf::TFConst *xc, moco::tf::TFConst *yc)
+{
+  // True iff exactly one of the two operands is a TFConst
+  // (logical XOR on nullness: both-null and both-non-null yield false)
+  return (xc == nullptr) != (yc == nullptr);
+}
+
+// TODO Put this in some common place
+void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst)
+{
+  // Copy rank and every dimension from 'src' to 'dst'; dimensions that are
+  // unknown in 'src' are explicitly unset in 'dst' as well.
+  assert(src != nullptr);
+  assert(dst != nullptr);
+
+  const uint32_t rank = src->rank();
+  dst->rank(rank);
+
+  for (uint32_t axis = 0; axis < rank; ++axis)
+  {
+    if (src->dim(axis).known())
+      dst->dim(axis) = src->dim(axis);
+    else
+      dst->dim(axis).unset();
+  }
+}
+
+/**
+ * @brief return true if shape is identical
+ */
+bool shape_match(const moco::tf::TFConst *c1, const moco::tf::TFConst *c2)
+{
+  // Shapes match only when ranks agree and every dimension is known and
+  // equal on both sides; any unknown dimension makes the match fail.
+  assert(c1 != nullptr);
+  assert(c2 != nullptr);
+
+  const uint32_t rank = c1->rank();
+  if (c2->rank() != rank)
+    return false;
+
+  for (uint32_t axis = 0; axis < rank; ++axis)
+  {
+    if (!c1->dim(axis).known() || !c2->dim(axis).known())
+      return false;
+    if (c1->dim(axis).value() != c2->dim(axis).value())
+      return false;
+  }
+
+  return true;
+}
+
+template <FuseType FT>
+moco::tf::TFConst *create_kernel_from_fuse_mulparam(loco::Graph *graph, moco::tf::TFConst *ker,
+ moco::tf::TFConst *mulparam);
+
+template <>
+moco::tf::TFConst *create_kernel_from_fuse_mulparam<FuseType::Conv2D>(loco::Graph *graph,
+                                                                      moco::tf::TFConst *ker,
+                                                                      moco::tf::TFConst *mulparam)
+{
+  // Build a new Conv2D kernel whose values are ker * mulparam, broadcasting
+  // the rank-1 'mulparam' vector over the kernel's output-channel (O) axis.
+  // Both shapes must already be annotated with ShapeInferenceData.
+  auto ker_shape_inf = ker->annot<moco::tf::ShapeInferenceData>();
+  assert(ker_shape_inf);
+  auto ker_shape = ker_shape_inf->tensor_shape();
+
+  auto mulparam_shape_inf = mulparam->annot<moco::tf::ShapeInferenceData>();
+  assert(mulparam_shape_inf != nullptr);
+  auto mulparam_shape = mulparam_shape_inf->tensor_shape();
+
+  // create new ker_fused with same size of ker
+  auto ker_fused = graph->nodes()->create<moco::tf::TFConst>();
+
+  // mulparam length must equal the kernel's output-channel count
+  assert(ker_shape.rank() == 4);
+  assert(mulparam_shape.rank() == 1);
+  assert(ker_shape.dim(3).value() == mulparam_shape.dim(0).value());
+
+  ker_fused->dtype(loco::DataType::FLOAT32);
+  copy_shape(ker, ker_fused);
+  auto ker_num_elements = ker->size<loco::DataType::FLOAT32>();
+  ker_fused->size<loco::DataType::FLOAT32>(ker_num_elements);
+
+  // TensorFlow Conv2D Kernel has HWIO format
+  // Broadcast Mul vector to Kernel tensor by the Output
+  const uint32_t ker_height = ker_shape.dim(0).value();
+  const uint32_t ker_width = ker_shape.dim(1).value();
+  const uint32_t ker_input = ker_shape.dim(2).value();
+  const uint32_t ker_output = ker_shape.dim(3).value();
+
+  for (uint32_t ker_y = 0; ker_y < ker_height; ++ker_y)
+  {
+    for (uint32_t ker_x = 0; ker_x < ker_width; ++ker_x)
+    {
+      for (uint32_t in_ch = 0; in_ch < ker_input; ++in_ch)
+      {
+        // Flat offset of the first output channel at (ker_y, ker_x, in_ch)
+        // in row-major HWIO order
+        uint32_t num_items = ((ker_y * ker_width + ker_x) * ker_input + in_ch) * ker_output;
+        for (uint32_t out_ch = 0; out_ch < ker_output; ++out_ch)
+        {
+          auto mulparam_v = mulparam->at<loco::DataType::FLOAT32>(out_ch);
+          auto ker_v = ker->at<loco::DataType::FLOAT32>(num_items + out_ch);
+          ker_fused->at<loco::DataType::FLOAT32>(num_items + out_ch) = ker_v * mulparam_v;
+        }
+      }
+    }
+  }
+
+  return ker_fused;
+}
+
+/**
+ * @brief Create a kernel from fuse mulparam<FuseType::DepthwiseConv2D> object
+ * @return Kernel of fused mulparam
+ */
+template <>
+moco::tf::TFConst *create_kernel_from_fuse_mulparam<FuseType::DepthwiseConv2D>(
+    loco::Graph *graph, moco::tf::TFConst *ker, moco::tf::TFConst *mulparam)
+{
+  // Build a new DepthwiseConv2dNative kernel whose values are ker * mulparam.
+  // The rank-1 'mulparam' vector has input_channels * multiplier entries and
+  // is broadcast over the kernel's (I, M) axes.
+  auto ker_shape_inf = ker->annot<moco::tf::ShapeInferenceData>();
+  assert(ker_shape_inf);
+  auto ker_shape = ker_shape_inf->tensor_shape();
+
+  auto mulparam_shape_inf = mulparam->annot<moco::tf::ShapeInferenceData>();
+  assert(mulparam_shape_inf != nullptr);
+  auto mulparam_shape = mulparam_shape_inf->tensor_shape();
+
+  // create new ker_fused with same size of ker
+  auto ker_fused = graph->nodes()->create<moco::tf::TFConst>();
+
+  // mulparam length must equal input_channels * channel_multiplier
+  assert(ker_shape.rank() == 4);
+  assert(mulparam_shape.rank() == 1);
+  assert(ker_shape.dim(2).value() * ker_shape.dim(3).value() == mulparam_shape.dim(0).value());
+
+  ker_fused->dtype(loco::DataType::FLOAT32);
+  copy_shape(ker, ker_fused);
+  auto ker_num_elements = ker->size<loco::DataType::FLOAT32>();
+  ker_fused->size<loco::DataType::FLOAT32>(ker_num_elements);
+
+  // TensorFlow DepthwiseConv2DNative Kernel has HWIM format
+  // Broadcast Mul vector to Kernel tensor by the Output
+  const uint32_t ker_height = ker_shape.dim(0).value();
+  const uint32_t ker_width = ker_shape.dim(1).value();
+  const uint32_t ker_input = ker_shape.dim(2).value();
+  const uint32_t ker_multiplier = ker_shape.dim(3).value();
+
+  for (uint32_t ker_y = 0; ker_y < ker_height; ++ker_y)
+  {
+    for (uint32_t ker_x = 0; ker_x < ker_width; ++ker_x)
+    {
+      for (uint32_t in_ch = 0; in_ch < ker_input; ++in_ch)
+      {
+        // Flat offset of the first multiplier entry at (ker_y, ker_x, in_ch)
+        // in row-major HWIM order
+        uint32_t num_items = ((ker_y * ker_width + ker_x) * ker_input + in_ch) * ker_multiplier;
+        for (uint32_t ker_ch = 0; ker_ch < ker_multiplier; ++ker_ch)
+        {
+          // NOTE(review): mulparam is indexed as in_ch + ker_ch * ker_input,
+          // i.e. multiplier-major over channels — confirm against the layout
+          // TFMul's output channel ordering actually uses
+          auto mulparam_v = mulparam->at<loco::DataType::FLOAT32>(in_ch + ker_ch * ker_input);
+          auto ker_v = ker->at<loco::DataType::FLOAT32>(num_items + ker_ch);
+          ker_fused->at<loco::DataType::FLOAT32>(num_items + ker_ch) = ker_v * mulparam_v;
+        }
+      }
+    }
+  }
+
+  return ker_fused;
+}
+
+/**
+ * @brief Create a fused convolution opertion from kernel of fused mulparam
+ * @return Fused convolution operation
+ */
+template <FuseType FT, class T>
+T *fused_conv_node(loco::Graph *graph, moco::tf::TFConst *mulparam, T *conv_node)
+{
+  // Create a copy of 'conv_node' (T is TFConv2D or TFDepthwiseConv2dNative)
+  // whose kernel has 'mulparam' pre-multiplied in. Returns nullptr when the
+  // kernel or mulparam is not ready (not const / no shape) so the caller can
+  // retry on a later pass.
+  LOGGER(l);
+
+  // ker should be constant
+  auto ker = dynamic_cast<moco::tf::TFConst *>(conv_node->filter());
+  if (ker == nullptr)
+  {
+    // Wait until ker is becomes TFConst: there are cases when it's Identity.
+    INFO(l) << "Mul fuse_to_preceding: precedingOp ker is not TFConst";
+    return nullptr;
+  }
+  auto ifm = conv_node->input();
+  assert(ifm != nullptr);
+
+  // we need shape information, if not wait till it's ready
+  if (ker->annot<moco::tf::ShapeInferenceData>() == nullptr)
+  {
+    INFO(l) << "Mul fuse_to_preceding: precedingOp ker has no shape";
+    return nullptr;
+  }
+
+  auto mulparam_shape_inf = mulparam->annot<moco::tf::ShapeInferenceData>();
+  if (mulparam_shape_inf == nullptr)
+  {
+    INFO(l) << "Mul fuse_to_preceding: precedingOp mulparam has no shape";
+    return nullptr;
+  }
+  // if MulParam rank is not 1 we cannot fuse, just skip
+  auto mulparam_shape = mulparam_shape_inf->tensor_shape();
+  if (mulparam_shape.rank() != 1)
+  {
+    INFO(l) << "Mul fuse_to_preceding: Mul rank is not 1";
+    return nullptr;
+  }
+
+  // Clone the convolution with the fused kernel; all other attributes are
+  // copied verbatim from the original node
+  auto ker_fused = create_kernel_from_fuse_mulparam<FT>(graph, ker, mulparam);
+  auto conv_fused = graph->nodes()->create<T>();
+
+  conv_fused->input(ifm);
+  conv_fused->filter(ker_fused);
+  conv_fused->padding(conv_node->padding());
+  conv_fused->data_layout(conv_node->data_layout());
+  conv_fused->strides(conv_node->strides());
+
+  return conv_fused;
+}
+
+/**
+ * @note This creates fused ker:2 from ker:1, 'mulparam' and
+ * new precedingOp:2 that uses ker:2 as the kernel.
+ * Then make C to use precedingOp:2 as new input.
+ *
+ * <Before>
+ * mulparam-\
+ * ker:1 --\ \
+ * ifm ----- precedingOp:1 ----------- Mul --- C
+ *
+ *
+ * <After>
+ * mulparam-\
+ * ker:1 --\ \
+ * - precedingOp:1 ----------- Mul ---
+ * /
+ * ifm ----- precedingOp:2 ------------------- C
+ * ker:2 ---/
+ *
+ *
+ * [Where]
+ * - precedingOp:1 can be one of TFConv2D, TFDepthwiseConv2dNative, FullyConnected
+ * - 'mulparam' and Mul will be disconnected from the Output.
+ * - ker:2 is added with fused values of ker:1 and mulparam
+ * - precedingOp:2 is added using ifm and ker:2 and other parameters
+ * same as precedingOp:1.
+ * - ker:1, precedingOp:1, 'mulparam' and Mul should be removed in
+ * RemoveDeadNodeTransform if not used.
+ */
+bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFMul *node)
+{
+  // Fuse a TFMul(const, conv-like) into the preceding convolution's kernel.
+  // Returns true when the graph changed; false when the pattern does not
+  // match (yet) so this transform can be retried on a later pass.
+  LOGGER(l);
+
+  auto xc = dynamic_cast<moco::tf::TFConst *>(node->x());
+  auto yc = dynamic_cast<moco::tf::TFConst *>(node->y());
+
+  // Note: if both are constants, it should be done by constant-folding
+  if (!(is_only_one_valid(xc, yc)))
+    return false;
+
+  // Whichever operand is the TFConst becomes 'mulparam'; the other operand
+  // is the candidate preceding operation
+  moco::tf::TFConst *mulparam = nullptr;
+  moco::tf::TFNode *precedingOp = nullptr;
+
+  if (xc != nullptr)
+  {
+    mulparam = xc;
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
+  }
+  else // yc != nullptr
+  {
+    mulparam = yc;
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
+  }
+
+  assert(mulparam->dtype() == loco::DataType::FLOAT32);
+
+  // TODO support FullyConnected
+  moco::tf::TFNode *fused_node = nullptr;
+  if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
+    fused_node = fused_conv_node<FuseType::Conv2D, moco::tf::TFConv2D>(graph, mulparam, conv2d);
+  else if (auto dw_conv2d = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(precedingOp))
+    fused_node = fused_conv_node<FuseType::DepthwiseConv2D, moco::tf::TFDepthwiseConv2dNative>(
+        graph, mulparam, dw_conv2d);
+
+  // Not ready yet
+  if (fused_node == nullptr)
+    return false;
+
+  // Replace TFMul node with new precedingOp with fused kernel
+  // This will leave existing precedingOp as-is but can be removed if not used
+  // from other transformations
+  replace(node).with(fused_node);
+  // TODO check if need to disconnect
+  // node->x(nullptr);
+  // node->y(nullptr);
+  // fused_node->ifm(nullptr);
+  // fused_node->ker(nullptr);
+
+  return true;
+}
+
+/**
+ * @brief Create zero-filled BiasAdd opertion and insert after precedingOp
+ * The plan is to fuse 'addparam' to TFBiasAdd bias
+ * @return Zero-filled BiasAdd operation
+ */
+template <class T>
+moco::tf::TFBiasAdd *create_biasadd_node(loco::Graph *graph, moco::tf::TFConst *addparam,
+                                         T *precedingOp)
+{
+  // Insert a zero-filled TFBiasAdd right after 'precedingOp' so the caller
+  // can later fold 'addparam' into its bias. Returns the new TFBiasAdd.
+  auto dtype = addparam->dtype();
+  assert(dtype == loco::DataType::FLOAT32);
+
+  // Create TFConst(bias of TFBiasAdd) with same shape and dtype of 'addparam' but
+  // with values 0.0
+  auto biasadd_param = graph->nodes()->create<moco::tf::TFConst>();
+  biasadd_param->dtype(dtype);
+  copy_shape(addparam, biasadd_param);
+  auto biasadd_num_elements = addparam->size<loco::DataType::FLOAT32>();
+  biasadd_param->size<loco::DataType::FLOAT32>(biasadd_num_elements);
+  // FIX: use an unsigned index; size<>() is unsigned and the original
+  // int32_t index triggered a signed/unsigned comparison
+  for (uint32_t i = 0; i < biasadd_num_elements; i++)
+  {
+    biasadd_param->at<loco::DataType::FLOAT32>(i) = 0.0f;
+  }
+
+  // Create TFBiasAdd with same shape as TFAdd
+  auto data_layout = precedingOp->data_layout();
+  auto tf_biasadd = graph->nodes()->create<moco::tf::TFBiasAdd>();
+  tf_biasadd->data_layout(data_layout);
+
+  // Splice: users of precedingOp now read the TFBiasAdd, which in turn
+  // reads precedingOp and the zero bias
+  loco::replace(precedingOp).with(tf_biasadd);
+  tf_biasadd->value(precedingOp);
+  tf_biasadd->bias(biasadd_param);
+
+  return tf_biasadd;
+}
+
+/**
+ * @note TFAdd will be fused to TFBiasAdd
+ *
+ * <Before>
+ * If precedingOp is not TFBiasAdd, then insert TFConst:1 + TFBiasAdd that
+ * TFConst:1 has zero values.
+ *
+ * addparam --\
+ * \
+ * precedingOp ---------------------------- TFAdd ----- C
+ *
+ *
+ * <Intermediate>
+ * If it's TFBiasAdd and one of the input is TFConst type,
+ * then we can fuse 'addparam' to the input TFConst:2 value of TFBiasAdd, where
+ * TFConst:2 has added values from 'addparam'
+ *
+ * addparam --\
+ * TFConst:1 --------\ \
+ * precedingOp ------- TFBiasAdd ---------- TFAdd ----- C
+ *
+ *
+ * <After>
+ * addparam --\
+ * TFConst:2 --------\ \
+ * precedingOp ------- TFBiasAdd ---------- TFAdd -----
+ * \--------------------- C
+ *
+ *
+ * [Where]
+ * - precedingOp can be TFConv2D, TFDepthwiseConv2dNative, FullyConnected,
+ * TFBiasAdd.
+ * - Intermediate is to insert TFBiasAdd + TFConst:1
+ * - After is to fuse 'addparam' of TFAdd into TFConst:1 + TFBiasAdd
+ * that becomes TFConst:2 + TFBiasAdd
+ */
+bool fuse_to_preceding(loco::Graph *graph, moco::tf::TFAdd *node)
+{
+  // Fuse a TFAdd(const, conv-like) into a TFBiasAdd bias: first ensure a
+  // TFBiasAdd exists after the preceding op (creating a zero-bias one if
+  // needed), then accumulate 'addparam' into that bias and bypass the TFAdd.
+  // Returns true when the graph changed; false to retry on a later pass.
+  LOGGER(l);
+
+  auto xc = dynamic_cast<moco::tf::TFConst *>(node->x());
+  auto yc = dynamic_cast<moco::tf::TFConst *>(node->y());
+
+  // Note: if both are constants, it should be done by constant-folding
+  if (!(is_only_one_valid(xc, yc)))
+    return false;
+
+  moco::tf::TFConst *addparam = nullptr;
+  moco::tf::TFNode *precedingOp = nullptr;
+
+  if (xc != nullptr)
+  {
+    addparam = xc;
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->y());
+  }
+  else // yc != nullptr
+  {
+    addparam = yc;
+    precedingOp = dynamic_cast<moco::tf::TFNode *>(node->x());
+  }
+
+  auto addparam_shape_inf = addparam->annot<moco::tf::ShapeInferenceData>();
+  if (addparam_shape_inf == nullptr)
+  {
+    INFO(l) << "Add fuse_to_preceding: addparam has no shape";
+    return false;
+  }
+  // if AddParam rank is not 0 or 1 we cannot fuse, just skip
+  auto addparam_shape = addparam_shape_inf->tensor_shape();
+  if (addparam_shape.rank() > 1)
+  {
+    INFO(l) << "Add fuse_to_preceding: Add rank is not 0 or 1";
+    return false;
+  }
+
+  // TODO do something when rank() is 0
+  if (addparam_shape.rank() == 0)
+  {
+    // Not supported yet
+    return false;
+  }
+  assert(addparam_shape.rank() != 0);
+
+  // TODO support FullyConnected
+  moco::tf::TFBiasAdd *biasadd = nullptr;
+  if (auto conv2d = dynamic_cast<moco::tf::TFConv2D *>(precedingOp))
+    biasadd = create_biasadd_node<moco::tf::TFConv2D>(graph, addparam, conv2d);
+  else if (auto dw_conv2d = dynamic_cast<moco::tf::TFDepthwiseConv2dNative *>(precedingOp))
+    biasadd = create_biasadd_node<moco::tf::TFDepthwiseConv2dNative>(graph, addparam, dw_conv2d);
+  else if (auto old_bias_add = dynamic_cast<moco::tf::TFBiasAdd *>(precedingOp))
+    biasadd = old_bias_add;
+
+  if (biasadd == nullptr)
+  {
+    // try next turn
+    return false;
+  }
+
+  // Let's fuse addparam into biasadd bias
+  auto biasadd_bias = dynamic_cast<moco::tf::TFConst *>(biasadd->bias());
+  assert(biasadd_bias != nullptr);
+  if (!shape_match(biasadd_bias, addparam))
+  {
+    INFO(l) << "TFBiasAdd bias and TFAdd input shape mismatch";
+    return false;
+  }
+  auto add_num_elements = addparam->size<loco::DataType::FLOAT32>();
+  assert(add_num_elements == biasadd_bias->size<loco::DataType::FLOAT32>());
+  // FIX: unsigned index; size<>() is unsigned and the original int32_t
+  // index triggered a signed/unsigned comparison
+  for (uint32_t i = 0; i < add_num_elements; i++)
+  {
+    biasadd_bias->at<loco::DataType::FLOAT32>(i) += addparam->at<loco::DataType::FLOAT32>(i);
+  }
+
+  replace(node).with(biasadd);
+  // TODO check if need to disconnect
+  // node->x(nullptr);
+  // node->y(nullptr);
+
+  return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool FuseBinaryIntoPreceding::run(loco::Graph *graph)
+{
+  // Visit every active TF-dialect node and try to fuse TFMul / TFAdd into
+  // its preceding operation. Returns true if anything was fused.
+  bool changed = false;
+
+  for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+  {
+    if (node->dialect() != TFDialect::get())
+      continue;
+
+    if (auto mul_node = dynamic_cast<moco::tf::TFMul *>(node))
+    {
+      if (fuse_to_preceding(graph, mul_node))
+        changed = true;
+    }
+    else if (auto add_node = dynamic_cast<moco::tf::TFAdd *>(node))
+    {
+      if (fuse_to_preceding(graph, add_node))
+        changed = true;
+    }
+    // TODO support Div
+    // TODO support Sub
+  }
+
+  return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.h b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.h
new file mode 100644
index 000000000..33d4af14a
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/FuseBinaryIntoPreceding.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_FUSE_BINARY_INTO_PRECEDING_H__
+#define __MOCO_TF_FUSE_BINARY_INTO_PRECEDING_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Fuse TFAdd, TFMul to preceding TFConv2D or TFDepthWiseConv2D
+*/
+class FuseBinaryIntoPreceding : public Transform
+{
+public:
+ const char *name(void) const final { return "FuseBinaryIntoPreceding"; }
+
+public:
+ bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_FUSE_BINARY_INTO_PRECEDING_H__
diff --git a/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.cpp b/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.cpp
new file mode 100644
index 000000000..50f7ab92f
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "RemoveTFIdentityNodeTransform.h"
+
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNode.h"
+
+#include <set>
+
+namespace moco
+{
+namespace tf
+{
+
+bool RemoveTFIdentityNodeTransform::run(loco::Graph *g)
+{
+ struct Collector final : public moco::tf::TFNodeMutableVisitor<void>
+ {
+ void visit(moco::tf::TFIdentity *node) final
+ {
+ if (node->input() != nullptr)
+ {
+ candidates.insert(node);
+ }
+ }
+
+ void visit(moco::tf::TFNode *) final { return; }
+
+ std::set<moco::tf::TFIdentity *> candidates;
+ };
+
+ Collector collector;
+
+ for (auto node : loco::all_nodes(g))
+ {
+ if (node->dialect() == moco::tf::TFDialect::get())
+ {
+ auto tf_node = dynamic_cast<moco::tf::TFNode *>(node);
+ tf_node->accept(&collector);
+ }
+ }
+
+ for (auto node : collector.candidates)
+ {
+ replace(node).with(node->input());
+ node->input(nullptr);
+ }
+
+ return collector.candidates.size() > 0;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.h b/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.h
new file mode 100644
index 000000000..534061ee3
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/RemoveTFIdentityNodeTransform.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_REMOVE_TFIDENTITY_NODE_TRANSFORM_H__
+#define __MOCO_TF_REMOVE_TFIDENTITY_NODE_TRANSFORM_H__
+
+#include "Transform.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Use the input of "TFIdentity" node instead
+ *
+ * BEFORE:
+ * [X] -> [TFIdentity] -> [Y]
+ *
+ * AFTER:
+ * [X] -> [Y]
+ * [TFIdentity]
+ *
+ * NOTE This transform does not remove "TFIdentity" node
+ * This transform is identical to RemoveForwardNodeTransform
+ */
+struct RemoveTFIdentityNodeTransform final : public Transform
+{
+ const char *name(void) const final { return "RemoveTFIdentityNodeTransform"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_REMOVE_TFIDENTITY_NODE_TRANSFORM_H__
diff --git a/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp b/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp
new file mode 100644
index 000000000..017aa666f
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ResolveConstantShape.cpp
@@ -0,0 +1,126 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveConstantShape.h"
+
+#include "IR/TFShape.h"
+#include "IR/TFConst.h"
+#include "Annotations/ShapeInferenceData.h"
+
+#include <loco.h>
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * WHEN:
+ * - TFShape's input shape is determined
+ * DO:
+ * - Replace TFShape into TFConst
+ *
+ *
+ * <Before>
+ * in ---- TFShape ---- out(s)
+ *
+ * <After>
+ * in ---- TFShape
+ *
+ * TFConst ---- out(s)
+ */
+bool resolve_constant_shape(loco::Graph *graph, moco::tf::TFShape *shape_node)
+{
+ using moco::tf::ShapeInferenceData;
+
+ auto input_shape = shape_node->input()->annot<ShapeInferenceData>();
+
+ // Check condition
+ if (!input_shape)
+ {
+ // Cannot resolve without known input_shape
+ return false;
+ }
+ auto shape_rank = input_shape->rank();
+ for (uint32_t axis = 0; axis < shape_rank; ++axis)
+ {
+ if (!input_shape->dim(axis).known())
+ {
+ // Cannot resolve with unknown dimension
+ return false;
+ }
+ }
+
+ auto input_tensor_shape = input_shape->tensor_shape();
+
+ // Make TFConst to replace TFShape
+ auto const_node = graph->nodes()->create<moco::tf::TFConst>();
+
+ // set dtype
+ auto dtype = shape_node->dtype();
+ const_node->dtype(dtype);
+
+ // set shape
+ const_node->rank(1);
+ const_node->dim(0) = shape_rank;
+
+ // set data
+ if (dtype == loco::DataType::S32)
+ {
+ // TODO Better to make template for this when support new dtype
+ const_node->size<loco::DataType::S32>(shape_rank);
+ for (uint32_t axis = 0; axis < shape_rank; ++axis)
+ {
+ int32_t dim = (int32_t)input_tensor_shape.dim(axis).value();
+ assert(dim > 0);
+ const_node->at<loco::DataType::S32>(axis) = dim;
+ }
+ }
+ else
+ {
+ throw std::runtime_error("ResolveConstantShape: Not supported output data type");
+ }
+
+ // replace
+ loco::replace(shape_node).with(const_node);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveConstantShape::run(loco::Graph *graph)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+ {
+ if (auto shape_node = as<moco::tf::TFShape>(node))
+ {
+ if (resolve_constant_shape(graph, shape_node))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h b/compiler/moco-tf/src/Transforms/ResolveConstantShape.h
index 64bb6041a..069418b7b 100644
--- a/compiler/moco-tf/src/Canonicalization/PadCanonicalizer.h
+++ b/compiler/moco-tf/src/Transforms/ResolveConstantShape.h
@@ -14,13 +14,12 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_PAD_CANONICALIZER_H__
-#define __MOCO_TF_PAD_CANONICALIZER_H__
+#ifndef __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__
+#define __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-#include <moco/IR/TFNodes.h>
+#include <loco.h>
namespace moco
{
@@ -28,18 +27,18 @@ namespace tf
{
/**
- * @brief Convert TFPad to Canonical TensorConstantPad
+ * @brief Replace fully determined TFShape node into TFConst
*/
-class PadCanonicalizer final : public SimpleNodeTransform<moco::TFPad>
+class ResolveConstantShape : public Transform
{
public:
- const char *name(void) const final { return "PadCanonicalizer"; }
+ const char *name(void) const final { return "ResolveConstantShape"; }
public:
- bool transform(moco::TFPad *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_PAD_CANONICALIZER_H__
+#endif // __MOCO_TF_RESOLVE_CONSTANT_SHAPE_H__
diff --git a/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp
new file mode 100644
index 000000000..1eeb31f53
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.cpp
@@ -0,0 +1,259 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "IR/TFAdd.h"
+#include "IR/TFConst.h"
+#include "IR/TFMul.h"
+
+#include "IR/TFFusedBatchNorm.h"
+
+#include <loco.h>
+#include <moco/Log.h>
+
+#include <cassert>
+#include <cmath>
+#include <memory>
+
+namespace
+{
+
+bool is_same_shape(moco::tf::TFConst *lc, moco::tf::TFConst *rc)
+{
+ if (lc->rank() != rc->rank())
+ return false;
+
+ for (auto r = 0; r < lc->rank(); ++r)
+ {
+ if (lc->dim(r).value() != rc->dim(r).value())
+ return false;
+ }
+ return true;
+}
+
+void copy_shape(const moco::tf::TFConst *src, moco::tf::TFConst *dst)
+{
+ assert(src != nullptr);
+ assert(dst != nullptr);
+
+ uint32_t rank = src->rank();
+ dst->rank(rank);
+ for (uint32_t index = 0; index < rank; ++index)
+ {
+ if (src->dim(index).known())
+ dst->dim(index) = src->dim(index).value();
+ else
+ dst->dim(index).unset();
+ }
+}
+
+/**
+ * @note resolve_to_muladd() will transform TFFusedBatchNorm to TFMul, TFAdd and two ConstGen
+ *
+ * <arguments>
+ * %0:input
+ * %1:gamma : const
+ * %2:beta : const
+ * %3:mean : const
+ * %4:variance : const
+ * %5:epsilon : const
+ *
+ * <constant operations>
+ * fbn_epsilon_array = make_array(%5:epsilon)
+ * fbn_epsilon = %4:variance + fbn_epsilon_array
+ * fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+ *
+ * fbn_mean = %3:mean
+ * fbn_mul = fbn_rsqrt * %1:gamma
+ * fbn_offset = %2:beta
+ *
+ * fbn_mul_0_param = fbn_mul
+ * fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+ *
+ * <new replace nodes>
+ * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+ * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+ * %21:fbn_add_param = ConstGen(fbn_add_param)
+ * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+ */
+bool resolve_to_muladd(loco::Graph *graph, moco::tf::TFFusedBatchNorm *node)
+{
+ LOGGER(lfbn);
+
+ auto tffbn_input = node->input();
+ if (tffbn_input == nullptr)
+ {
+ // This node is already converted
+ return false;
+ }
+
+ auto tffbn_gamma = dynamic_cast<moco::tf::TFConst *>(node->gamma());
+ auto tffbn_beta = dynamic_cast<moco::tf::TFConst *>(node->beta());
+ auto tffbn_mean = dynamic_cast<moco::tf::TFConst *>(node->mean());
+ auto tffbn_variance = dynamic_cast<moco::tf::TFConst *>(node->variance());
+
+ // all should be const
+ if (tffbn_gamma == nullptr || tffbn_beta == nullptr || tffbn_mean == nullptr ||
+ tffbn_variance == nullptr)
+ {
+ INFO(lfbn) << "TFFBN resolve_to_muladd: One of constant input node is not a constant"
+ << std::endl;
+ return false;
+ }
+ assert(tffbn_gamma->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_beta->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_mean->dtype() == loco::DataType::FLOAT32);
+ assert(tffbn_variance->dtype() == loco::DataType::FLOAT32);
+
+ // check all const shape are the same
+ if (!is_same_shape(tffbn_gamma, tffbn_beta) || !is_same_shape(tffbn_gamma, tffbn_mean) ||
+ !is_same_shape(tffbn_gamma, tffbn_variance))
+ {
+ INFO(lfbn) << "TFFBN resolve_to_muladd: Shape of constant are not same" << std::endl;
+ return false;
+ }
+
+ auto tffbn_epsilon = node->epsilon();
+ INFO(lfbn) << "TFFBN tffbn_epsilon = " << tffbn_epsilon << std::endl;
+ auto const_num_elements = tffbn_gamma->size<loco::DataType::FLOAT32>();
+ INFO(lfbn) << "TFFBN const_num_elements = " << const_num_elements << std::endl;
+
+ // fbn_epsilon = %4:variance + fbn_epsilon_array
+ std::unique_ptr<float[]> fbn_epsilon{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ auto variance = tffbn_variance->at<loco::DataType::FLOAT32>(i);
+ fbn_epsilon.get()[i] = variance + tffbn_epsilon;
+ }
+
+ // fbn_rsqrt = 1.0 / math::sqrt(fbn_epsilon)
+ std::unique_ptr<float[]> fbn_rsqrt{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_rsqrt.get()[i] = 1.0 / sqrt(fbn_epsilon.get()[i]);
+ }
+
+ // fbn_mean = %3:mean : TODO remove this block and use %3:mean
+ std::unique_ptr<float[]> fbn_mean{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mean.get()[i] = tffbn_mean->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_mul = fbn_rsqrt * %1:gamma
+ std::unique_ptr<float[]> fbn_mul{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mul.get()[i] = fbn_rsqrt.get()[i] * tffbn_gamma->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_offset = %2:beta : TODO remove this block and use %2:beta
+ std::unique_ptr<float[]> fbn_offset{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_offset.get()[i] = tffbn_beta->at<loco::DataType::FLOAT32>(i);
+ }
+
+ // fbn_mul_0_param = fbn_mul : remove this and use fbn_mul
+ std::unique_ptr<float[]> fbn_mul_0_param{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_mul_0_param.get()[i] = fbn_mul.get()[i];
+ }
+
+ // fbn_add_param = fbn_offset - fbn_mean * fbn_mul
+ std::unique_ptr<float[]> fbn_add_param{new float[const_num_elements]};
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ fbn_add_param.get()[i] = fbn_offset.get()[i] - fbn_mean.get()[i] * fbn_mul.get()[i];
+ }
+
+ INFO(lfbn) << "TFFBN create ConstGen" << std::endl;
+
+ /*
+ * %11:fbn_mul_0_param = ConstGen(fbn_mul_0_param)
+ * %21:fbn_add_param = ConstGen(fbn_add_param)
+ */
+ auto const_fbn_mul_0_param = graph->nodes()->create<moco::tf::TFConst>();
+ const_fbn_mul_0_param->dtype(loco::DataType::FLOAT32);
+ copy_shape(tffbn_gamma, const_fbn_mul_0_param);
+ const_fbn_mul_0_param->size<loco::DataType::FLOAT32>(const_num_elements);
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ const_fbn_mul_0_param->at<loco::DataType::FLOAT32>(i) = fbn_mul_0_param.get()[i];
+ }
+ auto const_fbn_add_param = graph->nodes()->create<moco::tf::TFConst>();
+ const_fbn_add_param->dtype(loco::DataType::FLOAT32);
+ copy_shape(tffbn_gamma, const_fbn_add_param);
+ const_fbn_add_param->size<loco::DataType::FLOAT32>(const_num_elements);
+ for (int32_t i = 0; i < const_num_elements; i++)
+ {
+ const_fbn_add_param->at<loco::DataType::FLOAT32>(i) = fbn_add_param.get()[i];
+ }
+
+ INFO(lfbn) << "TFFBN create TFMul, TFAdd" << std::endl;
+ /*
+ * %12:fbn_mul_0 = TFMul(%0:input, %11:fbn_mul_0_param)
+ * %22:fbn = TFAdd(%12:fbn_mul_0,%21:fbn_add_param)
+ */
+ auto fbn_mul_0 = graph->nodes()->create<moco::tf::TFMul>();
+ fbn_mul_0->x(tffbn_input);
+ fbn_mul_0->y(const_fbn_mul_0_param);
+
+ auto fbn = graph->nodes()->create<moco::tf::TFAdd>();
+ fbn->x(fbn_mul_0);
+ fbn->y(const_fbn_add_param);
+
+ // replace old node with new fbn
+ replace(node).with(fbn);
+ // unlink from graph
+ node->input(nullptr);
+ node->gamma(nullptr);
+ node->beta(nullptr);
+ node->mean(nullptr);
+ node->variance(nullptr);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveFusedBatchNorm::run(loco::Graph *graph)
+{
+ for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+ {
+ if (as<moco::tf::TFFusedBatchNorm>(node))
+ {
+ if (resolve_to_muladd(graph, as<moco::tf::TFFusedBatchNorm>(node)))
+ {
+ // tree has been changed. let's return so that we don't need to
+ // considier about following node is correct or not.
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.h b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.h
new file mode 100644
index 000000000..9243951f5
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
+#define __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Trasform TFFusedBatchNorm into TFAdd + TFRsqrt + TFMul + TFBatchNorm
+*/
+class ResolveFusedBatchNorm : public Transform
+{
+public:
+ const char *name(void) const final { return "ResolveFusedBatchNorm"; }
+
+public:
+ bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_RESOLVE_FUSEDBATCHNORM_H__
diff --git a/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp
new file mode 100644
index 000000000..de4e1051d
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ResolveFusedBatchNorm.test.cpp
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveFusedBatchNorm.h"
+
+#include "LogHelper.h"
+#include "TestHelper.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "Importer.h"
+
+#include <loco.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+#include <plier/tf/TestHelper.h>
+
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+const char *fbn_basic_pbtxt = STRING_CONTENT(
+node {
+ name: "input"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 1 }
+ dim { size: 4 }
+ dim { size: 4 }
+ dim { size: 1 }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "gamma"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "beta"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/mean"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01/variance"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "FBN_01"
+ op: "FusedBatchNorm"
+ input: "input"
+ input: "gamma"
+ input: "beta"
+ input: "FBN_01/mean"
+ input: "FBN_01/variance"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "epsilon"
+ value {
+ f: 0.001
+ }
+ }
+ attr {
+ key: "is_training"
+ value {
+ b: false
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+namespace
+{
+
// Render a boolean as a single character for log output: 'Y' / 'N'
char to_char(bool b)
{
  if (b)
    return 'Y';
  return 'N';
}
+
+} // namespace
+
+TEST(ResolveFusedBatchNorm, fbn_resolve_basic)
+{
+ LOGGER(l);
+
+ // load graph
+ moco::tf::Importer importer;
+ moco::tf::ModelSignature signature;
+ signature.add_output(moco::tf::TensorName("FBN_01", 0));
+
+ tensorflow::GraphDef graph_def;
+ EXPECT_TRUE(plier::tf::parse_graphdef(fbn_basic_pbtxt, graph_def));
+ auto graph = importer.import(signature, graph_def);
+
+ INFO(l) << "Before ResolveFusedBatchNorm";
+ INFO(l) << moco::tf::fmt(graph);
+
+ moco::tf::ResolveFusedBatchNorm transform;
+ bool changed = transform.run(graph.get());
+
+ INFO(l) << "After ResolveFusedBatchNorm " << to_char(changed);
+ INFO(l) << moco::tf::fmt(graph);
+
+ // Output value test will be done with mocotest-tf
+ // Network structure of transformation is not important and may be changed
+ // in the future so it will not be checked here.
+
+ SUCCEED();
+}
diff --git a/compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.cpp b/compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.cpp
new file mode 100644
index 000000000..80242521a
--- /dev/null
+++ b/compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.cpp
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ResolveReshapeWildcardDim.h"
+
+#include "IR/TFReshape.h"
+#include "IR/TFConst.h"
+#include "Annotations/ShapeInferenceData.h"
+
+#include <loco.h>
+
+#include <cassert>
+#include <limits>
+
+namespace
+{
+
+/**
+ * @return true when 'node' has one and only one wildcard dimension
+ * @return false when 'node' has no wildcard dimension, i.e. fixed reshape case
+ *
+ * @note Assertions in this function are sanity check for 'node', Reshape's
+ * Const shape input
+ */
+bool has_one_wildcard_dim(const moco::tf::TFConst *node)
+{
+ assert(node->dtype() == loco::DataType::S32);
+ assert(node->rank() == 1);
+
+ auto len = node->dim(0).value();
+ assert(len > 0);
+
+ // Must have one and only wildcard dimension(-1)
+ uint32_t count_wildcard_dim = 0;
+ for (uint32_t i = 0; i < len; ++i)
+ {
+ auto dim = node->at<loco::DataType::S32>(i);
+ if (dim == -1)
+ count_wildcard_dim++;
+ else
+ assert(dim >= 1);
+ }
+
+ assert(count_wildcard_dim <= 1 &&
+ "Invalid Reshape: there should be none or only one wildcard dimension");
+ return count_wildcard_dim;
+}
+
+uint32_t volume(const loco::TensorShape &shape)
+{
+ uint32_t ret = 1;
+ auto rank = shape.rank();
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ ret *= shape.dim(axis).value();
+ }
+ return ret;
+}
+
+void deduce_and_fix_wildcard_dim(moco::tf::TFConst *node,
+ const moco::tf::ShapeInferenceData *shape_data)
+{
+ assert(has_one_wildcard_dim(node));
+
+ assert(shape_data->domain() == loco::Domain::Tensor);
+ auto shape = shape_data->tensor_shape();
+
+ auto len = node->dim(0).value();
+ uint32_t wildcard_index = std::numeric_limits<uint32_t>::max();
+ uint32_t product_of_non_wildcard_dims = 1;
+
+ // Deduce
+ for (uint32_t i = 0; i < len; ++i)
+ {
+ auto dim = node->at<loco::DataType::S32>(i);
+ if (dim == -1)
+ {
+ wildcard_index = i;
+ }
+ else
+ {
+ product_of_non_wildcard_dims *= dim;
+ }
+ }
+ assert(wildcard_index != std::numeric_limits<uint32_t>::max());
+
+ // Fix
+ assert(volume(shape) % product_of_non_wildcard_dims == 0);
+ node->at<loco::DataType::S32>(wildcard_index) = volume(shape) / product_of_non_wildcard_dims;
+}
+
+/**
+ * WHEN:
+ * - TFReshape's shape input is TFConst
+ * - The TFConst is valid shape input for dynamic reshape, i.e. it has one and
+ * only wildcard dimension(-1)
+ * - TFReshape's tensor input has complete shape inference data
+ * DO:
+ * - Deduce what the wildcard dimension is and fix it
+ */
+bool resolve_wildcard_dim(moco::tf::TFReshape *reshape)
+{
+ // Check conditions (WHEN)
+ auto const_shape_input = dynamic_cast<moco::tf::TFConst *>(reshape->shape());
+ if (!const_shape_input)
+ return false;
+
+ if (!has_one_wildcard_dim(const_shape_input))
+ return false;
+
+ auto tensor_input_shape = reshape->tensor()->annot<moco::tf::ShapeInferenceData>();
+ if (!tensor_input_shape)
+ return false;
+
+ // Deduce (DO)
+ deduce_and_fix_wildcard_dim(const_shape_input, tensor_input_shape);
+
+ return true;
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ResolveReshapeWildcardDim::run(loco::Graph *graph)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph)))
+ {
+ if (auto reshape = as<moco::tf::TFReshape>(node))
+ {
+ if (resolve_wildcard_dim(reshape))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h b/compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.h
index 66eafe6af..c165c9027 100644
--- a/compiler/moco-tf/src/Canonicalization/PlaceholderCanonicalizer.h
+++ b/compiler/moco-tf/src/Transforms/ResolveReshapeWildcardDim.h
@@ -14,13 +14,12 @@
* limitations under the License.
*/
-#ifndef __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
-#define __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
+#ifndef __MOCO_TF_RESOLVE_RESHAPE_WILDCARD_DIM_H__
+#define __MOCO_TF_RESOLVE_RESHAPE_WILDCARD_DIM_H__
#include "Transform.h"
-#include "SimpleNodeTransform.h"
-#include <moco/IR/Nodes/TFPlaceholder.h>
+#include <loco.h>
namespace moco
{
@@ -28,20 +27,19 @@ namespace tf
{
/**
- * @brief Convert TFPlaceholder to Canonical Pull
- *
- * @note GraphInputIndex is copied to Pull
+ * @brief Determine wildcard dimension (denoted as -1) of Reshape's shape input
+ * if possible
*/
-class PlaceholderCanonicalizer : public SimpleNodeTransform<::moco::TFPlaceholder>
+class ResolveReshapeWildcardDim : public Transform
{
public:
- const char *name(void) const final { return "PlaceholderCanonicalizer"; }
+ const char *name(void) const final { return "ResolveReshapeWildcardDim"; }
public:
- bool transform(moco::TFPlaceholder *) const final;
+ bool run(loco::Graph *graph) override;
};
} // namespace tf
} // namespace moco
-#endif // __MOCO_TF_PLACEHOLDER_CANONICALIZER_H__
+#endif // __MOCO_TF_RESOLVE_RESHAPE_WILDCARD_DIM_H__
diff --git a/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp b/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp
index 64ba9dfb1..e47b06964 100644
--- a/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp
+++ b/compiler/moco-tf/src/Transforms/ShapeInferencePass.cpp
@@ -16,9 +16,8 @@
#include "ShapeInferencePass.h"
-#include <moco/IR/TFDialect.h>
-
-#include <moco/Service/TFShapeInferenceRule.h>
+#include "Dialect/TFShapeInferenceRule.h"
+#include "Dialect/TFDialect.h"
#include <loco.h>
@@ -40,7 +39,7 @@ namespace tf
bool ShapeInferencePass::run(loco::Graph *graph)
{
loco::CanonicalShapeInferenceRule canonical_rule;
- moco::TFShapeInferenceRule tf_rule;
+ TFShapeInferenceRule tf_rule;
locoex::COpShapeInferenceRule cop_rule; // rule for custop op
loco::MultiDialectShapeInferenceRule rules;
diff --git a/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp b/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp
index db6cf7521..efdacc5a0 100644
--- a/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp
+++ b/compiler/moco-tf/src/Transforms/TypeInferencePass.cpp
@@ -16,9 +16,8 @@
#include "TypeInferencePass.h"
-#include <moco/IR/TFDialect.h>
-
-#include <moco/Service/TFTypeInferenceRule.h>
+#include "Dialect/TFTypeInferenceRule.h"
+#include "Dialect/TFDialect.h"
#include <loco.h>
@@ -36,7 +35,7 @@ namespace tf
bool TypeInferencePass::run(loco::Graph *graph)
{
loco::CanonicalTypeInferenceRule canonical_rule;
- moco::TFTypeInferenceRule tf_rule; // rule for TF dialect
+ TFTypeInferenceRule tf_rule; // rule for TF dialect
locoex::COpTypeInferenceRule cop_rule; // rule for custop op
loco::MultiDialectTypeInferenceRule rules;