add float16 check for GPU fakequant op

pull/8081/head
yuchaojie 4 years ago
parent 5c4940cdcc
commit 033e73ef12

@@ -467,7 +467,10 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
             {"min": min_type}, valid_types, self.name)
@@ -521,7 +524,10 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
@@ -616,7 +622,10 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
         validator.check_tensor_type_same(
             {"min": min_type}, valid_types, self.name)
@@ -670,7 +679,10 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
         return dout_shape
 
     def infer_dtype(self, dout_type, x_type, min_type, max_type):
-        valid_types = (mstype.float16, mstype.float32)
+        if context.get_context('device_target') == "GPU":
+            valid_types = (mstype.float32,)
+        else:
+            valid_types = (mstype.float16, mstype.float32)
         validator.check_tensor_type_same(
             {"dout": dout_type}, valid_types, self.name)
         validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
