parent
b94949ea99
commit
d35c41e737
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,149 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""_BatchNormFold op"""
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
from te import tvm
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
|
||||
batch_norm_op_info = TBERegOp("BatchNormFoldD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnorm_fold.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold") \
|
||||
.partial_flag(True) \
|
||||
.attr("momentum", "optional", "float", "all") \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.attr("is_training", "optional", "bool", "all") \
|
||||
.attr("freeze_bn", "optional", "int", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "x_sum", False, "required", "all") \
|
||||
.input(2, "x_square_sum", False, "required", "all") \
|
||||
.input(3, "mean", False, "required", "all") \
|
||||
.input(4, "variance", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "batch_mean", False, "required", "all") \
|
||||
.output(2, "batch_std", False, "required", "all") \
|
||||
.output(3, "running_mean", False, "required", "all") \
|
||||
.output(4, "running_std", False, "required", "all") \
|
||||
.output(5, "mean_updated", False, "required", "all") \
|
||||
.output(6, "variance_updated", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batch_norm_op_info)
|
||||
def _batchnorm_fold_tbe():
|
||||
"""_BatchNormFold TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict,
|
||||
dict, dict, dict, dict, dict, dict, dict,
|
||||
float, float, bool, int, str, str)
|
||||
def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
|
||||
y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated,
|
||||
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
|
||||
kernel_name="batchnorm_fold"):
|
||||
"""batchnorm_fold TBE op"""
|
||||
momentum = 1.0 - momentum
|
||||
util.check_kernel_name(kernel_name)
|
||||
data_format = data_format.upper()
|
||||
if data_format != "NCHW":
|
||||
raise RuntimeError("The data_format only support NCHW")
|
||||
|
||||
shape_x = x.get("shape")
|
||||
shape_mean = mean.get("shape")
|
||||
shape_variance = variance.get("shape")
|
||||
dtype_x = x.get("dtype")
|
||||
dtype_mean = mean.get("dtype")
|
||||
dtype_variance = variance.get("dtype")
|
||||
for shape in (shape_x, shape_mean, shape_variance):
|
||||
util.check_shape_rule(shape)
|
||||
util.check_tensor_shape_size(shape)
|
||||
check_tuple = ("float16", "float32")
|
||||
for dtype in (dtype_x, dtype_mean, dtype_variance):
|
||||
util.check_dtype_rule(dtype.lower(), check_tuple)
|
||||
|
||||
format_data = x.get("format").upper()
|
||||
if format_data not in ("NCHW", "NC1HWC0"):
|
||||
raise RuntimeError("Format of input only support 4D and 5HD")
|
||||
|
||||
if format_data == "NC1HWC0":
|
||||
if len(shape_x) != 5:
|
||||
raise RuntimeError("batchnorm_fold only support shape 5D"
|
||||
"when input format is NC1HWC0")
|
||||
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
|
||||
elif format_data == "NCHW":
|
||||
if len(shape_x) < 2 or len(shape_x) > 4:
|
||||
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
|
||||
if shape_x[1] != shape_mean[0]:
|
||||
raise RuntimeError("data_format is NCHW, shape_bias must"
|
||||
"be equal to the second axis of shape_x")
|
||||
shape_mean = (1, shape_x[1],)
|
||||
for _ in range(2, len(shape_x)):
|
||||
shape_mean = shape_mean + (1,)
|
||||
|
||||
x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
|
||||
x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
|
||||
x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum", dtype=dtype_x.lower())
|
||||
mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
|
||||
variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())
|
||||
|
||||
shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
|
||||
num = shape_x[0] * shape_x[2] * shape_x[3]
|
||||
num_rec = 1.0 / num
|
||||
|
||||
# compute the mean of x
|
||||
batch_mean = te.lang.cce.vmuls(x_sum, num_rec)
|
||||
|
||||
# compute the variance of x
|
||||
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
|
||||
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
|
||||
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
|
||||
|
||||
if num == 1:
|
||||
batch_var_scaler = 0.0
|
||||
else:
|
||||
batch_var_scaler = float(num) / (num - 1)
|
||||
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
|
||||
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
|
||||
|
||||
factor = 1.0 - momentum
|
||||
factor_reverse = momentum
|
||||
mean_mul = te.lang.cce.vmuls(batch_mean, factor)
|
||||
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
|
||||
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
|
||||
|
||||
var_mul = te.lang.cce.vmuls(batch_variance, factor)
|
||||
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
|
||||
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
|
||||
|
||||
y = te.lang.cce.vadds(x_input, 0.0)
|
||||
running_mean = te.lang.cce.vadds(mean, 0.0)
|
||||
running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
|
||||
res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
config = {"name": kernel_name,
|
||||
"tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res}
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,110 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""_BatchNormFold2 op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
|
||||
batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnorm_fold2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "beta", None, "required", None) \
|
||||
.input(2, "gamma", None, "required", None) \
|
||||
.input(3, "batch_std", None, "required", None) \
|
||||
.input(4, "batch_mean", None, "required", None) \
|
||||
.input(5, "running_std", None, "required", None) \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batchnorm_fold2_op_info)
|
||||
def _batchnorm_fold2_tbe():
|
||||
"""_BatchNormFold2 TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("batchnorm_fold2")
|
||||
def batchnorm_fold2_compute(x, beta, gamma, batch_std, batch_mean, running_std, kernel_name="batchnorm_fold2"):
|
||||
"""_BatchNormFold2 compute"""
|
||||
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||
factor = te.lang.cce.vdiv(running_std, batch_std)
|
||||
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||
res = te.lang.cce.vmul(x, factor_b)
|
||||
bias = te.lang.cce.vdiv(batch_mean, batch_std)
|
||||
bias = te.lang.cce.vmul(bias, gamma)
|
||||
bias = te.lang.cce.vsub(beta, bias)
|
||||
bias_b = te.lang.cce.broadcast(bias, shape_x)
|
||||
res = te.lang.cce.vadd(res, bias_b)
|
||||
return res
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, str)
|
||||
def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"):
|
||||
"""_BatchNormFold2 op"""
|
||||
shape = x.get("shape")
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape)
|
||||
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||
check_list = ["float16", "float32"]
|
||||
inp_dtype = x.get("dtype").lower()
|
||||
if not inp_dtype in check_list:
|
||||
raise RuntimeError("Dtype of input only support float16, float32")
|
||||
data_format = x.get("format")
|
||||
ori_format = x.get("ori_format")
|
||||
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||
shape_c = gamma.get("shape")
|
||||
if gamma.get("format").upper() == "NCHW":
|
||||
shape_c = 1, gamma.get("shape")[0], 1, 1
|
||||
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||
beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype)
|
||||
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
|
||||
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
|
||||
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||
|
||||
res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t,
|
||||
running_std_t, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,126 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""_BatchNormFold2Grad op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
|
||||
batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnorm_fold2_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2_grad") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "dout_reduce", None, "required", None) \
|
||||
.input(2, "dout_x_reduce", None, "required", None) \
|
||||
.input(3, "gamma", None, "required", None) \
|
||||
.input(4, "batch_std", None, "required", None) \
|
||||
.input(5, "batch_mean", None, "required", None) \
|
||||
.input(6, "running_std", None, "required", None) \
|
||||
.output(0, "d_batch_std", True, "required", "all") \
|
||||
.output(1, "d_batch_mean", True, "required", "all") \
|
||||
.output(2, "d_gamma", True, "required", "all") \
|
||||
.output(3, "dx", True, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batchnorm_fold2_grad_op_info)
|
||||
def _batchnorm_fold2_grad_tbe():
|
||||
"""_BatchNormFold2Grad TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("batchnorm_fold2_grad")
|
||||
def batchnorm_fold2_grad_compute(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std,
|
||||
kernel_name="batchnorm_fold2_grad"):
|
||||
"""_BatchNormFold2Grad"""
|
||||
shape_x = te.lang.cce.util.shape_to_list(dout.shape)
|
||||
|
||||
d_batch_std_1 = te.lang.cce.vmul(dout_reduce, batch_mean)
|
||||
d_batch_std_1 = te.lang.cce.vmul(d_batch_std_1, gamma)
|
||||
d_batch_std_2 = te.lang.cce.vmul(dout_x_reduce, running_std)
|
||||
d_batch_std = te.lang.cce.vsub(d_batch_std_1, d_batch_std_2)
|
||||
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
|
||||
d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
|
||||
|
||||
d_batch_mean = te.lang.cce.vmul(dout_reduce, gamma)
|
||||
d_batch_mean = te.lang.cce.vdiv(d_batch_mean, batch_std)
|
||||
d_batch_mean = te.lang.cce.vmuls(d_batch_mean, -1.)
|
||||
|
||||
d_gamma = te.lang.cce.vmul(dout_reduce, batch_mean)
|
||||
d_gamma = te.lang.cce.vdiv(d_gamma, batch_std)
|
||||
d_gamma = te.lang.cce.vmuls(d_gamma, -1.)
|
||||
|
||||
dx = te.lang.cce.vdiv(running_std, batch_std)
|
||||
dx = te.lang.cce.broadcast(dx, shape_x)
|
||||
dx = te.lang.cce.vmul(dx, dout)
|
||||
return [d_batch_std, d_batch_mean, d_gamma, dx]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, str)
|
||||
def batchnorm_fold2_grad(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std, d_batch_std,
|
||||
d_batch_mean, d_gamma, dx, kernel_name="batchnorm_fold2_grad"):
|
||||
"""_BatchNormFold2Grad op """
|
||||
shape = dout.get("shape")
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape)
|
||||
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||
check_list = ["float16", "float32"]
|
||||
inp_dtype = dout.get("dtype").lower()
|
||||
if not inp_dtype in check_list:
|
||||
raise RuntimeError("Dtype of input only support float16, float32")
|
||||
data_format = dout.get("format")
|
||||
ori_format = dout.get("ori_format")
|
||||
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||
shape_c = gamma.get("shape")
|
||||
if gamma.get("format").upper() == "NCHW":
|
||||
shape_c = 1, gamma.get("shape")[0], 1, 1
|
||||
|
||||
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
|
||||
dout_reduce_t = tvm.placeholder(shape_c, name="dout_reduce", dtype=inp_dtype)
|
||||
dout_x_reduce_t = tvm.placeholder(shape_c, name="dout_x_reduce", dtype=inp_dtype)
|
||||
gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
|
||||
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||
batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
|
||||
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||
|
||||
res_list = batchnorm_fold2_grad_compute(dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t,
|
||||
running_std_t, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
|
||||
tensor_list = [dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t, running_std_t] + list(
|
||||
res_list)
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,107 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""_BatchNormFold2GradReduce op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from te.platform.cce_build import build_config
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
|
||||
batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnorm_fold2_grad_reduce.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2_grad_reduce") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "x", None, "required", None) \
|
||||
.output(0, "dout_reduce", True, "required", "all") \
|
||||
.output(1, "dout_x_reduce", True, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batchnorm_fold2_grad_reduce_op_info)
|
||||
def _batchnorm_fold2_grad_reduce_tbe():
|
||||
"""_BatchNormFold2GradReduce TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("batchnorm_fold2_grad_reduce")
|
||||
def batchnorm_fold2_grad_reduce_compute(dout, x, dout_args, kernel_name="batchnorm_fold2_grad_reduce"):
|
||||
"""_BatchNormFold2GradReduce compute"""
|
||||
dtype = dout_args.get("dtype")
|
||||
dout_format = dout_args.get("format")
|
||||
ori_format = dout_args.get("ori_format")
|
||||
shape = dout_args.get("shape")
|
||||
|
||||
if dtype == "float16":
|
||||
dout = te.lang.cce.cast_to(dout, "float32")
|
||||
x = te.lang.cce.cast_to(x, "float32")
|
||||
|
||||
dout_x = te.lang.cce.vmul(dout, x)
|
||||
if dout_format == "NC1HWC0":
|
||||
axis = [0, 2, 3]
|
||||
dout_reduce, dout_x_reduce = te.lang.cce.tuple_sum([dout, dout_x], axis, True)
|
||||
else:
|
||||
axis = list(range(len(shape)))
|
||||
if ori_format == "NCHW":
|
||||
axis.pop(1)
|
||||
for _, i in enumerate(range(len(shape))):
|
||||
if shape[i] == 1 and i in axis:
|
||||
axis.remove(i)
|
||||
dout_reduce = te.lang.cce.sum(dout, axis, False)
|
||||
dout_x_reduce = te.lang.cce.sum(dout_x, axis, False)
|
||||
return [dout_reduce, dout_x_reduce]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, str)
|
||||
def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"):
|
||||
"""_BatchNormFold2GradReduce op"""
|
||||
shape = x.get("shape")
|
||||
x_format = x.get("format")
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape)
|
||||
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||
check_list = ["float16", "float32"]
|
||||
inp_dtype = x.get("dtype").lower()
|
||||
if not inp_dtype in check_list:
|
||||
raise RuntimeError("Dtype of input only support float16, float32")
|
||||
dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
|
||||
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||
|
||||
res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name)
|
||||
|
||||
if x_format == "NC1HWC0":
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
tensor_list = [dout_t, x_t] + list(res_list)
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
||||
return
|
||||
from impl.bn_training_reduce import bn_training_reduce_schedule_nd
|
||||
sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
|
||||
with build_config:
|
||||
tvm.build(sch, tensor_list, "cce", name=kernel_name)
|
@ -0,0 +1,124 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""_BatchNormFoldGrad op"""
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
|
||||
batch_norm_op_info = TBERegOp("BatchNormFoldGradD") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("batchnorm_fold_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("epsilon", "optional", "float", "all") \
|
||||
.attr("is_training", "optional", "bool", "all") \
|
||||
.attr("freeze_bn", "optional", "int", "all") \
|
||||
.input(0, "d_batch_mean", False, "required", "all") \
|
||||
.input(1, "d_batch_std", False, "required", "all") \
|
||||
.input(2, "x", False, "required", "all") \
|
||||
.input(3, "batch_mean", False, "required", "all") \
|
||||
.input(4, "batch_std", False, "required", "all") \
|
||||
.output(0, "dx", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(batch_norm_op_info)
|
||||
def _batchnorm_fold_grad_tbe():
|
||||
"""_BatchNormFoldGrad TBE register"""
|
||||
return
|
||||
|
||||
|
||||
def _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std):
|
||||
"""_batchnorm_fold_grad_compute """
|
||||
shape_x = te.lang.cce.util.shape_to_list(data_x.shape)
|
||||
normal_size = shape_x[0] * shape_x[2] * shape_x[3]
|
||||
|
||||
d_batch_mean_broad = te.lang.cce.broadcast(d_batch_mean, shape_x)
|
||||
d_batch_std_broad = te.lang.cce.broadcast(d_batch_std, shape_x)
|
||||
batch_mean_broad = te.lang.cce.broadcast(batch_mean, shape_x)
|
||||
batch_std_broad = te.lang.cce.broadcast(batch_std, shape_x)
|
||||
|
||||
dx = te.lang.cce.vsub(data_x, batch_mean_broad)
|
||||
dx = te.lang.cce.vmul(dx, d_batch_std_broad)
|
||||
dx = te.lang.cce.vdiv(dx, batch_std_broad)
|
||||
dx = te.lang.cce.vadd(dx, d_batch_mean_broad)
|
||||
dx = te.lang.cce.vmuls(dx, tvm.const(1. / normal_size, dtype=dx.dtype))
|
||||
return [dx]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict,
|
||||
float, bool, int, str)
|
||||
def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
|
||||
epsilon=1e-5, is_training=True, freeze_bn=0, kernel_name="batchnorm_fold_grad"):
|
||||
"""batchnorm_fold_grad op """
|
||||
util.check_kernel_name(kernel_name)
|
||||
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
|
||||
util.check_shape_rule(iv.get("shape"))
|
||||
util.check_tensor_shape_size(iv.get("shape"))
|
||||
check_tuple = ("float16", "float32")
|
||||
for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
|
||||
util.check_dtype_rule(iv.get("dtype").lower(), check_tuple)
|
||||
|
||||
shape_x = x.get("shape")
|
||||
dtype_x = x.get("dtype")
|
||||
format_data = x.get("format").upper()
|
||||
if format_data not in ("NCHW", "NC1HWC0"):
|
||||
raise RuntimeError("Format of input only support 4D and 5HD")
|
||||
|
||||
shape_mean = d_batch_mean.get("shape")
|
||||
dtype_mean = d_batch_mean.get("dtype").lower()
|
||||
if format_data == "NC1HWC0":
|
||||
if len(shape_x) != 5:
|
||||
raise RuntimeError("batchnorm_fold only support shape 5D"
|
||||
"when input format is NC1HWC0")
|
||||
shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
|
||||
elif format_data == "NCHW":
|
||||
if len(shape_x) < 2 or len(shape_x) > 4:
|
||||
raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
|
||||
if shape_x[1] != shape_mean[0]:
|
||||
raise RuntimeError("data_format is NCHW, shape_bias must"
|
||||
"be equal to the second axis of shape_x")
|
||||
shape_mean = (1, shape_x[1],)
|
||||
for _ in range(2, len(shape_x)):
|
||||
shape_mean = shape_mean + (1,)
|
||||
|
||||
d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean", dtype=dtype_mean)
|
||||
d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std", dtype=dtype_mean)
|
||||
data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower())
|
||||
batch_mean = tvm.placeholder(shape_mean, name="batch_mean", dtype=dtype_mean)
|
||||
batch_std = tvm.placeholder(shape_mean, name="batch_std", dtype=dtype_mean)
|
||||
|
||||
res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std)
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean, batch_std] + res
|
||||
config = {"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,92 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CorrectionMul op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
|
||||
correction_mul_op_info = TBERegOp("CorrectionMul") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("correction_mul.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "batch_std", None, "required", None) \
|
||||
.input(2, "running_std", None, "required", None) \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(correction_mul_op_info)
|
||||
def _correction_mul_tbe():
|
||||
"""CorrectionMul TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("correction_mul")
|
||||
def correction_mul_compute(x, batch_std, running_std, kernel_name="correction_mul"):
|
||||
"""CorrectionMul compute"""
|
||||
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||
factor = te.lang.cce.vdiv(batch_std, running_std)
|
||||
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||
res = te.lang.cce.vmul(x, factor_b)
|
||||
return res
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, int, str)
|
||||
def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"):
|
||||
"""CorrectionMul op"""
|
||||
shape = x.get("shape")
|
||||
data_format = x.get("format")
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape)
|
||||
util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
|
||||
check_list = ["float16", "float32"]
|
||||
inp_dtype = x.get("dtype").lower()
|
||||
if not inp_dtype in check_list:
|
||||
raise RuntimeError("Dtype of input only support float16, float32")
|
||||
|
||||
# shape = util.shape_refine(shape)
|
||||
x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
|
||||
shape_c = [1] * len(shape)
|
||||
shape_c[channel] = batch_std.get("ori_shape")[0]
|
||||
if data_format == "NC1HWC0" and channel == 1:
|
||||
shape_c = batch_std.get("shape")
|
||||
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
|
||||
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
|
||||
res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": [x_t, batch_std_t, running_std_t, res]}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,134 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CorrectionMul op"""
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
|
||||
correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("correction_mul_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul_grad") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "x", None, "required", None) \
|
||||
.input(2, "batch_std", None, "required", None) \
|
||||
.input(3, "running_std", None, "required", None) \
|
||||
.output(0, "dx", True, "required", "all") \
|
||||
.output(1, "d_batch_std", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(correction_mul_grad_op_info)
|
||||
def _correction_mul_grad_tbe():
|
||||
"""CorrectionMulGrad TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("correction_mul_grad")
|
||||
def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_format, kernel_name="correction_mul"):
|
||||
"""CorrectionMulGrad compute"""
|
||||
shape_x = te.lang.cce.util.shape_to_list(x.shape)
|
||||
factor = te.lang.cce.vdiv(batch_std, running_std)
|
||||
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
||||
dx = te.lang.cce.vmul(dout, factor_b)
|
||||
mul_data = te.lang.cce.vmul(dout, x)
|
||||
if channel == 0:
|
||||
if data_format == "NCHW":
|
||||
axis = [1, 2, 3]
|
||||
else:
|
||||
axis = [1, 2, 3, 4]
|
||||
else:
|
||||
axis = [2, 3]
|
||||
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
|
||||
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
|
||||
return [dx, d_batch_std]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
|
||||
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
|
||||
"""CorrectionMulGrad op"""
|
||||
shape_dout = dout.get("shape")
|
||||
shape_x = dout.get("shape")
|
||||
|
||||
dtype_dout = dout.get("dtype")
|
||||
dtype_x = x.get("dtype")
|
||||
dtype_batch_std = batch_std.get("dtype")
|
||||
dtype_running_std = running_std.get("dtype")
|
||||
|
||||
inp_dtype_dout = dtype_dout.lower()
|
||||
inp_dtype_x = dtype_x.lower()
|
||||
inp_dtype_batch_std = dtype_batch_std.lower()
|
||||
inp_dtype_running_std = dtype_running_std.lower()
|
||||
|
||||
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
|
||||
util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
|
||||
util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
|
||||
util.check_dtype_rule(inp_dtype_running_std, ("float32",))
|
||||
util.compare_tensor_dict_key(dout, x, "dtype")
|
||||
util.compare_tensor_dict_key(dout, x, "shape")
|
||||
util.compare_tensor_dict_key(dx, x, "shape")
|
||||
util.compare_tensor_dict_key(batch_std, running_std, "shape")
|
||||
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
|
||||
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(shape_x)
|
||||
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
|
||||
|
||||
data_format = dout.get("format")
|
||||
ori_format = dout.get("format")
|
||||
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
||||
raise RuntimeError("Un supported data format {}".format(data_format))
|
||||
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
||||
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
||||
|
||||
shape_c = [1] * len(shape_x)
|
||||
shape_c[channel] = batch_std.get("ori_shape")[0]
|
||||
if data_format == "NC1HWC0" and channel == 1:
|
||||
shape_c = batch_std.get("shape")
|
||||
|
||||
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
|
||||
x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
|
||||
batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std)
|
||||
running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std)
|
||||
res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t, channel, data_format, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
|
||||
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,146 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FakeQuantWithMinMax op"""
|
||||
|
||||
from functools import reduce as functools_reduce
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fake_quant_with_min_max_vars_ema.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fake_quant_with_min_max_vars_ema") \
|
||||
.partial_flag(True) \
|
||||
.attr("ema", "optional", "bool", "all") \
|
||||
.attr("ema_decay", "optional", "float", "all") \
|
||||
.attr("symmetric", "optional", "bool", "all") \
|
||||
.attr("narrow_range", "optional", "bool", "all") \
|
||||
.attr("training", "optional", "bool", "all") \
|
||||
.attr("num_bits", "optional", "int", "all") \
|
||||
.attr("quant_delay", "optional", "int", "all") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "min", None, "required", None) \
|
||||
.input(2, "max", None, "required", None) \
|
||||
.output(0, "y", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(fake_quant_op_info)
|
||||
def _fake_quant_tbe():
|
||||
"""FakeQuantWithMinMax TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("fake_quant_with_min_max_vars_ema")
|
||||
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
|
||||
kernel_name="correction_mul"):
|
||||
"""FakeQuantWithMinMax"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
|
||||
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
|
||||
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
|
||||
|
||||
# CalNudge(NudgeMinMax)
|
||||
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
||||
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
|
||||
# Nudge zero point
|
||||
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
|
||||
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
|
||||
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
|
||||
|
||||
# boradcast to shape
|
||||
nudge_min = te.lang.cce.broadcast(nudge_min, shape, x.dtype)
|
||||
nudge_max = te.lang.cce.broadcast(nudge_max, shape, x.dtype)
|
||||
scale = te.lang.cce.broadcast(scale, shape, x.dtype)
|
||||
|
||||
# FakeQuant
|
||||
input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
|
||||
nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
|
||||
0.5))
|
||||
res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
|
||||
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
|
||||
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
|
||||
kernel_name="fake_quant"):
|
||||
"""FakeQuantWithMinMax"""
|
||||
input_shape = x.get("shape")
|
||||
input_dtype = x.get("dtype")
|
||||
min_shape = min_val.get("ori_shape")
|
||||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
min_shape = util.scalar2tensor_one(min_shape)
|
||||
max_shape = util.scalar2tensor_one(max_shape)
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(input_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||
util.check_tensor_shape_size(input_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
||||
check_list = ["float32", "float16"]
|
||||
x_dtype = input_dtype.lower()
|
||||
min_dtype = min_dtype.lower()
|
||||
max_dtype = max_dtype.lower()
|
||||
util.check_dtype_rule(x_dtype, check_list)
|
||||
util.check_dtype_rule(min_dtype, check_list)
|
||||
util.check_dtype_rule(max_dtype, check_list)
|
||||
|
||||
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||
|
||||
if symmetric:
|
||||
quant_min = 0 - 2 ** (num_bits - 1)
|
||||
quant_max = 2 ** (num_bits - 1) - 1
|
||||
else:
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
if narrow_range:
|
||||
quant_min = quant_min + 1
|
||||
|
||||
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||
res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
|
||||
quant_min, quant_max, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
tensor_list = [input_data, min_data, max_data, res]
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,156 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FakeQuantWithMinMaxGrad op"""
|
||||
|
||||
from functools import reduce as functools_reduce
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
SHAPE_SIZE_LIMIT = 2147483648
|
||||
D_TYPE = 'float32'
|
||||
|
||||
fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fake_quant_with_min_max_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fake_quant_with_min_max_grad") \
|
||||
.partial_flag(True) \
|
||||
.attr("num_bits", "optional", "int", "all") \
|
||||
.attr("quant_delay", "optional", "int", "all") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "x", None, "required", None) \
|
||||
.input(2, "min", None, "required", None) \
|
||||
.input(3, "max", None, "required", None) \
|
||||
.output(0, "dx", True, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
def _less_compare_float32(data_x, data_y):
|
||||
"""_less_compare_float32 compute"""
|
||||
shape_inputs = te.lang.cce.util.shape_to_list(data_x.shape)
|
||||
min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
|
||||
max_value = tvm.const(2 ** 62, dtype=D_TYPE)
|
||||
factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
|
||||
data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
|
||||
min_value_tensor = te.lang.cce.vadds(data_zero, min_value)
|
||||
|
||||
res_sub = te.lang.cce.vsub(data_y, data_x)
|
||||
res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
|
||||
res_max = te.lang.cce.vmax(res_min, data_zero)
|
||||
|
||||
res_max_mul = te.lang.cce.vmuls(res_max, max_value)
|
||||
res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
|
||||
res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@op_info_register(fake_quant_grad_op_info)
|
||||
def _fake_quant_grad_tbe():
|
||||
"""FakeQuantWithMinMaxGrad TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("fake_quant_with_min_max_grad")
|
||||
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
|
||||
kernel_name="fake_quant_with_min_max_grad"):
|
||||
"""FakeQuantWithMinMaxGrad"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||
quant_min = tvm.const(quant_min, x.dtype)
|
||||
quant_max = tvm.const(quant_max, x.dtype)
|
||||
quant_min = te.lang.cce.broadcast(quant_min, shape_min)
|
||||
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
|
||||
|
||||
# CalNudge(NudgeMinMax)
|
||||
scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
||||
zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
|
||||
# Nudge zero point
|
||||
nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
|
||||
nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
|
||||
nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
|
||||
nudge_min = te.lang.cce.broadcast(nudge_min, shape)
|
||||
nudge_max = te.lang.cce.broadcast(nudge_max, shape)
|
||||
|
||||
bool_over_min = _less_compare_float32(nudge_min, x)
|
||||
bool_less_max = _less_compare_float32(x, nudge_max)
|
||||
bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
|
||||
res = te.lang.cce.vmul(dout, bool_between)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
|
||||
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
|
||||
kernel_name="fake_quant_with_min_max_grad"):
|
||||
"""FakeQuantWithMinMaxGrad"""
|
||||
input_shape = x.get("shape")
|
||||
input_dtype = x.get("dtype")
|
||||
min_shape = min_val.get("ori_shape")
|
||||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
min_shape = util.scalar2tensor_one(min_shape)
|
||||
max_shape = util.scalar2tensor_one(max_shape)
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(input_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||
util.check_tensor_shape_size(input_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
||||
check_list = ["float32", 'float16']
|
||||
x_dtype = input_dtype.lower()
|
||||
min_dtype = min_dtype.lower()
|
||||
max_dtype = max_dtype.lower()
|
||||
util.check_dtype_rule(x_dtype, check_list)
|
||||
util.check_dtype_rule(min_dtype, check_list)
|
||||
util.check_dtype_rule(max_dtype, check_list)
|
||||
|
||||
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
|
||||
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||
res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
|
||||
quant_max, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
||||
tensor_list = [dout_data, input_data, min_data, max_data, res]
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
@ -0,0 +1,137 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""FakeQuantWithMinMaxUpdate op"""
|
||||
from functools import reduce as functools_reduce
|
||||
import te.lang.cce
|
||||
from te import tvm
|
||||
from te.platform.fusion_manager import fusion_manager
|
||||
from topi import generic
|
||||
from topi.cce import util
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
|
||||
fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("fake_quant_with_min_max_update5d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("fake_quant_with_min_max_update") \
|
||||
.partial_flag(True) \
|
||||
.attr("ema", "optional", "bool", "all") \
|
||||
.attr("ema_decay", "optional", "float", "all") \
|
||||
.attr("symmetric", "optional", "bool", "all") \
|
||||
.attr("narrow_range", "optional", "bool", "all") \
|
||||
.attr("training", "optional", "bool", "all") \
|
||||
.attr("num_bits", "optional", "int", "all") \
|
||||
.attr("quant_delay", "optional", "int", "all") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "min", None, "required", None) \
|
||||
.input(2, "max", None, "required", None) \
|
||||
.output(0, "min_up", True, "required", "all") \
|
||||
.output(1, "max_up", True, "required", "all") \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(fake_quant_update5d_op_info)
|
||||
def _fake_quant_update5d_tbe():
|
||||
"""_FakeQuantWithMinMaxUpdate5D TBE register"""
|
||||
return
|
||||
|
||||
|
||||
@fusion_manager.register("fake_quant_with_min_max_update")
|
||||
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
|
||||
kernel_name="fake_quant_update"):
|
||||
"""FakeQuantWithMinMaxUpdate compute"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
|
||||
if not ema:
|
||||
ema_decay = 0.0
|
||||
if training:
|
||||
# CalMinMax
|
||||
axis = tuple(range(len(shape)))
|
||||
x_min = te.lang.cce.reduce_min(x, axis=axis)
|
||||
x_max = te.lang.cce.reduce_max(x, axis=axis)
|
||||
x_min = te.lang.cce.broadcast(x_min, shape_min)
|
||||
x_max = te.lang.cce.broadcast(x_max, shape_min)
|
||||
min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay)))
|
||||
max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay)))
|
||||
min_val = te.lang.cce.vmins(min_val, 0)
|
||||
max_val = te.lang.cce.vmaxs(max_val, 0)
|
||||
|
||||
return [min_val, max_val]
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
|
||||
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
|
||||
ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay,
|
||||
kernel_name="fake_quant_update"):
|
||||
"""FakeQuantWithMinMax op"""
|
||||
input_shape = x.get("shape")
|
||||
input_dtype = x.get("dtype")
|
||||
min_shape = min_val.get("ori_shape")
|
||||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
min_shape = util.scalar2tensor_one(min_shape)
|
||||
max_shape = util.scalar2tensor_one(max_shape)
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(input_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, 1)
|
||||
util.check_shape_rule(max_shape, 1, 1, 1)
|
||||
util.check_tensor_shape_size(input_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
||||
check_list = ["float32", "float16"]
|
||||
x_dtype = input_dtype.lower()
|
||||
min_dtype = min_dtype.lower()
|
||||
max_dtype = max_dtype.lower()
|
||||
util.check_dtype_rule(x_dtype, check_list)
|
||||
util.check_dtype_rule(min_dtype, check_list)
|
||||
util.check_dtype_rule(max_dtype, check_list)
|
||||
|
||||
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||
|
||||
if symmetric:
|
||||
quant_min = 0 - 2 ** (num_bits - 1)
|
||||
quant_max = 2 ** (num_bits - 1) - 1
|
||||
else:
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
if narrow_range:
|
||||
quant_min = quant_min + 1
|
||||
|
||||
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||
res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
|
||||
ema, ema_decay, quant_min, quant_max, training, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
|
||||
tensor_list = [input_data, min_data, max_data] + list(res_list)
|
||||
config = {"print_ir": False,
|
||||
"name": kernel_name,
|
||||
"tensor_list": tensor_list}
|
||||
|
||||
te.lang.cce.cce_build_code(sch, config)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue