|
|
|
@ -71,8 +71,14 @@ static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims,
|
|
|
|
|
for (int i = 0; i < ndim; ++i) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1,
|
|
|
|
|
true, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) and Input(Y) has error dim."));
|
|
|
|
|
true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) and Input(Y) has error dim."
|
|
|
|
|
"X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s],"
|
|
|
|
|
"or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1,"
|
|
|
|
|
"But received X_broadcast's shape[%s] = [%s]"
|
|
|
|
|
"received Y_broadcast's shape[%s] = [%s]",
|
|
|
|
|
i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i]));
|
|
|
|
|
if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) {
|
|
|
|
|
out_bd_dims[i] = 0;
|
|
|
|
|
} else {
|
|
|
|
@ -118,10 +124,13 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
const T* y_data = Y->data<T>();
|
|
|
|
|
|
|
|
|
|
if (x_ndim == 1 && y_ndim == 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(X->numel(), Y->numel(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"X's numbers is not equal to Y's numbers,"
|
|
|
|
|
"when X/Y's dims =1"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
X->numel(), Y->numel(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"X's numbers must be equal to Y's numbers,"
|
|
|
|
|
"when X/Y's dims =1. But received X has [%d] elements,"
|
|
|
|
|
"received Y has [%d] elements",
|
|
|
|
|
X->numel(), Y->numel()));
|
|
|
|
|
VLOG(3) << "MatMul's case 1";
|
|
|
|
|
Out->Resize({1});
|
|
|
|
|
Out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
@ -140,13 +149,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
if (x_ndim == 1) {
|
|
|
|
|
const int N = X->numel();
|
|
|
|
|
if (trans_y) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y_dims[y_ndim - 1], N,
|
|
|
|
|
platform::errors::InvalidArgument("Input(Y) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Y) has error dim."
|
|
|
|
|
"Y'dims[%d] must be equal to %d"
|
|
|
|
|
"But received Y'dims[%d] is %d",
|
|
|
|
|
y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y_dims[y_ndim - 2], N,
|
|
|
|
|
platform::errors::InvalidArgument("Input(Y) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Y) has error dim."
|
|
|
|
|
"Y'dims[%d] must be equal to %d"
|
|
|
|
|
"But received Y'dims[%d] is %d",
|
|
|
|
|
y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2]));
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::int64_t> out_dims(y_ndim - 1);
|
|
|
|
|
if (trans_y) {
|
|
|
|
@ -182,13 +197,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
if (y_ndim == 1) {
|
|
|
|
|
const int N = Y->numel();
|
|
|
|
|
if (trans_x) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[x_ndim - 2], N,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) has error dim."
|
|
|
|
|
"X'dims[%d] must be equal to %d"
|
|
|
|
|
"But received X'dims[%d] is %d",
|
|
|
|
|
x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[x_ndim - 1], N,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) has error dim."
|
|
|
|
|
"X'dims[%d] must be equal to %d"
|
|
|
|
|
"But received X'dims[%d] is %d",
|
|
|
|
|
x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1]));
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::int64_t> out_dims(x_ndim - 1);
|
|
|
|
|
if (trans_x) {
|
|
|
|
@ -225,11 +246,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
|
|
|
|
|
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
|
|
|
|
|
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
|
|
|
|
|
if (trans_y) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Y) has error dim."
|
|
|
|
|
"Y'dims[%d] must be equal to %d"
|
|
|
|
|
"But received Y'dims[%d] is %d",
|
|
|
|
|
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) has error dim."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Y) has error dim."
|
|
|
|
|
"Y'dims[%d] must be equal to %d"
|
|
|
|
|
"But received Y'dims[%d] is %d",
|
|
|
|
|
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
|
|
|
|
|
}
|
|
|
|
|
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
|
|
|
|
|
const int ndim = (std::max)(x_ndim, y_ndim);
|
|
|
|
|