@@ -467,6 +467,9 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
@@ -521,6 +524,9 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
@@ -616,6 +622,9 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
@@ -670,6 +679,9 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
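All four hunks make the same change: the set of input dtypes accepted by infer_dtype now depends on the backend, with GPU restricted to float32 (presumably because the GPU fake-quant kernels are float32-only) while other targets continue to accept float16 as well. Below is a minimal standalone sketch of that selection-and-check pattern in plain Python; select_valid_types and check_dtype are illustrative stand-ins for MindSpore's context.get_context('device_target') and validator.check_tensor_type_same machinery, not its actual API.

def select_valid_types(device_target: str) -> tuple:
    # On GPU the fake-quant kernels accept float32 only; other
    # backends (e.g. Ascend) also accept float16.
    if device_target == "GPU":
        return ("float32",)
    return ("float16", "float32")


def check_dtype(name: str, dtype: str, device_target: str) -> None:
    # Analogue of validator.check_tensor_type_same: reject any
    # input dtype outside the valid set for the current backend.
    valid = select_valid_types(device_target)
    if dtype not in valid:
        raise TypeError(f"For a fake-quant op, '{name}' must be one of "
                        f"{valid} on {device_target}, but got {dtype}.")


check_dtype("x", "float16", "Ascend")   # passes
try:
    check_dtype("x", "float16", "GPU")  # rejected: GPU is float32-only
except TypeError as err:
    print(err)

In the patch itself this branch is simply repeated inline in each of the four classes rather than factored into a helper; the sketch only centralizes it to make the shared logic easier to see.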