diff options
Diffstat (limited to 'caffe2/utils/math/half_utils.h')
-rw-r--r-- | caffe2/utils/math/half_utils.h | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/caffe2/utils/math/half_utils.h b/caffe2/utils/math/half_utils.h new file mode 100644 index 0000000000..ac841d165a --- /dev/null +++ b/caffe2/utils/math/half_utils.h @@ -0,0 +1,49 @@ +#ifndef CAFFE2_UTILS_MATH_HALF_UTILS_H_ +#define CAFFE2_UTILS_MATH_HALF_UTILS_H_ + +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" +#include "caffe2/utils/conversions.h" +#include "caffe2/utils/math/utils.h" + +namespace caffe2 { +namespace math { +namespace utils { + +struct HalfAddFunctor { + MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b) + const { + return convert::To<float, at::Half>( + convert::To<at::Half, float>(a) + convert::To<at::Half, float>(b)); + } +}; + +struct HalfSubFunctor { + MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b) + const { + return convert::To<float, at::Half>( + convert::To<at::Half, float>(a) - convert::To<at::Half, float>(b)); + } +}; + +struct HalfMulFunctor { + MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b) + const { + return convert::To<float, at::Half>( + convert::To<at::Half, float>(a) * convert::To<at::Half, float>(b)); + } +}; + +struct HalfDivFunctor { + MATH_UTILS_DECL at::Half operator()(const at::Half a, const at::Half b) + const { + return convert::To<float, at::Half>( + convert::To<at::Half, float>(a) / convert::To<at::Half, float>(b)); + } +}; + +} // namespace utils +} // namespace math +} // namespace caffe2 + +#endif // CAFFE2_UTILS_MATH_HALF_UTILS_H_ |