From 7714a7fa878ba45c1e6b840ac8797a9e9c08c4b2 Mon Sep 17 00:00:00 2001 From: fary86 Date: Tue, 28 Apr 2020 16:35:32 +0800 Subject: [PATCH] Fix checking bug of PReLU --- mindspore/ops/operations/nn_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 66656b559e..6ae6c68844 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer): return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - args = {"input_x": input_x_dtype, "weight": weight_dtype} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name) return input_x_dtype