summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/locop/src/FormattedGraph.cpp44
1 files changed, 42 insertions, 2 deletions
diff --git a/compiler/locop/src/FormattedGraph.cpp b/compiler/locop/src/FormattedGraph.cpp
index 04ca7c4..a6896b9 100644
--- a/compiler/locop/src/FormattedGraph.cpp
+++ b/compiler/locop/src/FormattedGraph.cpp
@@ -17,6 +17,7 @@
#include "locop/FormattedGraph.h"
#include "locop/FormattedTensorShape.h"
+#include <loco/Service/TypeInference.h>
#include <loco/Service/ShapeInference.h>
#include <pp/Format.h>
@@ -33,6 +34,45 @@ using locop::SymbolTable;
namespace
{
+std::string str(const loco::DataType &dtype)
+{
+ switch (dtype)
+ {
+ case loco::DataType::Unknown:
+ return "Unknown";
+
+ case loco::DataType::U8:
+ return "U8";
+ case loco::DataType::U16:
+ return "U16";
+ case loco::DataType::U32:
+ return "U32";
+ case loco::DataType::U64:
+ return "U64";
+
+ case loco::DataType::S8:
+ return "S8";
+ case loco::DataType::S16:
+ return "S16";
+ case loco::DataType::S32:
+ return "S32";
+ case loco::DataType::S64:
+ return "S64";
+
+ case loco::DataType::FLOAT16:
+ return "FLOAT16";
+ case loco::DataType::FLOAT32:
+ return "FLOAT32";
+ case loco::DataType::FLOAT64:
+ return "FLOAT64";
+
+ default:
+ break;
+ };
+
+ throw std::invalid_argument{"dtype"};
+}
+
std::string str(const loco::Domain &domain)
{
// TODO Generate!
@@ -372,8 +412,8 @@ void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const
os << "<";
os << str(node_shape);
os << ", ";
- os << "?";
- // TODO Show DataType
+ // Show DataType
+ os << (loco::dtype_known(node) ? str(loco::dtype_get(node)) : std::string{"?"});
os << ">";
}