|
|
|
@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase):
|
|
|
|
|
pred_loss = paddle.nn.functional.kl_div(input, label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestKLDivLossTypePromotion(unittest.TestCase):
|
|
|
|
|
def test_kl_div_promotion(self):
|
|
|
|
|
|
|
|
|
|
with paddle.fluid.dygraph.guard():
|
|
|
|
|
x1 = paddle.rand([5, 20], dtype='float32')
|
|
|
|
|
target1 = paddle.rand([5, 20], dtype='float64')
|
|
|
|
|
|
|
|
|
|
kldiv_criterion = paddle.nn.KLDivLoss()
|
|
|
|
|
pred_loss1 = kldiv_criterion(x1, target1)
|
|
|
|
|
|
|
|
|
|
x2 = paddle.rand([5, 20], dtype='float64')
|
|
|
|
|
target2 = paddle.rand([5, 20], dtype='float32')
|
|
|
|
|
pred_loss2 = paddle.nn.functional.kl_div(x2, target2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|