@@ -48,52 +48,17 @@ struct CastOpFunctor {
   }
 };
 
-template <typename DeviceContext, typename InT, typename OutT>
-static void CastFunction(const framework::ExecutionContext& context) {
-  auto* in = context.Input<framework::Tensor>("X");
-  auto* out = context.Output<framework::Tensor>("Out");
-
-  auto in_t = framework::EigenVector<InT>::Flatten(*in);
-  out->mutable_data<OutT>(context.GetPlace());
-  auto out_t = framework::EigenVector<OutT>::Flatten(*out);
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
-  out_t.device(place) = in_t.template cast<OutT>();
-}
-
 template <typename DeviceContext, typename InT>
 class CastOpKernel : public framework::OpKernel<InT> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto out_type = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("out_dtype"));
-
-    if (out_type == paddle::framework::proto::VarType::FP64) {
-      CastFunction<DeviceContext, InT, double>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP32) {
-      CastFunction<DeviceContext, InT, float>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP16) {
-      CastFunction<DeviceContext, InT, paddle::platform::float16>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT64) {
-      CastFunction<DeviceContext, InT, int64_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT32) {
-      CastFunction<DeviceContext, InT, int>(context);
-    } else if (out_type == paddle::framework::proto::VarType::UINT8) {
-      CastFunction<DeviceContext, InT, uint8_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::BOOL) {
-      CastFunction<DeviceContext, InT, bool>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
-      CastFunction<DeviceContext, InT, paddle::platform::complex64>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
-      CastFunction<DeviceContext, InT, paddle::platform::complex128>(context);
-    } else {
-      // NOTE(chenweihang): if else branch do nothing, the output var will
-      // be non-initialized in dygraph, which will throw error if the
-      // non-initialized var is used as the next op's input
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Now does not support casting Tensor to `%s` data type.",
-          framework::DataTypeToString(out_type)));
-    }
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("out_dtype")),
+        CastOpFunctor<DeviceContext, InT>(
+            in, out, context.template device_context<DeviceContext>()));
   }
 };
 
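The + side of this hunk replaces the hand-written per-dtype if/else ladder with framework::VisitDataType plus the existing CastOpFunctor. Below is a minimal, self-contained sketch of that dispatch pattern, not Paddle's actual implementation: the DType enum, the VisitDataType switch, and the CastFunctor here are simplified stand-ins for framework::proto::VarType::Type, framework::VisitDataType, and CastOpFunctor<DeviceContext, InT>, and the tensor is just a std::vector.

// Sketch only: a runtime dtype value selects which template instantiation of
// the visitor's apply<OutT>() is called, so each kernel no longer needs its
// own dtype ladder.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for framework::proto::VarType::Type.
enum class DType { FP32, FP64, INT32, INT64, BOOL };

// Hypothetical stand-in for framework::VisitDataType: one shared switch that
// forwards to the visitor's templated apply<OutT>().
template <typename Visitor>
void VisitDataType(DType dtype, Visitor visitor) {
  switch (dtype) {
    case DType::FP32:  visitor.template apply<float>();   break;
    case DType::FP64:  visitor.template apply<double>();  break;
    case DType::INT32: visitor.template apply<int32_t>(); break;
    case DType::INT64: visitor.template apply<int64_t>(); break;
    case DType::BOOL:  visitor.template apply<bool>();    break;
    default: throw std::runtime_error("unsupported dtype");
  }
}

// Simplified analogue of CastOpFunctor<DeviceContext, InT>: it captures the
// input, and apply<OutT>() performs the element-wise cast for one output type.
template <typename InT>
struct CastFunctor {
  const std::vector<InT>& in;
  explicit CastFunctor(const std::vector<InT>& in) : in(in) {}

  template <typename OutT>
  void apply() const {
    std::vector<OutT> out(in.size());
    for (size_t i = 0; i < in.size(); ++i) out[i] = static_cast<OutT>(in[i]);
    std::cout << "cast " << in.size() << " elements\n";
  }
};

int main() {
  std::vector<float> x = {1.5f, 2.5f, 3.5f};
  DType out_dtype = DType::INT64;  // would come from the "out_dtype" attribute
  VisitDataType(out_dtype, CastFunctor<float>(x));
}

With this pattern the per-dtype branching lives in one shared visitor rather than being repeated inside each kernel's Compute, which is what the shrink from 52 to 17 lines in the hunk header reflects.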