Add conditional compile for gru opt (#17368)

* improve gru unit performance.
refine code

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>

* Add conditional compile for gru opt

Not enable gru opt if compute ability < 700

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>

* refine code.

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
resnext-opt
zhaoyuchen2018 6 years ago committed by GitHub
parent 6a53fa95e7
commit b02f2aff04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,25 +30,31 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
dim3 threads;
dim3 grid;
if (batch_size == 1) {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<T,
tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<T,
tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size, active_node,
origin_mode);
return;
if (context.GetComputeCapability() >= 70) {
constexpr int tiled_size = 16;
int frame_blocks = (frame_size * 2 + tiled_size - 1) / tiled_size;
threads = dim3(tiled_size, 1);
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruGate<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.gate_value, value.prev_out_value, value.gate_weight,
value.reset_output_value, frame_size, active_gate);
frame_blocks = (frame_size + tiled_size - 1) / tiled_size;
grid = dim3(frame_blocks, 1);
detail::KeFastCollectiveGruOut<
T, tiled_size><<<grid, threads, 0, stream>>>(
value.state_weight, value.prev_out_value, value.output_value,
value.gate_value, value.reset_output_value, frame_size, active_node,
origin_mode);
return;
} else {
int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
int frame_blocks = (frame_size + 1024 - 1) / 1024;
threads = dim3(frame_per_block, 1);
grid = dim3(frame_blocks, 1);
}
} else {
threads = dim3(32, 32);
grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);

Loading…
Cancel
Save