|
|
|
@ -33,14 +33,13 @@ __all__ = ["CusBatchMatMul",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusBatchMatMul(PrimitiveWithInfer):
|
|
|
|
|
"""CusBatchMatMul definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b` in batch.
|
|
|
|
|
|
|
|
|
|
The rank of input tensors must be `3`.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If
|
|
|
|
|
- **input_x** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`.
|
|
|
|
|
- **input_y** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(N, D, D)`. If
|
|
|
|
|
`transpose_b` is True.
|
|
|
|
|
|
|
|
|
@ -73,7 +72,6 @@ class CusBatchMatMul(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusCholeskyTrsm(PrimitiveWithInfer):
|
|
|
|
|
"""CusCholeskyTrsm definition"""
|
|
|
|
|
"""
|
|
|
|
|
L * LT = A.
|
|
|
|
|
LT * (LT)^-1 = I.
|
|
|
|
@ -112,7 +110,6 @@ class CusCholeskyTrsm(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusFusedAbsMax1(PrimitiveWithInfer):
|
|
|
|
|
"""CusFusedAbsMax1 definition"""
|
|
|
|
|
"""
|
|
|
|
|
Compute the abs max of Tensor input.
|
|
|
|
|
|
|
|
|
@ -154,7 +151,6 @@ class CusFusedAbsMax1(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusImg2Col(PrimitiveWithInfer):
|
|
|
|
|
"""CusImg2Col definition"""
|
|
|
|
|
"""
|
|
|
|
|
Img2col the feature map and the result in reorganized in NC1HWC0.
|
|
|
|
|
|
|
|
|
@ -203,7 +199,6 @@ class CusImg2Col(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatMulCube definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b`.
|
|
|
|
|
|
|
|
|
@ -242,7 +237,6 @@ class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatMulCubeFraczRightMul definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b` and muls the result by scalar `c`.
|
|
|
|
|
|
|
|
|
@ -283,7 +277,6 @@ class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatMulCube(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatMulCube definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b`.
|
|
|
|
|
|
|
|
|
@ -342,7 +335,6 @@ class CusMatMulCube(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatrixCombine(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatrixCombine definition"""
|
|
|
|
|
"""
|
|
|
|
|
move the batch matrix to result matrix diag part.
|
|
|
|
|
The rank of input tensors must be `3`.
|
|
|
|
@ -381,7 +373,6 @@ class CusMatrixCombine(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusTranspose02314(PrimitiveWithInfer):
|
|
|
|
|
"""CusTranspose02314 definition"""
|
|
|
|
|
"""
|
|
|
|
|
Permute input tensor with perm (0, 2, 3, 1, 4)
|
|
|
|
|
|
|
|
|
@ -423,7 +414,6 @@ class CusTranspose02314(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatMulCubeDenseRight(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatMulCubeDenseRight definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b`.
|
|
|
|
|
|
|
|
|
@ -464,7 +454,6 @@ class CusMatMulCubeDenseRight(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):
|
|
|
|
|
"""CusMatMulCubeFraczLeftCast definition"""
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b`.
|
|
|
|
|
|
|
|
|
|