Pass grad checking for projection weight

emailweixu-patch-1
Yibing Liu 7 years ago
parent 552c901204
commit 7a5b8ffacb

@@ -217,7 +217,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
   AddComment(R"DOC(
 Long-Short Term Memory with Recurrent Projection (LSTMP) Operator.
-LATMP is stand LSTM appended by a recurrent projection layer to reduce the
+LSTMP is stand LSTM appended by a recurrent projection layer to reduce the
 number of parameters, espeacially when the output size is relative large.
 The formula is as follows:
@@ -232,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
 h_t = o_t \odot act_h(c_t)
-r_t = act_h'(W_{rh}h_t)
+r_t = act_{h'}(W_{rh}h_t)
 $$
 where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix

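As a quick numeric illustration of the corrected formula r_t = act_{h'}(W_{rh}h_t), here is a minimal NumPy sketch of the recurrent projection step; the names (`lstmp_project`, `act_h_prime`, the shapes) are illustrative and not the operator's attribute names.

```python
import numpy as np

def lstmp_project(h_t, W_rh, act_h_prime=np.tanh):
    """Recurrent projection r_t = act_h'(W_rh h_t) from the LSTMP doc above.

    h_t:  (batch, hidden_size) hidden state at step t
    W_rh: (hidden_size, proj_size) projection weight
    Returns r_t with shape (batch, proj_size).
    """
    return act_h_prime(h_t.dot(W_rh))

# Example: project a 4-unit hidden state down to 2 units.
h_t = np.random.rand(3, 4)
W_rh = np.random.rand(4, 2)
r_t = lstmp_project(h_t, W_rh)
assert r_t.shape == (3, 2)
```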
@@ -365,10 +365,18 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
                        proj_g_dev);
       }
+      /* hidden state backward */
       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
       math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
                                      true, static_cast<T>(1.0), &out_g,
                                      static_cast<T>(0.0));
+      /* projection weight backward */
+      if (proj_weight_g) {
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
+                                       false, static_cast<T>(1.0),
+                                       proj_weight_g, static_cast<T>(1.0));
+      }
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
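The block added above computes the projection weight gradient inside the per-timestep loop: each step accumulates h_t^T * dr_t into `proj_weight_g` (the `true`/`false` flags transpose `hidden_t` but not `proj_g`, and the trailing `1.0` makes the matmul accumulate rather than overwrite). A minimal NumPy sketch of the same accumulation, with illustrative names:

```python
import numpy as np

def accumulate_proj_weight_grad(hidden_steps, proj_grad_steps,
                                hidden_size, proj_size):
    """dW_rh = sum_t h_t^T @ dr_t, mirroring the accumulating matmul above."""
    proj_weight_g = np.zeros((hidden_size, proj_size))
    for h_t, dr_t in zip(hidden_steps, proj_grad_steps):
        # transpose h_t (trans_a=true in the C++ call), accumulate in place
        proj_weight_g += h_t.T.dot(dr_t)
    return proj_weight_g
```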
@@ -407,19 +415,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                                        static_cast<T>(1.0), &pre_proj_g,
                                        static_cast<T>(1.0));
         if (weight_g) {
-          /* backward weight */
+          /* weight backward */
           auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
           math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
                                          false, static_cast<T>(1.0), weight_g,
                                          static_cast<T>(1.0));
         }
-        if (proj_weight_g) {
-          /* backward proj weigh */
-          Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-          math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-                                         false, static_cast<T>(1.0),
-                                         proj_weight_g, static_cast<T>(1.0));
-        }
       } else {
         if (h0 && weight_g) {
           ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
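For comparison, the pre-existing `weight_g` branch above accumulates the gate weight gradient from the previous step's projected state: dW += r_{t-1}^T * dgate_t. A NumPy sketch of that accumulation, assuming the four gates are concatenated along the last axis (illustrative shapes, not the operator's layout guarantee):

```python
import numpy as np

def accumulate_gate_weight_grad(prev_proj_steps, gate_grad_steps,
                                proj_size, hidden_size):
    """dW += sum_t r_{t-1}^T @ dgate_t, as in the weight_g matmul above."""
    weight_g = np.zeros((proj_size, 4 * hidden_size))  # i, f, c, o gates concatenated
    for r_prev, dgate in zip(prev_proj_steps, gate_grad_steps):
        weight_g += r_prev.T.dot(dgate)
    return weight_g
```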
@@ -444,7 +445,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
           ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
                          proj0_g_dev);
         }
-        // Tensor proj0_g = proj_g.Slice(bstart, bend);
         if (h0_g) {
           math::matmul<DeviceContext, T>(
               device_ctx, proj0_g, false, *proj_weight, true,

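The matmul at the end of this hunk propagates the projected initial-state gradient back to `h0` by multiplying with the transposed projection weight, the same pattern used for the hidden-state backward step earlier. A one-line NumPy sketch with illustrative names:

```python
import numpy as np

def h0_grad(proj0_g, W_rh):
    """dh0 = dr0 @ W_rh^T, mirroring matmul(proj0_g, false, proj_weight, true)."""
    return proj0_g.dot(W_rh.T)
```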
@@ -207,8 +207,8 @@ class TestLstmOp(OpTest):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias'], ['Projection'],
-            max_relative_error=5e-3)
+            ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2)


 class TestLstmOpHasInitial(TestLstmOp):
@@ -235,8 +235,9 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
-            max_relative_error=5e-3)
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'],
+            ['Projection'],
+            max_relative_error=1e-2)

     def test_check_grad_ingore_bias(self):
         N = len(self.lod[0]) - 1
@@ -246,8 +247,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'ProjWeight', 'Weight'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Bias'))

     def test_check_grad_ingore_weight(self):
@@ -258,10 +259,22 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Bias'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Weight'))
+
+    def test_check_grad_ingore_proj_weight(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
+            no_grad_set=set('ProjWeight'))

     def test_check_grad_ingore_input(self):
         N = len(self.lod[0]) - 1
         self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
@@ -270,8 +283,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Weight', 'Bias'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Weight', 'ProjWeight', 'Bias'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('Input'))

     def test_check_grad_ingore_h0(self):
@@ -282,8 +295,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('H0'))

     def test_check_grad_ingore_c0(self):
@@ -294,8 +307,8 @@ class TestLstmOpHasInitial(TestLstmOp):
         self.outputs['BatchCellPreAct'] = np.zeros(
             (N, self.D)).astype('float64')
         self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
-            max_relative_error=5e-3,
+            ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'],
+            max_relative_error=1e-2,
             no_grad_set=set('C0'))
