|
|
|
@ -467,7 +467,10 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type, min_type, max_type):
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
if context.get_context('device_target') == "GPU":
|
|
|
|
|
valid_types = (mstype.float32,)
|
|
|
|
|
else:
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
{"min": min_type}, valid_types, self.name)
|
|
|
|
@ -521,7 +524,10 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
|
|
|
|
|
return dout_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
if context.get_context('device_target') == "GPU":
|
|
|
|
|
valid_types = (mstype.float32,)
|
|
|
|
|
else:
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
{"dout": dout_type}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
|
|
|
@ -616,7 +622,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|
|
|
|
return x_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x_type, min_type, max_type):
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
if context.get_context('device_target') == "GPU":
|
|
|
|
|
valid_types = (mstype.float32,)
|
|
|
|
|
else:
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
{"min": min_type}, valid_types, self.name)
|
|
|
|
@ -670,7 +679,10 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
|
|
|
|
|
return dout_shape
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, dout_type, x_type, min_type, max_type):
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
if context.get_context('device_target') == "GPU":
|
|
|
|
|
valid_types = (mstype.float32,)
|
|
|
|
|
else:
|
|
|
|
|
valid_types = (mstype.float16, mstype.float32)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
{"dout": dout_type}, valid_types, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
|
|
|
|