|
|
|
@ -50,7 +50,7 @@ def _fake_quant_perchannel_tbe():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@fusion_manager.register("fake_quant_perchannel")
|
|
|
|
|
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
|
|
|
|
|
def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
|
|
|
|
|
kernel_name="fake_quant_perchannel"):
|
|
|
|
|
"""FakeQuantPerChannel"""
|
|
|
|
|
x_shape = te.lang.cce.util.shape_to_list(x.shape)
|
|
|
|
@ -59,6 +59,9 @@ def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max,
|
|
|
|
|
quant_max = tvm.const(quant_max, x.dtype)
|
|
|
|
|
quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
|
|
|
|
|
quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
|
|
|
|
|
if symmetric:
|
|
|
|
|
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
|
|
|
|
|
min_val = te.lang.cce.vmuls(max_val, -1.)
|
|
|
|
|
|
|
|
|
|
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
|
|
|
|
|
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
|
|
|
@ -119,12 +122,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
|
|
|
|
|
util.check_dtype_rule(min_dtype, check_list)
|
|
|
|
|
util.check_dtype_rule(max_dtype, check_list)
|
|
|
|
|
|
|
|
|
|
if symmetric:
|
|
|
|
|
quant_min = 0 - 2 ** (num_bits - 1)
|
|
|
|
|
quant_max = 2 ** (num_bits - 1) - 1
|
|
|
|
|
else:
|
|
|
|
|
quant_min = 0
|
|
|
|
|
quant_max = 2 ** num_bits - 1
|
|
|
|
|
quant_min = 0
|
|
|
|
|
quant_max = 2 ** num_bits - 1
|
|
|
|
|
if narrow_range:
|
|
|
|
|
quant_min = quant_min + 1
|
|
|
|
|
|
|
|
|
@ -136,7 +135,7 @@ def fake_quant_perchannel(x, min_val, max_val, y,
|
|
|
|
|
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
|
|
|
|
|
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
|
|
|
|
|
res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
|
|
|
|
|
quant_min, quant_max, kernel_name)
|
|
|
|
|
quant_min, quant_max, symmetric, kernel_name)
|
|
|
|
|
|
|
|
|
|
with tvm.target.cce():
|
|
|
|
|
sch = generic.auto_schedule(res)
|
|
|
|
|