@@ -32,9 +32,7 @@ namespace detail {
 */
 
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
-                              int batchSize, activation_mode_t active_node,
-                              activation_mode_t active_gate,
-                              activation_mode_t active_state) {
+                              int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::ForwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-     rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
-     act(active_state));
+     rOut, rCheckI, rCheckF, rCheckO);
 
   value.gateValue[frameIdx] = rValueIn;
   value.gateValue[frameIdx + frameSize] = rValueIg;
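Note: the forward kernel now calls `op` without the `act(active_node)` / `act(active_gate)` / `act(active_state)` arguments, so the nonlinearities must be resolved inside the `Op` functor itself. A minimal sketch of what such a functor could look like, assuming fixed activations (sigmoid for the gates, tanh for the candidate input and cell state) and the standard peephole-LSTM equations; `LstmForwardOp` and its body are illustrative, not the actual operator:

```cpp
// Hypothetical sketch only: one way the Op functor can absorb the
// activations that previously arrived as activation_mode_t parameters.
// Assumes sigmoid gates and tanh input/state activations (peephole LSTM).
template <typename T>
struct LstmForwardOp {
  __device__ static T sigmoid(T x) { return T(1) / (T(1) + exp(-x)); }

  __device__ void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                             T &prevState, T &state, T &stateAtv, T &output,
                             T &checkI, T &checkF, T &checkO) {
    valueIn = tanh(valueIn);                          // candidate input
    valueIg = sigmoid(valueIg + prevState * checkI);  // input gate (+ peephole)
    valueFg = sigmoid(valueFg + prevState * checkF);  // forget gate (+ peephole)
    state = valueIn * valueIg + prevState * valueFg;  // new cell state
    valueOg = sigmoid(valueOg + state * checkO);      // output gate (+ peephole)
    stateAtv = tanh(state);                           // activated cell state
    output = valueOg * stateAtv;                      // hidden output
  }
};
```

The in-place overwrites match how the kernel stores the activated gate values back into `value.gateValue` afterwards.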
@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
 
 template <class T, class Op, bool isBatch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                                LstmMetaGrad<T> grad, int frameSize,
-                               int batchSize, activation_mode_t active_node,
-                               activation_mode_t active_gate,
-                               activation_mode_t active_state) {
+                               int batchSize) {
   const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frameIdx >= frameSize) return;
@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     rPrevState = value.prevStateValue[frameIdx];
   }
 
-  hppl::gpu::BackwardAct<T> act;
   op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
      rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
-     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
-     act(active_node), act(active_gate), act(active_state));
+     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
 
   grad.gateGrad[frameIdx] = rGradIn;
   grad.gateGrad[frameIdx + frameSize] = rGradIg;
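Same pattern on the backward path: `hppl::gpu::BackwardAct<T>` disappears, so the gradient op would carry its own activation derivatives. A sketch under the same assumptions (sigmoid gates, tanh input/state), using the derivative-from-activated-value forms since the forward pass stored the activated values; all names are illustrative:

```cpp
// Hypothetical sketch: backward op with baked-in activation derivatives.
// Uses sigmoid'(x) = a * (1 - a) and tanh'(x) = 1 - a * a, where a is the
// activated value saved by the forward pass.
template <typename T>
struct LstmBackwardOp {
  __device__ void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
                             T &prevState, T &prevStateGrad, T &state,
                             T &stateGrad, T &stateAtv, T &outputGrad,
                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
                             T &checkFGrad, T &checkOGrad) {
    gradOg = outputGrad * stateAtv * valueOg * (T(1) - valueOg);
    // Accumulate: stateGrad already holds the gradient flowing in from the
    // next time step's cell state.
    stateGrad += outputGrad * valueOg * (T(1) - stateAtv * stateAtv) +
                 gradOg * checkO;
    gradIn = stateGrad * valueIg * (T(1) - valueIn * valueIn);
    gradIg = stateGrad * valueIn * valueIg * (T(1) - valueIg);
    gradFg = stateGrad * prevState * valueFg * (T(1) - valueFg);
    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
    checkIGrad = gradIg * prevState;  // peephole weight gradients
    checkFGrad = gradFg * prevState;
    checkOGrad = gradOg * state;
  }
};
```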
@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmForward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   } else {
     KeLstmForward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, frameSize, batchSize);
   }
 }
 
@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
   if (batchSize == 1) {
     KeLstmBackward<T, Op,
                    /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   } else {
     KeLstmBackward<T, Op,
                    /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+        op, value, grad, frameSize, batchSize);
   }
 }
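The diff does not show how `grid` and `threads` are constructed, but the guard in both kernels (`if (frameIdx >= frameSize) return;`) implies a 1-D thread mapping over `frameSize`, presumably extended along `y` for the batched (`isBatch == true`) variant. A plausible reconstruction, with the block size chosen arbitrarily:

```cpp
// Hypothetical launch configuration; the actual values are not in the diff.
constexpr int kThreadsPerBlock = 256;  // assumption, not from the source
dim3 threads(kThreadsPerBlock, 1);
dim3 grid((frameSize + kThreadsPerBlock - 1) / kThreadsPerBlock,
          batchSize == 1 ? 1 : batchSize);
```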