|
|
|
@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CastOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
|
|
|
|
|
// CastOp kernel's device type is decided by input tensor place
|
|
|
|
|
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
|
|
|
|
|
return kt;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
using CPU = paddle::platform::CPUDeviceContext;
|
|
|
|
|
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
|
|
|
|
|
ops::CastOpProtoMaker);
|
|
|
|
|
REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker,
|
|
|
|
|
ops::CastOpInferShape, ops::CastOpProtoMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
|
|
|
|
|
ops::CastOpKernel<CPU, double>,
|
|
|
|
|
ops::CastOpKernel<CPU, int>,
|
|
|
|
|