@@ -20,14 +20,44 @@ import math
 import paddle.fluid.core as core
 from op_test import OpTest
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
+import random
+
+random.seed(2)
+np.set_printoptions(threshold=np.inf)
+paddle.enable_static()

 SIGMOID_THRESHOLD_MIN = -40.0
 SIGMOID_THRESHOLD_MAX = 13.0
 EXP_MAX_INPUT = 40.0
+
+
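+# One shared set of uniformly initialized LSTM parameters; updata_weight()
+# refreshes it so the numpy reference LSTM and the cudnn_lstm op read
+# exactly the same weights.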
+class RandomWeight:
+    def __init__(self):
+        pass
+
+    def updata_weight(self, hidden_size, input_size, dtype):
+        std = 1.0 / math.sqrt(hidden_size)
+        self.hidden_size = hidden_size
+        self.input_size = input_size
+        self.dtype = dtype
+
+        self.weight_ih = np.random.uniform(
+            low=-std, high=std, size=(4 * self.hidden_size,
+                                      self.input_size)).astype(dtype)
+        self.weight_hh = np.random.uniform(
+            low=-std, high=std, size=(4 * self.hidden_size,
+                                      self.hidden_size)).astype(dtype)
+        self.bias_ih = np.random.uniform(
+            low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
+        self.bias_hh = np.random.uniform(
+            low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
+
+
+weight = RandomWeight()


 class LayerMixin(object):
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
@@ -51,16 +81,13 @@ class LSTMCell(LayerMixin):
         self.bias = bias
         self.dtype = np.float64
         self.parameters = dict()
-        std = 1.0 / math.sqrt(hidden_size)
-        self.weight_ih = np.ones(
-            (4 * hidden_size, input_size), dtype=self.dtype)
-        self.weight_hh = np.ones((4 * hidden_size,
-                                  hidden_size)).astype(self.dtype)
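+        # take the parameters from the shared RandomWeight pool instead of
+        # the old constant np.ones initialization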
+        self.weight_ih = weight.weight_ih
+        self.weight_hh = weight.weight_hh
         self.parameters['weight_ih'] = self.weight_ih
         self.parameters['weight_hh'] = self.weight_hh
         if bias:
-            self.bias_ih = np.ones((4 * hidden_size)).astype(self.dtype)
-            self.bias_hh = np.ones((4 * hidden_size)).astype(self.dtype)
+            self.bias_ih = weight.bias_ih
+            self.bias_hh = weight.bias_hh
             self.parameters['bias_ih'] = self.bias_ih
             self.parameters['bias_hh'] = self.bias_hh
         else:
@@ -353,24 +380,26 @@ class LSTM(RNNMixin):
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestCUDNNLstmOp(OpTest):
+    # TODO(GaoWei8): Need to satisfy the result through the new interface
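+    # names must match the WeightList entries built in setUp:
+    # all weights first, then all biases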
+    def get_weight_names(self):
+        weight_names = []
+        for i in range(2 * self.num_layers):
+            weight_names.append('weight{}'.format(i))
+        for i in range(2 * self.num_layers):
+            weight_names.append('bias{}'.format(i))
+        return weight_names
+
     def setUp(self):
         self.op_type = "cudnn_lstm"
         self.dtype = np.float64
         self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
         self.num_layers = 1
+        self.set_attrs()

         seq_length = 12
         batch_size = 5
         input_size = 21
         hidden_size = 21

-        input_weight_size = (hidden_size * hidden_size) * 4
-        hidden_weight_size = (hidden_size * hidden_size) * 4
-        weight_size = input_weight_size + hidden_weight_size
-        weight_size += hidden_size * 8
-        weight_size *= self.num_layers
-
         input = np.random.uniform(
             low=-0.1, high=0.1,
             size=(seq_length, batch_size, input_size)).astype(self.dtype)
@@ -379,17 +408,39 @@ class TestCUDNNLstmOp(OpTest):
         input[9][3:][:] = 0
         input[8][4:][:] = 0

+        weight.updata_weight(hidden_size, input_size, self.dtype)
         rnn1 = LSTM(
             input_size,
             hidden_size,
-            self.num_layers,
+            num_layers=self.num_layers,
             time_major=True,
             direction="forward")

         output, (last_hidden, last_cell) = rnn1(
             input, sequence_length=self.sequence_length)

-        flat_w = np.ones((weight_size)).astype(self.dtype)
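+        # build WeightList in the order get_weight_names() reports:
+        # the weight matrices for every layer first, then the biases,
+        # reusing the same arrays the numpy reference LSTM consumed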
+        flat_w = []
+        num = 0
+        for i in range(self.num_layers):
+            if i == 0:
+                weight_ih = weight.weight_ih
+            else:
+                weight_ih = weight.weight_hh
+            flat_w.append(("weight" + str(num), weight_ih))
+            num += 1
+        for i in range(self.num_layers):
+            weight_hh = weight.weight_hh
+            flat_w.append(("weight" + str(num), weight_hh))
+            num += 1
+        num = 0
+        for i in range(self.num_layers):
+            bias_ih = weight.bias_ih
+            flat_w.append(("bias" + str(num), bias_ih))
+            num += 1
+        for i in range(self.num_layers):
+            bias_hh = weight.bias_hh
+            flat_w.append(("bias" + str(num), bias_hh))
+            num += 1
         init_h = np.zeros((self.num_layers, batch_size,
                            hidden_size)).astype(self.dtype)
         init_c = np.zeros((self.num_layers, batch_size,
@@ -398,7 +449,7 @@ class TestCUDNNLstmOp(OpTest):

         self.inputs = {
             'Input': input,
-            'W': flat_w,
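+            # WeightList entries are (parameter_name, ndarray) pairs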
+            'WeightList': flat_w,
             'InitH': init_h,
             'InitC': init_c,
             'SequenceLength': self.sequence_length
@@ -408,7 +459,7 @@ class TestCUDNNLstmOp(OpTest):
             'is_bidirec': False,
             'input_size': input_size,
             'hidden_size': hidden_size,
-            'num_layers': 1,
+            'num_layers': self.num_layers,
         }
         self.outputs = {
             'Out': output,
@@ -428,16 +479,42 @@ class TestCUDNNLstmOp(OpTest):

     def test_grad_with_place(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place,
-                                   set(['Input', 'W', 'InitH', 'InitC']),
-                                   ['Out', 'LastH', 'LastC'])
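+        # with WeightList the gradient is checked one parameter at a time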
+        var_name_list = self.get_weight_names()
+        for var_name in var_name_list:
+            self.check_grad_with_place(
+                place,
+                set(['Input', var_name, 'InitH', 'InitC']),
+                ['Out', 'LastH', 'LastC'])
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNLstmOp2(TestCUDNNLstmOp):
+    def set_attrs(self):
+        self.num_layers = 2
+
+
+class TestCUDNNlstmAPI(unittest.TestCase):
+    def test_lstm(self):
+        seq_len = 20
+        batch_size = 5
+        hidden_size = 20
+        dropout_prob = 0.0
+        num_layers = 1
+        input = fluid.data(
+            name='input',
+            shape=[seq_len, batch_size, hidden_size],
+            dtype='float64')
+        init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
+                                      'float64', 0.0)
+        init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
+                                      'float64', 0.0)
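+        # layers.lstm positional args: input, init_h, init_c, max_len,
+        # hidden_size, num_layers, dropout_prob, is_bidirec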
+        rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
+                                              hidden_size, num_layers,
+                                              dropout_prob, False)
+        exe = fluid.Executor(fluid.CUDAPlace(0))
+        exe.run(fluid.default_startup_program())
+        input_i = np.random.uniform(
+            low=-0.1, high=0.1, size=(seq_len, batch_size,
+                                      hidden_size)).astype("float64")
+        out = exe.run(fluid.default_main_program(),
+                      feed={'input': input_i},
+                      fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])


 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -448,7 +525,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
         batch_size = 5
         hidden_size = 20
         dropout_prob = 0.0
-        num_layers = 1
+        num_layers = 2
         input = fluid.data(
             name='input',
             shape=[seq_len, batch_size, hidden_size],