|
|
|
@ -1076,20 +1076,17 @@ class TestDygraphTransformer(unittest.TestCase):
|
|
|
|
|
4]] = out[k]
|
|
|
|
|
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_avg_cost_value, dy_avg_cost._numpy()))
|
|
|
|
|
np.array_equal(static_avg_cost_value, dy_avg_cost._numpy()))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_sum_cost_value, dy_sum_cost._numpy()))
|
|
|
|
|
np.array_equal(static_sum_cost_value, dy_sum_cost._numpy()))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(
|
|
|
|
|
static_predict_value, dy_predict._numpy(), atol=1e-5))
|
|
|
|
|
np.array_equal(static_predict_value, dy_predict._numpy()))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(static_token_num_value, dy_token_num._numpy()))
|
|
|
|
|
np.array_equal(static_token_num_value, dy_token_num._numpy()))
|
|
|
|
|
for key, value in six.iteritems(static_param_init):
|
|
|
|
|
self.assertTrue(np.allclose(value, dy_param_init[key]))
|
|
|
|
|
self.assertTrue(np.array_equal(value, dy_param_init[key]))
|
|
|
|
|
for key, value in six.iteritems(static_param_updated):
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
np.allclose(
|
|
|
|
|
value, dy_param_updated[key], atol=1e-4))
|
|
|
|
|
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|