Update the error message for the margin_ranking_loss

Update the error message for the margin_ranking_loss
test_feature_precision_test_c
wawltor 5 years ago committed by GitHub
parent 94b05850d2
commit ecfb89e133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -177,6 +177,16 @@ class MarginRakingLossError(unittest.TestCase):
self.assertRaises(ValueError, test_margin_value_error)
def test_functional_margin_value_error():
x = paddle.static.data(name="x", shape=[10, 10], dtype="float64")
y = paddle.static.data(name="y", shape=[10, 10], dtype="float64")
label = paddle.static.data(
name="label", shape=[10, 10], dtype="float64")
result = paddle.nn.functional.margin_ranking_loss(
x, y, label, margin=0.1, reduction="reduction_mean")
self.assertRaises(ValueError, test_functional_margin_value_error)
if __name__ == "__main__":
unittest.main()

@ -338,6 +338,10 @@ def margin_ranking_loss(input,
loss = paddle.nn.functional.margin_ranking_loss(input, other, label)
print(loss.numpy()) # [0.75]
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in MarginRankingLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)
if fluid.framework.in_dygraph_mode():
out = core.ops.elementwise_sub(other, input)
out = core.ops.elementwise_mul(out, label)

Loading…
Cancel
Save