From 5424a698df953eb68afd03a69d42842bf1deda61 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 17 Aug 2018 17:53:34 +0800 Subject: [PATCH] refine gru op unit test --- .../fluid/tests/unittests/test_gru_op.py | 207 +++++++++--------- 1 file changed, 104 insertions(+), 103 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 001fd7efb1..9f6f03f9cf 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -19,22 +19,19 @@ import numpy as np import math import functools from op_test import OpTest -from test_lstm_op import identity, sigmoid, tanh, relu - - -class TestGRUOp(OpTest): - lod = [[2, 4, 3]] - batch_size = sum(lod[0]) - frame_size = 5 - activate = { - 'identity': identity, - 'sigmoid': sigmoid, - 'tanh': tanh, - 'relu': relu - } - - @staticmethod - def seq_to_batch(lod, is_reverse): +from test_lstm_op import ACTIVATION + + +def gru( + input, # T x 3D + lod, # 1 x N + h0, # N x D + weight, # D x 3D + bias, # 1 x 3D + is_reverse, + act_state, + act_gate): + def _seq_to_batch(lod, is_reverse): idx_in_seq_list = [] seq_lens = lod[0] seq_starts = [0] @@ -56,121 +53,125 @@ class TestGRUOp(OpTest): idx_in_seq_list.append(idx_in_seq) return idx_in_seq_list, sorted_seqs - def gru_step(self, x, h_p, w, b): - batch_size = x.shape[0] - frame_size = w.shape[0] - g = x + np.tile(b, (batch_size, 1)) - w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape( - (frame_size, frame_size * 2)) - u_r = self.activate[self.attrs['gate_activation']](np.dot( - h_p, w_u_r) + g[:, :frame_size * 2]) - u = u_r[:, :frame_size] - r = u_r[:, frame_size:frame_size * 2] + def _step(x, h_p, w, b, act_state, act_gate): + T = x.shape[0] + D = w.shape[0] + g = x + np.tile(b, (T, 1)) + w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2)) + u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2]) + u = u_r[:, :D] + r = u_r[:, D:D * 2] r_h_p = r * h_p - w_c = w.flatten()[frame_size * frame_size * 2:].reshape( - (frame_size, frame_size)) - c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + - g[:, frame_size * 2:]) + w_c = w.flatten()[D * D * 2:].reshape((D, D)) + c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:]) g = np.hstack((u_r, c)) h = u * c + (1 - u) * h_p return g, r_h_p, h - def gru(self): - input, lod = self.inputs['Input'] - w = self.inputs['Weight'] - b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros( - (1, self.frame_size * 3)) - batch_gate = self.outputs['BatchGate'] - batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev'] - batch_hidden = self.outputs['BatchHidden'] - hidden = self.outputs['Hidden'] - idx_in_seq_list = self.idx_in_seq_list - h_p = self.inputs['H0'][ - self.sorted_seqs] if 'H0' in self.inputs else np.zeros( - (len(idx_in_seq_list[0]), self.frame_size)) - num_batch = len(idx_in_seq_list) - end_idx = 0 - for batch_idx in range(num_batch): - x = input[idx_in_seq_list[batch_idx]] - g, r_h_p, h = self.gru_step(x, h_p, w, b) - if batch_idx < (num_batch - 1): - h_p = h[:len(idx_in_seq_list[batch_idx + 1])] - start_idx = end_idx - end_idx = start_idx + len(idx_in_seq_list[batch_idx]) - batch_gate[start_idx:end_idx] = g - batch_reset_hidden_prev[start_idx:end_idx] = r_h_p - batch_hidden[start_idx:end_idx] = h - hidden[idx_in_seq_list[batch_idx]] = h - return batch_gate, batch_reset_hidden_prev, hidden - - def set_data(self): - lod = self.lod - self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch( - lod, self.is_reverse) - batch_size = self.batch_size - frame_size = self.frame_size - input = np.random.rand(batch_size, frame_size * 3).astype('float64') - h0 = np.random.rand(len(self.idx_in_seq_list[0]), - frame_size).astype('float64') - weight = np.random.rand(frame_size, frame_size * 3).astype('float64') - bias = np.random.rand(1, frame_size * 3).astype('float64') - - self.inputs = { - 'Input': (input, lod), - 'H0': h0, - 'Weight': weight, - 'Bias': bias - } + T = sum(lod[0]) + N = len(lod[0]) + D = weight.shape[0] + batch_gate = np.zeros((T, 3 * D), dtype='float64') + batch_reset_hidden_prev = np.zeros((T, D), dtype='float64') + batch_hidden = np.zeros((T, D), dtype='float64') + hidden = np.zeros((T, D), dtype='float64') + + idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse) + h_p = h0[sorted_seqs] + max_seq_len = len(idx_in_seq_list) + assert len(idx_in_seq_list[0]) == N + end_idx = 0 + for batch_idx in range(max_seq_len): + x = input[idx_in_seq_list[batch_idx]] + g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate) + if batch_idx < (max_seq_len - 1): + h_p = h[:len(idx_in_seq_list[batch_idx + 1])] + start_idx = end_idx + end_idx = start_idx + len(idx_in_seq_list[batch_idx]) + batch_gate[start_idx:end_idx] = g + batch_reset_hidden_prev[start_idx:end_idx] = r_h_p + batch_hidden[start_idx:end_idx] = h + hidden[idx_in_seq_list[batch_idx]] = h + return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden - self.outputs = { - 'BatchGate': np.zeros( - (batch_size, frame_size * 3), dtype='float64'), - 'BatchResetHiddenPrev': np.zeros( - (batch_size, frame_size), dtype='float64'), - 'BatchHidden': np.zeros( - (batch_size, frame_size), dtype='float64'), - 'Hidden': np.zeros( - (batch_size, frame_size), dtype='float64') - } +class TestGRUOp(OpTest): def set_confs(self): - self.is_reverse = False - self.attrs = { - 'activation': 'tanh', - 'gate_activation': 'sigmoid', - 'is_reverse': self.is_reverse - } + pass def setUp(self): self.op_type = "gru" + self.lod = [[2, 4, 3]] + self.D = 5 + self.is_reverse = False + self.with_h0 = True + self.with_bias = True + self.act_state = 'tanh' + self.act_gate = 'sigmoid' self.set_confs() - self.set_data() - self.gru() + + T = sum(self.lod[0]) + N = len(self.lod[0]) + + input = np.random.rand(T, 3 * self.D).astype('float64') + weight = np.random.rand(self.D, 3 * self.D).astype('float64') + bias = np.random.rand( + 1, 3 * self.D).astype('float64') if self.with_bias else np.zeros( + (1, 3 * self.D), dtype='float64') + h0 = np.random.rand( + N, self.D).astype('float64') if self.with_h0 else np.zeros( + (N, self.D), dtype='float64') + + batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru( + input, self.lod, h0, weight, bias, self.is_reverse, + ACTIVATION[self.act_state], ACTIVATION[self.act_gate]) + self.inputs = {'Input': (input, self.lod), 'Weight': weight} + + if self.with_bias: + self.inputs['Bias'] = bias + + if self.with_h0: + self.inputs['H0'] = h0 + + self.outputs = { + 'Hidden': (hidden, self.lod), + 'BatchGate': batch_gate, + 'BatchResetHiddenPrev': batch_reset_hidden_prev, + 'BatchHidden': batch_hidden, + } + + self.attrs = { + 'activation': self.act_state, + 'gate_activation': self.act_gate, + 'is_reverse': self.is_reverse + } def test_check_output(self): - self.check_output() + self.check_output(atol=1e-8) def test_check_grad(self): self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) class TestGRUOpNoInitial(TestGRUOp): - def set_data(self): - super(TestGRUOpNoInitial, self).set_data() - self.inputs.pop('H0') + def set_confs(self): + self.with_h0 = False def test_check_grad(self): self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden']) +class TestGRUOpNoBias(TestGRUOp): + def set_confs(self): + self.with_bias = False + + def test_check_grad(self): + self.check_grad(['Input', 'H0', 'Weight'], ['Hidden']) + + class TestGRUOpReverse(TestGRUOp): def set_confs(self): self.is_reverse = True - self.attrs = { - 'activation': 'tanh', - 'gate_activation': 'sigmoid', - 'is_reverse': self.is_reverse - } if __name__ == "__main__":