refine gru op unit test

createGenDocLib
tensor-tang 7 years ago
parent 6f78fd7d1e
commit 5424a698df

@@ -19,22 +19,19 @@ import numpy as np
 import math
 import functools
 from op_test import OpTest
-from test_lstm_op import identity, sigmoid, tanh, relu
+from test_lstm_op import ACTIVATION
 
 
-class TestGRUOp(OpTest):
-    lod = [[2, 4, 3]]
-    batch_size = sum(lod[0])
-    frame_size = 5
-    activate = {
-        'identity': identity,
-        'sigmoid': sigmoid,
-        'tanh': tanh,
-        'relu': relu
-    }
-
-    @staticmethod
-    def seq_to_batch(lod, is_reverse):
+def gru(
+        input,  # T x 3D
+        lod,  # 1 x N
+        h0,  # N x D
+        weight,  # D x 3D
+        bias,  # 1 x 3D
+        is_reverse,
+        act_state,
+        act_gate):
+    def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
         seq_starts = [0]
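Note: the refactored test swaps the four individually imported activation functions for the ACTIVATION table from test_lstm_op. That table is not shown in this diff; below is a minimal sketch of what such a name-to-function lookup presumably contains, based on the four activations the old import pulled in one by one:

import numpy as np

# Hypothetical stand-in for test_lstm_op.ACTIVATION: maps the attribute
# strings ('sigmoid', 'tanh', ...) to NumPy reference implementations.
ACTIVATION = {
    'identity': lambda x: x,
    'sigmoid': lambda x: 1.0 / (1.0 + np.exp(-x)),
    'tanh': np.tanh,
    'relu': lambda x: np.maximum(x, 0.0),
}

This is what lets the reference gru() be parameterized by act_state and act_gate instead of reading self.attrs inside the step function, as the old gru_step did.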
@@ -56,44 +53,38 @@ class TestGRUOp(OpTest):
             idx_in_seq_list.append(idx_in_seq)
         return idx_in_seq_list, sorted_seqs
 
-    def gru_step(self, x, h_p, w, b):
-        batch_size = x.shape[0]
-        frame_size = w.shape[0]
-        g = x + np.tile(b, (batch_size, 1))
-        w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
-            (frame_size, frame_size * 2))
-        u_r = self.activate[self.attrs['gate_activation']](np.dot(
-            h_p, w_u_r) + g[:, :frame_size * 2])
-        u = u_r[:, :frame_size]
-        r = u_r[:, frame_size:frame_size * 2]
+    def _step(x, h_p, w, b, act_state, act_gate):
+        T = x.shape[0]
+        D = w.shape[0]
+        g = x + np.tile(b, (T, 1))
+        w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))
+        u_r = act_gate(np.dot(h_p, w_u_r) + g[:, :D * 2])
+        u = u_r[:, :D]
+        r = u_r[:, D:D * 2]
         r_h_p = r * h_p
-        w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
-            (frame_size, frame_size))
-        c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
-                                                    g[:, frame_size * 2:])
+        w_c = w.flatten()[D * D * 2:].reshape((D, D))
+        c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
         g = np.hstack((u_r, c))
         h = u * c + (1 - u) * h_p
         return g, r_h_p, h
 
-    def gru(self):
-        input, lod = self.inputs['Input']
-        w = self.inputs['Weight']
-        b = self.inputs['Bias'] if 'Bias' in self.inputs else np.zeros(
-            (1, self.frame_size * 3))
-        batch_gate = self.outputs['BatchGate']
-        batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
-        batch_hidden = self.outputs['BatchHidden']
-        hidden = self.outputs['Hidden']
-        idx_in_seq_list = self.idx_in_seq_list
-        h_p = self.inputs['H0'][
-            self.sorted_seqs] if 'H0' in self.inputs else np.zeros(
-                (len(idx_in_seq_list[0]), self.frame_size))
-        num_batch = len(idx_in_seq_list)
-        end_idx = 0
-        for batch_idx in range(num_batch):
-            x = input[idx_in_seq_list[batch_idx]]
-            g, r_h_p, h = self.gru_step(x, h_p, w, b)
-            if batch_idx < (num_batch - 1):
-                h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
-            start_idx = end_idx
-            end_idx = start_idx + len(idx_in_seq_list[batch_idx])
+    T = sum(lod[0])
+    N = len(lod[0])
+    D = weight.shape[0]
+    batch_gate = np.zeros((T, 3 * D), dtype='float64')
+    batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
+    batch_hidden = np.zeros((T, D), dtype='float64')
+    hidden = np.zeros((T, D), dtype='float64')
+
+    idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
+    h_p = h0[sorted_seqs]
+    max_seq_len = len(idx_in_seq_list)
+    assert len(idx_in_seq_list[0]) == N
+    end_idx = 0
+    for batch_idx in range(max_seq_len):
+        x = input[idx_in_seq_list[batch_idx]]
+        g, r_h_p, h = _step(x, h_p, weight, bias, act_state, act_gate)
+        if batch_idx < (max_seq_len - 1):
+            h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
+        start_idx = end_idx
+        end_idx = start_idx + len(idx_in_seq_list[batch_idx])
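The middle of _seq_to_batch falls outside these hunks, but its contract is visible from the call sites: sequences are sorted by descending length, idx_in_seq_list[t] holds the input-row indices of every sequence still alive at time step t, and sorted_seqs maps batch order back to original sequence order (hence h_p = h0[sorted_seqs]). A worked trace for the test's lod, assuming that contract holds:

# lod = [[2, 4, 3]]: seq 0 occupies input rows 0-1, seq 1 rows 2-5,
# seq 2 rows 6-8. Sorted by descending length: seq 1, seq 2, seq 0.
sorted_seqs = [1, 2, 0]
idx_in_seq_list = [
    [2, 6, 0],  # t=0: first row of each sorted sequence (all N alive)
    [3, 7, 1],  # t=1
    [4, 8],     # t=2: the length-2 sequence has ended
    [5],        # t=3: only the length-4 sequence remains
]
# Shorter sequences drop off the end, so the state carried to the next
# step is just a prefix: h_p = h[:len(idx_in_seq_list[batch_idx + 1])].

With is_reverse=True the per-sequence row order would be flipped, but the prefix property is preserved the same way.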
@@ -101,76 +92,86 @@ class TestGRUOp(OpTest):
-            batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
-            batch_hidden[start_idx:end_idx] = h
-            hidden[idx_in_seq_list[batch_idx]] = h
-        return batch_gate, batch_reset_hidden_prev, hidden
+        batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
+        batch_hidden[start_idx:end_idx] = h
+        hidden[idx_in_seq_list[batch_idx]] = h
+    return batch_gate, batch_reset_hidden_prev, batch_hidden, hidden
 
-    def set_data(self):
-        lod = self.lod
-        self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
-            lod, self.is_reverse)
-        batch_size = self.batch_size
-        frame_size = self.frame_size
-        input = np.random.rand(batch_size, frame_size * 3).astype('float64')
-        h0 = np.random.rand(len(self.idx_in_seq_list[0]),
-                            frame_size).astype('float64')
-        weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
-        bias = np.random.rand(1, frame_size * 3).astype('float64')
-        self.inputs = {
-            'Input': (input, lod),
-            'H0': h0,
-            'Weight': weight,
-            'Bias': bias
-        }
-        self.outputs = {
-            'BatchGate': np.zeros(
-                (batch_size, frame_size * 3), dtype='float64'),
-            'BatchResetHiddenPrev': np.zeros(
-                (batch_size, frame_size), dtype='float64'),
-            'BatchHidden': np.zeros(
-                (batch_size, frame_size), dtype='float64'),
-            'Hidden': np.zeros(
-                (batch_size, frame_size), dtype='float64')
-        }
-
+
+
+class TestGRUOp(OpTest):
     def set_confs(self):
-        self.is_reverse = False
-        self.attrs = {
-            'activation': 'tanh',
-            'gate_activation': 'sigmoid',
-            'is_reverse': self.is_reverse
-        }
+        pass
 
     def setUp(self):
         self.op_type = "gru"
+        self.lod = [[2, 4, 3]]
+        self.D = 5
+        self.is_reverse = False
+        self.with_h0 = True
+        self.with_bias = True
+        self.act_state = 'tanh'
+        self.act_gate = 'sigmoid'
         self.set_confs()
-        self.set_data()
-        self.gru()
+
+        T = sum(self.lod[0])
+        N = len(self.lod[0])
+        input = np.random.rand(T, 3 * self.D).astype('float64')
+        weight = np.random.rand(self.D, 3 * self.D).astype('float64')
+        bias = np.random.rand(
+            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype='float64')
+        h0 = np.random.rand(
+            N, self.D).astype('float64') if self.with_h0 else np.zeros(
+                (N, self.D), dtype='float64')
+
+        batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
+            input, self.lod, h0, weight, bias, self.is_reverse,
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+        self.inputs = {'Input': (input, self.lod), 'Weight': weight}
+        if self.with_bias:
+            self.inputs['Bias'] = bias
+        if self.with_h0:
+            self.inputs['H0'] = h0
+        self.outputs = {
+            'Hidden': (hidden, self.lod),
+            'BatchGate': batch_gate,
+            'BatchResetHiddenPrev': batch_reset_hidden_prev,
+            'BatchHidden': batch_hidden,
+        }
+        self.attrs = {
+            'activation': self.act_state,
+            'gate_activation': self.act_gate,
+            'is_reverse': self.is_reverse
+        }
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(atol=1e-8)
 
     def test_check_grad(self):
         self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
 class TestGRUOpNoInitial(TestGRUOp):
-    def set_data(self):
-        super(TestGRUOpNoInitial, self).set_data()
-        self.inputs.pop('H0')
+    def set_confs(self):
+        self.with_h0 = False
 
     def test_check_grad(self):
         self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOpNoBias(TestGRUOp):
+    def set_confs(self):
+        self.with_bias = False
+
+    def test_check_grad(self):
+        self.check_grad(['Input', 'H0', 'Weight'], ['Hidden'])
+
+
 class TestGRUOpReverse(TestGRUOp):
     def set_confs(self):
         self.is_reverse = True
-        self.attrs = {
-            'activation': 'tanh',
-            'gate_activation': 'sigmoid',
-            'is_reverse': self.is_reverse
-        }
 
 
 if __name__ == "__main__":
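One layout detail worth calling out: _step recovers the gate weights by flattening the buffer, not by column slicing. For a row-major D x 3D array those are different operations; the reference's flatten-and-reshape assumes the op packs the update/reset matrix (D x 2D) and the candidate matrix (D x D) back to back in memory. A quick self-contained check of the distinction:

import numpy as np

D = 3
w = np.arange(D * 3 * D, dtype='float64').reshape((D, 3 * D))

w_u_r = w.flatten()[:D * D * 2].reshape((D, D * 2))  # what _step does
w_cols = w[:, :D * 2]                                # naive column slice

# The two differ for any D > 1: the flat buffer stores the update/reset
# matrix first and the candidate matrix after it, not side by side.
print(np.array_equal(w_u_r, w_cols))  # -> False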
