|
|
|
@ -405,6 +405,11 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int cur_batch_size = bend - bstart;
|
|
|
|
|
// lstm_value.output_value not used in bp, set to null
|
|
|
|
|
// lstm_grad.state_active_grad not used in bp, set to null
|
|
|
|
|
lstm_value.output_value = nullptr;
|
|
|
|
|
lstm_grad.state_active_grad = nullptr;
|
|
|
|
|
|
|
|
|
|
math::LstmUnitGradFunctor<DeviceContext, T>::compute(
|
|
|
|
|
device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size,
|
|
|
|
|
gate_act, cell_act, cand_act);
|
|
|
|
|