!1784 fix bug in fake quant grad

Merge pull request !1784 from SanjayChan/fakequant_bug_fix
pull/1784/MERGE
mindspore-ci-bot committed 5 years ago (via Gitee)
commit 90eedfb351

@@ -30,7 +30,9 @@ FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
       quant_max_(0),
       quant_size_(0),
       quant_delay_(0),
-      global_step_(0) {}
+      global_step_(0),
+      narrow_range_(false),
+      symmetric_(false) {}
 
 const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
@@ -59,8 +61,19 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
     MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
   }
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  if (symmetric_) {
+    quant_min_ = 0 - (1 << (num_bits_ - 1));
+    quant_max_ = (1 << (num_bits_ - 1)) - 1;
+  } else {
+    quant_min_ = 0;
+    quant_max_ = (1 << num_bits_) - 1;
+  }
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  if (narrow_range_) {
+    quant_min_++;
+  }
   if (quant_size_ == 0) {
     quant_size_ = 1;
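
The range selection that Init() now performs reduces to the following standalone
sketch (plain Python, for illustration only; quant_range is not a MindSpore
function):

    def quant_range(num_bits, symmetric, narrow_range):
        """Return (quant_min, quant_max) mirroring the kernel's Init() logic."""
        if symmetric:
            # Signed range centered on zero, e.g. [-128, 127] for 8 bits.
            quant_min = -(1 << (num_bits - 1))
            quant_max = (1 << (num_bits - 1)) - 1
        else:
            # Unsigned range, e.g. [0, 255] for 8 bits.
            quant_min = 0
            quant_max = (1 << num_bits) - 1
        if narrow_range:
            # Drop the lowest code, e.g. [-127, 127] for symmetric 8 bits.
            quant_min += 1
        return quant_min, quant_max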

@@ -54,6 +54,8 @@ class FakeQuantGradGpuKernel : public GpuKernel {
   int quant_size_;
   int quant_delay_;
   int global_step_;
+  bool narrow_range_;
+  bool symmetric_;
 };
 }  // namespace kernel
 }  // namespace mindspore

@@ -35,6 +35,8 @@ fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
     .partial_flag(True) \
     .attr("num_bits", "optional", "int", "all") \
     .attr("quant_delay", "optional", "int", "all") \
+    .attr("symmetric", "optional", "bool", "all") \
+    .attr("narrow_range", "optional", "bool", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .input(2, "min", None, "required", None) \
@@ -104,8 +106,9 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
     return res
 
 
-@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
-def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
+@util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str)
+def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
+                                 num_bits, quant_delay, symmetric, narrow_range,
                                  kernel_name="fake_quant_with_min_max_grad"):
     """FakeQuantWithMinMaxGrad"""
     input_shape = x.get("shape")
@@ -136,8 +139,15 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_
     input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
     shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
 
-    quant_min = 0
-    quant_max = 2 ** num_bits - 1
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
 
     dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
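
The TBE compute uses the same arithmetic as the GPU kernel, with 2 ** num_bits in
place of shifts. For num_bits=8 the four flag combinations give the following
ranges, checked here against the illustrative quant_range helper sketched above:

    assert quant_range(8, symmetric=False, narrow_range=False) == (0, 255)
    assert quant_range(8, symmetric=False, narrow_range=True) == (1, 255)
    assert quant_range(8, symmetric=True, narrow_range=False) == (-128, 127)
    assert quant_range(8, symmetric=True, narrow_range=True) == (-127, 127)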

@@ -23,10 +23,10 @@ from topi.cce import util
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
+fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("fake_quant_with_min_max_update5d.so") \
+    .binfile_name("fake_quant_with_min_max_update.so") \
     .compute_cost(10) \
     .kernel_name("fake_quant_with_min_max_update") \
     .partial_flag(True) \
@@ -47,9 +47,9 @@ fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
     .get_op_info()
 
 
-@op_info_register(fake_quant_update5d_op_info)
-def _fake_quant_update5d_tbe():
-    """_FakeQuantWithMinMaxUpdate5D TBE register"""
+@op_info_register(fake_quant_update_op_info)
+def _fake_quant_update_tbe():
+    """_FakeQuantWithMinMaxUpdate TBE register"""
     return

@@ -116,15 +116,17 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.array([2]), mindspore.float32)
         >>> result = fake_min_max_grad(dout, input_x, _min, _max)
         """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
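
With the new keyword arguments, the primitive from the docstring example might be
constructed as below. This is a hedged sketch: the import path and the input
values are assumptions, and quant primitives have moved between modules across
MindSpore versions.

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P  # assumed import path

    # symmetric=True with narrow_range=True gives the 8-bit range [-127, 127].
    fake_min_max_grad = P.FakeQuantWithMinMaxGrad(num_bits=8, quant_delay=0,
                                                  symmetric=True, narrow_range=True)
    dout = Tensor(np.array([[1.0, -1.0]]), mindspore.float32)
    input_x = Tensor(np.array([[0.5, 1.5]]), mindspore.float32)
    _min = Tensor(np.array([-2]), mindspore.float32)
    _max = Tensor(np.array([2]), mindspore.float32)
    result = fake_min_max_grad(dout, input_x, _min, _max)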
@@ -172,7 +174,7 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
         >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
         >>> result = fake_quant(input_x, _min, _max)
         """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
     channel_axis = 0
 
     @prim_attr_register
@@ -219,16 +221,18 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
         >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
         >>> result = fqmmpc_grad(dout, input_x, _min, _max)
         """
-    support_quant_bit = [4, 8]
+    support_quant_bit = [4, 7, 8]
 
     @prim_attr_register
-    def __init__(self, num_bits=8, quant_delay=0):
+    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False):
         """init FakeQuantWithMinMaxPerChannel Fill"""
         if num_bits not in self.support_quant_bit:
             raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.")
 
-        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name)
+        self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name)
+        self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name)
+        self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name)
         self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
 
     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
