|
|
|
@ -92,57 +92,8 @@ def _inner_matmul_new_1_64_32_64(tik_instance, dtype, input1, input1_index, inpu
|
|
|
|
|
matmul_hybrid_f_t_local_UB, 0, 1, 4, 0, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(cus_batchmatmul_op_info)
|
|
|
|
|
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
|
|
|
|
|
"""CusBatchMatMul"""
|
|
|
|
|
if util.get_product_version() == util.VERSION_MINI:
|
|
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
|
|
|
|
|
else:
|
|
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
|
|
|
|
|
x1_shape = input_x1.get("shape")
|
|
|
|
|
dtype = input_x1.get("dtype").lower()
|
|
|
|
|
x2_shape = input_x2.get("shape")
|
|
|
|
|
if dtype != input_x2.get("dtype").lower():
|
|
|
|
|
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
|
|
|
|
|
dtype, input_x2.get("dtype").lower()))
|
|
|
|
|
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
|
|
|
|
|
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
|
|
|
|
|
((36, 128, 128), (36, 128, 128), "float32", False, True),
|
|
|
|
|
((5, 128, 128), (5, 128, 128), "float32", False, True),
|
|
|
|
|
((18, 128, 128), (18, 128, 128), "float32", False, True),
|
|
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True),
|
|
|
|
|
((9, 128, 128), (9, 128, 128), "float32", False, True),
|
|
|
|
|
((1, 64, 64), (1, 64, 64), "float32", False, True),
|
|
|
|
|
((1, 128, 128), (1, 128, 128), "float32", False, True),
|
|
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True),
|
|
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True),
|
|
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
|
|
|
|
|
if input_shape not in support_shape:
|
|
|
|
|
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
|
|
|
|
|
|
|
|
|
|
# if not transpose_a and transpose_b:
|
|
|
|
|
batch, m, k = x1_shape
|
|
|
|
|
|
|
|
|
|
input1_shape = _get_flattern_shape(x1_shape)
|
|
|
|
|
input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
|
|
|
|
|
input2_shape = _get_flattern_shape(x2_shape)
|
|
|
|
|
input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
|
|
|
|
|
|
|
|
|
|
output_shape = x1_shape
|
|
|
|
|
res_shape = _get_flattern_shape(output_shape)
|
|
|
|
|
res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
|
|
|
|
|
|
|
|
|
|
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 2) as cc0:
|
|
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc1:
|
|
|
|
|
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
|
|
|
|
|
input2_index = block_idx * 32768 + cc0 * 16384
|
|
|
|
|
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
|
|
|
|
|
_inner_matmul_new(tik_instance, dtype,
|
|
|
|
|
input1, input1_index,
|
|
|
|
|
input2, input2_index,
|
|
|
|
|
res, res_index)
|
|
|
|
|
def process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res):
|
|
|
|
|
"""process input shape of 640"""
|
|
|
|
|
if input_shape == ((5, 128, 128), (5, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 30, block_num=30) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 11) as cc1_db:
|
|
|
|
@ -189,17 +140,9 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
|
|
|
|
|
thread_idx * 128 + thread_idx2 * 64],
|
|
|
|
|
matmul_hybrid_f_t_local_UB, 0, 1, 8, 0, 0)
|
|
|
|
|
|
|
|
|
|
if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc0:
|
|
|
|
|
input1_index = block_idx * 16384 + cc0 * 128
|
|
|
|
|
input2_index = block_idx * 16384
|
|
|
|
|
res_index = block_idx * 16384 + cc0 * 128
|
|
|
|
|
_inner_matmul_new(tik_instance, dtype,
|
|
|
|
|
input1, input1_index,
|
|
|
|
|
input2, input2_index,
|
|
|
|
|
res, res_index)
|
|
|
|
|
|
|
|
|
|
def process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res):
|
|
|
|
|
"""process input shape of 1152"""
|
|
|
|
|
if input_shape == ((9, 128, 128), (9, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 27, block_num=27) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 42, thread_num=2) as cc0:
|
|
|
|
@ -219,6 +162,76 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
|
|
|
|
|
input2, input2_index,
|
|
|
|
|
res, res_index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(cus_batchmatmul_op_info)
|
|
|
|
|
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
|
|
|
|
|
"""CusBatchMatMul"""
|
|
|
|
|
if util.get_product_version() == util.VERSION_MINI:
|
|
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
|
|
|
|
|
else:
|
|
|
|
|
tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
|
|
|
|
|
x1_shape = input_x1.get("shape")
|
|
|
|
|
dtype = input_x1.get("dtype").lower()
|
|
|
|
|
x2_shape = input_x2.get("shape")
|
|
|
|
|
if dtype != input_x2.get("dtype").lower():
|
|
|
|
|
raise RuntimeError("dtype of input_x1 and input_x2 must be same, but got %s vs %s" % (
|
|
|
|
|
dtype, input_x2.get("dtype").lower()))
|
|
|
|
|
input_shape = (tuple(x1_shape), tuple(x2_shape), dtype, transpose_a, transpose_b)
|
|
|
|
|
support_shape = [((8, 128, 128), (8, 128, 128), "float32", False, True),
|
|
|
|
|
((36, 128, 128), (36, 128, 128), "float32", False, True),
|
|
|
|
|
((5, 128, 128), (5, 128, 128), "float32", False, True),
|
|
|
|
|
((18, 128, 128), (18, 128, 128), "float32", False, True),
|
|
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True),
|
|
|
|
|
((9, 128, 128), (9, 128, 128), "float32", False, True),
|
|
|
|
|
((1, 64, 64), (1, 64, 64), "float32", False, True),
|
|
|
|
|
((1, 128, 128), (1, 128, 128), "float32", False, True),
|
|
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True),
|
|
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True),
|
|
|
|
|
((6, 128, 128), (6, 128, 128), "float32", False, True),
|
|
|
|
|
((24, 128, 128), (24, 128, 128), "float32", False, True),
|
|
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
|
|
|
|
|
if input_shape not in support_shape:
|
|
|
|
|
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
|
|
|
|
|
|
|
|
|
|
# if not transpose_a and transpose_b:
|
|
|
|
|
batch, m, k = x1_shape
|
|
|
|
|
|
|
|
|
|
input1_shape = _get_flattern_shape(x1_shape)
|
|
|
|
|
input1 = tik_instance.Tensor(dtype, input1_shape, name="input1", scope=tik.scope_gm)
|
|
|
|
|
input2_shape = _get_flattern_shape(x2_shape)
|
|
|
|
|
input2 = tik_instance.Tensor(dtype, input2_shape, name="input2", scope=tik.scope_gm)
|
|
|
|
|
|
|
|
|
|
output_shape = x1_shape
|
|
|
|
|
res_shape = _get_flattern_shape(output_shape)
|
|
|
|
|
res = tik_instance.Tensor(dtype, res_shape, name="res", scope=tik.scope_gm)
|
|
|
|
|
|
|
|
|
|
if input_shape == ((36, 128, 128), (36, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 2) as cc0:
|
|
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc1:
|
|
|
|
|
input1_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
|
|
|
|
|
input2_index = block_idx * 32768 + cc0 * 16384
|
|
|
|
|
res_index = block_idx * 32768 + cc0 * 16384 + cc1 * 128
|
|
|
|
|
_inner_matmul_new(tik_instance, dtype,
|
|
|
|
|
input1, input1_index,
|
|
|
|
|
input2, input2_index,
|
|
|
|
|
res, res_index)
|
|
|
|
|
|
|
|
|
|
process_input_shape_640(input_shape, tik_instance, dtype, input1, input2, res)
|
|
|
|
|
|
|
|
|
|
if input_shape == ((18, 128, 128), (18, 128, 128), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 18, block_num=18) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 128, thread_num=2) as cc0:
|
|
|
|
|
input1_index = block_idx * 16384 + cc0 * 128
|
|
|
|
|
input2_index = block_idx * 16384
|
|
|
|
|
res_index = block_idx * 16384 + cc0 * 128
|
|
|
|
|
_inner_matmul_new(tik_instance, dtype,
|
|
|
|
|
input1, input1_index,
|
|
|
|
|
input2, input2_index,
|
|
|
|
|
res, res_index)
|
|
|
|
|
|
|
|
|
|
process_input_shape_1152(input_shape, tik_instance, dtype, input1, input2, res)
|
|
|
|
|
|
|
|
|
|
if input_shape == ((1, 64, 64), (1, 64, 64), "float32", False, True):
|
|
|
|
|
with tik_instance.for_range(0, 32, block_num=32) as block_idx:
|
|
|
|
|
with tik_instance.for_range(0, 2, thread_num=2) as cc0:
|
|
|
|
@ -233,8 +246,10 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
|
|
|
|
|
input_shape_list = [((1, 128, 128), (1, 128, 128), "float32", False, True),
|
|
|
|
|
((2, 128, 128), (2, 128, 128), "float32", False, True),
|
|
|
|
|
((4, 128, 128), (4, 128, 128), "float32", False, True),
|
|
|
|
|
((6, 128, 128), (6, 128, 128), "float32", False, True),
|
|
|
|
|
((8, 128, 128), (8, 128, 128), "float32", False, True),
|
|
|
|
|
((16, 128, 128), (16, 128, 128), "float32", False, True),
|
|
|
|
|
((24, 128, 128), (24, 128, 128), "float32", False, True),
|
|
|
|
|
((32, 128, 128), (32, 128, 128), 'float32', False, True)
|
|
|
|
|
]
|
|
|
|
|
if input_shape in input_shape_list:
|
|
|
|
|