|
|
@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
|
|
|
|
.input(2, "batch_std", None, "required", None) \
|
|
|
|
.input(2, "batch_std", None, "required", None) \
|
|
|
|
.input(3, "running_std", None, "required", None) \
|
|
|
|
.input(3, "running_std", None, "required", None) \
|
|
|
|
.output(0, "dx", True, "required", "all") \
|
|
|
|
.output(0, "dx", True, "required", "all") \
|
|
|
|
.output(1, "d_batch_std", True, "required", "all") \
|
|
|
|
.output(1, "mul_dx", True, "required", "all") \
|
|
|
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
|
|
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
|
|
|
DataType.F32_5HD, DataType.F32_5HD) \
|
|
|
|
DataType.F32_5HD, DataType.F32_5HD) \
|
|
|
|
.get_op_info()
|
|
|
|
.get_op_info()
|
|
|
@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f
|
|
|
|
factor = te.lang.cce.vdiv(batch_std, running_std)
|
|
|
|
factor = te.lang.cce.vdiv(batch_std, running_std)
|
|
|
|
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
|
|
|
factor_b = te.lang.cce.broadcast(factor, shape_x)
|
|
|
|
dx = te.lang.cce.vmul(dout, factor_b)
|
|
|
|
dx = te.lang.cce.vmul(dout, factor_b)
|
|
|
|
mul_data = te.lang.cce.vmul(dout, x)
|
|
|
|
mul_dx = te.lang.cce.vmul(dout, x)
|
|
|
|
if channel == 0:
|
|
|
|
running_std_b = te.lang.cce.broadcast(running_std, shape_x)
|
|
|
|
if data_format == "NCHW":
|
|
|
|
mul_dx = te.lang.cce.vdiv(mul_dx, running_std_b)
|
|
|
|
axis = [1, 2, 3]
|
|
|
|
return [dx, mul_dx]
|
|
|
|
else:
|
|
|
|
|
|
|
|
axis = [1, 2, 3, 4]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
axis = [2, 3]
|
|
|
|
|
|
|
|
red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
|
|
|
|
|
|
|
|
d_batch_std = te.lang.cce.vdiv(red_data, running_std)
|
|
|
|
|
|
|
|
return [dx, d_batch_std]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
|
|
|
|
@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
|
|
|
|
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
|
|
|
|
def correction_mul_grad(dout, x, batch_std, running_std, dx, mul_dx, channel, kernel_name="correction_mul_grad"):
|
|
|
|
"""CorrectionMulGrad op"""
|
|
|
|
"""CorrectionMulGrad op"""
|
|
|
|
shape_dout = dout.get("shape")
|
|
|
|
shape_dout = dout.get("shape")
|
|
|
|
shape_x = dout.get("shape")
|
|
|
|
shape_x = dout.get("shape")
|
|
|
@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
|
|
|
|
util.compare_tensor_dict_key(dout, x, "shape")
|
|
|
|
util.compare_tensor_dict_key(dout, x, "shape")
|
|
|
|
util.compare_tensor_dict_key(dx, x, "shape")
|
|
|
|
util.compare_tensor_dict_key(dx, x, "shape")
|
|
|
|
util.compare_tensor_dict_key(batch_std, running_std, "shape")
|
|
|
|
util.compare_tensor_dict_key(batch_std, running_std, "shape")
|
|
|
|
util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")
|
|
|
|
util.compare_tensor_dict_key(dx, mul_dx, "shape")
|
|
|
|
|
|
|
|
|
|
|
|
util.check_kernel_name(kernel_name)
|
|
|
|
util.check_kernel_name(kernel_name)
|
|
|
|
util.check_shape_rule(shape_x)
|
|
|
|
util.check_shape_rule(shape_x)
|
|
|
@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
|
|
|
|
with tvm.target.cce():
|
|
|
|
with tvm.target.cce():
|
|
|
|
sch = generic.auto_schedule(res_list)
|
|
|
|
sch = generic.auto_schedule(res_list)
|
|
|
|
|
|
|
|
|
|
|
|
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
|
|
|
|
tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + res_list
|
|
|
|
|
|
|
|
config = {"print_ir": False,
|
|
|
|
|
|
|
|
"name": kernel_name,
|
|
|
|
|
|
|
|
"tensor_list": tensor_list}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
te.lang.cce.cce_build_code(sch, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Operator registration info for the CorrectionMulGradReduce TBE kernel:
# one required input ("dout"), one required output ("d_batch_std"), an
# optional integer attribute "channel_axis", and a single dtype/format pair.
correction_mul_grad_reduce_op_info = (
    TBERegOp("CorrectionMulGradReduce")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("correction_mul_grad_reduce.so")
    .compute_cost(10)
    .kernel_name("correction_mul_grad_reduce")
    .partial_flag(True)
    .op_pattern("formatAgnostic")
    .attr("channel_axis", "optional", "int", "all")
    .input(0, "dout", None, "required", None)
    .output(0, "d_batch_std", True, "required", "all")
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD)
    .get_op_info()
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(correction_mul_grad_reduce_op_info)
def _correction_mul_grad_reduce_tbe():
    """Register the CorrectionMulGradReduce TBE operator.

    The body is intentionally empty: the function exists only so the
    ``op_info_register`` decorator can be applied to it, and it returns None.
    """
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@fusion_manager.register("correction_mul_grad_reduce")
def correction_mul_grad_reduce_compute(mul_dx, channel, data_format, kernel_name="correction_mul"):
    """Sum ``mul_dx`` over the axes selected by ``channel`` and ``data_format``.

    Args:
        mul_dx: tensor passed to ``te.lang.cce.sum``.
        channel: channel axis index. When it is 0 the reduction axes depend
            on ``data_format``; for any other value axes 2 and 3 are reduced.
        data_format: only consulted when ``channel`` is 0. "NCHW" reduces
            axes 1-3; any other value reduces axes 1-4.
        kernel_name: accepted but not used inside this function.

    Returns:
        The result of the sum, computed with ``keepdims=True``.
    """
    if channel != 0:
        reduce_axes = [2, 3]
    elif data_format == "NCHW":
        reduce_axes = [1, 2, 3]
    else:
        reduce_axes = [1, 2, 3, 4]
    return te.lang.cce.sum(mul_dx, reduce_axes, keepdims=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@util.check_input_type(dict, dict, int, str)
|
|
|
|
|
|
|
|
def correction_mul_grad_reduce(mul_dx, d_batch_std, channel, kernel_name="correction_mul_grad_reduce"):
|
|
|
|
|
|
|
|
"""CorrectionMulGradReduce op"""
|
|
|
|
|
|
|
|
shape_dout = mul_dx.get("shape")
|
|
|
|
|
|
|
|
shape_x = mul_dx.get("shape")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dtype_dout = mul_dx.get("dtype")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inp_dtype_dout = dtype_dout.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
util.check_kernel_name(kernel_name)
|
|
|
|
|
|
|
|
util.check_shape_rule(shape_x)
|
|
|
|
|
|
|
|
util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_format = mul_dx.get("format")
|
|
|
|
|
|
|
|
ori_format = mul_dx.get("format")
|
|
|
|
|
|
|
|
if data_format.upper() not in ("NC1HWC0", "NCHW"):
|
|
|
|
|
|
|
|
raise RuntimeError("Un supported data format {}".format(data_format))
|
|
|
|
|
|
|
|
if data_format.upper() == "NCHW" and ori_format != "NCHW":
|
|
|
|
|
|
|
|
raise RuntimeError("data_format(NCHW) must same as ori_format")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shape_c = [1] * len(shape_x)
|
|
|
|
|
|
|
|
shape_c[channel] = d_batch_std.get("ori_shape")[0]
|
|
|
|
|
|
|
|
if data_format == "NC1HWC0" and channel == 1:
|
|
|
|
|
|
|
|
shape_c = d_batch_std.get("shape")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
|
|
|
|
|
|
|
|
res = correction_mul_grad_reduce_compute(dout_t, channel, data_format, kernel_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with tvm.target.cce():
|
|
|
|
|
|
|
|
sch = generic.auto_schedule(res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor_list = [dout_t, res]
|
|
|
|
config = {"print_ir": False,
|
|
|
|
config = {"print_ir": False,
|
|
|
|
"name": kernel_name,
|
|
|
|
"name": kernel_name,
|
|
|
|
"tensor_list": tensor_list}
|
|
|
|
"tensor_list": tensor_list}
|
|
|
|