add type promotion (#27756)

my_2.0rc
LielinJiang authored 4 years ago, committed by GitHub
parent 9089841b6e
commit b9c7c66ea5

@@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase):
        pred_loss = paddle.nn.functional.kl_div(input, label)


class TestKLDivLossTypePromotion(unittest.TestCase):
    def test_kl_div_promotion(self):
        with paddle.fluid.dygraph.guard():
            x1 = paddle.rand([5, 20], dtype='float32')
            target1 = paddle.rand([5, 20], dtype='float64')

            kldiv_criterion = paddle.nn.KLDivLoss()
            pred_loss1 = kldiv_criterion(x1, target1)

            x2 = paddle.rand([5, 20], dtype='float64')
            target2 = paddle.rand([5, 20], dtype='float32')

            pred_loss2 = paddle.nn.functional.kl_div(x2, target2)


if __name__ == "__main__":
    unittest.main()
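
For reference, a minimal standalone sketch of what the new test exercises: calling kl_div in dygraph mode with one float32 tensor and one float64 tensor. The shapes and the dygraph guard mirror the test above; the printed result dtype is illustrative only and assumes the Paddle 2.0rc-era API used in this commit.

import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # One operand is float32, the other float64, as in TestKLDivLossTypePromotion.
    x = paddle.rand([5, 20], dtype='float32')
    target = paddle.rand([5, 20], dtype='float64')

    # With the promotion added in this commit, the float32 operand is cast to
    # float64 inside kl_div, so the call no longer fails on a dtype mismatch.
    loss = paddle.nn.functional.kl_div(x, target)
    print(loss.dtype)  # expected to report a float64 result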

@@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None):
            # shape=[5, 20]
    """
    # Manual type promotion: when input and label differ in float precision,
    # cast the float32 operand up to float64 so both sides share a dtype.
    if fluid.data_feeder.convert_dtype(
            input.dtype) == 'float32' and fluid.data_feeder.convert_dtype(
                label.dtype) == 'float64':
        input = fluid.layers.cast(input, 'float64')
    elif fluid.data_feeder.convert_dtype(
            input.dtype) == 'float64' and fluid.data_feeder.convert_dtype(
                label.dtype) == 'float32':
        label = fluid.layers.cast(label, 'float64')

    if paddle.in_dynamic_mode():
        out = core.ops.kldiv_loss(input, label, 'reduction', reduction)
        return out
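
The rule the new branch implements is simple: if exactly one of input and label is float32 and the other float64, the float32 side is cast up to float64; otherwise both dtypes are left unchanged. Below is a framework-free sketch of the same rule, using a hypothetical helper name promote_pair that is not part of Paddle.

# Hypothetical helper (not in Paddle) mirroring the promotion rule above.
def promote_pair(input_dtype, label_dtype):
    """Return the (input_dtype, label_dtype) pair after kl_div-style promotion."""
    if {input_dtype, label_dtype} == {'float32', 'float64'}:
        return 'float64', 'float64'
    return input_dtype, label_dtype

assert promote_pair('float32', 'float64') == ('float64', 'float64')
assert promote_pair('float64', 'float32') == ('float64', 'float64')
assert promote_pair('float32', 'float32') == ('float32', 'float32')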
