Use fixed activation in the lstm kernel, since there is some bug in the activation function pointer. It will be fixed later.

fix-typo
dangqingqing 8 years ago
parent bd680f157f
commit b50c33fd00

@ -82,6 +82,13 @@ class LSTMOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell"); ctx->ShareLoD("Input", "Cell");
} }
protected:
// Derives the framework data type from the type reported by the "Input"
// LoDTensor of the execution context.
framework::DataType IndicateDataType(
    const framework::ExecutionContext& ctx) const override {
  const auto input = ctx.Input<framework::LoDTensor>("Input");
  return framework::ToDataType(input->type());
}
}; };
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
@ -239,6 +246,13 @@ class LSTMGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(b_g_name)) if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
} }
protected:
// Derives the framework data type from the type reported by the "Input"
// LoDTensor of the execution context.
framework::DataType IndicateDataType(
    const framework::ExecutionContext& ctx) const override {
  const auto input = ctx.Input<framework::LoDTensor>("Input");
  return framework::ToDataType(input->type());
}
}; };
} // namespace operators } // namespace operators

@ -26,10 +26,7 @@ namespace detail {
template <class T, class Op> template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize, int frameSize) {
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn; T rValueIn;
T rValueIg; T rValueIg;
T rValueFg; T rValueFg;
@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i]; rPrevState = value.prevStateValue[i];
} }
hppl::cpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), rOut, rCheckI, rCheckF, rCheckO);
act(active_state));
valueIn[i] = rValueIn; valueIn[i] = rValueIn;
valueIg[i] = rValueIg; valueIg[i] = rValueIg;
@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frameSize) {
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn; T rValueIn;
T rValueIg; T rValueIg;
T rValueFg; T rValueFg;
@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i]; rPrevState = value.prevStateValue[i];
} }
hppl::cpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, act(active_node), act(active_gate), act(active_state)); rCheckOGrad);
gradIn[i] = rGradIn; gradIn[i] = rGradIn;
gradIg[i] = rGradIg; gradIg[i] = rGradIg;
@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node, avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node, naive_lstm_forward_one_sequence<T>(op, value, frameSize);
active_gate, active_state);
} }
} }
@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node, avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node, naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize);
active_gate, active_state);
} }
} }

@ -32,9 +32,7 @@ namespace detail {
*/ */
template <class T, class Op, bool isBatch> template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node, int batchSize) {
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frameIdx >= frameSize) return;
@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
rPrevState = value.prevStateValue[frameIdx]; rPrevState = value.prevStateValue[frameIdx];
} }
hppl::gpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), rOut, rCheckI, rCheckF, rCheckO);
act(active_state));
value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg; value.gateValue[frameIdx + frameSize] = rValueIg;
@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template <class T, class Op, bool isBatch> template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node, int batchSize) {
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frameIdx >= frameSize) return;
@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[frameIdx]; rPrevState = value.prevStateValue[frameIdx];
} }
hppl::gpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
act(active_node), act(active_gate), act(active_state));
grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg; grad.gateGrad[frameIdx + frameSize] = rGradIg;
@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) { if (batchSize == 1) {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, op, value, frameSize, batchSize);
active_state);
} else { } else {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate, op, value, frameSize, batchSize);
active_state);
} }
} }
@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) { if (batchSize == 1) {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate, op, value, grad, frameSize, batchSize);
active_state);
} else { } else {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate, op, value, grad, frameSize, batchSize);
active_state);
} }
} }

