|
|
|
@ -99,7 +99,7 @@ def cus_cube_matmul_right_mul(tik_instance, input_x1, input_x2, input_x3,
|
|
|
|
|
"""cus_cube_matmul_right_mul"""
|
|
|
|
|
diag_size = 128
|
|
|
|
|
ko, mo, _, _ = input_x1.shape
|
|
|
|
|
no, ko, ki, _ = input_x2.shape
|
|
|
|
|
no, ko, _, _ = input_x2.shape
|
|
|
|
|
c0 = input_x1.shape[-1]
|
|
|
|
|
diag_outer = diag_size // c0
|
|
|
|
|
if [input_x1.shape[-1], input_x1.shape[-2], input_x2.shape[-1], input_x2.shape[-2]] != [c0, c0, c0, c0]:
|
|
|
|
|