|
|
|
@ -51,6 +51,20 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ReduceSumDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void Apply(GradOpPtr<T> op) const override {
|
|
|
|
|
op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|
op->SetType("reduce_sum");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceSumGradNoNeedBufferVarInferer, "X");
|
|
|
|
|
class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
@ -77,6 +91,8 @@ REGISTER_OPERATOR(reduce_sum, ops::ReduceOp, ReduceSumOpMaker,
|
|
|
|
|
ops::ReduceSumOpGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ReduceSumOpGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
|
|
|
|
|
ops::ReduceSumDoubleOpGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::ReduceSumDoubleOpGradMaker<paddle::imperative::OpBase>,
|
|
|
|
|
ops::ReduceSumGradNoNeedBufferVarInferer);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|