@ -24,15 +24,29 @@ namespace detail {
namespace forward { namespace forward {
// Forward logistic sigmoid: 1 / (1 + exp(-x)), where x is the input clamped
// to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before exponentiation.
template <typename T>
DEVICE inline T sigmoid(const T a) {
  const T lowerBound = SIGMOID_THRESHOLD_MIN;
  const T upperBound = SIGMOID_THRESHOLD_MAX;

  // Clamp the argument so the exponent below stays within the thresholds.
  T clipped = a;
  if (a < lowerBound) {
    clipped = lowerBound;
  } else if (a > upperBound) {
    clipped = upperBound;
  }

  const T one = static_cast<T>(1.0);
  return one / (one + exp(-clipped));
}
// Forward hyperbolic tangent evaluated as 2 / (1 + exp(-2a)) - 1.
// The exponent argument (-2a) is capped at EXP_MAX_INPUT before calling exp.
// NOTE(review): the literals are double, so the final division/subtraction is
// carried out in double and converted back to T on return.
template <typename T>
DEVICE inline T tanh(const T a) {
  T scaled = -2.0 * a;
  if (scaled > EXP_MAX_INPUT) {
    scaled = EXP_MAX_INPUT;
  }
  return (2.0 / (1.0 + exp(scaled))) - 1.0;
}
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output, T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO, T &checkI, T &checkF, T &checkO) {
typename hppl::ForwardActType<T>::type actInput, #if 0
typename hppl::ForwardActType<T>::type actGate, // TODO(qingqing) support to activation speficed by users
typename hppl::ForwardActType<T>::type actState) {
valueIn = actInput(valueIn); valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI); valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF); valueFg = actGate(valueFg + prevState * checkF);
@ -40,6 +54,15 @@ class lstm {
valueOg = actGate(valueOg + state * checkO); valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state); stateAtv = actState(state);
output = valueOg * stateAtv; output = valueOg * stateAtv;
#else
valueIn = tanh<T>(valueIn);
valueIg = sigmoid<T>(valueIg + prevState * checkI);
valueFg = sigmoid<T>(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = sigmoid<T>(valueOg + state * checkO);
stateAtv = tanh<T>(state);
output = valueOg * stateAtv;
#endif
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
@ -72,6 +95,16 @@ class lstm {
namespace backward { namespace backward {
// Backward sigmoid: returns a * b * (1 - b).
// In the lstm backward functor `a` is the incoming gradient term and `b` is
// the gate value that was produced by the forward sigmoid.
template <typename T>
DEVICE inline T sigmoid(const T a, const T b) {
  const auto complement = 1.0 - b;
  return a * b * complement;
}
// Backward tanh: returns a * (1 - b^2).
// In the lstm backward functor `a` is the incoming gradient term and `b` is
// the value that was produced by the forward tanh.
template <typename T>
DEVICE inline T tanh(const T a, const T b) {
  const auto oneMinusSquare = 1.0 - b * b;
  return a * oneMinusSquare;
}
template <class T> template <class T>
class lstm { class lstm {
public: public:
@ -80,10 +113,9 @@ class lstm {
T &prevState, T &prevStateGrad, T &state, T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad, T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad, T &checkFGrad, T &checkOGrad) {
typename hppl::BackwardActType<T>::type actInput, #if 0
typename hppl::BackwardActType<T>::type actGate, // TODO(qingqing) support to activation speficed by users
typename hppl::BackwardActType<T>::type actState) {
gradOg = actGate(outputGrad * stateAtv, valueOg); gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn); gradIn = actInput(stateGrad * valueIg, valueIn);
@ -93,6 +125,17 @@ class lstm {
checkIGrad = gradIg * prevState; checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState; checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state; checkOGrad = gradOg * state;
#else
gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = tanh<T>(stateGrad * valueIg, valueIn);
gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#endif
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default

@ -110,7 +110,7 @@ def lstm(
class TestLstmOp(OpTest): class TestLstmOp(OpTest):
def set_argument(self): def set_argument(self):
self.lod = [[0, 2, 6]] self.lod = [[0, 2, 5, 7]]
self.D = 16 self.D = 16
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'
@ -164,12 +164,13 @@ class TestLstmOp(OpTest):
# TODO(qingqing) remove following two lines after the check_grad is refined. # TODO(qingqing) remove following two lines after the check_grad is refined.
self.outputs['BatchGate'] = None self.outputs['BatchGate'] = None
self.outputs['BatchCellPreAct'] = None self.outputs['BatchCellPreAct'] = None
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden']) self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02)
class TestLstmOpHasNoInitial(TestLstmOp): class TestLstmOpHasNoInitial(TestLstmOp):
def set_argument(self): def set_argument(self):
self.lod = [[0, 2, 6]] self.lod = [[0, 2, 5, 7]]
self.D = 16 self.D = 16
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'
@ -182,7 +183,7 @@ class TestLstmOpHasNoInitial(TestLstmOp):
class TestLstmOpRerverse(TestLstmOp): class TestLstmOpRerverse(TestLstmOp):
def set_argument(self): def set_argument(self):
self.lod = [[0, 2, 6]] self.lod = [[0, 2, 5, 7]]
self.D = 16 self.D = 16
self.act_gate = 'sigmoid' self.act_gate = 'sigmoid'

Loading…
Cancel
Save