|
|
|
@ -31,19 +31,25 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
dim3 grid;
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
if (context.GetComputeCapability() >= 70) {
|
|
|
|
|
constexpr int tiled_size = 16;
|
|
|
|
|
// Picks the tile width for the given frame size: 16 when frame_size is at
// least 16, otherwise 8.
// The result is returned through a single exhaustive expression so that every
// control path yields a value (the previous `if / else if` chain had no final
// `else`, leaving a syntactic fall-through with no return).
auto ComputeTiledSize = [](int frame_size) -> int {
  return frame_size >= 16 ? 16 : 8;
};
|
|
|
|
|
|
|
|
|
|
auto tiled_size = ComputeTiledSize(frame_size);
|
|
|
|
|
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
|
|
|
|
|
threads = dim3(tiled_size, 1);
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
detail::KeFastCollectiveGruGate<
|
|
|
|
|
T, tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
|
|
|
|
|
detail::KeFastCollectiveGruGate<T><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.gate_value, value.prev_out_value, value.gate_weight,
|
|
|
|
|
value.reset_output_value, frame_size, active_gate);
|
|
|
|
|
|
|
|
|
|
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
detail::KeFastCollectiveGruOut<
|
|
|
|
|
T, tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
detail::KeFastCollectiveGruOut<T><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.state_weight, value.prev_out_value, value.output_value,
|
|
|
|
|
value.gate_value, value.reset_output_value, frame_size, active_node,
|
|
|
|
|
origin_mode);
|
|
|
|
|