|
|
|
@ -43,7 +43,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
|
|
|
|
|
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
|
|
|
|
|
frame_size, batch_size, active_gate, true,
|
|
|
|
|
&context);
|
|
|
|
|
nullptr);
|
|
|
|
|
|
|
|
|
|
if (value.prev_out_value) {
|
|
|
|
|
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
|
|
|
|
@ -54,7 +54,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
|
|
|
|
|
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
|
|
|
|
|
frame_size, batch_size, active_node,
|
|
|
|
|
origin_mode, &context);
|
|
|
|
|
origin_mode, true, nullptr);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|