summaryrefslogtreecommitdiff
path: root/runtime/onert/core/include/ir/train
diff options
context:
space:
mode:
Diffstat (limited to 'runtime/onert/core/include/ir/train')
-rw-r--r--runtime/onert/core/include/ir/train/ITrainableOperation.h49
-rw-r--r--runtime/onert/core/include/ir/train/Operations.Include.h29
-rw-r--r--runtime/onert/core/include/ir/train/Operations.lst28
-rw-r--r--runtime/onert/core/include/ir/train/TrainableGraph.h145
-rw-r--r--runtime/onert/core/include/ir/train/TrainableOperationVisitor.h43
-rw-r--r--runtime/onert/core/include/ir/train/operation/Conv2D.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h52
-rw-r--r--runtime/onert/core/include/ir/train/operation/FullyConnected.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/Loss.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/Permute.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/Pool2D.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/Reshape.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/Softmax.h51
-rw-r--r--runtime/onert/core/include/ir/train/operation/UntrainableOperation.h63
14 files changed, 766 insertions, 0 deletions
diff --git a/runtime/onert/core/include/ir/train/ITrainableOperation.h b/runtime/onert/core/include/ir/train/ITrainableOperation.h
new file mode 100644
index 000000000..590bed45d
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/ITrainableOperation.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__
+#define __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__
+
+#include "ir/IOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+struct TrainableOperationVisitor;
+
+// NOTE Virtual inheritance is introduced because trainable operations inherit
+// `ITrainableOperation` and `Operation` which inherit `IOperation`.
+class ITrainableOperation : virtual public IOperation
+{
+public:
+ virtual ~ITrainableOperation() = default;
+
+public:
+ virtual std::unique_ptr<ITrainableOperation> clone() const = 0;
+ virtual void accept(OperationVisitor &v) const override = 0;
+ virtual void accept(TrainableOperationVisitor &v) const = 0;
+ // TODO Add virtual methods related to training
+};
+
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__
diff --git a/runtime/onert/core/include/ir/train/Operations.Include.h b/runtime/onert/core/include/ir/train/Operations.Include.h
new file mode 100644
index 000000000..56e752f94
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/Operations.Include.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__
+#define __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__
+
+#include "ir/train/operation/Conv2D.h"
+#include "ir/train/operation/ElementwiseActivation.h"
+#include "ir/train/operation/FullyConnected.h"
+#include "ir/train/operation/Loss.h"
+#include "ir/train/operation/Permute.h"
+#include "ir/train/operation/Pool2D.h"
+#include "ir/train/operation/Reshape.h"
+#include "ir/train/operation/Softmax.h"
+
+#endif // __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__
diff --git a/runtime/onert/core/include/ir/train/Operations.lst b/runtime/onert/core/include/ir/train/Operations.lst
new file mode 100644
index 000000000..14dc38819
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/Operations.lst
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef OP
+#error Define OP before including this file
+#endif
+
+OP(Conv2D)
+OP(ElementwiseActivation)
+OP(FullyConnected)
+OP(Loss)
+OP(Permute)
+OP(Pool2D)
+OP(Reshape)
+OP(Softmax)
diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h
new file mode 100644
index 000000000..90c49e212
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
+#define __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
+
+#include <functional>
+#include <unordered_map>
+
+#include "ir/Graph.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+class TrainableGraph : public IGraph
+{
+public:
+ /**
+ * @brief Construct a new Trainable Graph object
+ *
+ * @param graph
+ */
+ explicit TrainableGraph();
+ explicit TrainableGraph(const TrainableGraph &tgraph);
+ explicit TrainableGraph(const Graph &graph);
+ ~TrainableGraph() = default;
+
+ // TrainableGraph Building
+public:
+ OperandIndex addOperand(const Shape &shape, const TypeInfo &type);
+ /**
+ * @brief Add an operand to the graph with the given index and object
+ *
+ * If the given index is available, it succeeds. And @c operand is moved which invalidates the
+ * caller's pointer. If the given index is already taken, it fails. And @c operand will not be
+ * moved so the caller's pointer will be still valid.
+ *
+ * @param[in] index Index to be added
+ * @param[in] operand Operand to be added
+ * @return OperandIndex @c index if successful, UNDEFINED otherwise
+ */
+ OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand);
+ /**
+ * @brief Add a new trainable operation to the graph
+ *
+ * If the given @c operation has at least one invalid operand index, it fails. And @c operation
+ * will not be moved so the caller's pointer will be still valid.
+ *
+ * @param operation Operation to be added
+ * @return OperationIndex @c index if successful, UNDEFINED otherwise
+ */
+ OperationIndex addOperation(std::unique_ptr<ITrainableOperation> &&operation);
+ /**
+ * @brief Replace a trainable operation which the graph already has
+ *
+ * If the given @c index is available, it succeeds. And @c operation is moved which invalidates
+ * the caller's pointer. If the given @c operation has at least one invalid operand index, it
+ * fails. And @c operation will not be moved so the caller's pointer will be still valid.
+ *
+ * No information in the graph is changed except for replacing an operation.
+ *
+ * @param operation Operation to be added
+ * @return OperationIndex @c index if successful, UNDEFINED otherwise
+ */
+ OperationIndex replaceOperation(OperationIndex index,
+ std::unique_ptr<ITrainableOperation> &&operation);
+
+ /**
+ * @brief Add a derivative to the graph with the given index and object
+ *
+ * If the given index is available, it succeeds. And @c derivative is moved which invalidates the
+ * caller's pointer. If the given index is already taken, it fails. And @c derivative will not be
+ * moved so the caller's pointer will be still valid.
+ *
+ * @param[in] index Index to be added
+ * @param[in] derivative Derivative operand to be added
+ * @return OperandIndex @c index if successful, UNDEFINED otherwise
+ */
+ OperandIndex addDerivative(OperandIndex index, std::unique_ptr<Operand> &&derivative);
+
+public:
+ void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override;
+ void changeDerivativeShape(const OperandIndex &ind, const ir::Shape &new_shape);
+ void addInput(const OperandIndex &ind, const std::string &name = "");
+ void addOutput(const OperandIndex &ind, const std::string &name = "");
+ void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind);
+ void verify() const;
+ void removeOperand(const OperandIndex &ind);
+ void setLayout(Layout layout);
+ void setInputs(OperandIndexSequence inputs,
+ std::unordered_map<std::string, IOIndex> name_to_input);
+ void setOutputs(OperandIndexSequence outputs,
+ std::unordered_map<std::string, IOIndex> name_to_output);
+
+ // Accessors
+public:
+ const OperandIndexSequence &getInputs() const override { return _graph.getInputs(); }
+ const OperandIndexSequence &getOutputs() const override { return _graph.getOutputs(); }
+ IOIndex getInputIndex(const std::string &name) const override;
+ IOIndex getOutputIndex(const std::string &name) const override;
+ const Operands &operands() const override { return _graph.operands(); }
+ Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor
+ const Operations &operations() const override { return _graph.operations(); }
+ const Operands &derivatives() const { return _derivatives; }
+ OperandIndex getLossIndex(const IOIndex &pred_io_ind) const;
+ Layout layout() const { return _graph.layout(); }
+ const Graph &graph() const { return _graph; }
+
+public:
+ const ITrainableOperation &operation(OperationIndex index) const;
+
+public:
+ std::vector<ir::OperationIndex> topolSortOperations() const;
+ // TODO Support topological sort for backwarding
+
+private:
+ Graph _graph;
+ Operands _derivatives;
+
+ std::unordered_map<IOIndex, OperandIndex> _losses;
+};
+
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__
diff --git a/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h
new file mode 100644
index 000000000..fc58c351d
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__
+#define __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__
+
+#include "ir/train/Operations.Include.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+struct TrainableOperationVisitor
+{
+ virtual ~TrainableOperationVisitor() = default;
+
+#define OP(InternalName) \
+ virtual void visit(const operation::InternalName &) {}
+#include "ir/train/Operations.lst"
+#undef OP
+};
+
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Conv2D.h b/runtime/onert/core/include/ir/train/operation/Conv2D.h
new file mode 100644
index 000000000..b8968926a
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Conv2D.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_CONV2D_H__
+#define __ONERT_IR_TRAIN_OPERATION_CONV2D_H__
+
+#include "ir/operation/Conv2D.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Conv2D : public ir::operation::Conv2D, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Conv2D;
+
+public:
+ Conv2D(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_CONV2D_H__
diff --git a/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h
new file mode 100644
index 000000000..97ab54d17
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__
+#define __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__
+
+#include "ir/operation/ElementwiseActivation.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class ElementwiseActivation : public ir::operation::ElementwiseActivation,
+ public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::ElementwiseActivation;
+
+public:
+ ElementwiseActivation(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__
diff --git a/runtime/onert/core/include/ir/train/operation/FullyConnected.h b/runtime/onert/core/include/ir/train/operation/FullyConnected.h
new file mode 100644
index 000000000..bede58d69
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/FullyConnected.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__
+#define __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__
+
+#include "ir/operation/FullyConnected.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class FullyConnected : public ir::operation::FullyConnected, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::FullyConnected;
+
+public:
+ FullyConnected(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Loss.h b/runtime/onert/core/include/ir/train/operation/Loss.h
new file mode 100644
index 000000000..c7cc4213a
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Loss.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_LOSS_H__
+#define __ONERT_IR_TRAIN_OPERATION_LOSS_H__
+
+#include "ir/operation/Loss.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Loss : public ir::operation::Loss, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Loss;
+
+public:
+ Loss(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_LOSS_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Permute.h b/runtime/onert/core/include/ir/train/operation/Permute.h
new file mode 100644
index 000000000..e652b136d
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Permute.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__
+#define __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__
+
+#include "ir/operation/Permute.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Permute : public ir::operation::Permute, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Permute;
+
+public:
+ Permute(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Pool2D.h b/runtime/onert/core/include/ir/train/operation/Pool2D.h
new file mode 100644
index 000000000..024997074
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Pool2D.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_POOL2D_H__
+#define __ONERT_IR_TRAIN_OPERATION_POOL2D_H__
+
+#include "ir/operation/Pool2D.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Pool2D : public ir::operation::Pool2D, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Pool2D;
+
+public:
+ Pool2D(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_POOL2D_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Reshape.h b/runtime/onert/core/include/ir/train/operation/Reshape.h
new file mode 100644
index 000000000..1efd62cfe
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Reshape.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__
+#define __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__
+
+#include "ir/operation/Reshape.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Reshape : public ir::operation::Reshape, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Reshape;
+
+public:
+ Reshape(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__
diff --git a/runtime/onert/core/include/ir/train/operation/Softmax.h b/runtime/onert/core/include/ir/train/operation/Softmax.h
new file mode 100644
index 000000000..b12e6abc1
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/Softmax.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__
+#define __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__
+
+#include "ir/operation/Softmax.h"
+#include "ir/train/ITrainableOperation.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+class Softmax : public ir::operation::Softmax, public ITrainableOperation
+{
+private:
+ using OperationType = ir::operation::Softmax;
+
+public:
+ Softmax(const OperationType &operation);
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override;
+ void accept(OperationVisitor &v) const override;
+ void accept(TrainableOperationVisitor &v) const override;
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__
diff --git a/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h
new file mode 100644
index 000000000..7cda0ec0c
--- /dev/null
+++ b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__
+#define __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__
+
+#include "ir/train/ITrainableOperation.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include <type_traits>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+// `UntrainableOperation` wraps operations that are not yet supported for training.
+// This class can be removed if all operations are supported for training.
+template <typename OperationType,
+ typename = std::enable_if_t<std::is_base_of<Operation, OperationType>::value>>
+class UntrainableOperation : public OperationType, public ITrainableOperation
+{
+public:
+ UntrainableOperation(const OperationType &operation) : OperationType{operation} {}
+ virtual ~UntrainableOperation() = default;
+
+public:
+ std::unique_ptr<ITrainableOperation> clone() const override
+ {
+ return std::make_unique<UntrainableOperation<OperationType>>(*this);
+ }
+ void accept(OperationVisitor &v) const override { v.visit(*this); }
+ void accept(TrainableOperationVisitor &) const override
+ {
+ throw std::runtime_error(OperationType::name() + "operation is not trainable yet");
+ }
+};
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__