diff options
Diffstat (limited to 'runtime/onert/core/include/ir/train')
14 files changed, 766 insertions, 0 deletions
diff --git a/runtime/onert/core/include/ir/train/ITrainableOperation.h b/runtime/onert/core/include/ir/train/ITrainableOperation.h new file mode 100644 index 000000000..590bed45d --- /dev/null +++ b/runtime/onert/core/include/ir/train/ITrainableOperation.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ +#define __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ + +#include "ir/IOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +struct TrainableOperationVisitor; + +// NOTE Virtual inheritance is introduced because trainable operations inherit +// `ITrainableOperation` and `Operation` which inherit `IOperation`. +class ITrainableOperation : virtual public IOperation +{ +public: + virtual ~ITrainableOperation() = default; + +public: + virtual std::unique_ptr<ITrainableOperation> clone() const = 0; + virtual void accept(OperationVisitor &v) const override = 0; + virtual void accept(TrainableOperationVisitor &v) const = 0; + // TODO Add virtual methods related to training +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_ITRAINABLE_OPERATION_H__ diff --git a/runtime/onert/core/include/ir/train/Operations.Include.h b/runtime/onert/core/include/ir/train/Operations.Include.h new file mode 100644 index 000000000..56e752f94 --- /dev/null +++ b/runtime/onert/core/include/ir/train/Operations.Include.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ +#define __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ + +#include "ir/train/operation/Conv2D.h" +#include "ir/train/operation/ElementwiseActivation.h" +#include "ir/train/operation/FullyConnected.h" +#include "ir/train/operation/Loss.h" +#include "ir/train/operation/Permute.h" +#include "ir/train/operation/Pool2D.h" +#include "ir/train/operation/Reshape.h" +#include "ir/train/operation/Softmax.h" + +#endif // __ONERT_IR_TRAIN_OPERATIONS_OPERATION_INCLUDE_H__ diff --git a/runtime/onert/core/include/ir/train/Operations.lst b/runtime/onert/core/include/ir/train/Operations.lst new file mode 100644 index 000000000..14dc38819 --- /dev/null +++ b/runtime/onert/core/include/ir/train/Operations.lst @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OP +#error Define OP before including this file +#endif + +OP(Conv2D) +OP(ElementwiseActivation) +OP(FullyConnected) +OP(Loss) +OP(Permute) +OP(Pool2D) +OP(Reshape) +OP(Softmax) diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h new file mode 100644 index 000000000..90c49e212 --- /dev/null +++ b/runtime/onert/core/include/ir/train/TrainableGraph.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ +#define __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ + +#include <functional> +#include <unordered_map> + +#include "ir/Graph.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +class TrainableGraph : public IGraph +{ +public: + /** + * @brief Construct a new Trainable Graph object + * + * @param graph + */ + explicit TrainableGraph(); + explicit TrainableGraph(const TrainableGraph &tgraph); + explicit TrainableGraph(const Graph &graph); + ~TrainableGraph() = default; + + // TrainableGraph Building +public: + OperandIndex addOperand(const Shape &shape, const TypeInfo &type); + /** + * @brief Add an operand to the graph with the given index and object + * + * If the given index is available, it succeeds. And @c operand is moved which invalidates the + * caller's pointer. If the given index is already taken, it fails. And @c operand will not be + * moved so the caller's pointer will be still valid. + * + * @param[in] index Index to be added + * @param[in] operand Operand to be added + * @return OperandIndex @c index if successful, UNDEFINED otherwise + */ + OperandIndex addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand); + /** + * @brief Add a new trainable operation to the graph + * + * If the given @c operation has at least one invalid operand index, it fails. And @c operation + * will not be moved so the caller's pointer will be still valid. + * + * @param operation Operation to be added + * @return OperationIndex @c index if successful, UNDEFINED otherwise + */ + OperationIndex addOperation(std::unique_ptr<ITrainableOperation> &&operation); + /** + * @brief Replace a trainable operation which the graph already has + * + * If the given @c index is available, it succeeds. And @c operation is moved which invalidates + * the caller's pointer. If the given @c operation has at least one invalid operand index, it + * fails. And @c operation will not be moved so the caller's pointer will be still valid. + * + * No information in the graph is changed except for replacing an operation. + * + * @param operation Operation to be added + * @return OperationIndex @c index if successful, UNDEFINED otherwise + */ + OperationIndex replaceOperation(OperationIndex index, + std::unique_ptr<ITrainableOperation> &&operation); + + /** + * @brief Add a derivative to the graph with the given index and object + * + * If the given index is available, it succeeds. And @c derivative is moved which invalidates the + * caller's pointer. If the given index is already taken, it fails. And @c derivative will not be + * moved so the caller's pointer will be still valid. + * + * @param[in] index Index to be added + * @param[in] derivative Derivative operand to be added + * @return OperandIndex @c index if successful, UNDEFINED otherwise + */ + OperandIndex addDerivative(OperandIndex index, std::unique_ptr<Operand> &&derivative); + +public: + void changeShape(const OperandIndex &ind, const ir::Shape &new_shape) override; + void changeDerivativeShape(const OperandIndex &ind, const ir::Shape &new_shape); + void addInput(const OperandIndex &ind, const std::string &name = ""); + void addOutput(const OperandIndex &ind, const std::string &name = ""); + void addLoss(const OperandIndex &loss_ind, const IOIndex &pred_io_ind); + void verify() const; + void removeOperand(const OperandIndex &ind); + void setLayout(Layout layout); + void setInputs(OperandIndexSequence inputs, + std::unordered_map<std::string, IOIndex> name_to_input); + void setOutputs(OperandIndexSequence outputs, + std::unordered_map<std::string, IOIndex> name_to_output); + + // Accessors +public: + const OperandIndexSequence &getInputs() const override { return _graph.getInputs(); } + const OperandIndexSequence &getOutputs() const override { return _graph.getOutputs(); } + IOIndex getInputIndex(const std::string &name) const override; + IOIndex getOutputIndex(const std::string &name) const override; + const Operands &operands() const override { return _graph.operands(); } + Operands &operands() { return _graph.operands(); } // TODO Remove this non-const accessor + const Operations &operations() const override { return _graph.operations(); } + const Operands &derivatives() const { return _derivatives; } + OperandIndex getLossIndex(const IOIndex &pred_io_ind) const; + Layout layout() const { return _graph.layout(); } + const Graph &graph() const { return _graph; } + +public: + const ITrainableOperation &operation(OperationIndex index) const; + +public: + std::vector<ir::OperationIndex> topolSortOperations() const; + // TODO Support topological sort for backwarding + +private: + Graph _graph; + Operands _derivatives; + + std::unordered_map<IOIndex, OperandIndex> _losses; +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_TRAINABLE_GRAPH_H__ diff --git a/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h new file mode 100644 index 000000000..fc58c351d --- /dev/null +++ b/runtime/onert/core/include/ir/train/TrainableOperationVisitor.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ +#define __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ + +#include "ir/train/Operations.Include.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ + +struct TrainableOperationVisitor +{ + virtual ~TrainableOperationVisitor() = default; + +#define OP(InternalName) \ + virtual void visit(const operation::InternalName &) {} +#include "ir/train/Operations.lst" +#undef OP +}; + +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_TRAINABLE_OPERATION_VISITOR_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Conv2D.h b/runtime/onert/core/include/ir/train/operation/Conv2D.h new file mode 100644 index 000000000..b8968926a --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Conv2D.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ +#define __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ + +#include "ir/operation/Conv2D.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Conv2D : public ir::operation::Conv2D, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Conv2D; + +public: + Conv2D(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_CONV2D_H__ diff --git a/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h new file mode 100644 index 000000000..97ab54d17 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/ElementwiseActivation.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ +#define __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ + +#include "ir/operation/ElementwiseActivation.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class ElementwiseActivation : public ir::operation::ElementwiseActivation, + public ITrainableOperation +{ +private: + using OperationType = ir::operation::ElementwiseActivation; + +public: + ElementwiseActivation(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_ELEMENTWISE_ACTIVATION_H__ diff --git a/runtime/onert/core/include/ir/train/operation/FullyConnected.h b/runtime/onert/core/include/ir/train/operation/FullyConnected.h new file mode 100644 index 000000000..bede58d69 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/FullyConnected.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ +#define __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ + +#include "ir/operation/FullyConnected.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class FullyConnected : public ir::operation::FullyConnected, public ITrainableOperation +{ +private: + using OperationType = ir::operation::FullyConnected; + +public: + FullyConnected(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_FULLYCONNECTED_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Loss.h b/runtime/onert/core/include/ir/train/operation/Loss.h new file mode 100644 index 000000000..c7cc4213a --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Loss.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_LOSS_H__ +#define __ONERT_IR_TRAIN_OPERATION_LOSS_H__ + +#include "ir/operation/Loss.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Loss : public ir::operation::Loss, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Loss; + +public: + Loss(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_LOSS_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Permute.h b/runtime/onert/core/include/ir/train/operation/Permute.h new file mode 100644 index 000000000..e652b136d --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Permute.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ +#define __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ + +#include "ir/operation/Permute.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Permute : public ir::operation::Permute, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Permute; + +public: + Permute(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_PERMUTE_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Pool2D.h b/runtime/onert/core/include/ir/train/operation/Pool2D.h new file mode 100644 index 000000000..024997074 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Pool2D.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ +#define __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ + +#include "ir/operation/Pool2D.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Pool2D : public ir::operation::Pool2D, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Pool2D; + +public: + Pool2D(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_POOL2D_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Reshape.h b/runtime/onert/core/include/ir/train/operation/Reshape.h new file mode 100644 index 000000000..1efd62cfe --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Reshape.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ +#define __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ + +#include "ir/operation/Reshape.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Reshape : public ir::operation::Reshape, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Reshape; + +public: + Reshape(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_RESHAPE_H__ diff --git a/runtime/onert/core/include/ir/train/operation/Softmax.h b/runtime/onert/core/include/ir/train/operation/Softmax.h new file mode 100644 index 000000000..b12e6abc1 --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/Softmax.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ +#define __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ + +#include "ir/operation/Softmax.h" +#include "ir/train/ITrainableOperation.h" + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +class Softmax : public ir::operation::Softmax, public ITrainableOperation +{ +private: + using OperationType = ir::operation::Softmax; + +public: + Softmax(const OperationType &operation); + +public: + std::unique_ptr<ITrainableOperation> clone() const override; + void accept(OperationVisitor &v) const override; + void accept(TrainableOperationVisitor &v) const override; +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_SOFTMAX_H__ diff --git a/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h new file mode 100644 index 000000000..7cda0ec0c --- /dev/null +++ b/runtime/onert/core/include/ir/train/operation/UntrainableOperation.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ +#define __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ + +#include "ir/train/ITrainableOperation.h" + +#include "ir/OperationVisitor.h" +#include "ir/train/TrainableOperationVisitor.h" + +#include <type_traits> + +namespace onert +{ +namespace ir +{ +namespace train +{ +namespace operation +{ + +// `UntrainableOperation` wraps operations that are not yet supported for training. +// This class can be removed if all operations are supported for training. +template <typename OperationType, + typename = std::enable_if_t<std::is_base_of<Operation, OperationType>::value>> +class UntrainableOperation : public OperationType, public ITrainableOperation +{ +public: + UntrainableOperation(const OperationType &operation) : OperationType{operation} {} + virtual ~UntrainableOperation() = default; + +public: + std::unique_ptr<ITrainableOperation> clone() const override + { + return std::make_unique<UntrainableOperation<OperationType>>(*this); + } + void accept(OperationVisitor &v) const override { v.visit(*this); } + void accept(TrainableOperationVisitor &) const override + { + throw std::runtime_error(OperationType::name() + "operation is not trainable yet"); + } +}; + +} // namespace operation +} // namespace train +} // namespace ir +} // namespace onert + +#endif // __ONERT_IR_TRAIN_OPERATION_UNTRAINABLE_OPERATION_H__ |