|
|
@ -2066,8 +2066,9 @@ class PReLU(PrimitiveWithInfer):
|
|
|
|
return input_x_shape
|
|
|
|
return input_x_shape
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, input_x_dtype, weight_dtype):
|
|
|
|
def infer_dtype(self, input_x_dtype, weight_dtype):
|
|
|
|
args = {"input_x": input_x_dtype, "weight": weight_dtype}
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
|
|
|
|
validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name)
|
|
|
|
return input_x_dtype
|
|
|
|
return input_x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|