Fix cpplint errors in lstm kernel (#10394)

Siddharth Goyal 8 years ago committed by Abhinav Arora
parent bd66eed50a
commit b65282168c
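The hunks below address cpplint's runtime/references check: output parameters that were passed as non-const references become pointers, and every call site passes an address explicitly (the header hunks also move the C++ standard include ahead of the project includes). As an illustration only, here is a minimal standalone sketch of that pattern; the function name `scale` is hypothetical and not part of the commit.

#include <iostream>

// Before the fix (cpplint flags this):  void scale(float &x, float factor);
// After the fix, the output parameter becomes a pointer:
void scale(float *x, float factor) { *x *= factor; }

int main() {
  float v = 2.0f;
  scale(&v, 3.0f);         // the call site now shows that v is mutated
  std::cout << v << "\n";  // prints 6
  return 0;
}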

@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;

@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
     r_prev_state = value.prev_state_value[frame_idx];
   }
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+     &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+     active_node, active_gate, active_state);
   value.gate_value[frame_idx] = r_value_in;
   value.gate_value[frame_idx + frame_size] = r_value_ig;
@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     r_prev_state = value.prev_state_value[frame_idx];
   }
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-     r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-     r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-     r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
+     &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
+     &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
+     &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
+     active_gate, active_state);
   grad.gate_grad[frame_idx] = r_grad_in;
   grad.gate_grad[frame_idx + frame_size] = r_grad_ig;

@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <type_traits>
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/platform/hostdevice.h"
-#include <type_traits>
 namespace paddle {
 namespace operators {
 namespace math {
@@ -27,19 +27,19 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &prev_state, T &state, T &state_atv, T &output,
-                             T &checkI, T &checkF, T &checkO,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *prev_state, T *state, T *state_atv, T *output,
+                             T *checkI, T *checkF, T *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig = activation(value_ig + prev_state * checkI, active_gate);
-    value_fg = activation(value_fg + prev_state * checkF, active_gate);
-    state = value_in * value_ig + prev_state * value_fg;
-    value_og = activation(value_og + state * checkO, active_gate);
-    state_atv = activation(state, active_state);
-    output = value_og * state_atv;
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
+    *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
+    *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
+    *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = (*value_og) * (*state_atv);
   }
 #ifndef __NVCC__
 #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
-  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
-                             __m256 &value_fg, __m256 &value_og,
-                             __m256 &prev_state, __m256 &state,
-                             __m256 &state_atv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
+                             __m256 *value_fg, __m256 *value_og,
+                             __m256 *prev_state, __m256 *state,
+                             __m256 *state_atv, __m256 *output, __m256 *checkI,
+                             __m256 *checkF, __m256 *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig =
-        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
-                   active_gate);
-    value_fg =
-        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
-                   active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
-                          _mm256_mul_ps(prev_state, value_fg));
-    value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
-                          active_gate);
-    state_atv = activation(state, active_state);
-    output = _mm256_mul_ps(value_og, state_atv);
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(
+        _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
+        active_gate);
+    *value_fg = activation(
+        _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
+        active_gate);
+    *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
+                           _mm256_mul_ps(*prev_state, *value_fg));
+    *value_og = activation(
+        _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = _mm256_mul_ps(*value_og, *state_atv);
   }
 #endif
 #endif
@@ -81,26 +81,29 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
-                             T &prev_state, T &prev_state_grad, T &state,
-                             T &state_grad, T &state_atv, T &output_grad,
-                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
+                             T *prev_state, T *prev_state_grad, T *state,
+                             T *state_grad, T *state_atv, T *output_grad,
+                             T *checkI, T *checkF, T *checkO, T *checkIGrad,
+                             T *checkFGrad, T *checkOGrad,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    grad_og = activation(output_grad * state_atv, value_og, active_gate);
-    state_grad += activation(output_grad * value_og, state_atv, active_state) +
-                  grad_og * checkO;
-    grad_in = activation(state_grad * value_ig, value_in, active_node);
-    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
-    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
-    prev_state_grad =
-        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
-    checkIGrad = grad_ig * prev_state;
-    checkFGrad = grad_fg * prev_state;
-    checkOGrad = grad_og * state;
+    *grad_og =
+        activation((*output_grad) * (*state_atv), *value_og, active_gate);
+    *state_grad +=
+        activation((*output_grad) * (*value_og), *state_atv, active_state) +
+        (*grad_og) * (*checkO);
+    *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
+    *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
+    *grad_fg =
+        activation((*state_grad) * (*prev_state), *value_fg, active_gate);
+    *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
+                       (*state_grad) * (*value_fg);
+    *checkIGrad = (*grad_ig) * (*prev_state);
+    *checkFGrad = (*grad_fg) * (*prev_state);
+    *checkOGrad = (*grad_og) * (*state);
   }
 #ifndef __NVCC__
 #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@@ -109,32 +112,33 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
-      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
-      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
-      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
-      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
+      __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
+      __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
+      __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
+      __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
+      __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
       ActivationType active_gate, ActivationType active_state) {
-    grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
-                         active_gate);
-    state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
-                                          state_atv, active_state),
-                               state_grad);
-    state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
-    grad_in =
-        activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
-    grad_ig =
-        activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
-    grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
-                         active_gate);
-    prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
-                                    _mm256_mul_ps(grad_fg, checkF));
-    prev_state_grad =
-        _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
-    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
-    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
-    checkOGrad = _mm256_mul_ps(grad_og, state);
+    *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
+                          active_gate);
+    *state_grad =
+        _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
+                                 *state_atv, active_state),
+                      *state_grad);
+    *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
+    *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
+                          active_node);
+    *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
+                          active_gate);
+    *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
+                          active_gate);
+    *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
+                                     _mm256_mul_ps(*grad_fg, *checkF));
+    *prev_state_grad =
+        _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
+    *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
+    *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
+    *checkOGrad = _mm256_mul_ps(*grad_og, *state);
   }
 #endif
 #endif
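For context, the CPU and GPU kernels above all follow the same shape after the fix: load per-element values into local register copies, invoke the functor with their addresses, and store the results back. Below is a simplified, hypothetical sketch of that calling pattern; SimpleGate and its arguments are invented for illustration and are not the PaddlePaddle code.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// A functor whose outputs are pointer parameters, in the post-fix style.
struct SimpleGate {
  void operator()(float *gate, float *state, float prev_state) const {
    *gate = 1.0f / (1.0f + std::exp(-*gate));  // sigmoid-activated gate
    *state = prev_state * (*gate);             // gated carry of the previous state
  }
};

int main() {
  std::vector<float> gate = {0.5f, -1.0f};
  std::vector<float> state = {0.0f, 0.0f};
  std::vector<float> prev = {1.0f, 2.0f};
  SimpleGate op;
  for (std::size_t i = 0; i < gate.size(); ++i) {
    float r_gate = gate[i];          // local "register" copies, as in the kernels
    float r_state = state[i];
    op(&r_gate, &r_state, prev[i]);  // pass addresses instead of references
    gate[i] = r_gate;                // write the results back
    state[i] = r_state;
  }
  std::printf("%f %f\n", state[0], state[1]);
  return 0;
}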
