|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/matmul_op.h"
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -41,10 +42,26 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input tensor X must be at least 1-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_GE(dim_y.size(), 1,
|
|
|
|
|
"Input tensor Y must be at least 1-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_LE(dim_x.size(), 3,
|
|
|
|
|
"Input tensor X must be at most 3-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_LE(dim_y.size(), 3,
|
|
|
|
|
"Input tensor Y must be at most 3-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_LE(dim_x.size(), 4,
|
|
|
|
|
"Input tensor X must be at most 4-dimensional.");
|
|
|
|
|
PADDLE_ENFORCE_LE(dim_y.size(), 4,
|
|
|
|
|
"Input tensor Y must be at most 4-dimensional.");
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_dim;
|
|
|
|
|
int64_t batch_count = 1;
|
|
|
|
|
if (dim_x.size() > 3) {
|
|
|
|
|
PADDLE_ENFORCE(dim_y.size() == dim_x.size(),
|
|
|
|
|
"The dimensions of X and Y must be the same, and both of "
|
|
|
|
|
"them should be 4-dimensional.");
|
|
|
|
|
for (int j = 0; j < dim_x.size() - 2; ++j) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
dim_y[j] == dim_x[j],
|
|
|
|
|
"The dimensions of X and Y must be the same, and both of "
|
|
|
|
|
"them should be 4-dimensional.");
|
|
|
|
|
out_dim.push_back(dim_x[j]);
|
|
|
|
|
batch_count *= dim_x[j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
|
|
|
|
|
bool remove_initial_dim = false, remove_final_dim = false;
|
|
|
|
@ -70,7 +87,11 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
KX = transpose_x ? dim_x[1] : dim_x[2];
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
assert(false);
|
|
|
|
|
batchCountX = batch_count;
|
|
|
|
|
size_t mat_s = dim_x.size() - 2;
|
|
|
|
|
M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
|
|
|
|
|
KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch (dim_y.size()) {
|
|
|
|
@ -94,7 +115,10 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
N = transpose_y ? dim_y[1] : dim_y[2];
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
assert(false);
|
|
|
|
|
batchCountY = batch_count;
|
|
|
|
|
size_t mat_s = dim_y.size() - 2;
|
|
|
|
|
KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
|
|
|
|
|
N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -110,7 +134,11 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> dim_out;
|
|
|
|
|
if (batchCount) {
|
|
|
|
|
dim_out.push_back(batchCount);
|
|
|
|
|
if (dim_x.size() > 3) {
|
|
|
|
|
dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
|
|
|
|
|
} else {
|
|
|
|
|
dim_out.push_back(batchCount);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!remove_initial_dim) {
|
|
|
|
|
dim_out.push_back(M);
|
|
|
|
|