|
|
|
@ -228,6 +228,26 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ExpandDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void Apply(GradOpPtr<T> op) const override {
|
|
|
|
|
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
if (this->HasInput("expand_times_tensor")) {
|
|
|
|
|
op->SetInput("expand_times_tensor", this->Input("expand_times_tensor"));
|
|
|
|
|
}
|
|
|
|
|
if (this->HasInput("ExpandTimes")) {
|
|
|
|
|
op->SetInput("ExpandTimes", this->Input("ExpandTimes"));
|
|
|
|
|
}
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|
op->SetType("expand");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandGradNoNeedBufVarsInferer, "X");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
@ -238,6 +258,8 @@ REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
|
|
|
|
|
ops::ExpandGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ExpandGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
|
|
|
|
|
ops::ExpandDoubleGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ExpandDoubleGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::ExpandGradNoNeedBufVarsInferer);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|