@ -143,6 +143,7 @@ def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
def _get_bias(shape_bias):
"""_get_bias"""
bias_length = shape_bias[0]
shb = []
if bias_length % 16 == 0:
@ -32,6 +32,7 @@ cus_matrix_combine_op_info = TBERegOp("CusMatrixCombine") \
@op_info_register(cus_matrix_combine_op_info)
def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"):
"""CusMatrixCombine"""
input_x_shape = input_x.get("shape")
output_shape = output.get("shape")
split_dim = 128