Make lstm_op follow Google code style.

release/0.11.0
qingqing01 7 years ago
parent d89061c39c
commit e5b51c4d10

@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
lstm_value.prevStateValue = nullptr;
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) {
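With use_peepholes enabled, the code above treats the bias tensor as the four gate biases (4 * frame_size entries) followed by the three peephole weight vectors, which is why check_ig / check_fg / check_og are obtained by plain pointer arithmetic. A minimal self-contained sketch of that slicing (illustrative only, not framework code; the float type and function name are placeholders):

#include <cstddef>

struct PeepholeView {
  const float *check_ig;  // peephole weights for the input gate
  const float *check_fg;  // peephole weights for the forget gate
  const float *check_og;  // peephole weights for the output gate
};

// Assumed layout: [4 * frame_size gate biases | check_ig | check_fg | check_og]
PeepholeView SlicePeepholes(const float *bias_data, std::size_t frame_size) {
  PeepholeView v;
  v.check_ig = bias_data + 4 * frame_size;
  v.check_fg = v.check_ig + frame_size;
  v.check_og = v.check_fg + frame_size;
  return v;
}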
@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
// to reorder.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>();
lstm_value.prev_state_value = ordered_c0.data<T>();
}
// Use a local variable here.
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast<T>(1.0));
}
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
lstm_value.gate_value = gate_t.data<T>();
lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
lstm_value.prev_state_value = lstm_value.state_value;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LstmMetaValue<T> lstm_value;
if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
lstm_value.check_ig = bias_data + 4 * frame_size;
lstm_value.check_fg = lstm_value.check_ig + frame_size;
lstm_value.check_og = lstm_value.check_fg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
}
math::LstmMetaGrad<T> lstm_grad;
@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>();
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
} else {
lstm_grad.checkIgGrad = nullptr;
lstm_grad.checkFgGrad = nullptr;
lstm_grad.checkOgGrad = nullptr;
lstm_grad.check_ig_grad = nullptr;
lstm_grad.check_fg_grad = nullptr;
lstm_grad.check_og_grad = nullptr;
}
math::LoDTensor2BatchFunctor<Place, T> to_batch;
@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>();
lstm_value.stateValue = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>();
lstm_value.gate_value = gate.data<T>();
lstm_value.state_value = cell.data<T>();
lstm_value.state_active_value = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
lstm_grad.state_grad = cell_g.data<T>();
lstm_grad.gate_grad = gate_g.data<T>();
lstm_grad.output_grad = out_g.data<T>();
if (n > 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
lstm_value.prev_state_value = cell_pre.data<T>();
lstm_grad.prev_state_grad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
}
int cur_batch_size = bend - bstart;

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@@ -27,19 +27,19 @@ namespace forward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &prev_state, T &state, T &state_atv, T &output,
T &checkI, T &checkF, T &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
valueIn = activation(valueIn, active_node);
valueIg = activation(valueIg + prevState * checkI, active_gate);
valueFg = activation(valueFg + prevState * checkF, active_gate);
state = valueIn * valueIg + prevState * valueFg;
valueOg = activation(valueOg + state * checkO, active_gate);
stateAtv = activation(state, active_state);
output = valueOg * stateAtv;
value_in = activation(value_in, active_node);
value_ig = activation(value_ig + prev_state * checkI, active_gate);
value_fg = activation(value_fg + prev_state * checkF, active_gate);
state = value_in * value_ig + prev_state * value_fg;
value_og = activation(value_og + state * checkO, active_gate);
state_atv = activation(state, active_state);
output = value_og * state_atv;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructions, disable AVX by default
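For reference, the scalar path above and the AVX path in the next hunk evaluate the same peephole LSTM recurrence. In the renamed variables (a sketch; sigma_node, sigma_gate and sigma_state stand for active_node, active_gate and active_state, and w_ci, w_cf, w_co are the checkI / checkF / checkO peephole weights):

\begin{aligned}
\tilde{c}_t &= \sigma_{node}(a_{in}) \\
i_t &= \sigma_{gate}(a_{ig} + c_{t-1} \odot w_{ci}) \\
f_t &= \sigma_{gate}(a_{fg} + c_{t-1} \odot w_{cf}) \\
c_t &= \tilde{c}_t \odot i_t + c_{t-1} \odot f_t \\
o_t &= \sigma_{gate}(a_{og} + c_t \odot w_{co}) \\
h_t &= o_t \odot \sigma_{state}(c_t)
\end{aligned}

Here a_{in}, a_{ig}, a_{fg}, a_{og} are the pre-activation gate values (value_in, value_ig, value_fg, value_og on entry), c_t is state, sigma_state(c_t) is state_atv, and h_t is output.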
@@ -48,24 +48,27 @@ class lstm {
// Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
__m256 &value_fg, __m256 &value_og,
__m256 &prev_state, __m256 &state,
__m256 &state_atv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
valueIn = activation(valueIn, active_node);
valueIg = activation(
_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
valueFg = activation(
_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg));
valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
active_gate);
stateAtv = activation(state, active_state);
output = _mm256_mul_ps(valueOg, stateAtv);
value_in = activation(value_in, active_node);
value_ig =
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
active_gate);
value_fg =
activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
active_gate);
state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
_mm256_mul_ps(prev_state, value_fg));
value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
active_gate);
state_atv = activation(state, active_state);
output = _mm256_mul_ps(value_og, state_atv);
}
#endif
#endif
@@ -78,25 +81,26 @@ namespace backward {
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
T &prev_state, T &prev_state_grad, T &state,
T &state_grad, T &state_atv, T &output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
gradOg * checkO;
gradIn = activation(stateGrad * valueIg, valueIn, active_node);
gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
gradFg = activation(stateGrad * prevState, valueFg, active_gate);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
grad_og = activation(output_grad * state_atv, value_og, active_gate);
state_grad += activation(output_grad * value_og, state_atv, active_state) +
grad_og * checkO;
grad_in = activation(state_grad * value_ig, value_in, active_node);
grad_ig = activation(state_grad * value_in, value_ig, active_gate);
grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
prev_state_grad =
grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
checkIGrad = grad_ig * prev_state;
checkFGrad = grad_fg * prev_state;
checkOGrad = grad_og * state;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructions, disable AVX by default
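The backward functor (scalar above, AVX in the next hunk) is the chain rule applied to those forward equations. Writing delta x for the loss gradient with respect to x, and assuming that activation(a, b, mode) in the backward path returns a multiplied by the activation derivative evaluated at the stored output b, the updates read (sketch):

\begin{aligned}
\delta a_{og} &= \delta h_t \odot \sigma_{state}(c_t) \odot \sigma'_{gate}(o_t) \\
\delta c_t &\mathrel{+}= \delta h_t \odot o_t \odot \sigma'_{state}(c_t) + \delta a_{og} \odot w_{co} \\
\delta a_{in} &= \delta c_t \odot i_t \odot \sigma'_{node}(\tilde{c}_t) \\
\delta a_{ig} &= \delta c_t \odot \tilde{c}_t \odot \sigma'_{gate}(i_t) \\
\delta a_{fg} &= \delta c_t \odot c_{t-1} \odot \sigma'_{gate}(f_t) \\
\delta c_{t-1} &= \delta a_{ig} \odot w_{ci} + \delta a_{fg} \odot w_{cf} + \delta c_t \odot f_t \\
\delta w_{ci} &= \delta a_{ig} \odot c_{t-1}, \quad
\delta w_{cf} = \delta a_{fg} \odot c_{t-1}, \quad
\delta w_{co} = \delta a_{og} \odot c_t
\end{aligned}

which lines up with grad_og, state_grad, grad_in, grad_ig, grad_fg, prev_state_grad and checkIGrad / checkFGrad / checkOGrad above.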
@@ -105,32 +109,32 @@ class lstm {
// Only float supports AVX optimization
static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(
__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
__m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
__m256 &prevState, __m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
__m256 &checkOGrad, activation_mode_t active_node,
__m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
__m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
activation_mode_t active_gate, activation_mode_t active_state) {
gradOg =
activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate);
stateGrad = _mm256_add_ps(
activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state),
stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn =
activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node);
gradIg =
activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate);
gradFg =
activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF));
prevStateGrad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
checkIGrad = _mm256_mul_ps(gradIg, prevState);
checkFGrad = _mm256_mul_ps(gradFg, prevState);
checkOGrad = _mm256_mul_ps(gradOg, state);
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
active_gate);
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
state_atv, active_state),
state_grad);
state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
grad_in =
activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
grad_ig =
activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
active_gate);
prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
_mm256_mul_ps(grad_fg, checkF));
prev_state_grad =
_mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
checkOGrad = _mm256_mul_ps(grad_og, state);
}
#endif
#endif

@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
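The pointer strides above imply a contiguous per-sample layout inside one time-step batch: every buffer holds frame_size values per sample, and a sample's four gate slices are stored back to back, hence the 4 * frame_size stride for gate_value (the same strides reappear in the gradient functor in the next hunk). A tiny sketch of the assumed offset computation (illustrative only, not framework code):

#include <cstddef>

// Start offset of sample `b` inside the concatenated gate buffer.
std::size_t GateOffset(std::size_t b, std::size_t frame_size) {
  return b * 4 * frame_size;  // four gate slices per sample
}

// Start offset of sample `b` inside the state / output / active-state buffers.
std::size_t FrameOffset(std::size_t b, std::size_t frame_size) {
  return b * frame_size;
}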
@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
frame_size, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frame_size * 4;
value.stateValue += frame_size;
value.stateActiveValue += frame_size;
value.outputValue += frame_size;
if (value.prevStateValue) {
value.prevStateValue += frame_size;
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
grad.gateGrad += frame_size * 4;
grad.stateGrad += frame_size;
grad.stateActiveGrad += frame_size;
grad.outputGrad += frame_size;
if (grad.prevStateGrad) {
grad.prevStateGrad += frame_size;
grad.gate_grad += frame_size * 4;
grad.state_grad += frame_size;
grad.state_active_grad += frame_size;
grad.output_grad += frame_size;
if (grad.prev_state_grad) {
grad.prev_state_grad += frame_size;
}
}
}

@@ -31,26 +31,26 @@ typedef enum {
template <class T>
struct LstmMetaValue {
T *gateValue;
T *prevStateValue;
T *stateValue;
T *stateActiveValue;
T *outputValue;
T *checkIg;
T *checkFg;
T *checkOg;
T *gate_value;
T *prev_state_value;
T *state_value;
T *state_active_value;
T *output_value;
T *check_ig;
T *check_fg;
T *check_og;
};
template <class T>
struct LstmMetaGrad {
T *gateGrad;
T *prevStateGrad;
T *stateGrad;
T *stateActiveGrad;
T *outputGrad;
T *checkIgGrad;
T *checkFgGrad;
T *checkOgGrad;
T *gate_grad;
T *prev_state_grad;
T *state_grad;
T *state_active_grad;
T *output_grad;
T *check_ig_grad;
T *check_fg_grad;
T *check_og_grad;
};
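To see the renamed fields in use end to end, here is a self-contained scalar sketch of one peephole LSTM step over plain arrays. It is illustrative only, not the framework's kernel: the gate slice order (in, ig, fg, og) and the tanh/sigmoid activation choices are assumptions for the demo, while the update rule itself follows the forward functor shown earlier.

#include <cmath>

struct MetaValue {                // mirrors the LstmMetaValue<float> field names
  float *gate_value;              // 4 * frame_size pre-activations per sample
  float *prev_state_value;        // c_{t-1}; may be nullptr on the first step
  float *state_value;             // c_t
  float *state_active_value;      // sigma_state(c_t)
  float *output_value;            // h_t
  const float *check_ig;          // peephole weights; may be nullptr
  const float *check_fg;
  const float *check_og;
};

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

void LstmStep(const MetaValue &v, int frame_size) {
  for (int i = 0; i < frame_size; ++i) {
    float prev = v.prev_state_value ? v.prev_state_value[i] : 0.0f;
    float w_ci = v.check_ig ? v.check_ig[i] : 0.0f;
    float w_cf = v.check_fg ? v.check_fg[i] : 0.0f;
    float w_co = v.check_og ? v.check_og[i] : 0.0f;
    // Assumed gate order within gate_value: in, ig, fg, og.
    float in = std::tanh(v.gate_value[i]);
    float ig = Sigmoid(v.gate_value[frame_size + i] + prev * w_ci);
    float fg = Sigmoid(v.gate_value[2 * frame_size + i] + prev * w_cf);
    float state = in * ig + prev * fg;
    float og = Sigmoid(v.gate_value[3 * frame_size + i] + state * w_co);
    v.state_value[i] = state;
    v.state_active_value[i] = std::tanh(state);
    v.output_value[i] = og * v.state_active_value[i];
  }
}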
inline activation_mode_t ActiveType(const std::string &type) {
