|
|
|
@ -1241,9 +1241,9 @@ class ActsULQ(PrimitiveWithInfer):
|
|
|
|
|
def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
|
|
|
|
|
"""infer dtype of primitive"""
|
|
|
|
|
valid_types = [mstype.float32, mstype.float16]
|
|
|
|
|
validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"clamp_min": clamp_min_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"clamp_max": clamp_max_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)
|
|
|
|
|
|
|
|
|
|
return x_dtype, mstype.bool_, mstype.bool_, x_dtype
|
|
|
|
|
|
|
|
|
@ -1267,7 +1267,7 @@ class ActsULQInputGrad(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
|
|
|
|
|
valid_types = [mstype.float32, mstype.float16]
|
|
|
|
|
validator.check_tensor_type_same({"y_grad": y_grad_type}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
|
|
|
|
|
return y_grad_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1386,9 +1386,9 @@ class WtsARQ(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
|
|
|
|
|
valid_types = [mstype.float32, mstype.float16]
|
|
|
|
|
validator.check_tensor_type_same({"w": w_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
|
|
|
|
|
return w_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|