parent 9106a4bba1
commit 3cace73701
@@ -0,0 +1,64 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_

#include "hl_functions.h"

/**
 * Active functions: sigmoid, relu, tanh and linear.
 */
#define HPPL_ACTIVE_FUNCTION \
  { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }

namespace hppl {

/**
 * Hppl supports sigmoid, relu, tanh, linear active functions
 * for neural networks' forward and backward activation.
 */
template <class T>
class Active {
 public:
  typedef T (*forward)(T);
  typedef T (*backward)(T, T);
};

#ifdef __NVCC__
namespace gpu {
// The single- and double-precision tables need distinct names; reusing
// `forward`/`backward` for both element types would redefine the same symbol.
static __device__ Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<double>::forward forward_d[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward_d[] = HPPL_ACTIVE_FUNCTION;
}  // namespace gpu
#else
namespace cpu {
// Same naming scheme as the GPU tables above.
static Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
static Active<double>::forward forward_d[] = HPPL_ACTIVE_FUNCTION;
static Active<double>::backward backward_d[] = HPPL_ACTIVE_FUNCTION;
}  // namespace cpu

#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION;
}  // namespace avx
#endif
#endif

}  // namespace hppl

#endif  // HL_ACTIVATION_FUNCTIONS_H_
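For reference, the tables above are meant to be indexed by an activation-mode value whose ordering matches HPPL_ACTIVE_FUNCTION (sigmoid, relu, tanh, linear). A minimal sketch of that dispatch, not part of the commit, assuming a CPU (non-__NVCC__) build with float data; whether it links depends on how the scalar hppl definitions are wired into the build:

// Sketch: dispatching through the hppl::cpu function-pointer tables.
// `active_type` mirrors the activation_mode_t ordering declared in
// lstm_compute.h: 0 sigmoid, 1 relu, 2 tanh, 3 linear.
#include "hl_activation_functions.h"

float apply_forward(int active_type, float x) {
  return hppl::cpu::forward[active_type](x);
}

float apply_backward(int active_type, float grad, float out) {
  return hppl::cpu::backward[active_type](grad, out);
}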
@@ -0,0 +1,68 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <immintrin.h>
#include "hl_functions.h"

namespace hppl {

extern __m256 exp(__m256 a);

__m256 relu(const __m256 a) {
  __m256 tmp = _mm256_set1_ps(0.0f);
  return _mm256_max_ps(a, tmp);
}

__m256 sigmoid(const __m256 a) {
  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
  __m256 tmp = _mm256_max_ps(a, min);
  tmp = _mm256_min_ps(tmp, max);
  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
  tmp = exp(tmp);
  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
  return tmp;
}

__m256 tanh(const __m256 a) {
  __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
  tmp = _mm256_min_ps(tmp, max);
  tmp = exp(tmp);
  return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
                                     _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
                       _mm256_set1_ps(1.0f));
}

__m256 linear(const __m256 a) { return a; }

__m256 relu(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(
      a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
                       _mm256_set1_ps(1.0f)));
}

__m256 sigmoid(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(_mm256_mul_ps(a, b),
                       _mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}

__m256 tanh(const __m256 a, const __m256 b) {
  return _mm256_mul_ps(
      a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}

__m256 linear(const __m256 a, const __m256 b) { return a; }
}  // namespace hppl
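A short usage sketch for the packed forward/backward pair, using relu so it does not depend on the external AVX exp. Illustration only, assuming an AVX-capable build (e.g. -mavx); not part of the commit:

// Sketch: applying hppl::relu forward and backward to 8 packed floats.
#include <immintrin.h>
#include "hl_avx_functions.h"

void relu_demo() {
  float in[8] = {-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, -4.f, 5.f};
  float grad[8] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f};
  __m256 x = _mm256_loadu_ps(in);
  __m256 y = hppl::relu(x);       // forward: max(x, 0)
  __m256 g = _mm256_loadu_ps(grad);
  __m256 dx = hppl::relu(g, y);   // backward: grad * (y > 0 ? 1 : 0)
  float out[8], dout[8];
  _mm256_storeu_ps(out, y);
  _mm256_storeu_ps(dout, dx);
}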
@@ -0,0 +1,32 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_

#include <immintrin.h>

namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);

__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
}  // namespace hppl

#endif  // HL_AVX_FUNCTIONS_H_
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <math.h>
#include "paddle/operators/math/detail/hl_functions.h"

namespace hppl {

real relu(const real a) { return a > 0.0f ? a : 0.0f; }

real sigmoid(const real a) {
  const real min = SIGMOID_THRESHOLD_MIN;
  const real max = SIGMOID_THRESHOLD_MAX;
  real tmp = (a < min) ? min : ((a > max) ? max : a);
  return 1.0 / (1.0 + exp(-tmp));
}

real tanh(const real a) {
  real tmp = -2.0 * a;
  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
  return (2.0 / (1.0 + exp(tmp))) - 1.0;
}

real linear(const real a) { return a; }

real relu(const real a, const real b) { return a * (b > 0.0f ? 1.0f : 0.0f); }

real sigmoid(const real a, const real b) { return a * b * (1 - b); }

real tanh(const real a, const real b) { return a * (1.0f - b * b); }

real linear(const real a, const real b) { return a; }
}  // namespace hppl
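The scalar kernels above encode the usual identities: sigmoid via 1/(1+e^{-a}) with the input clamped to the sigmoid thresholds, tanh rewritten as 2/(1+e^{-2a}) - 1 with the exponent clamped by EXP_MAX_INPUT, and the two-argument backward forms taking (grad, activated output). A self-contained sketch that mirrors these formulas rather than calling hppl directly, not part of the commit:

// Sketch: the identities behind the scalar kernels, checked against <cmath>.
#include <cassert>
#include <cmath>

void scalar_identity_check() {
  double a = 0.7;
  double sig = 1.0 / (1.0 + std::exp(-a));              // sigmoid(a)
  double th = 2.0 / (1.0 + std::exp(-2.0 * a)) - 1.0;    // tanh(a), rewritten
  assert(std::fabs(th - std::tanh(a)) < 1e-12);
  // Backward forms take (grad, activated output) and return grad w.r.t. the
  // pre-activation input:
  double dsig = 1.0 * sig * (1.0 - sig);                 // sigmoid backward
  double dth = 1.0 * (1.0 - th * th);                    // tanh backward
  (void)dsig;
  (void)dth;
}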
@@ -0,0 +1,63 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_

/**
 * sigmoid threshold minimum
 */
#define SIGMOID_THRESHOLD_MIN -40.0

/**
 * sigmoid threshold maximum
 */
#define SIGMOID_THRESHOLD_MAX 13.0

#ifndef EXP_MAX_INPUT
/**
 * Upper bound on the input of exp, referenced by the tanh implementations in
 * hl_cpu_functions.cc and hl_avx_functions.cc; 40.0 keeps exp() well inside
 * single-precision range.
 */
#define EXP_MAX_INPUT 40.0
#endif

#ifndef __NVCC__
namespace hppl {
/*
 * forward activation
 */
template <typename T>
T relu(const T a);
template <typename T>
T sigmoid(const T a);
template <typename T>
T tanh(const T a);
template <typename T>
T linear(const T a);

/*
 * backward activation
 */
template <typename T>
T relu(const T a, const T b);
template <typename T>
T sigmoid(const T a, const T b);
template <typename T>
T tanh(const T a, const T b);
template <typename T>
T linear(const T a, const T b);
}  // namespace hppl

#ifdef __AVX__
#include "hl_avx_functions.h"
#endif

#else
#include "hl_gpu_functions.h"
#endif

#endif  // HL_FUNCTIONS_H_
@@ -0,0 +1,80 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_

#include "hl_base.h"

namespace hppl {

template <typename T>
__device__ static T relu(const T a) {
  return a > 0.0f ? a : 0.0f;
}

// Primary templates for the forward sigmoid/tanh, so that the explicit
// float/double specializations below are well-formed.
template <typename T>
__device__ static T sigmoid(const T a);
template <typename T>
__device__ static T tanh(const T a);

template <>
__device__ static float sigmoid(const float a) {
  const float min = SIGMOID_THRESHOLD_MIN;
  const float max = SIGMOID_THRESHOLD_MAX;
  float tmp = (a < min) ? min : ((a > max) ? max : a);
  return __fdividef(1.0f, 1.0f + __expf(-tmp));
}

template <>
__device__ static double sigmoid(const double a) {
  const double min = SIGMOID_THRESHOLD_MIN;
  const double max = SIGMOID_THRESHOLD_MAX;
  double tmp = (a < min) ? min : ((a > max) ? max : a);
  return 1.0 / (1.0 + exp(-tmp));
}

template <>
__device__ static float tanh(const float a) {
  return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f;
}

template <>
__device__ static double tanh(const double a) {
  return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
}

template <typename T>
__device__ static T linear(const T a) {
  return a;
}

template <typename T>
__device__ static T relu(const T a, const T b) {
  return a * (b > 0.0f ? 1.0f : 0.0f);
}

template <typename T>
__device__ static T sigmoid(const T a, const T b) {
  return a * b * (1 - b);
}

template <typename T>
__device__ static T tanh(const T a, const T b) {
  return a * (1.0f - b * b);
}

template <typename T>
__device__ static T linear(const T a, const T b) {
  return a;
}

}  // namespace hppl

#endif  // HL_GPU_FUNCTIONS_CUH_
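A hypothetical sketch of calling these device functions from an elementwise kernel; not part of the commit, and it assumes the header is compiled with NVCC and that hl_base.h provides the usual CUDA setup:

// Sketch: elementwise sigmoid using the device-side float specialization above.
#include "hl_gpu_functions.h"

__global__ void SigmoidForwardKernel(const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = hppl::sigmoid(x[i]);  // resolves to the float specialization
  }
}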
File diff suppressed because it is too large
@@ -0,0 +1,244 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/operators/math/detail/lstm_kernel.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, lstm_value value, int frameSize,
                              int batchSize, activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    value.gateValue += batchIdx * frameSize * 4;
    value.outputValue += batchIdx * frameSize;
    value.stateValue += batchIdx * frameSize;
    value.stateActiveValue += batchIdx * frameSize;
  }

  T rState;
  T rPrevState = 0;
  T rStateAtv;
  T rOut;
  T rValueIn;
  T rValueIg;
  T rValueFg;
  T rValueOg;
  T rCheckI = value.checkIg[frameIdx];
  T rCheckF = value.checkFg[frameIdx];
  T rCheckO = value.checkOg[frameIdx];

  rValueIn = value.gateValue[frameIdx];
  rValueIg = value.gateValue[frameIdx + frameSize];
  rValueFg = value.gateValue[frameIdx + frameSize * 2];
  rValueOg = value.gateValue[frameIdx + frameSize * 3];

  if (value.prevStateValue) {
    if (isBatch) value.prevStateValue += batchIdx * frameSize;
    rPrevState = value.prevStateValue[frameIdx];
  }

  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
     rOut, rCheckI, rCheckF, rCheckO, hppl::gpu::forward[active_node],
     hppl::gpu::forward[active_gate], hppl::gpu::forward[active_state]);

  value.gateValue[frameIdx] = rValueIn;
  value.gateValue[frameIdx + frameSize] = rValueIg;
  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
  value.gateValue[frameIdx + frameSize * 3] = rValueOg;

  value.stateValue[frameIdx] = rState;
  value.stateActiveValue[frameIdx] = rStateAtv;
  value.outputValue[frameIdx] = rOut;
}

/*
 * threads(framePerBlock, batchPerBlock)
 * grid(frameBlocks, batchBlocks)
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad,
                               int frameSize, int batchSize,
                               activation_mode_t active_node,
                               activation_mode_t active_gate,
                               activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    value.gateValue += batchIdx * frameSize * 4;
    value.stateValue += batchIdx * frameSize;
    value.stateActiveValue += batchIdx * frameSize;
    grad.gateGrad += batchIdx * frameSize * 4;
    grad.stateGrad += batchIdx * frameSize;
    grad.outputGrad += batchIdx * frameSize;
  }

  T rValueIn;
  T rValueIg;
  T rValueFg;
  T rValueOg;
  T rGradIn;
  T rGradIg;
  T rGradFg;
  T rGradOg;
  T rPrevState = 0;
  T rPrevStateGrad;
  T rState;
  T rStateGrad;
  T rStateAtv;
  T rOutputGrad;
  T rCheckI = value.checkIg[frameIdx];
  T rCheckF = value.checkFg[frameIdx];
  T rCheckO = value.checkOg[frameIdx];
  T rCheckIGrad;
  T rCheckFGrad;
  T rCheckOGrad;

  rValueIn = value.gateValue[frameIdx];
  rValueIg = value.gateValue[frameIdx + frameSize];
  rValueFg = value.gateValue[frameIdx + frameSize * 2];
  rValueOg = value.gateValue[frameIdx + frameSize * 3];
  rState = value.stateValue[frameIdx];
  rStateAtv = value.stateActiveValue[frameIdx];
  rOutputGrad = grad.outputGrad[frameIdx];
  rStateGrad = grad.stateGrad[frameIdx];

  if (value.prevStateValue) {
    if (isBatch) value.prevStateValue += batchIdx * frameSize;
    rPrevState = value.prevStateValue[frameIdx];
  }

  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
     rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
     hppl::gpu::backward[active_node], hppl::gpu::backward[active_gate],
     hppl::gpu::backward[active_state]);

  grad.gateGrad[frameIdx] = rGradIn;
  grad.gateGrad[frameIdx + frameSize] = rGradIg;
  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
  grad.stateGrad[frameIdx] = rStateGrad;
  if (grad.prevStateGrad) {
    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
  }

  if (isBatch) {
    if (value.prevStateValue) {
      if (grad.checkIgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
                                        rCheckIGrad);
      if (grad.checkFgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
                                        rCheckFGrad);
    }
    if (grad.checkOgGrad)
      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
  } else {
    if (value.prevStateValue) {
      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
    }
    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
  }
}

template <class T, class Op>
void gpu_lstm_forward(Op op, lstm_value value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate,
                      activation_mode_t active_state) {
  dim3 threads;
  dim3 grid;
  if (batchSize == 1) {
    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    threads = dim3(framePerBlock, 1);
    grid = dim3(frameBlocks, 1);
  } else {
    /* framePerBlock = 32 batchPerBlock = 32 */
    threads = dim3(32, 32);
    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
  }

  if (batchSize == 1) {
    KeLstmForward<T, Op,
                  /* isBatch= */ false><<<grid, threads, 0, STREAM_DEFAULT>>>(
        op, value, frameSize, batchSize, active_node, active_gate,
        active_state);
  } else {
    KeLstmForward<T, Op,
                  /* isBatch= */ true><<<grid, threads, 0, STREAM_DEFAULT>>>(
        op, value, frameSize, batchSize, active_node, active_gate,
        active_state);
  }
}

template <class T, class Op>
void gpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize,
                       int batchSize, activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
  dim3 threads;
  dim3 grid;
  if (batchSize == 1) {
    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    threads = dim3(framePerBlock, 1);
    grid = dim3(frameBlocks, 1);
  } else {
    /* framePerBlock = 32 batchPerBlock = 32 */
    threads = dim3(32, 32);
    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
  }

  if (batchSize == 1) {
    KeLstmBackward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, STREAM_DEFAULT>>>(
        op, value, grad, frameSize, batchSize, active_node, active_gate,
        active_state);
  } else {
    KeLstmBackward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, STREAM_DEFAULT>>>(
        op, value, grad, frameSize, batchSize, active_node, active_gate,
        active_state);
  }
}

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle
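The launch geometry above is one-dimensional over frames when batchSize == 1 and a 32x32 (frame, batch) tile otherwise. A standalone sketch of the same arithmetic for illustration only (e.g. frameSize=200, batchSize=16 gives a 7x1 grid of 32x32 blocks); not part of the commit:

// Sketch: the launch-geometry arithmetic used by gpu_lstm_forward/backward.
#include <algorithm>

struct LaunchDims {
  int threads_x, threads_y, blocks_x, blocks_y;
};

LaunchDims LstmLaunchDims(int frameSize, int batchSize) {
  if (batchSize == 1) {
    int framePerBlock = std::min(frameSize, 1024);
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    return {framePerBlock, 1, frameBlocks, 1};
  }
  // 32x32 tiles over (frame, batch).
  return {32, 32, (frameSize + 31) / 32, (batchSize + 31) / 32};
}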
@@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "hl_activation_functions.h"

#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif

namespace paddle {
namespace operators {
namespace math {
namespace detail {

namespace forward {

template <class T>
class lstm {
 public:
  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                         T &prevState, T &state, T &stateAtv, T &output,
                         T &checkI, T &checkF, T &checkO,
                         typename hppl::Active<T>::forward actInput,
                         typename hppl::Active<T>::forward actGate,
                         typename hppl::Active<T>::forward actState) {
    valueIn = actInput(valueIn);
    valueIg = actGate(valueIg + prevState * checkI);
    valueFg = actGate(valueFg + prevState * checkF);
    state = valueIn * valueIg + prevState * valueFg;
    valueOg = actGate(valueOg + state * checkO);
    stateAtv = actState(state);
    output = valueOg * stateAtv;
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
                         __m256 &valueOg, __m256 &prevState, __m256 &state,
                         __m256 &stateAtv, __m256 &output, __m256 &checkI,
                         __m256 &checkF, __m256 &checkO,
                         hppl::Active<__m256>::forward actInput,
                         hppl::Active<__m256>::forward actGate,
                         hppl::Active<__m256>::forward actState) {
    valueIn = actInput(valueIn);
    valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
    valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
    state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
                          _mm256_mul_ps(prevState, valueFg));
    valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
    stateAtv = actState(state);
    output = _mm256_mul_ps(valueOg, stateAtv);
  }
#endif
#endif
};

}  // namespace forward

namespace backward {

template <class T>
class lstm {
 public:
  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
                         T &gradIn, T &gradIg, T &gradFg, T &gradOg,
                         T &prevState, T &prevStateGrad, T &state, T &stateGrad,
                         T &stateAtv, T &outputGrad, T &checkI, T &checkF,
                         T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
                         typename hppl::Active<T>::backward actInput,
                         typename hppl::Active<T>::backward actGate,
                         typename hppl::Active<T>::backward actState) {
    gradOg = actGate(outputGrad * stateAtv, valueOg);
    stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
    gradIn = actInput(stateGrad * valueIg, valueIn);
    gradIg = actGate(stateGrad * valueIn, valueIg);
    gradFg = actGate(stateGrad * prevState, valueFg);
    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
    checkIGrad = gradIg * prevState;
    checkFGrad = gradFg * prevState;
    checkOGrad = gradOg * state;
  }
#ifndef __NVCC__
#ifndef __AVX__
  static const bool avx = false;
#else
  static const bool avx = true;
  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
                         __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
                         __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
                         __m256 &prevStateGrad, __m256 &state,
                         __m256 &stateGrad, __m256 &stateAtv,
                         __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
                         __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
                         __m256 &checkOGrad,
                         hppl::Active<__m256>::backward actInput,
                         hppl::Active<__m256>::backward actGate,
                         hppl::Active<__m256>::backward actState) {
    gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
    stateGrad = _mm256_add_ps(
        actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
    stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
    gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
    gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
    gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
    prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
                                  _mm256_mul_ps(gradFg, checkF));
    prevStateGrad =
        _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
    checkIGrad = _mm256_mul_ps(gradIg, prevState);
    checkFGrad = _mm256_mul_ps(gradFg, prevState);
    checkOGrad = _mm256_mul_ps(gradOg, state);
  }
#endif
#endif
};

}  // namespace backward

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle
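The forward functor is the standard peephole LSTM cell. A scalar restatement of its equations for a single element, with sigmoid gates and a tanh candidate/state activation assumed for concreteness (the functor itself takes the activations as parameters); illustration only, not part of the commit:

// Sketch: one scalar LSTM cell step, mirroring forward::lstm::operator().
#include <cmath>

void lstm_cell_step(float in, float ig, float fg, float og,
                    float prevState, float checkI, float checkF, float checkO,
                    float *state, float *output) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  float a = std::tanh(in);                      // actInput(valueIn)
  float i = sigmoid(ig + prevState * checkI);   // input gate with peephole
  float f = sigmoid(fg + prevState * checkF);   // forget gate with peephole
  *state = a * i + prevState * f;               // new cell state
  float o = sigmoid(og + *state * checkO);      // output gate with peephole
  *output = o * std::tanh(*state);              // hidden output
}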
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"

namespace paddle {
namespace operators {
namespace math {

template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> {
  static void compute(lstm_value value, int frame_size, int batch_size,
                      std::string gate_act, std::string cell_act,
                      std::string cand_act) {
    for (int b = 0; b < batch_size; b++) {
      detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
                               ActiveType(cand_act), ActiveType(gate_act),
                               ActiveType(cell_act));
      value.gateValue += frame_size * 4;
      value.stateValue += frame_size;
      value.stateActiveValue += frame_size;
      value.outputValue += frame_size;
      if (value.prevStateValue) {
        value.prevStateValue += frame_size;
      }
    }
  }
};

template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
  static void compute(lstm_value value, lstm_grad grad, int frame_size,
                      int batch_size, std::string gate_act,
                      std::string cell_act, std::string cand_act) {
    for (int b = 0; b < batch_size; b++) {
      detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
                                frame_size, ActiveType(cand_act),
                                ActiveType(gate_act), ActiveType(cell_act));

      value.gateValue += frame_size * 4;
      value.stateValue += frame_size;
      value.stateActiveValue += frame_size;
      value.outputValue += frame_size;
      if (value.prevStateValue) {
        value.prevStateValue += frame_size;
      }

      grad.gateGrad += frame_size * 4;
      grad.stateGrad += frame_size;
      grad.stateActiveGrad += frame_size;
      grad.outputGrad += frame_size;
      if (grad.prevStateGrad) {
        grad.prevStateGrad += frame_size;
      }
    }
  }
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
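The per-batch pointer advances above imply a flat layout: gates are stored as batch_size rows of 4 * frame_size values (the four gates contiguous per row), while states and outputs are batch_size rows of frame_size values. A small indexing helper that spells out the same arithmetic; illustration only, with the gate order (in, ig, fg, og) taken from how KeLstmForward reads gateValue:

// Sketch: flat indices implied by the pointer arithmetic in the functors above.
inline int GateIndex(int batch, int gate, int frame, int frame_size) {
  // gateValue layout: [batch][gate][frame], 4 gates per batch row.
  return batch * frame_size * 4 + gate * frame_size + frame;
}

inline int StateIndex(int batch, int frame, int frame_size) {
  // stateValue / stateActiveValue / outputValue layout: [batch][frame].
  return batch * frame_size + frame;
}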
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"

namespace paddle {
namespace operators {
namespace math {

template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> {
  static void compute(lstm_value value, int frame_size, int batch_size,
                      std::string gate_act, std::string cell_act,
                      std::string cand_act) {
    for (int b = 0; b < batch_size; b++) {
      // Each iteration handles one batch row, so the launcher is given a
      // batch size of 1 and the pointers are advanced below.
      detail::gpu_lstm_forward<T>(detail::forward::lstm<T>(), value, frame_size,
                                  /* batchSize= */ 1, ActiveType(cand_act),
                                  ActiveType(gate_act), ActiveType(cell_act));
      value.gateValue += frame_size * 4;
      value.stateValue += frame_size;
      value.stateActiveValue += frame_size;
      value.outputValue += frame_size;
      if (value.prevStateValue) {
        value.prevStateValue += frame_size;
      }
    }
  }
};

template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
  static void compute(lstm_value value, lstm_grad grad, int frame_size,
                      int batch_size, std::string gate_act,
                      std::string cell_act, std::string cand_act) {
    for (int b = 0; b < batch_size; b++) {
      detail::gpu_lstm_backward<T>(detail::backward::lstm<T>(), value, grad,
                                   frame_size, /* batchSize= */ 1,
                                   ActiveType(cand_act), ActiveType(gate_act),
                                   ActiveType(cell_act));

      value.gateValue += frame_size * 4;
      value.stateValue += frame_size;
      value.stateActiveValue += frame_size;
      value.outputValue += frame_size;
      if (value.prevStateValue) {
        value.prevStateValue += frame_size;
      }

      grad.gateGrad += frame_size * 4;
      grad.stateGrad += frame_size;
      grad.stateActiveGrad += frame_size;
      grad.outputGrad += frame_size;
      if (grad.prevStateGrad) {
        grad.prevStateGrad += frame_size;
      }
    }
  }
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,87 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h"

namespace paddle {
namespace operators {
namespace math {

typedef enum {
  HL_ACTIVATION_SIGMOID = 0,
  HL_ACTIVATION_RELU = 1,
  HL_ACTIVATION_TANH = 2,
  HL_ACTIVATION_LINEAR = 3,
  HL_ACTIVATION_END
} activation_mode_t;

struct lstm_value {
  real *gateValue;
  real *prevStateValue;
  real *stateValue;
  real *stateActiveValue;
  real *outputValue;
  real *checkIg;
  real *checkFg;
  real *checkOg;
};

struct lstm_grad {
  real *gateGrad;
  real *prevStateGrad;
  real *stateGrad;
  real *stateActiveGrad;
  real *outputGrad;
  real *checkIgGrad;
  real *checkFgGrad;
  real *checkOgGrad;
};

inline activation_mode_t ActiveType(const std::string &type) {
  if (type == "sigmoid") {
    return HL_ACTIVATION_SIGMOID;
  } else if (type == "relu") {
    return HL_ACTIVATION_RELU;
  } else if (type == "tanh") {
    return HL_ACTIVATION_TANH;
  } else if (type == "linear" || type == "") {
    return HL_ACTIVATION_LINEAR;
  } else {
    PADDLE_THROW("Unsupported activation type.");
  }
}

template <typename Place, typename T>
class LstmUnitFunctor {
 public:
  static void compute(lstm_value value, int frame_size, int batch_size,
                      std::string gate_act, std::string cell_act,
                      std::string cand_act);
};

template <typename Place, typename T>
class LstmUnitGradFunctor {
 public:
  static void compute(lstm_value value, lstm_grad grad, int frame_size,
                      int batch_size, std::string gate_act,
                      std::string cell_act, std::string cand_act);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
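A hypothetical call-site sketch showing how the string activation attributes map through ActiveType before the functors use them; illustration only, not part of the commit:

// Sketch: resolving the operator's string attributes to activation_mode_t.
#include <string>
#include "paddle/operators/math/lstm_compute.h"

void resolve_activations(const std::string &gate_act,
                         const std::string &cell_act,
                         const std::string &cand_act) {
  using paddle::operators::math::ActiveType;
  auto gate = ActiveType(gate_act);  // e.g. "sigmoid" -> HL_ACTIVATION_SIGMOID
  auto cell = ActiveType(cell_act);  // e.g. "tanh"    -> HL_ACTIVATION_TANH
  auto cand = ActiveType(cand_act);  // ""             -> HL_ACTIVATION_LINEAR
  (void)gate;
  (void)cell;
  (void)cand;
}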