|
|
|
@ -31,7 +31,8 @@ def gru(
|
|
|
|
|
is_reverse,
|
|
|
|
|
act_state,
|
|
|
|
|
act_gate,
|
|
|
|
|
dtype='float32'):
|
|
|
|
|
dtype='float32',
|
|
|
|
|
origin_mode=False):
|
|
|
|
|
def _seq_to_batch(lod, is_reverse):
|
|
|
|
|
idx_in_seq_list = []
|
|
|
|
|
seq_lens = lod[0]
|
|
|
|
@ -66,7 +67,10 @@ def gru(
|
|
|
|
|
w_c = w.flatten()[D * D * 2:].reshape((D, D))
|
|
|
|
|
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
|
|
|
|
|
g = np.hstack((u_r, c))
|
|
|
|
|
h = u * c + (1 - u) * h_p
|
|
|
|
|
if origin_mode:
|
|
|
|
|
h = (1 - u) * c + u * h_p
|
|
|
|
|
else:
|
|
|
|
|
h = u * c + (1 - u) * h_p
|
|
|
|
|
return g, r_h_p, h
|
|
|
|
|
|
|
|
|
|
T = sum(lod[0])
|
|
|
|
@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
|
|
|
|
|
self.act_state = 'tanh'
|
|
|
|
|
self.act_gate = 'sigmoid'
|
|
|
|
|
self.dtype = 'float64'
|
|
|
|
|
self.origin_mode = False
|
|
|
|
|
self.set_confs()
|
|
|
|
|
|
|
|
|
|
T = sum(self.lod[0])
|
|
|
|
@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
|
|
|
|
|
|
|
|
|
|
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
|
|
|
|
|
input, self.lod, h0, weight, bias, self.is_reverse,
|
|
|
|
|
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
|
|
|
|
|
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
|
|
|
|
|
self.origin_mode)
|
|
|
|
|
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
|
|
|
|
|
|
|
|
|
|
if self.with_bias:
|
|
|
|
@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
|
|
|
|
|
self.attrs = {
|
|
|
|
|
'activation': self.act_state,
|
|
|
|
|
'gate_activation': self.act_gate,
|
|
|
|
|
'is_reverse': self.is_reverse
|
|
|
|
|
'is_reverse': self.is_reverse,
|
|
|
|
|
'origin_mode': self.origin_mode
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
|
|
|
|
|
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRUOriginMode(TestGRUOp):
|
|
|
|
|
def set_confs(self):
|
|
|
|
|
self.origin_mode = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRUOp2(TestGRUOp):
|
|
|
|
|
def set_confs(self):
|
|
|
|
|
self.D = 19
|
|
|
|
|
self.dtype = 'float32'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRUOp2OriginMode(TestGRUOp):
|
|
|
|
|
def set_confs(self):
|
|
|
|
|
self.D = 19
|
|
|
|
|
self.dtype = 'float32'
|
|
|
|
|
self.origin_mode = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRUOpNoInitial(TestGRUOp):
|
|
|
|
|
def set_confs(self):
|
|
|
|
|
self.with_h0 = False
|
|
|
|
@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
|
|
|
|
|
self.is_reverse = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRUOpReverseOriginMode(TestGRUOp):
|
|
|
|
|
def set_confs(self):
|
|
|
|
|
self.is_reverse = True
|
|
|
|
|
self.origin_mode = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|