#include "caffe2/core/context_gpu.h" #include "caffe2/operators/cast_op.h" #include "caffe2/utils/conversions.h" namespace caffe2 { template __global__ void CastKernel(const int N, const SrcType* X, DstType* Y) { CUDA_1D_KERNEL_LOOP(i, N) { // Y[i] = static_cast(X[i]); Y[i] = convert::To(X[i]); } } template <> template bool CastOp::DoRunWithType() { auto& input = Input(0); auto* output = Output(0); output->ResizeLike(input); const auto* data = input.template data(); auto* out = output->template mutable_data(); DCHECK(input.size() < INT_MAX); int N = input.size(); if (N == 0) { // skip the rest of the computation if input is empty return true; } CastKernel <<>>(N, data, out); return true; } template <> template bool CastOp::DoRunWithDstType() { return DispatchHelper< TensorTypes< float, int32_t, bool, uint8_t, int8_t, uint16_t, int16_t, int64_t, double>, DstType>::call(this, Input(0)); } // specific version that allows for casting to fp16 template <> template <> bool CastOp::DoRunWithDstType() { return DispatchHelper< TensorTypes< float, float16, int32_t, bool, uint8_t, int8_t, uint16_t, int16_t, int64_t, double>, float /* DstType */>::call(this, Input(0)); } // specific version for casting _from_ fp16 template <> template <> bool CastOp::DoRunWithDstType() { return DispatchHelper< TensorTypes< float, float16>, float16 /* DstType */>::call(this, Input(0)); } template <> void CastOp::SetBody(TensorProto_DataType to) { switch (to) { case TensorProto_DataType_FLOAT: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_INT32: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_BYTE: LOG(FATAL) << "BYTE is deprecated"; break; case TensorProto_DataType_STRING: CAFFE_THROW("Casting to and from strings is not supported yet"); // break; case TensorProto_DataType_BOOL: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_UINT8: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_INT8: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_UINT16: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_INT16: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_INT64: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_FLOAT16: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_DOUBLE: body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_UNDEFINED: CAFFE_THROW("Cast op must have 'to' argument of type DataType"); // break; default: CAFFE_THROW("Unexpected 'to' argument value: ", to); } } REGISTER_CUDA_OPERATOR(Cast, CastOp); } // namespace caffe2