|
|
|
@ -505,8 +505,9 @@ class DistributionTest(unittest.TestCase):
|
|
|
|
|
feed={'logits': logits_np},
|
|
|
|
|
fetch_list=[entropy_np, kl_np])
|
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
|
output_entropy_np, gt_entropy_np, rtol=tolerance)
|
|
|
|
|
np.testing.assert_allclose(output_kl_np, gt_kl_np, rtol=tolerance)
|
|
|
|
|
output_entropy_np, gt_entropy_np, rtol=tolerance, atol=tolerance)
|
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
|
output_kl_np, gt_kl_np, rtol=tolerance, atol=tolerance)
|
|
|
|
|
|
|
|
|
|
def test_multivariateNormalDiag_distribution(self,
|
|
|
|
|
batch_size=2,
|
|
|
|
@ -568,8 +569,9 @@ class DistributionTest(unittest.TestCase):
|
|
|
|
|
},
|
|
|
|
|
fetch_list=[entropy_np, kl_np])
|
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
|
output_entropy_np, gt_entropy_np, rtol=tolerance)
|
|
|
|
|
np.testing.assert_allclose(output_kl_np, gt_kl_np, rtol=tolerance)
|
|
|
|
|
output_entropy_np, gt_entropy_np, rtol=tolerance, atol=tolerance)
|
|
|
|
|
np.testing.assert_allclose(
|
|
|
|
|
output_kl_np, gt_kl_np, rtol=tolerance, atol=tolerance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|