Fix checking bug of PReLU

pull/801/head
fary86 5 years ago
parent 64abbeaa89
commit 7714a7fa87

@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer):
return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype):
args = {"input_x": input_x_dtype, "weight": weight_dtype}
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name)
return input_x_dtype

Loading…
Cancel
Save