!801 fix checking bug of prelu

Merge pull request !801 from fary86/fix_checking_bug_of_prelu
pull/801/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit eeb8e4d4d3

@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer):
return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype):
args = {"input_x": input_x_dtype, "weight": weight_dtype}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name)
return input_x_dtype

Loading…
Cancel
Save