|
|
|
@ -13,39 +13,13 @@
|
|
|
|
|
#limitations under the License.
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
|
|
|
|
|
SIGMOID_THRESHOLD_MIN = -40.0
|
|
|
|
|
SIGMOID_THRESHOLD_MAX = 13.0
|
|
|
|
|
EXP_MAX_INPUT = 40.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def identity(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sigmoid(x):
|
|
|
|
|
y = np.copy(x)
|
|
|
|
|
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
|
|
|
|
|
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
|
|
|
|
|
return 1. / (1. + np.exp(-y))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tanh(x):
|
|
|
|
|
y = -2. * x
|
|
|
|
|
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
|
|
|
|
|
return (2. / (1. + np.exp(y))) - 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def relu(x):
|
|
|
|
|
return np.maximum(x, 0)
|
|
|
|
|
|
|
|
|
|
import test_lstm_op as LstmTest
|
|
|
|
|
|
|
|
|
|
ACTIVATION = {
|
|
|
|
|
'identity': identity,
|
|
|
|
|
'sigmoid': sigmoid,
|
|
|
|
|
'tanh': tanh,
|
|
|
|
|
'relu': relu
|
|
|
|
|
'identity': LstmTest.identity,
|
|
|
|
|
'sigmoid': LstmTest.sigmoid,
|
|
|
|
|
'tanh': LstmTest.tanh,
|
|
|
|
|
'relu': LstmTest.relu
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -55,7 +29,7 @@ def lstmp(
|
|
|
|
|
lod, # 1 x N
|
|
|
|
|
h0=None, # N x D
|
|
|
|
|
c0=None, # N x D
|
|
|
|
|
w_r=None, # P x 5D
|
|
|
|
|
w_r=None, # P x 4D
|
|
|
|
|
w_rh=None, # D x P
|
|
|
|
|
w_b=None, # 1 x 4D
|
|
|
|
|
w_c=None, # 1 x 3D
|
|
|
|
@ -130,26 +104,16 @@ def lstmp(
|
|
|
|
|
return projection, cell
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLstmpOp(OpTest):
|
|
|
|
|
class TestLstmpOp(LstmTest.TestLstmOp):
|
|
|
|
|
def reset_argument(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.lod = [[0, 2, 5, 7]]
|
|
|
|
|
# hidden size
|
|
|
|
|
self.D = 16
|
|
|
|
|
self.set_argument()
|
|
|
|
|
# projection size
|
|
|
|
|
self.P = 10
|
|
|
|
|
|
|
|
|
|
self.act_gate = 'sigmoid'
|
|
|
|
|
self.act_cell = 'tanh'
|
|
|
|
|
self.act_cand = 'tanh'
|
|
|
|
|
self.act_proj = self.act_cell
|
|
|
|
|
|
|
|
|
|
self.has_initial_state = False
|
|
|
|
|
self.is_reverse = False
|
|
|
|
|
self.use_peepholes = True
|
|
|
|
|
|
|
|
|
|
self.reset_argument()
|
|
|
|
|
self.op_type = 'lstmp'
|
|
|
|
|
|
|
|
|
|