@@ -517,7 +517,8 @@ class DecoderSubLayer(Layer):
         y = self._preprocess_layer(None, input, "n", 0.1)
         slf_attn_output = self._multihead_attention_layer(y, None, None,
                                                           slf_attn_bias)
-        return slf_attn_output
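+        # also return the LayerNorm-ed tensor y so the test can read its gradient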
+        return slf_attn_output, y


 class TestDygraphTransformer(unittest.TestCase):
@@ -536,7 +537,7 @@ class TestDygraphTransformer(unittest.TestCase):
             dy_param_init = dict()
             dy_param_updated = dict()
             for i in range(batch_num):
-                loss = transformer(to_variable(x1), to_variable(x2))
+                loss, y = transformer(to_variable(x1), to_variable(x2))
                 loss = fluid.layers.reduce_sum(loss)
                 print('dy los', loss.shape)
                 if i == 0:
@@ -545,6 +546,8 @@ class TestDygraphTransformer(unittest.TestCase):

                 loss._backward()
                 optimizer.minimize(loss)
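+                # capture the gradient of y computed by the dygraph backward pass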
+                dy_key_value = y._gradient()
                 transformer.clear_gradients()
                 if i == batch_num - 1:
                     for param in transformer.parameters():
@@ -563,7 +566,7 @@ class TestDygraphTransformer(unittest.TestCase):
             data1 = fluid.layers.data(name='X', shape=[4, 512], dtype='float32')
             data2 = fluid.layers.data(
                 name='Y', shape=[8, 4, 4], dtype='float32')
-            loss = transformer(data1, data2)
+            loss, y = transformer(data1, data2)
             loss = fluid.layers.reduce_sum(loss)
             print('loss hspae', loss.shape)

@@ -580,24 +583,36 @@ class TestDygraphTransformer(unittest.TestCase):
             for i in range(len(static_param_name_list)):
                 static_param_init[static_param_name_list[i]] = out[i]

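+            # print the static program so gradient variable names can be looked up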
+            print(fluid.default_main_program())
             for i in range(batch_num):
                 feed_dict = {"X": x1, "Y": x2}
-                fetch_list = []
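+                # fetch the gradient of the matching LayerNorm output from the static graph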
+                fetch_list = [
+                    "transformer/DecoderSubLayer_0/PrePostProcessLayer_0/LayerNorm_0.tmp_2@GRAD"
+                ]
                 fetch_list.extend(static_param_name_list)

                 out = exe.run(fluid.default_main_program(),
                               feed=feed_dict,
                               fetch_list=fetch_list)
                 if i == batch_num - 1:
-                    for k in range(0, len(out)):
+                    static_key_value = out[0]
+                    for k in range(1, len(out)):
                         static_param_updated[static_param_name_list[k -
-                                                                    0]] = out[k]
+                                                                    1]] = out[k]

         for key, value in six.iteritems(static_param_init):
             self.assertTrue(np.array_equal(value, dy_param_init[key]))
         for key, value in six.iteritems(static_param_updated):
             if not (value == dy_param_updated[key]).all():
                 print(key)
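+        # debug dump: show where the dygraph and static-graph gradients differ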
+        if not np.array_equal(dy_key_value, static_key_value):
+            print("xxx", dy_key_value, static_key_value)
+            print("yyy")
+            print(dy_key_value - static_key_value)
+            print(np.where(dy_key_value - static_key_value))


 if __name__ == '__main__':