|
|
|
@ -30,25 +30,31 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
dim3 threads;
|
|
|
|
|
dim3 grid;
|
|
|
|
|
if (batch_size == 1) {
|
|
|
|
|
constexpr int tiled_size = 16;
|
|
|
|
|
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
|
|
|
|
|
threads = dim3(tiled_size, 1);
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
|
|
|
|
|
detail::KeFastCollectiveGruGate<T,
|
|
|
|
|
tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.gate_value, value.prev_out_value, value.gate_weight,
|
|
|
|
|
value.reset_output_value, frame_size, active_gate);
|
|
|
|
|
|
|
|
|
|
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
detail::KeFastCollectiveGruOut<T,
|
|
|
|
|
tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.state_weight, value.prev_out_value, value.output_value,
|
|
|
|
|
value.gate_value, value.reset_output_value, frame_size, active_node,
|
|
|
|
|
origin_mode);
|
|
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
if (context.GetComputeCapability() >= 70) {
|
|
|
|
|
constexpr int tiled_size = 16;
|
|
|
|
|
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
|
|
|
|
|
threads = dim3(tiled_size, 1);
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
detail::KeFastCollectiveGruGate<
|
|
|
|
|
T, tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.gate_value, value.prev_out_value, value.gate_weight,
|
|
|
|
|
value.reset_output_value, frame_size, active_gate);
|
|
|
|
|
|
|
|
|
|
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
detail::KeFastCollectiveGruOut<
|
|
|
|
|
T, tiled_size><<<grid, threads, 0, stream>>>(
|
|
|
|
|
value.state_weight, value.prev_out_value, value.output_value,
|
|
|
|
|
value.gate_value, value.reset_output_value, frame_size, active_node,
|
|
|
|
|
origin_mode);
|
|
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
} else {
|
|
|
|
|
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
|
|
|
|
|
int frame_blocks = (frame_size + 1024 - 1) / 1024;
|
|
|
|
|
threads = dim3(frame_per_block, 1);
|
|
|
|
|
grid = dim3(frame_blocks, 1);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
threads = dim3(32, 32);
|
|
|
|
|
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
|
|
|
|
|