fix rowwise_add_grad_op

revert-3824-remove_grad_op_type
qiaolongfei 8 years ago
parent cef27dab47
commit 82b820e97b

@ -63,7 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1")
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
}

Loading…
Cancel
Save