"add rowwise add backward op"

revert-3824-remove_grad_op_type
dongzhihong 8 years ago
parent 7e60706b51
commit 264b644718

@ -46,6 +46,17 @@ for i in xrange(X.shape[0]):
)DOC"); )DOC");
} }
}; };
class RowWiseAddGradOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 4UL,
"RowWiseAddGrad inputs is I, O, OG, size must be 4");
PADDLE_ENFORCE(ctx.OutputSize() == 2,
"RowWiseAddGrad output is IG, size must be 2");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
@ -53,3 +64,7 @@ for i in xrange(X.shape[0]):
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add, REGISTER_OP_CPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::CPUPlace, float>); ops::RowWiseAddKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp);
REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
ops::RowWiseAddGradKernel<ops::CPUPlace, float>);

@ -38,5 +38,24 @@ public:
} }
}; };
template <typename Place, typename T>
class RowWiseAddGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto XGrad = context.Output<Tensor>(0);
auto bGrad = context.Output<Tensor>(1);
XGrad->mutable_data<T>(context.GetPlace());
bGrad->mutable_data<T>(context.GetPlace());
// I, O, OG => [X, b], [Out], [OutGrad]
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
OutGrad;
// const int dimension = bGrad.dimension(0);
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
OutGrad.cumsum(1); // colwise add
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

Loading…
Cancel
Save