|
|
|
@ -92,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
|
|
|
|
|
int frame_size, int batch_size,
|
|
|
|
|
const detail::ActivationType active_node,
|
|
|
|
|
const detail::ActivationType active_gate) {
|
|
|
|
|
const detail::ActivationType active_gate,
|
|
|
|
|
bool origin_mode) {
|
|
|
|
|
auto stream = context.stream();
|
|
|
|
|
dim3 threads;
|
|
|
|
|
dim3 grid;
|
|
|
|
@ -112,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
|
|
|
|
|
detail::backward::gru_stateGrad<T>(), value.gate_value,
|
|
|
|
|
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
|
|
|
|
|
grad.output_grad, frame_size, batch_size, active_node);
|
|
|
|
|
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
|
|
|
|
|
} else {
|
|
|
|
|
detail::KeGruBackwardStateGrad<
|
|
|
|
|
detail::backward::gru_stateGrad<T>,
|
|
|
|
|
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
|
|
|
|
|
detail::backward::gru_stateGrad<T>(), value.gate_value,
|
|
|
|
|
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
|
|
|
|
|
grad.output_grad, frame_size, batch_size, active_node);
|
|
|
|
|
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
|
|
|
|
|