|
|
|
@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/math/detail/hl_activation_functions.h"
|
|
|
|
|
#include "paddle/platform/hostdevice.h"
|
|
|
|
|
|
|
|
|
|
#ifdef __CUDA_ARCH__
|
|
|
|
|
#define INLINE __device__ inline
|
|
|
|
|
#else
|
|
|
|
|
#define INLINE inline
|
|
|
|
|
#endif
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -30,12 +27,12 @@ namespace forward {
|
|
|
|
|
template <class T>
|
|
|
|
|
class lstm {
|
|
|
|
|
public:
|
|
|
|
|
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &prevState, T &state, T &stateAtv, T &output,
|
|
|
|
|
T &checkI, T &checkF, T &checkO,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actInput,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actGate,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actState) {
|
|
|
|
|
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &prevState, T &state, T &stateAtv, T &output,
|
|
|
|
|
T &checkI, T &checkF, T &checkO,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actInput,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actGate,
|
|
|
|
|
typename hppl::ForwardActType<T>::type actState) {
|
|
|
|
|
valueIn = actInput(valueIn);
|
|
|
|
|
valueIg = actGate(valueIg + prevState * checkI);
|
|
|
|
|
valueFg = actGate(valueFg + prevState * checkF);
|
|
|
|
@ -45,17 +42,19 @@ class lstm {
|
|
|
|
|
output = valueOg * stateAtv;
|
|
|
|
|
}
|
|
|
|
|
#ifndef __NVCC__
|
|
|
|
|
#ifndef __AVX__
|
|
|
|
|
#ifndef __AVX__ // If not compiled with AVX instructions, disable AVX by default
|
|
|
|
|
static const bool avx = false;
|
|
|
|
|
#else
|
|
|
|
|
static const bool avx = true;
|
|
|
|
|
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
|
|
|
|
__m256 &valueOg, __m256 &prevState, __m256 &state,
|
|
|
|
|
__m256 &stateAtv, __m256 &output, __m256 &checkI,
|
|
|
|
|
__m256 &checkF, __m256 &checkO,
|
|
|
|
|
hppl::Active<__m256>::forward actInput,
|
|
|
|
|
hppl::Active<__m256>::forward actGate,
|
|
|
|
|
hppl::Active<__m256>::forward actState) {
|
|
|
|
|
// Only float supports AVX optimization
|
|
|
|
|
static const bool avx = std::is_same<T, float>::value;
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
|
|
|
|
__m256 &valueOg, __m256 &prevState, __m256 &state,
|
|
|
|
|
__m256 &stateAtv, __m256 &output, __m256 &checkI,
|
|
|
|
|
__m256 &checkF, __m256 &checkO,
|
|
|
|
|
hppl::Active<__m256>::forward actInput,
|
|
|
|
|
hppl::Active<__m256>::forward actGate,
|
|
|
|
|
hppl::Active<__m256>::forward actState) {
|
|
|
|
|
valueIn = actInput(valueIn);
|
|
|
|
|
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
|
|
|
|
|
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
|
|
|
|
@ -76,14 +75,15 @@ namespace backward {
|
|
|
|
|
template <class T>
|
|
|
|
|
class lstm {
|
|
|
|
|
public:
|
|
|
|
|
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
|
|
|
|
|
T &prevState, T &prevStateGrad, T &state, T &stateGrad,
|
|
|
|
|
T &stateAtv, T &outputGrad, T &checkI, T &checkF,
|
|
|
|
|
T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actInput,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actGate,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actState) {
|
|
|
|
|
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
|
|
|
|
|
T &prevState, T &prevStateGrad, T &state,
|
|
|
|
|
T &stateGrad, T &stateAtv, T &outputGrad,
|
|
|
|
|
T &checkI, T &checkF, T &checkO, T &checkIGrad,
|
|
|
|
|
T &checkFGrad, T &checkOGrad,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actInput,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actGate,
|
|
|
|
|
typename hppl::BackwardActType<T>::type actState) {
|
|
|
|
|
gradOg = actGate(outputGrad * stateAtv, valueOg);
|
|
|
|
|
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
|
|
|
|
|
gradIn = actInput(stateGrad * valueIg, valueIn);
|
|
|
|
@ -95,21 +95,22 @@ class lstm {
|
|
|
|
|
checkOGrad = gradOg * state;
|
|
|
|
|
}
|
|
|
|
|
#ifndef __NVCC__
|
|
|
|
|
#ifndef __AVX__
|
|
|
|
|
#ifndef __AVX__ // If not compiled with AVX instructions, disable AVX by default
|
|
|
|
|
static const bool avx = false;
|
|
|
|
|
#else
|
|
|
|
|
static const bool avx = true;
|
|
|
|
|
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
|
|
|
|
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
|
|
|
|
|
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
|
|
|
|
|
__m256 &prevStateGrad, __m256 &state,
|
|
|
|
|
__m256 &stateGrad, __m256 &stateAtv,
|
|
|
|
|
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
|
|
|
|
|
__m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
|
|
|
|
|
__m256 &checkOGrad,
|
|
|
|
|
hppl::Active<__m256>::backward actInput,
|
|
|
|
|
hppl::Active<__m256>::backward actGate,
|
|
|
|
|
hppl::Active<__m256>::backward actState) {
|
|
|
|
|
// Only float supports AVX optimization
|
|
|
|
|
static const bool avx = std::is_same<T, float>::value;
|
|
|
|
|
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
|
|
|
|
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
|
|
|
|
|
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
|
|
|
|
|
__m256 &prevStateGrad, __m256 &state,
|
|
|
|
|
__m256 &stateGrad, __m256 &stateAtv,
|
|
|
|
|
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
|
|
|
|
|
__m256 &checkO, __m256 &checkIGrad,
|
|
|
|
|
__m256 &checkFGrad, __m256 &checkOGrad,
|
|
|
|
|
hppl::Active<__m256>::backward actInput,
|
|
|
|
|
hppl::Active<__m256>::backward actGate,
|
|
|
|
|
hppl::Active<__m256>::backward actState) {
|
|
|
|
|
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
|
|
|
|
|
stateGrad = _mm256_add_ps(
|
|
|
|
|
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
|
|
|
|
|