|
|
@ -154,6 +154,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (h0_grad) {
|
|
|
|
if (h0_grad) {
|
|
|
|
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
|
|
|
|
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
|
|
|
|
|
|
|
|
zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool is_reverse = context.Attr<bool>("is_reverse");
|
|
|
|
bool is_reverse = context.Attr<bool>("is_reverse");
|
|
|
|