|
|
|
@ -27,19 +27,19 @@ namespace forward {
|
|
|
|
|
template <class T>
|
|
|
|
|
class lstm {
|
|
|
|
|
public:
|
|
|
|
|
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &prevState, T &state, T &stateAtv, T &output,
|
|
|
|
|
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
|
|
|
|
|
T &prev_state, T &state, T &state_atv, T &output,
|
|
|
|
|
T &checkI, T &checkF, T &checkO,
|
|
|
|
|
activation_mode_t active_node,
|
|
|
|
|
activation_mode_t active_gate,
|
|
|
|
|
activation_mode_t active_state) {
|
|
|
|
|
valueIn = activation(valueIn, active_node);
|
|
|
|
|
valueIg = activation(valueIg + prevState * checkI, active_gate);
|
|
|
|
|
valueFg = activation(valueFg + prevState * checkF, active_gate);
|
|
|
|
|
state = valueIn * valueIg + prevState * valueFg;
|
|
|
|
|
valueOg = activation(valueOg + state * checkO, active_gate);
|
|
|
|
|
stateAtv = activation(state, active_state);
|
|
|
|
|
output = valueOg * stateAtv;
|
|
|
|
|
value_in = activation(value_in, active_node);
|
|
|
|
|
value_ig = activation(value_ig + prev_state * checkI, active_gate);
|
|
|
|
|
value_fg = activation(value_fg + prev_state * checkF, active_gate);
|
|
|
|
|
state = value_in * value_ig + prev_state * value_fg;
|
|
|
|
|
value_og = activation(value_og + state * checkO, active_gate);
|
|
|
|
|
state_atv = activation(state, active_state);
|
|
|
|
|
output = value_og * state_atv;
|
|
|
|
|
}
|
|
|
|
|
#ifndef __NVCC__
|
|
|
|
|
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
|
|
|
|
@ -48,24 +48,27 @@ class lstm {
|
|
|
|
|
// Only float support AVX optimization
|
|
|
|
|
static const bool avx = std::is_same<T, float>::value;
|
|
|
|
|
|
|
|
|
|
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
|
|
|
|
|
__m256 &valueOg, __m256 &prevState, __m256 &state,
|
|
|
|
|
__m256 &stateAtv, __m256 &output, __m256 &checkI,
|
|
|
|
|
HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
|
|
|
|
|
__m256 &value_fg, __m256 &value_og,
|
|
|
|
|
__m256 &prev_state, __m256 &state,
|
|
|
|
|
__m256 &state_atv, __m256 &output, __m256 &checkI,
|
|
|
|
|
__m256 &checkF, __m256 &checkO,
|
|
|
|
|
activation_mode_t active_node,
|
|
|
|
|
activation_mode_t active_gate,
|
|
|
|
|
activation_mode_t active_state) {
|
|
|
|
|
valueIn = activation(valueIn, active_node);
|
|
|
|
|
valueIg = activation(
|
|
|
|
|
_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
|
|
|
|
|
valueFg = activation(
|
|
|
|
|
_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
|
|
|
|
|
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
|
|
|
|
|
_mm256_mul_ps(prevState, valueFg));
|
|
|
|
|
valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
|
|
|
|
|
active_gate);
|
|
|
|
|
stateAtv = activation(state, active_state);
|
|
|
|
|
output = _mm256_mul_ps(valueOg, stateAtv);
|
|
|
|
|
value_in = activation(value_in, active_node);
|
|
|
|
|
value_ig =
|
|
|
|
|
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
|
|
|
|
|
active_gate);
|
|
|
|
|
value_fg =
|
|
|
|
|
activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
|
|
|
|
|
active_gate);
|
|
|
|
|
state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
|
|
|
|
|
_mm256_mul_ps(prev_state, value_fg));
|
|
|
|
|
value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
|
|
|
|
|
active_gate);
|
|
|
|
|
state_atv = activation(state, active_state);
|
|
|
|
|
output = _mm256_mul_ps(value_og, state_atv);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
@ -78,25 +81,26 @@ namespace backward {
|
|
|
|
|
template <class T>
|
|
|
|
|
class lstm {
|
|
|
|
|
public:
|
|
|
|
|
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
|
|
|
|
|
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
|
|
|
|
|
T &prevState, T &prevStateGrad, T &state,
|
|
|
|
|
T &stateGrad, T &stateAtv, T &outputGrad,
|
|
|
|
|
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
|
|
|
|
|
T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
|
|
|
|
|
T &prev_state, T &prev_state_grad, T &state,
|
|
|
|
|
T &state_grad, T &state_atv, T &output_grad,
|
|
|
|
|
T &checkI, T &checkF, T &checkO, T &checkIGrad,
|
|
|
|
|
T &checkFGrad, T &checkOGrad,
|
|
|
|
|
activation_mode_t active_node,
|
|
|
|
|
activation_mode_t active_gate,
|
|
|
|
|
activation_mode_t active_state) {
|
|
|
|
|
gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
|
|
|
|
|
stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
|
|
|
|
|
gradOg * checkO;
|
|
|
|
|
gradIn = activation(stateGrad * valueIg, valueIn, active_node);
|
|
|
|
|
gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
|
|
|
|
|
gradFg = activation(stateGrad * prevState, valueFg, active_gate);
|
|
|
|
|
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
|
|
|
|
|
checkIGrad = gradIg * prevState;
|
|
|
|
|
checkFGrad = gradFg * prevState;
|
|
|
|
|
checkOGrad = gradOg * state;
|
|
|
|
|
grad_og = activation(output_grad * state_atv, value_og, active_gate);
|
|
|
|
|
state_grad += activation(output_grad * value_og, state_atv, active_state) +
|
|
|
|
|
grad_og * checkO;
|
|
|
|
|
grad_in = activation(state_grad * value_ig, value_in, active_node);
|
|
|
|
|
grad_ig = activation(state_grad * value_in, value_ig, active_gate);
|
|
|
|
|
grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
|
|
|
|
|
prev_state_grad =
|
|
|
|
|
grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
|
|
|
|
|
checkIGrad = grad_ig * prev_state;
|
|
|
|
|
checkFGrad = grad_fg * prev_state;
|
|
|
|
|
checkOGrad = grad_og * state;
|
|
|
|
|
}
|
|
|
|
|
#ifndef __NVCC__
|
|
|
|
|
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
|
|
|
|
@ -105,32 +109,32 @@ class lstm {
|
|
|
|
|
// Only float support AVX optimization
|
|
|
|
|
static const bool avx = std::is_same<T, float>::value;
|
|
|
|
|
HOSTDEVICE void operator()(
|
|
|
|
|
__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
|
|
|
|
|
__m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
|
|
|
|
|
__m256 &prevState, __m256 &prevStateGrad, __m256 &state,
|
|
|
|
|
__m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
|
|
|
|
|
__m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
|
|
|
|
|
__m256 &checkOGrad, activation_mode_t active_node,
|
|
|
|
|
__m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
|
|
|
|
|
__m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
|
|
|
|
|
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
|
|
|
|
|
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
|
|
|
|
|
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
|
|
|
|
|
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
|
|
|
|
|
activation_mode_t active_gate, activation_mode_t active_state) {
|
|
|
|
|
gradOg =
|
|
|
|
|
activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate);
|
|
|
|
|
stateGrad = _mm256_add_ps(
|
|
|
|
|
activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state),
|
|
|
|
|
stateGrad);
|
|
|
|
|
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
|
|
|
|
|
gradIn =
|
|
|
|
|
activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node);
|
|
|
|
|
gradIg =
|
|
|
|
|
activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate);
|
|
|
|
|
gradFg =
|
|
|
|
|
activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate);
|
|
|
|
|
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
|
|
|
|
|
_mm256_mul_ps(gradFg, checkF));
|
|
|
|
|
prevStateGrad =
|
|
|
|
|
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
|
|
|
|
|
checkIGrad = _mm256_mul_ps(gradIg, prevState);
|
|
|
|
|
checkFGrad = _mm256_mul_ps(gradFg, prevState);
|
|
|
|
|
checkOGrad = _mm256_mul_ps(gradOg, state);
|
|
|
|
|
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
|
|
|
|
|
active_gate);
|
|
|
|
|
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
|
|
|
|
|
state_atv, active_state),
|
|
|
|
|
state_grad);
|
|
|
|
|
state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
|
|
|
|
|
grad_in =
|
|
|
|
|
activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
|
|
|
|
|
grad_ig =
|
|
|
|
|
activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
|
|
|
|
|
grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
|
|
|
|
|
active_gate);
|
|
|
|
|
prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
|
|
|
|
|
_mm256_mul_ps(grad_fg, checkF));
|
|
|
|
|
prev_state_grad =
|
|
|
|
|
_mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
|
|
|
|
|
checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
|
|
|
|
|
checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
|
|
|
|
|
checkOGrad = _mm256_mul_ps(grad_og, state);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|