@@ -116,15 +116,17 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.array([2]), mindspore.float32)
         >>> result = fake_min_max_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
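For reference, the new keyword arguments introduced in this hunk are only type-checked in the constructor. A minimal plain-Python sketch of the equivalent checks, using isinstance in place of MindSpore's validator/Rel helpers (the helper name and error messages below are illustrative, not from the diff):

    # Sketch only: mirrors the checks added above with plain Python.
    support_quant_bit = [4, 7, 8]

    def check_fake_quant_args(num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
        if num_bits not in support_quant_bit:
            raise ValueError(f"attr 'num_bits' ({num_bits}) is not supported.")
        if not isinstance(quant_delay, int):
            raise TypeError("'quant_delay' must be an int.")
        if not isinstance(symmetric, bool) or not isinstance(narrow_range, bool):
            raise TypeError("'symmetric' and 'narrow_range' must be bool.")
        return num_bits, quant_delay, symmetric, narrow_range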
@@ -172,7 +174,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
         >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
         >>> result = fake_quant(input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
     channel_axis = 0
 
     @prim_attr_register
@@ -219,16 +221,18 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
         >>> result = fqmmpc_grad(dout, input_x, _min, _max)
     """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
         self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
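The diff does not show how the kernels consume symmetric and narrow_range. Under the usual fake-quant convention (an assumption here, not taken from this change), num_bits fixes the size of the integer grid and narrow_range drops its lowest code; a small sketch:

    # Assumed conventional semantics, not taken from this diff:
    # narrow_range=True makes qmin 1 instead of 0.
    def quant_range(num_bits, narrow_range=False):
        qmin = 1 if narrow_range else 0
        qmax = (1 << num_bits) - 1
        return qmin, qmax

    print(quant_range(8))                     # (0, 255)
    print(quant_range(7, narrow_range=True))  # (1, 127)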