@@ -21,66 +21,66 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, int frameSize, int batchSize,
+                      hl_gru_value<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate) {
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
     dim3 threads;
     dim3 grid;
-    if (batchSize == 1) {
-      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-      int frameBlocks = (frameSize + 1024 - 1) / 1024;
-      threads = dim3(framePerBlock, 1);
-      grid = dim3(frameBlocks, 1);
+    if (batch_size == 1) {
+      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+      int frame_blocks = (frame_size + 1024 - 1) / 1024;
+      threads = dim3(frame_per_block, 1);
+      grid = dim3(frame_blocks, 1);
     } else {
       threads = dim3(32, 32);
-      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
     }
 
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, false, batchSize, frameSize * 2, frameSize, 1,
-          value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
-          value.gateValue, frameSize * 3);
+          context, false, false, batch_size, frame_size * 2, frame_size, 1,
+          value.prev_out_value, frame_size, value.gate_weight, frame_size * 2,
+          1, value.gate_value, frame_size * 3);
     }
 
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
-                                      /* isBatch= */ false,
+                                      /* is_batch= */ false,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_resetOutput<T>(), value.gateValue,
-          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
-          active_gate);
+          detail::forward::gru_resetOutput<T>(), value.gate_value,
+          value.reset_output_value, value.prev_out_value, frame_size,
+          batch_size, active_gate);
     } else {
       detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
-                                      /* isBatch= */ true,
+                                      /* is_batch= */ true,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_resetOutput<T>(), value.gateValue,
-          value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
-          active_gate);
+          detail::forward::gru_resetOutput<T>(), value.gate_value,
+          value.reset_output_value, value.prev_out_value, frame_size,
+          batch_size, active_gate);
     }
 
-    if (value.prevOutValue) {
+    if (value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, false, batchSize, frameSize, frameSize, 1,
-          value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
-          value.gateValue + frameSize * 2, frameSize * 3);
+          context, false, false, batch_size, frame_size, frame_size, 1,
+          value.reset_output_value, frame_size, value.state_weight, frame_size,
+          1, value.gate_value + frame_size * 2, frame_size * 3);
     }
 
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
-                                      /* isBatch= */ false,
+                                      /* is_batch= */ false,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_finalOutput<T>(), value.gateValue,
-          value.prevOutValue, value.outputValue, frameSize, batchSize,
+          detail::forward::gru_finalOutput<T>(), value.gate_value,
+          value.prev_out_value, value.output_value, frame_size, batch_size,
           active_node);
     } else {
       detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
-                                      /* isBatch= */ true,
+                                      /* is_batch= */ true,
                                       T><<<grid, threads, 0, stream>>>(
-          detail::forward::gru_finalOutput<T>(), value.gateValue,
-          value.prevOutValue, value.outputValue, frameSize, batchSize,
+          detail::forward::gru_finalOutput<T>(), value.gate_value,
+          value.prev_out_value, value.output_value, frame_size, batch_size,
           active_node);
     }
   }
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::GPUPlace, T> {
   static void compute(const platform::DeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
-                      int batchSize, activation_mode_t active_node,
+                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      int frame_size, int batch_size,
+                      activation_mode_t active_node,
                       activation_mode_t active_gate) {
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
     dim3 threads;
     dim3 grid;
-    if (batchSize == 1) {
-      int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-      int frameBlocks = (frameSize + 1024 - 1) / 1024;
-      threads = dim3(framePerBlock, 1);
-      grid = dim3(frameBlocks, 1);
+    if (batch_size == 1) {
+      int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+      int frame_blocks = (frame_size + 1024 - 1) / 1024;
+      threads = dim3(frame_per_block, 1);
+      grid = dim3(frame_blocks, 1);
     } else {
       threads = dim3(32, 32);
-      grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+      grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
     }
 
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruBackwardStateGrad<
           detail::backward::gru_stateGrad<T>,
-          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
-          batchSize, active_node);
+          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_stateGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.output_grad, frame_size, batch_size, active_node);
     } else {
       detail::KeGruBackwardStateGrad<
           detail::backward::gru_stateGrad<T>,
-          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
-          batchSize, active_node);
+          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_stateGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.output_grad, frame_size, batch_size, active_node);
     }
 
-    if (value.prevOutValue && grad.prevOutGrad) {
+    if (value.prev_out_value && grad.prev_out_grad) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize, 1,
-          grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
-          frameSize, 0, grad.resetOutputGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size, 1,
+          grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight,
+          frame_size, 0, grad.reset_output_grad, frame_size);
 
-      if (grad.stateWeightGrad) {
+      if (grad.state_weight_grad) {
         math::gemm<platform::GPUPlace, T>(
-            context, true, false, frameSize, frameSize, batchSize, 1,
-            value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
-            frameSize * 3, 1, grad.stateWeightGrad, frameSize);
+            context, true, false, frame_size, frame_size, batch_size, 1,
+            value.reset_output_value, frame_size,
+            grad.gate_grad + frame_size * 2, frame_size * 3, 1,
+            grad.state_weight_grad, frame_size);
       }
     }
 
-    if (batchSize == 1) {
+    if (batch_size == 1) {
       detail::KeGruBackwardResetGrad<
           detail::backward::gru_resetGrad<T>,
-          /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
-          batchSize, active_gate);
+          /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_resetGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.reset_output_grad, frame_size, batch_size, active_gate);
     } else {
       detail::KeGruBackwardResetGrad<
          detail::backward::gru_resetGrad<T>,
-          /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-          detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
-          value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
-          batchSize, active_gate);
+          /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+          detail::backward::gru_resetGrad<T>(), value.gate_value,
+          grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
+          grad.reset_output_grad, frame_size, batch_size, active_gate);
     }
 
-    if (grad.prevOutGrad && value.prevOutValue) {
+    if (grad.prev_out_grad && value.prev_out_value) {
       math::gemm<platform::GPUPlace, T>(
-          context, false, true, batchSize, frameSize, frameSize * 2, 1,
-          grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
-          grad.prevOutGrad, frameSize);
+          context, false, true, batch_size, frame_size, frame_size * 2, 1,
+          grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1,
+          grad.prev_out_grad, frame_size);
 
-      if (grad.gateWeightGrad) {
+      if (grad.gate_weight_grad) {
         math::gemm<platform::GPUPlace, T>(
-            context, true, false, frameSize, frameSize * 2, batchSize, 1,
-            value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
-            grad.gateWeightGrad, frameSize * 2);
+            context, true, false, frame_size, frame_size * 2, batch_size, 1,
+            value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1,
+            grad.gate_weight_grad, frame_size * 2);
       }
     }
   }
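
Note: the hunks above only rename identifiers from camelCase to snake_case; the kernel launch configuration both functors pick from frame_size and batch_size is unchanged. As a reading aid, the sketch below restates that configuration with the renamed variables. The helper gru_launch_config and the GruLaunchConfig struct are hypothetical names that do not exist in the patched file; the arithmetic is copied verbatim from the diff.

// Reading aid only -- not part of the patch. gru_launch_config is a
// hypothetical helper that restates the launch configuration used above.
#include <cuda_runtime.h>

struct GruLaunchConfig {
  dim3 threads;
  dim3 grid;
};

inline GruLaunchConfig gru_launch_config(int frame_size, int batch_size) {
  GruLaunchConfig config;
  if (batch_size == 1) {
    // Single sequence: 1-D launch over the frame (hidden-unit) dimension,
    // capped at 1024 threads per block.
    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
    int frame_blocks = (frame_size + 1024 - 1) / 1024;
    config.threads = dim3(frame_per_block, 1);
    config.grid = dim3(frame_blocks, 1);
  } else {
    // Batched input: 32x32 thread tiles covering (frame_size, batch_size).
    config.threads = dim3(32, 32);
    config.grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
  }
  return config;
}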