MatMul operator (#4856)
* initial matmul operator

Similar to np.matmul, but also has transpose_X and transpose_Y flags, and only supports tensors of rank 1 to 3 inclusive. For GPU, uses cublas?gemmStridedBatched. For CPU, uses cblas_?gemm_batch if available via MKL; otherwise, a simple serial implementation that loops over the batch dimension is employed for now.
parent fd96914d23
commit 164898277c
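As a reference for the semantics described above, here is a minimal NumPy sketch (not part of this patch; matmul_ref is a hypothetical name) of what the forward pass computes, including the rank-1 promotion and the transpose flags:

    import numpy as np

    def matmul_ref(x, y, transpose_x=False, transpose_y=False):
        """Sketch of the forward semantics: rank-1 to rank-3 inputs, optional
        transposes, np.matmul-style batching over the leading dimension."""
        if transpose_x:
            # A rank-1 [D] is viewed as [D, 1]; otherwise swap the last two axes.
            x = x.reshape(-1, 1) if x.ndim == 1 else np.swapaxes(x, -2, -1)
        if transpose_y:
            # A rank-1 [D] is viewed as [1, D]; otherwise swap the last two axes.
            y = y.reshape(1, -1) if y.ndim == 1 else np.swapaxes(y, -2, -1)
        out = np.matmul(x, y)
        # 0-dimensional results (scalars) are promoted to shape (1,).
        return out.reshape(1) if out.ndim == 0 else out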
@ -0,0 +1,124 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

// Implements the logic of numpy matmul:
// https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
//
// but also allows a and b to be transposed.
//
// Both a and b can be 1- to 3-dimensional. Higher-rank tensors are not
// supported yet.
template <typename Place, typename T>
class MatMulFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& a, bool trans_a,
                  const framework::Tensor& b, bool trans_b, T alpha,
                  framework::Tensor* out, T beta) {
    auto dim_a = a.dims();
    auto dim_b = b.dims();

    PADDLE_ENFORCE(a.place() == b.place() && b.place() == out->place(),
                   "Tensors must all be in the same place.");
    PADDLE_ENFORCE_GE(dim_a.size(), 1,
                      "Input tensor a must be at least 1-dimensional.");
    PADDLE_ENFORCE_GE(dim_b.size(), 1,
                      "Input tensor b must be at least 1-dimensional.");
    PADDLE_ENFORCE_LE(dim_a.size(), 3,
                      "Input tensor a must be at most 3-dimensional.");
    PADDLE_ENFORCE_LE(dim_b.size(), 3,
                      "Input tensor b must be at most 3-dimensional.");

    int M = 0, N = 0, kA = 0, kB = 0, batchCountA = 0, batchCountB = 0,
        strideA = 0, strideB = 0;

    switch (dim_a.size()) {
      case 1:
        // similar to np.matmul:
        // prepend dimension 1 (no transpose) or append dimension 1 (transpose)
        M = trans_a ? dim_a[0] : 1;
        kA = trans_a ? 1 : dim_a[0];
        break;
      case 2:
        M = trans_a ? dim_a[1] : dim_a[0];
        kA = trans_a ? dim_a[0] : dim_a[1];
        break;
      case 3:
        batchCountA = dim_a[0];
        M = trans_a ? dim_a[2] : dim_a[1];
        kA = trans_a ? dim_a[1] : dim_a[2];
        strideA = M * kA;
        break;
      default:
        assert(false);
    }

    switch (dim_b.size()) {
      case 1:
        // similar to np.matmul:
        // append dimension 1 (no transpose) or prepend dimension 1 (transpose)
        kB = trans_b ? 1 : dim_b[0];
        N = trans_b ? dim_b[0] : 1;
        break;
      case 2:
        kB = trans_b ? dim_b[1] : dim_b[0];
        N = trans_b ? dim_b[0] : dim_b[1];
        break;
      case 3:
        batchCountB = dim_b[0];
        kB = trans_b ? dim_b[2] : dim_b[1];
        N = trans_b ? dim_b[1] : dim_b[2];
        strideB = kB * N;
        break;
      default:
        assert(false);
    }

    PADDLE_ENFORCE_EQ(
        kA, kB,
        "First matrix's width must be equal to second matrix's height.");
    if (batchCountA && batchCountB) {
      PADDLE_ENFORCE_EQ(
          batchCountA, batchCountB,
          "When input tensors a and b are both batched, they must have the "
          "same batch dimension.");
    }
    int batchCount = std::max(batchCountA, batchCountB);

    CBLAS_TRANSPOSE transA = trans_a ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE transB = trans_b ? CblasTrans : CblasNoTrans;

    if (!batchCount) {
      // regular matrix multiplication
      gemm<Place, T>(context, transA, transB, M, N, kA, alpha, a.data<T>(),
                     b.data<T>(), beta, out->data<T>());
    } else {
      // batched matrix multiplication
      batched_gemm<Place, T>(context, transA, transB, M, N, kA, alpha,
                             a.data<T>(), b.data<T>(), beta, out->data<T>(),
                             batchCount, strideA, strideB);
    }
  }
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
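One detail worth noting above: when only one operand is rank 3, the other operand's stride stays 0, so batched_gemm re-reads the same matrix for every batch element. A rough Python analogue of that stride convention (a sketch assuming row-major buffers and no transposes; strided_batched_gemm_sketch is a hypothetical name):

    import numpy as np

    def strided_batched_gemm_sketch(a, b, batch_count, stride_a, stride_b, M, N, K):
        # Loop-over-batch analogue of cublas?gemmStridedBatched: a stride of 0
        # makes every batch element read the same matrix (the broadcast case).
        a, b = a.reshape(-1), b.reshape(-1)
        out = np.empty((batch_count, M, N), dtype=a.dtype)
        for p in range(batch_count):
            A = a[p * stride_a:p * stride_a + M * K].reshape(M, K)
            B = b[p * stride_b:p * stride_b + K * N].reshape(K, N)
            out[p] = A @ B
        return out

    # e.g. a: [M, K] shared across the batch (stride 0), b: [batch, K, N]
    a = np.random.rand(3, 5).astype(np.float32)
    b = np.random.rand(2, 5, 4).astype(np.float32)
    assert np.allclose(np.matmul(a, b),
                       strided_batched_gemm_sketch(a, b, 2, 0, 5 * 4, 3, 4, 5))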
@ -0,0 +1,208 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/matmul_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class MatMulOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("X"),
                   "Input(X) of MatMulOp should not be null.");
    PADDLE_ENFORCE(context->HasInput("Y"),
                   "Input(Y) of MatMulOp should not be null.");
    PADDLE_ENFORCE(context->HasOutput("Out"),
                   "Output(Out) of MatMulOp should not be null.");

    auto dim_x = context->GetInputDim("X");
    auto dim_y = context->GetInputDim("Y");
    bool transpose_x = context->Attrs().Get<bool>("transpose_X");
    bool transpose_y = context->Attrs().Get<bool>("transpose_Y");

    PADDLE_ENFORCE_GE(dim_x.size(), 1,
                      "Input tensor X must be at least 1-dimensional.");
    PADDLE_ENFORCE_GE(dim_y.size(), 1,
                      "Input tensor Y must be at least 1-dimensional.");
    PADDLE_ENFORCE_LE(dim_x.size(), 3,
                      "Input tensor X must be at most 3-dimensional.");
    PADDLE_ENFORCE_LE(dim_y.size(), 3,
                      "Input tensor Y must be at most 3-dimensional.");

    int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
    bool remove_initial_dim = false, remove_final_dim = false;

    switch (dim_x.size()) {
      case 1:
        if (transpose_x) {
          M = dim_x[0];
          KX = 1;
        } else {
          M = 1;
          KX = dim_x[0];
          remove_initial_dim = true;
        }
        break;
      case 2:
        M = transpose_x ? dim_x[1] : dim_x[0];
        KX = transpose_x ? dim_x[0] : dim_x[1];
        break;
      case 3:
        batchCountX = dim_x[0];
        M = transpose_x ? dim_x[2] : dim_x[1];
        KX = transpose_x ? dim_x[1] : dim_x[2];
        break;
      default:
        assert(false);
    }

    switch (dim_y.size()) {
      case 1:
        if (transpose_y) {
          N = dim_y[0];
          KY = 1;
        } else {
          N = 1;
          KY = dim_y[0];
          remove_final_dim = true;
        }
        break;
      case 2:
        KY = transpose_y ? dim_y[1] : dim_y[0];
        N = transpose_y ? dim_y[0] : dim_y[1];
        break;
      case 3:
        batchCountY = dim_y[0];
        KY = transpose_y ? dim_y[2] : dim_y[1];
        N = transpose_y ? dim_y[1] : dim_y[2];
        break;
      default:
        assert(false);
    }

    PADDLE_ENFORCE_EQ(
        KX, KY,
        "First matrix's width must be equal to second matrix's height.");
    if (batchCountX && batchCountY) {
      PADDLE_ENFORCE_EQ(
          batchCountX, batchCountY,
          "When Input(X) and Input(Y) are both three dimensional, they "
          "must have the same batch dimension.");
    }
    int batchCount = std::max(batchCountX, batchCountY);

    std::vector<int64_t> dim_out;
    if (batchCount) {
      dim_out.push_back(batchCount);
    }
    if (!remove_initial_dim) {
      dim_out.push_back(M);
    }
    if (!remove_final_dim) {
      dim_out.push_back(N);
    }
    if (dim_out.size() == 0) {
      // We don't support 0-dimensional Tensors (scalars), so instead
      // treat the output as a Tensor of shape (1,) in this case.
      dim_out.push_back(1);
    }
    context->SetOutputDim("Out", framework::make_ddim(dim_out));
    context->ShareLoD("X", /*->*/ "Out");
  }
};

class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MatMulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The first input of MatMul op");
    AddInput("Y", "The second input of MatMul op");
    AddOutput("Out", "The output of MatMul op");
    AddAttr<bool>("transpose_X",
                  R"DOC(If true, use the transpose of `X`.
        )DOC")
        .SetDefault(false);
    AddAttr<bool>("transpose_Y",
                  R"DOC(If true, use the transpose of `Y`.
        )DOC")
        .SetDefault(false);
    AddComment(R"DOC(
The MatMul operator is used to perform (batched) matrix multiplication
over the last two dimensions of the input tensors `X` and `Y`.

If a transpose flag is specified, the last two dimensions of the
tensor are transposed. If the tensor is rank-1 of shape [D], then
for `X` it is treated as [1, D] in nontransposed form and as [D, 1]
in transposed form, whereas for `Y` it is the opposite: it is treated
as [D, 1] in nontransposed form and as [1, D] in transposed form.

Examples without transpose:
- X: [K], Y: [K] => Out: [1]
- X: [K], Y: [K, N] => Out: [N]
- X: [B, M, K], Y: [K] => Out: [B, M]
- X: [M, K], Y: [B, K, N] => Out: [B, M, N]
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]

The behavior is designed to be similar to the `numpy.matmul` function.
The differences are:
- Currently only rank-1 to rank-3 input tensors are supported.
- We add `transpose_X` and `transpose_Y` flags.

Both of the inputs `X` and `Y` can carry LoD (Level of Details) information,
or not. However, the output only shares the LoD with input `X`.
)DOC");
  }
};

class MatMulOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = context->GetInputDim("X");
    auto y_dims = context->GetInputDim("Y");

    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");

    if (context->HasOutput(x_grad_name)) {
      context->SetOutputDim(x_grad_name, x_dims);
    }
    if (context->HasOutput(y_grad_name)) {
      context->SetOutputDim(y_grad_name, y_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(matmul, ops::MatMulOp, ops::MatMulOpMaker, matmul_grad,
            ops::MatMulOpGrad);
REGISTER_OP_CPU_KERNEL(matmul,
                       ops::MatMulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    matmul_grad, ops::MatMulGradKernel<paddle::platform::CPUPlace, float>);
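For readers checking the InferShape logic, here is a shape-only Python restatement (illustrative; infer_matmul_out_shape is a hypothetical helper, not part of the op's API), verified against the examples in the op comment:

    def infer_matmul_out_shape(dim_x, dim_y, transpose_x=False, transpose_y=False):
        # Shape-only mirror of MatMulOp::InferShape.
        batch_x = batch_y = 0
        remove_initial_dim = remove_final_dim = False
        if len(dim_x) == 1:
            M, KX = (dim_x[0], 1) if transpose_x else (1, dim_x[0])
            remove_initial_dim = not transpose_x  # drop the prepended 1 again
        elif len(dim_x) == 2:
            M, KX = (dim_x[1], dim_x[0]) if transpose_x else (dim_x[0], dim_x[1])
        else:  # rank 3
            batch_x = dim_x[0]
            M, KX = (dim_x[2], dim_x[1]) if transpose_x else (dim_x[1], dim_x[2])
        if len(dim_y) == 1:
            KY, N = (1, dim_y[0]) if transpose_y else (dim_y[0], 1)
            remove_final_dim = not transpose_y  # drop the appended 1 again
        elif len(dim_y) == 2:
            KY, N = (dim_y[1], dim_y[0]) if transpose_y else (dim_y[0], dim_y[1])
        else:  # rank 3
            batch_y = dim_y[0]
            KY, N = (dim_y[2], dim_y[1]) if transpose_y else (dim_y[1], dim_y[2])
        assert KX == KY, "contracted dimensions must match"
        out = []
        if max(batch_x, batch_y):
            out.append(max(batch_x, batch_y))
        if not remove_initial_dim:
            out.append(M)
        if not remove_final_dim:
            out.append(N)
        return out or [1]  # scalars become shape (1,)

    # The examples from the op documentation:
    assert infer_matmul_out_shape([5], [5]) == [1]
    assert infer_matmul_out_shape([2, 3, 5], [5]) == [2, 3]
    assert infer_matmul_out_shape([3, 5], [2, 5, 4]) == [2, 3, 4]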
@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/matmul_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(matmul,
                       ops::MatMulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    matmul_grad, ops::MatMulGradKernel<paddle::platform::GPUPlace, float>);
@ -0,0 +1,228 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/matmul.h"
#include "paddle/operators/transpose_op.h"

namespace paddle {
namespace operators {
namespace matmul_detail {

using Tensor = framework::Tensor;
using DDim = framework::DDim;
using framework::make_ddim;
using framework::vectorize;

template <typename Place, typename T>
class MatMulKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor& x = *context.Input<Tensor>("X");
    const Tensor& y = *context.Input<Tensor>("Y");
    Tensor* out = context.Output<Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());
    bool transpose_x = context.Attr<bool>("transpose_X");
    bool transpose_y = context.Attr<bool>("transpose_Y");

    math::MatMulFunctor<Place, T>()(context.device_context(), x, transpose_x,
                                    y, transpose_y, T(1), out, T(0));
  }
};

template <typename T>
inline Tensor Reshape(const Tensor& input, const DDim& dims) {
  Tensor output;
  output.ShareDataWith<T>(input);
  output.Resize(dims);
  return output;
}

// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
template <typename T>
Tensor CombineBatchAndM(const Tensor& input) {
  Tensor output;
  output.ShareDataWith<T>(input);
  auto in_dims = input.dims();
  if (in_dims.size() == 3) {
    std::vector<int64_t> out_dims = {in_dims[0] * in_dims[1], in_dims[2]};
    output.Resize(make_ddim(out_dims));
  }
  return output;
}

// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename Place, typename T>
Tensor CombineBatchAndN(const framework::ExecutionContext& context,
                        const Tensor& input) {
  Tensor output;
  auto in_dims = input.dims();
  if (in_dims.size() == 3) {
    output.Resize(in_dims);
    output.mutable_data<T>(context.GetPlace());
    EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2});
    std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
    output.Resize(make_ddim(out_dims));
  } else {
    output.ShareDataWith<T>(input);
  }
  return output;
}

// Using dimensional constraints on matrix multiplication, it is
// straightforward to check the following table for when X and Y
// are both matrices.
//
// transpose_X | False    | True     | False    | True
// transpose_Y | False    | False    | True     | True
// ------------+----------+----------+----------+-----------
//        dX = | dOut Y^T | Y dOut^T | dOut Y   | Y^T dOut^T
//        dY = | X^T dOut | X dOut   | dOut^T X | dOut^T X^T
//
// When X is a vector of size K, we treat it instead as a matrix of shape
// (1, K). Similarly, when Y is a vector of size K, we treat it instead as
// a matrix of shape (K, 1).
//
// When X and Y are both 3-dimensional tensors, the first (batch) dimension
// can be ignored and the exact same formulas apply as for two matrices.
//
// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end
// up with formulas like
//
//   dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj}
//
// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N
// to X: (P * M) x K, dOut: (P * M) x N.
template <typename Place, typename T>
class MatMulGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor& x = *context.Input<Tensor>("X");
    const Tensor& y = *context.Input<Tensor>("Y");
    const Tensor& dout = *context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* dx = context.Output<Tensor>(framework::GradVarName("X"));
    Tensor* dy = context.Output<Tensor>(framework::GradVarName("Y"));
    bool transpose_x = context.Attr<bool>("transpose_X");
    bool transpose_y = context.Attr<bool>("transpose_Y");

    std::vector<int64_t> x_dims = vectorize(x.dims());
    std::vector<int64_t> y_dims = vectorize(y.dims());

    // If X is a vector, reshape it to a matrix.
    if (x_dims.size() == 1) {
      x_dims.insert(x_dims.begin(), 1);
    }

    // If Y is a vector, reshape it to a matrix.
    if (y_dims.size() == 1) {
      y_dims.push_back(1);
    }

    // Fix the dOut dimensions.
    int M = 0, N = 0, batchCountX = 0, batchCountY = 0;

    switch (x_dims.size()) {
      case 2:
        M = transpose_x ? x_dims[1] : x_dims[0];
        break;
      case 3:
        batchCountX = x_dims[0];
        M = transpose_x ? x_dims[2] : x_dims[1];
        break;
      default:
        assert(false);
    }

    switch (y_dims.size()) {
      case 2:
        N = transpose_y ? y_dims[0] : y_dims[1];
        break;
      case 3:
        batchCountY = y_dims[0];
        N = transpose_y ? y_dims[1] : y_dims[2];
        break;
      default:
        assert(false);
    }
    if (batchCountX && batchCountY) {
      PADDLE_ENFORCE_EQ(
          batchCountX, batchCountY,
          "When Input(X) and Input(Y) are both three dimensional, they "
          "must have the same batch dimension.");
    }
    int batchCount = std::max(batchCountX, batchCountY);
    std::vector<int64_t> dout_dims = {M, N};
    if (batchCount) {
      dout_dims.insert(dout_dims.begin(), batchCount);
    }
    Tensor X = Reshape<T>(x, make_ddim(x_dims));
    Tensor Y = Reshape<T>(y, make_ddim(y_dims));
    Tensor dOut = Reshape<T>(dout, make_ddim(dout_dims));

    if (dx) {
      dx->mutable_data<T>(context.GetPlace());
      const Tensor& dOut_for_dX =
          (x_dims.size() == 2 && y_dims.size() == 3)
              ? CombineBatchAndN<Place, T>(context, dOut)
              : dOut;
      if (x_dims.size() == 2 && y_dims.size() == 3) {
        Y = transpose_y ? CombineBatchAndM<T>(Y)
                        : CombineBatchAndN<Place, T>(context, Y);
      }
      if (transpose_x) {
        math::MatMulFunctor<Place, T>()(context.device_context(), Y,
                                        transpose_y, dOut_for_dX, transpose_x,
                                        T(1), dx, T(0));
      } else {
        math::MatMulFunctor<Place, T>()(context.device_context(), dOut_for_dX,
                                        transpose_x, Y, !transpose_y, T(1), dx,
                                        T(0));
      }
    }

    if (dy) {
      dy->mutable_data<T>(context.GetPlace());
      const Tensor& dOut_for_dY = (y_dims.size() == 2 && x_dims.size() == 3)
                                      ? CombineBatchAndM<T>(dOut)
                                      : dOut;
      if (y_dims.size() == 2 && x_dims.size() == 3) {
        X = transpose_x ? CombineBatchAndN<Place, T>(context, X)
                        : CombineBatchAndM<T>(X);
        dOut = CombineBatchAndM<T>(dOut);
      }
      if (transpose_y) {
        math::MatMulFunctor<Place, T>()(context.device_context(), dOut_for_dY,
                                        transpose_y, X, transpose_x, T(1), dy,
                                        T(0));
      } else {
        math::MatMulFunctor<Place, T>()(context.device_context(), X,
                                        !transpose_x, dOut_for_dY, transpose_y,
                                        T(1), dy, T(0));
      }
    }
  }
};
}  // namespace matmul_detail

using matmul_detail::MatMulKernel;
using matmul_detail::MatMulGradKernel;

}  // namespace operators
}  // namespace paddle
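The gradient table in the comment above is easy to sanity-check numerically. A small NumPy sketch for the plain 2-D case (illustrative only; matmul_grads is a hypothetical name):

    import numpy as np

    def matmul_grads(X, Y, dOut, transpose_X=False, transpose_Y=False):
        # The 2-D rows of the gradient table above, returned as (dX, dY).
        if not transpose_X and not transpose_Y:
            return dOut @ Y.T, X.T @ dOut
        if transpose_X and not transpose_Y:
            return Y @ dOut.T, X @ dOut
        if not transpose_X and transpose_Y:
            return dOut @ Y, dOut.T @ X
        return Y.T @ dOut.T, dOut.T @ X.T

    # Finite-difference spot check for L = sum(X @ Y), so dL/dOut is all ones.
    rng = np.random.RandomState(0)
    X, Y = rng.rand(3, 5), rng.rand(5, 4)
    dX, dY = matmul_grads(X, Y, np.ones((3, 4)))
    eps = 1e-6
    Xp = X.copy()
    Xp[1, 2] += eps
    numeric = ((Xp @ Y).sum() - (X @ Y).sum()) / eps
    assert abs(numeric - dX[1, 2]) < 1e-4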
@ -0,0 +1,119 @@
import unittest
import numpy as np
from op_test import OpTest


def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
    BATCH_SIZE = 2
    M = 3
    N = 4
    K = 5
    if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y):
        K = 1
    if dim_X == 1:
        if transpose_X:
            shape_X = [M]
        else:
            shape_X = [K]
    if dim_Y == 1:
        if transpose_Y:
            shape_Y = [N]
        else:
            shape_Y = [K]
    if dim_X >= 2:
        if transpose_X:
            shape_X = [K, M]
        else:
            shape_X = [M, K]
    if dim_X == 3:
        shape_X = [BATCH_SIZE] + shape_X
    if dim_Y >= 2:
        if transpose_Y:
            shape_Y = [N, K]
        else:
            shape_Y = [K, N]
    if dim_Y == 3:
        shape_Y = [BATCH_SIZE] + shape_Y
    return shape_X, shape_Y


def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation using np.matmul."""
    # np.matmul does not support the transpose flags, so we manually
    # transpose X and Y appropriately.
    if transpose_X:
        if X.ndim == 1:
            X = X.reshape((X.size, 1))
        elif X.ndim == 2:
            X = X.T
        elif X.ndim == 3:
            X = np.transpose(X, (0, 2, 1))
        else:
            raise ValueError('X must have between 1 and 3 dimensions')
    if transpose_Y:
        if Y.ndim == 1:
            Y = Y.reshape((1, Y.size))
        elif Y.ndim == 2:
            Y = Y.T
        elif Y.ndim == 3:
            Y = np.transpose(Y, (0, 2, 1))
        else:
            raise ValueError('Y must have between 1 and 3 dimensions')
    Out = np.matmul(X, Y)
    if not Out.shape:
        # We do not support 0-dimensional Tensors (scalars). So where
        # np.matmul outputs a scalar, we must convert to a Tensor of
        # shape (1, ) instead.
        # Everywhere else, we are compatible with np.matmul.
        Out = np.array([Out], dtype="float32")
    return Out


class Generator(object):
    def setUp(self):
        self.op_type = "matmul"
        X = np.random.random(self.shape_X).astype("float32")
        Y = np.random.random(self.shape_Y).astype("float32")
        Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y)
        self.inputs = {'X': X, 'Y': Y}
        self.attrs = {
            'transpose_X': self.transpose_X,
            'transpose_Y': self.transpose_Y
        }
        self.outputs = {'Out': Out}

    def test_check_output(self):
        self.check_output(atol=1e-2)

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


# Generate test cases for all possibilities
for dim_X in [1, 2, 3]:
    for dim_Y in [1, 2, 3]:
        for transpose_X in [False, True]:
            for transpose_Y in [False, True]:
                test_name = (
                    'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
                        dim_X, dim_Y, transpose_X, transpose_Y))
                shape_X, shape_Y = generate_compatible_shapes(
                    dim_X, dim_Y, transpose_X, transpose_Y)
                test_class = type(test_name, (Generator, OpTest), {
                    'shape_X': shape_X,
                    'shape_Y': shape_Y,
                    'transpose_X': transpose_X,
                    'transpose_Y': transpose_Y,
                })
                globals()[test_name] = test_class

if __name__ == "__main__":
    unittest.main()