|
|
|
@@ -55,16 +55,13 @@ void hl_max_sequence_forward(real* input,
|
|
|
|
|
|
|
|
|
|
dim3 threads(256, 1);
|
|
|
|
|
dim3 grid(numSequences, 1);
|
|
|
|
|
KeMaxSequenceForward<<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(input, sequence, output, index, numSequences, dim);
|
|
|
|
|
KeMaxSequenceForward<<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
input, sequence, output, index, numSequences, dim);
|
|
|
|
|
CHECK_SYNC("hl_max_sequence_forward failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__global__ void KeMaxSequenceBackward(real *outputGrad,
|
|
|
|
|
int *index,
|
|
|
|
|
real* inputGrad,
|
|
|
|
|
int numSequences,
|
|
|
|
|
int dim) {
|
|
|
|
|
__global__ void KeMaxSequenceBackward(
|
|
|
|
|
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
|
|
|
|
|
int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
|
|
|
|
int colIdx = idx % dim;
|
|
|
|
|
if (idx < numSequences * dim) {
|
|
|
|
@@ -73,11 +70,8 @@ __global__ void KeMaxSequenceBackward(real *outputGrad,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void hl_max_sequence_backward(real* outputGrad,
|
|
|
|
|
int *index,
|
|
|
|
|
real* inputGrad,
|
|
|
|
|
int numSequences,
|
|
|
|
|
int dim) {
|
|
|
|
|
void hl_max_sequence_backward(
|
|
|
|
|
real* outputGrad, int* index, real* inputGrad, int numSequences, int dim) {
|
|
|
|
|
CHECK_NOTNULL(outputGrad);
|
|
|
|
|
CHECK_NOTNULL(index);
|
|
|
|
|
CHECK_NOTNULL(inputGrad);
|
|
|
|
@@ -85,8 +79,8 @@ void hl_max_sequence_backward(real* outputGrad,
|
|
|
|
|
unsigned int blocks = (numSequences * dim + 128 - 1) / 128;
|
|
|
|
|
dim3 threads(128, 1);
|
|
|
|
|
dim3 grid(blocks, 1);
|
|
|
|
|
KeMaxSequenceBackward<<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(outputGrad, index, inputGrad, numSequences, dim);
|
|
|
|
|
KeMaxSequenceBackward<<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
outputGrad, index, inputGrad, numSequences, dim);
|
|
|
|
|
CHECK_SYNC("hl_max_sequence_backward failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -118,9 +112,12 @@ __global__ void KeMatrixAddRows(real* output,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int blockDimX, int blockDimY, int gridDimX, bool seq2batch, bool isAdd>
|
|
|
|
|
__global__
|
|
|
|
|
void KeSequence2Batch(real *batch,
|
|
|
|
|
template <int blockDimX,
|
|
|
|
|
int blockDimY,
|
|
|
|
|
int gridDimX,
|
|
|
|
|
bool seq2batch,
|
|
|
|
|
bool isAdd>
|
|
|
|
|
__global__ void KeSequence2Batch(real* batch,
|
|
|
|
|
real* sequence,
|
|
|
|
|
const int* batchIndex,
|
|
|
|
|
int seqWidth,
|
|
|
|
@@ -164,11 +161,11 @@ void hl_sequence2batch_copy(real *batch,
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
dim3 grid(8, 1);
|
|
|
|
|
if (seq2batch) {
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
} else {
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
}
|
|
|
|
|
CHECK_SYNC("hl_sequence2batch_copy failed");
|
|
|
|
|
}
|
|
|
|
@@ -186,18 +183,17 @@ void hl_sequence2batch_add(real *batch,
|
|
|
|
|
dim3 threads(128, 8);
|
|
|
|
|
dim3 grid(8, 1);
|
|
|
|
|
if (seq2batch) {
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
} else {
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
KeSequence2Batch<128, 8, 8, 0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, batchIndex, seqWidth, batchCount);
|
|
|
|
|
}
|
|
|
|
|
CHECK_SYNC("hl_sequence2batch_add failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <bool normByTimes, bool seq2batch>
|
|
|
|
|
__global__
|
|
|
|
|
void KeSequence2BatchPadding(real* batch,
|
|
|
|
|
__global__ void KeSequence2BatchPadding(real* batch,
|
|
|
|
|
real* sequence,
|
|
|
|
|
const int* sequenceStartPositions,
|
|
|
|
|
const size_t sequenceWidth,
|
|
|
|
@@ -277,36 +273,48 @@ void hl_sequence2batch_copy_padding(real* batch,
|
|
|
|
|
/* sequence -> batch */
|
|
|
|
|
if (normByTimes) {
|
|
|
|
|
KeSequence2BatchPadding<1, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, sequenceStartPositions,
|
|
|
|
|
sequenceWidth, maxSequenceLength, numSequences);
|
|
|
|
|
batch,
|
|
|
|
|
sequence,
|
|
|
|
|
sequenceStartPositions,
|
|
|
|
|
sequenceWidth,
|
|
|
|
|
maxSequenceLength,
|
|
|
|
|
numSequences);
|
|
|
|
|
} else {
|
|
|
|
|
KeSequence2BatchPadding<0, 1><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, sequenceStartPositions,
|
|
|
|
|
sequenceWidth, maxSequenceLength, numSequences);
|
|
|
|
|
batch,
|
|
|
|
|
sequence,
|
|
|
|
|
sequenceStartPositions,
|
|
|
|
|
sequenceWidth,
|
|
|
|
|
maxSequenceLength,
|
|
|
|
|
numSequences);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
/* batch -> sequence */
|
|
|
|
|
if (normByTimes) {
|
|
|
|
|
KeSequence2BatchPadding<1, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, sequenceStartPositions,
|
|
|
|
|
sequenceWidth, maxSequenceLength, numSequences);
|
|
|
|
|
batch,
|
|
|
|
|
sequence,
|
|
|
|
|
sequenceStartPositions,
|
|
|
|
|
sequenceWidth,
|
|
|
|
|
maxSequenceLength,
|
|
|
|
|
numSequences);
|
|
|
|
|
} else {
|
|
|
|
|
KeSequence2BatchPadding<0, 0><<<grid, threads, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
batch, sequence, sequenceStartPositions,
|
|
|
|
|
sequenceWidth, maxSequenceLength, numSequences);
|
|
|
|
|
batch,
|
|
|
|
|
sequence,
|
|
|
|
|
sequenceStartPositions,
|
|
|
|
|
sequenceWidth,
|
|
|
|
|
maxSequenceLength,
|
|
|
|
|
numSequences);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CHECK_SYNC("hl_sequence2batch_copy_padding failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Single-precision overload of my_rsqrt: forwards its argument to rsqrtf
// (CUDA's float reciprocal square root). Paired with a double overload so that
// calls such as my_rsqrt((real)seqLength) resolve by argument type —
// presumably `real` is float or double; the typedef is not visible here.
__device__ inline float my_rsqrt(float x) {
|
|
|
|
|
return rsqrtf(x);
|
|
|
|
|
}
|
|
|
|
|
// Single-precision overload of my_rsqrt: forwards its argument to rsqrtf
// (CUDA's float reciprocal square root). Paired with a double overload so that
// calls such as my_rsqrt((real)seqLength) resolve by argument type —
// presumably `real` is float or double; the typedef is not visible here.
__device__ inline float my_rsqrt(float x) { return rsqrtf(x); }
|
|
|
|
|
|
|
|
|
|
// Double-precision overload of my_rsqrt: forwards its argument to rsqrt
// (CUDA's double reciprocal square root). Paired with a float overload so that
// calls such as my_rsqrt((real)seqLength) resolve by argument type —
// presumably `real` is float or double; the typedef is not visible here.
__device__ inline double my_rsqrt(double x) {
|
|
|
|
|
return rsqrt(x);
|
|
|
|
|
}
|
|
|
|
|
// Double-precision overload of my_rsqrt: forwards its argument to rsqrt
// (CUDA's double reciprocal square root). Paired with a float overload so that
// calls such as my_rsqrt((real)seqLength) resolve by argument type —
// presumably `real` is float or double; the typedef is not visible here.
__device__ inline double my_rsqrt(double x) { return rsqrt(x); }
|
|
|
|
|
|
|
|
|
|
__global__ void KeSequenceAvgForward(real* dst,
|
|
|
|
|
real* src,
|
|
|
|
@@ -327,8 +335,8 @@ __global__ void KeSequenceAvgForward(real* dst,
|
|
|
|
|
for (int i = start; i < end; i++) {
|
|
|
|
|
sum += src[i * width + col];
|
|
|
|
|
}
|
|
|
|
|
sum = mode == 1 ? sum :
|
|
|
|
|
(mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
|
|
|
|
|
sum = mode == 1 ? sum : (mode == 0 ? sum / seqLength
|
|
|
|
|
: sum * my_rsqrt((real)seqLength));
|
|
|
|
|
dst[gid] += sum;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -349,8 +357,8 @@ void hl_sequence_avg_forward(real* dst,
|
|
|
|
|
CHECK(mode == 0 || mode == 1 || mode == 2)
|
|
|
|
|
<< "mode error in hl_sequence_avg_forward!";
|
|
|
|
|
|
|
|
|
|
KeSequenceAvgForward<<< grid, block, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(dst, src, starts, height, width, mode);
|
|
|
|
|
KeSequenceAvgForward<<<grid, block, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
dst, src, starts, height, width, mode);
|
|
|
|
|
CHECK_SYNC("hl_sequence_avg_forward failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -370,8 +378,8 @@ __global__ void KeSequenceAvgBackward(real* dst,
|
|
|
|
|
int seqLength = end - start;
|
|
|
|
|
if (seqLength == 0) return;
|
|
|
|
|
real grad = src[gid];
|
|
|
|
|
grad = mode == 1 ? grad :
|
|
|
|
|
(mode == 0 ? grad / seqLength : grad * my_rsqrt((real)seqLength));
|
|
|
|
|
grad = mode == 1 ? grad : (mode == 0 ? grad / seqLength
|
|
|
|
|
: grad * my_rsqrt((real)seqLength));
|
|
|
|
|
for (int i = start; i < end; i++) {
|
|
|
|
|
dst[i * width + col] += grad;
|
|
|
|
|
}
|
|
|
|
@@ -394,7 +402,7 @@ void hl_sequence_avg_backward(real* dst,
|
|
|
|
|
CHECK(mode == 0 || mode == 1 || mode == 2)
|
|
|
|
|
<< "mode error in hl_sequence_avg_backward!";
|
|
|
|
|
|
|
|
|
|
KeSequenceAvgBackward<<< grid, block, 0, STREAM_DEFAULT >>>
|
|
|
|
|
(dst, src, starts, height, width, mode);
|
|
|
|
|
KeSequenceAvgBackward<<<grid, block, 0, STREAM_DEFAULT>>>(
|
|
|
|
|
dst, src, starts, height, width, mode);
|
|
|
|
|
CHECK_SYNC("hl_sequence_avg_backward failed");
|
|
|
|
|
}
|
|
|
|
|