diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 66656b559e..6ae6c68844 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer): return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - args = {"input_x": input_x_dtype, "weight": weight_dtype} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name) return input_x_dtype