|
|
|
@ -54,17 +54,12 @@ class CastOpKernel : public framework::OpKernel<InT> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* in = context.Input<framework::Tensor>("X");
|
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
#if !defined(_MSC_VER)
|
|
|
|
|
|
|
|
|
|
framework::VisitDataType(
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(
|
|
|
|
|
context.Attr<int>("out_dtype")),
|
|
|
|
|
CastOpFunctor<DeviceContext, InT>(
|
|
|
|
|
in, out, context.template device_context<DeviceContext>()));
|
|
|
|
|
#else
|
|
|
|
|
auto type = static_cast<framework::proto::VarType::Type>(
|
|
|
|
|
context.Attr<int>("out_dtype"));
|
|
|
|
|
trans
|
|
|
|
|
#endif // msvc
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|