|
|
|
@ -163,8 +163,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
@ -195,13 +196,16 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
|
|
|
|
|
ops::ExpandGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
|
|
|
|
|
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
|
|
|
|
|
ops::ExpandGradNoNeedBufVarsInferer);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|