"backward check todo"

revert-3824-remove_grad_op_type
dongzhihong 8 years ago
parent 789d6ed9b7
commit b7ee1e7d9c

@ -42,18 +42,18 @@ template <typename Place, typename T>
class RowwiseAddGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto XGrad = context.Output<Tensor>(0);
auto bGrad = context.Output<Tensor>(1);
auto* XGrad = context.Output<Tensor>(0);
auto* bGrad = context.Output<Tensor>(1);
XGrad->mutable_data<T>(context.GetPlace());
bGrad->mutable_data<T>(context.GetPlace());
// I, O, OG => [X, b], [Out], [OutGrad]
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
OutGrad;
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
OutGrad.cumsum(1); // colwise add
}
};

@ -15,5 +15,7 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
#TODO(dzh): rowwise_grad check
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save