|
|
@ -230,6 +230,26 @@ class ExpandV2GradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
|
|
|
class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
void Apply(GradOpPtr<T> op) const override {
|
|
|
|
|
|
|
|
op->SetType("expand_v2");
|
|
|
|
|
|
|
|
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
|
|
|
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
|
|
|
if (this->HasInput("expand_shapes_tensor")) {
|
|
|
|
|
|
|
|
op->SetInput("expand_shapes_tensor", this->Input("expand_shapes_tensor"));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (this->HasInput("Shape")) {
|
|
|
|
|
|
|
|
op->SetInput("Shape", this->Input("Shape"));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandV2GradNoNeedBufVarsInferer, "X");
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
@ -240,6 +260,8 @@ REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker,
|
|
|
|
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
ops::ExpandV2GradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
ops::ExpandV2GradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
|
|
|
|
REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp,
|
|
|
|
|
|
|
|
ops::ExpandV2DoubleGradOpMaker<paddle::framework::OpDesc>,
|
|
|
|
|
|
|
|
ops::ExpandV2DoubleGradOpMaker<paddle::imperative::OpBase>,
|
|
|
|
ops::ExpandV2GradNoNeedBufVarsInferer);
|
|
|
|
ops::ExpandV2GradNoNeedBufVarsInferer);
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
expand_v2, ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
expand_v2, ops::ExpandV2Kernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|