|
|
|
@ -144,7 +144,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
|
|
|
|
|
gru_grad.prev_out_grad = hidden_prev_grad_t.data<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
gru_value.output_value = nullptr;
|
|
|
|
|
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
|
|
|
|
|
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
|
|
|
|
|
active_gate, origin_mode);
|
|
|
|
|