|
|
|
@ -114,26 +114,20 @@ def lstm(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLstmOp(OpTest):
|
|
|
|
|
def set_data(self):
|
|
|
|
|
# self.lod = [[0, 2, 6, 9]]
|
|
|
|
|
# self.D = 64
|
|
|
|
|
# self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
|
|
|
|
|
|
|
|
|
|
self.lod = [[0, 1]]
|
|
|
|
|
self.D = 4
|
|
|
|
|
self.sort_idx = [0]
|
|
|
|
|
|
|
|
|
|
# self.act_gate = 'identity'
|
|
|
|
|
# self.act_cell = 'identity'
|
|
|
|
|
# self.act_cand = 'identity'
|
|
|
|
|
def set_argument(self):
|
|
|
|
|
self.lod = [[0, 2, 6, 9]]
|
|
|
|
|
self.D = 16
|
|
|
|
|
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
|
|
|
|
|
|
|
|
|
|
self.act_gate = 'sigmoid'
|
|
|
|
|
self.act_cell = 'tanh'
|
|
|
|
|
self.act_cand = 'tanh'
|
|
|
|
|
|
|
|
|
|
self.has_initial_state = True
|
|
|
|
|
self.is_reverse = False
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.set_data()
|
|
|
|
|
self.set_argument()
|
|
|
|
|
self.op_type = 'lstm'
|
|
|
|
|
|
|
|
|
|
T = self.lod[0][-1]
|
|
|
|
@ -155,17 +149,14 @@ class TestLstmOp(OpTest):
|
|
|
|
|
for i, j in enumerate(self.sort_idx):
|
|
|
|
|
g_sort[i, :] = g[j, :]
|
|
|
|
|
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'Input': (x, self.lod),
|
|
|
|
|
'H0': h0,
|
|
|
|
|
'C0': c0,
|
|
|
|
|
'Weight': w,
|
|
|
|
|
'Bias': b
|
|
|
|
|
}
|
|
|
|
|
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
|
|
|
|
|
self.inputs['H0'] = h0
|
|
|
|
|
self.inputs['C0'] = c0
|
|
|
|
|
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'Hidden': (h, self.lod),
|
|
|
|
|
'Cell': (c, self.lod),
|
|
|
|
|
#'BatchGate': g_sort,
|
|
|
|
|
'BatchGate': g_sort,
|
|
|
|
|
}
|
|
|
|
|
self.attrs = {
|
|
|
|
|
'usePeepholes': True,
|
|
|
|
@ -175,26 +166,43 @@ class TestLstmOp(OpTest):
|
|
|
|
|
'candidateActivation': self.act_cand
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def not_test_check_output(self):
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
#TODO(qingqing) add more unit testing case
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
# TODO(qingqing) remove folowing two lines after the check_grad is refined.
|
|
|
|
|
self.outputs['BatchGate'] = None
|
|
|
|
|
self.outputs['BatchCellPreAct'] = None
|
|
|
|
|
self.check_grad(['Input', 'Weight'], ['Hidden', 'Cell'])
|
|
|
|
|
#['Input', 'Weight', 'Bias'], ['Hidden', 'Cell'])
|
|
|
|
|
|
|
|
|
|
#class TestLstmOpRerverse(TestLstmOp):
|
|
|
|
|
# def set_data(self):
|
|
|
|
|
# self.lod = [[0, 2, 6, 9]]
|
|
|
|
|
# self.D = 64
|
|
|
|
|
# self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
|
|
|
|
|
#
|
|
|
|
|
# self.act_gate = 'sigmoid'
|
|
|
|
|
# self.act_cell = 'tanh'
|
|
|
|
|
# self.act_cand = 'tanh'
|
|
|
|
|
#
|
|
|
|
|
# self.is_reverse = True
|
|
|
|
|
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLstmOpHasNoInitial(TestLstmOp):
|
|
|
|
|
def set_argument(self):
|
|
|
|
|
self.lod = [[0, 2, 6, 9]]
|
|
|
|
|
self.D = 64
|
|
|
|
|
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
|
|
|
|
|
|
|
|
|
|
self.act_gate = 'sigmoid'
|
|
|
|
|
self.act_cell = 'tanh'
|
|
|
|
|
self.act_cand = 'tanh'
|
|
|
|
|
|
|
|
|
|
self.has_initial_state = False
|
|
|
|
|
self.is_reverse = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLstmOpRerverse(TestLstmOp):
|
|
|
|
|
def set_argument(self):
|
|
|
|
|
self.lod = [[0, 2, 6, 9]]
|
|
|
|
|
self.D = 64
|
|
|
|
|
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
|
|
|
|
|
|
|
|
|
|
self.act_gate = 'sigmoid'
|
|
|
|
|
self.act_cell = 'tanh'
|
|
|
|
|
self.act_cand = 'tanh'
|
|
|
|
|
|
|
|
|
|
self.has_initial_state = True
|
|
|
|
|
self.is_reverse = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|