fix teacher_student_sigmoid_loss dtype check, test=develop (#24586)

v1.8
Bai Yifan 5 years ago committed by GitHub
parent 7fa9f16c17
commit c417f991c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1536,9 +1536,11 @@ def teacher_student_sigmoid_loss(input,
cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)
"""
check_variable_and_dtype(input, "input", ['float32', 'float64'],
check_variable_and_dtype(input, "input",
['float32', 'float64', 'int32', 'int64'],
'teacher_student_sigmoid_loss')
check_variable_and_dtype(label, "label", ['float32', 'float64'],
check_variable_and_dtype(label, "label",
['float32', 'float64', 'int32', 'int64'],
'teacher_student_sigmoid_loss')
helper = LayerHelper('teacher_student_sigmoid_loss', **locals())

Loading…
Cancel
Save