|
|
|
@ -42,18 +42,18 @@ template <typename Place, typename T>
|
|
|
|
|
class RowwiseAddGradKernel : public OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const ExecutionContext& context) const override {
|
|
|
|
|
auto XGrad = context.Output<Tensor>(0);
|
|
|
|
|
auto bGrad = context.Output<Tensor>(1);
|
|
|
|
|
auto* XGrad = context.Output<Tensor>(0);
|
|
|
|
|
auto* bGrad = context.Output<Tensor>(1);
|
|
|
|
|
XGrad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
bGrad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
// I, O, OG => [X, b], [Out], [OutGrad]
|
|
|
|
|
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
|
|
|
|
|
EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
|
|
|
|
|
EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
|
|
|
|
|
OutGrad;
|
|
|
|
|
|
|
|
|
|
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
|
|
|
|
|
EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
|
|
|
|
|
EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
|
|
|
|
|
OutGrad.cumsum(1); // colwise add
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|