!7183 fix symmetric bug in FakeQuantPerChannel op

Merge pull request !7183 from yuchaojie/quant2
pull/7183/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4beb4de5fe

@ -50,7 +50,7 @@ def _fake_quant_perchannel_tbe():
@fusion_manager.register("fake_quant_perchannel")
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
kernel_name="fake_quant_perchannel"):
"""FakeQuantPerChannel"""
x_shape = te.lang.cce.util.shape_to_list(x.shape)
@ -59,6 +59,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
quant_max = tvm.const(quant_max, x.dtype)
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
if symmetric:
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
min_val = te.lang.cce.vmuls(max_val, -1.)
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
@ -119,12 +122,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
@ -136,7 +135,7 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name)
quant_min, quant_max, symmetric, kernel_name)
with tvm.target.cce():
sch = generic.auto_schedule(res)

Loading…
Cancel
Save