Enable backward computation in lstmp_op

emailweixu-patch-1
Yibing Liu 7 years ago
parent f2c4bb679b
commit 552c901204

@@ -39,21 +39,12 @@ class LSTMPOp : public framework::OperatorWithKernel {
"Output(BatchGate) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchCellPreAct) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(BatchHidden) of LSTMP should not be null.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) and Input(H0) of LSTMP should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
auto proj_dims = ctx->GetInputDim("ProjWeight");
@@ -75,6 +66,18 @@ class LSTMPOp : public framework::OperatorWithKernel {
"should be %d.",
frame_size);
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) and Input(H0) of LSTMP should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]});
}
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
@@ -98,6 +101,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->SetOutputDim("BatchHidden", out_dims);
ctx->ShareLoD("Input", "Projection");
ctx->ShareLoD("Input", "Cell");
}
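For reference, a minimal NumPy sketch of the shape relationships the InferShape checks above enforce; the variable names and concrete sizes are illustrative only (T: total time steps, N: batch size, D: cell size, P: projection size):

import numpy as np

T, N, D, P = 7, 3, 16, 5        # time steps, batch size, cell size, projection size
x = np.random.rand(T, 4 * D)    # Input:      (T x 4D)
w = np.random.rand(P, 4 * D)    # Weight:     (P x 4D)
w_rh = np.random.rand(D, P)     # ProjWeight: (D x P)
h0 = np.random.rand(N, D)       # H0:         (N x D)

frame_size = x.shape[1] // 4    # mirrors frame_size = in_dims[1] / 4 above
assert frame_size == D
ordered_p0 = np.dot(h0, w_rh)   # OrderedP0:  (N x P), matching SetOutputDim("OrderedP0", {N, P})
assert ordered_p0.shape == (N, P)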
@@ -169,6 +173,15 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.")
.AsIntermediate();
AddOutput("BatchHidden",
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.")
.AsIntermediate();
AddOutput("OrderedP0",
"(Tensor) The projection of the initial hidden state "
"H0. This is a tensor with shape (N x P), where N is the "
"batch size and P is the projection size.")
.AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, default: True) "
"whether to enable diagonal/peephole connections.")
@@ -177,6 +190,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: False) "
"whether to compute reversed LSTMP.")
.SetDefault(false);
AddAttr<bool>("share_cell_act",
"(bool, default: True) "
"whether to share the activation with the cell output. "
"If false, the projection is linear; otherwise it passes "
"through the same activation as the cell output.")
.SetDefault(true);
AddAttr<std::string>(
"gate_activation",
"(string, default: sigmoid)"
@@ -213,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
h_t = o_t \odot act_h(c_t)
r_t = W_{rh}h_t
r_t = act_h'(W_{rh}h_t)
$$
where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
@@ -229,7 +248,8 @@ layer.
The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
are the cell input and cell output activation functions and `tanh` is usually
used for them.
used for them. If `share_cell_act` is set to `False`, $act_h'$ will be linear;
otherwise it will be the same as $act_h$.
Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
operations on the input $x_{t}$ are NOT included in this operator.
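As a rough illustration of the projection step described above, a minimal sketch (assuming tanh for $act_h$; the function name is hypothetical, not part of the operator):

import numpy as np

def project(h_t, w_rh, share_cell_act=True):
    # r_t = W_{rh} h_t, optionally passed through the cell output activation
    r_t = np.dot(h_t, w_rh)
    if share_cell_act:
        r_t = np.tanh(r_t)  # act_h' is the same as act_h (tanh assumed here)
    return r_t              # with share_cell_act=False the projection stays linear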
@@ -246,12 +266,14 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(Hidden) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Projection"),
"Input(Projection) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
"Input(ProjWeight) of LSTMP should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTMP should not be null.");
@@ -268,6 +290,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
SetOutGradDim("Input");
SetOutGradDim("Weight");
SetOutGradDim("ProjWeight");
SetOutGradDim("Bias");
SetOutGradDim("H0");
SetOutGradDim("C0");

File diff suppressed because it is too large

@@ -62,7 +62,8 @@ def lstmp(
is_reverse=False,
act_gate=None,
act_cell=None,
act_cand=None):
act_cand=None,
share_cell_act=True):
def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand):
g = np.dot(r_pre, w_r) # 1 x 4D
g = g + x
@@ -85,6 +86,8 @@ def lstmp(
h = g_o * act_cell(c)
# projection
r = np.dot(h, w_rh)
if share_cell_act:
r = act_cell(r)
return r, c
def _reverse(x, lod):
@@ -107,6 +110,8 @@ def lstmp(
seq_len = offset[i + 1] - offset[i]
x = input[offset[i]:offset[i + 1], :]
r_pre = np.dot(h0[i], w_rh) # 1 x P
if share_cell_act:
r_pre = act_cell(r_pre)
c_pre = c0[i] # 1 x D
for j in range(seq_len):
# compute one step
@@ -138,6 +143,7 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
@@ -167,7 +173,7 @@ class TestLstmOp(OpTest):
w_rh = np.random.normal(size=(self.D, self.P)).astype('float64')
r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
ACTVATION[self.act_cand], self.share_cell_act)
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh}
@@ -192,28 +198,30 @@ class TestLstmOp(OpTest):
def test_check_output(self):
self.check_output(atol=1e-8)
"""
def test_check_grad(self):
# TODO(qingqing) remove the following lines after the check_grad is refined.
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
"""
['Input', 'Weight', 'Bias'], ['Projection'],
max_relative_error=5e-3)
"""
class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.P = 5
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.share_cell_act = True
self.has_initial_state = True
self.is_reverse = True
self.use_peepholes = True
@@ -221,63 +229,74 @@ class TestLstmOpHasInitial(TestLstmOp):
def test_check_grad(self):
# TODO(qingqing) remove the following lines after the check_grad is refined.
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4)
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
max_relative_error=5e-3)
def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
['Input', 'Weight'], ['Projection'],
max_relative_error=5e-3,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
['Input', 'Bias'], ['Projection'],
max_relative_error=5e-3,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
['Weight', 'Bias'], ['Projection'],
max_relative_error=5e-3,
no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
max_relative_error=5e-4,
['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
max_relative_error=5e-3,
no_grad_set=set('H0'))
def test_check_grad_ingore_c0(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
max_relative_error=5e-4,
['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
max_relative_error=5e-3,
no_grad_set=set('C0'))
"""
class TestLstmOpRerverse(TestLstmOp):
@@ -290,6 +309,7 @@ class TestLstmOpRerverse(TestLstmOp):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = True
@@ -305,6 +325,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.share_cell_act = True
self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = False
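A side note on the commented-out gradient checks above: each test repeats the same zero-filling of intermediate outputs before calling check_grad. A hypothetical helper (a sketch only, not part of this commit), assuming numpy is imported as np and self.outputs, self.lod, self.D and self.P are set up as in TestLstmOp, could consolidate those lines:

def _fill_intermediate_outputs(self):
    # Placeholder values for the intermediate outputs, needed until check_grad is refined.
    N = len(self.lod[0]) - 1
    self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros((N, self.D)).astype('float64')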
