|
|
@ -25,7 +25,7 @@ namespace operators {
|
|
|
|
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
|
|
|
|
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
|
|
|
|
* original x_dim is returned.
|
|
|
|
* original x_dim is returned.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
|
|
|
|
static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
|
|
|
|
if (x_dim.size() > 1) {
|
|
|
|
if (x_dim.size() > 1) {
|
|
|
|
return x_dim;
|
|
|
|
return x_dim;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
|
|
|
|
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
|
|
|
|
* Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
|
|
|
|
* original y_dim is returned.
|
|
|
|
* original y_dim is returned.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
|
|
|
|
static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
|
|
|
|
if (y_dim.size() > 1) {
|
|
|
|
if (y_dim.size() > 1) {
|
|
|
|
return y_dim;
|
|
|
|
return y_dim;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
class MatMulKernel : public framework::OpKernel<T> {
|
|
|
|
class MatMulKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
auto& x =
|
|
|
|
auto &x =
|
|
|
|
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
|
|
|
|
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
|
|
|
|
auto& y =
|
|
|
|
auto &y =
|
|
|
|
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
|
|
|
|
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
|
|
|
|
auto* out = context.Output<framework::Tensor>("Out");
|
|
|
|
auto *out = context.Output<framework::Tensor>("Out");
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(context);
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(context);
|
|
|
@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
|
|
|
|
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
|
|
|
|
// Identity op if the tensor is not of rank 3.
|
|
|
|
// Identity op if the tensor is not of rank 3.
|
|
|
|
static framework::Tensor FoldInitDims(const framework::Tensor& input) {
|
|
|
|
static framework::Tensor FoldInitDims(const framework::Tensor &input) {
|
|
|
|
auto output = input;
|
|
|
|
auto output = input;
|
|
|
|
auto in_dims = input.dims();
|
|
|
|
auto in_dims = input.dims();
|
|
|
|
if (in_dims.size() == 3) {
|
|
|
|
if (in_dims.size() == 3) {
|
|
|
@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
|
|
|
|
// (Warning: This requires transposing data and writes into new memory.)
|
|
|
|
// (Warning: This requires transposing data and writes into new memory.)
|
|
|
|
// Identity op if the tensor is not of rank 3.
|
|
|
|
// Identity op if the tensor is not of rank 3.
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
|
|
|
|
static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context,
|
|
|
|
const framework::Tensor& input) {
|
|
|
|
const framework::Tensor &input) {
|
|
|
|
auto in_dims = input.dims();
|
|
|
|
auto in_dims = input.dims();
|
|
|
|
if (in_dims.size() != 3) {
|
|
|
|
if (in_dims.size() != 3) {
|
|
|
|
return input;
|
|
|
|
return input;
|
|
|
@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
|
|
|
|
* If transposed, `H,W` will be swapped.
|
|
|
|
* If transposed, `H,W` will be swapped.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
static void ReshapeTensorIntoMatrixSequence(
|
|
|
|
static void ReshapeTensorIntoMatrixSequence(
|
|
|
|
framework::Tensor* x, const math::MatDescriptor& descriptor) {
|
|
|
|
framework::Tensor *x, const math::MatDescriptor &descriptor) {
|
|
|
|
int64_t h, w;
|
|
|
|
int64_t h, w;
|
|
|
|
h = descriptor.height_;
|
|
|
|
h = descriptor.height_;
|
|
|
|
w = descriptor.width_;
|
|
|
|
w = descriptor.width_;
|
|
|
@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
|
|
|
|
* If any of `X` and `Y` has batch size BatchSize, the out will have the
|
|
|
|
* If any of `X` and `Y` has batch size BatchSize, the out will have the
|
|
|
|
* BatchSize.
|
|
|
|
* BatchSize.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
|
|
|
|
static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
|
|
|
|
framework::Tensor* y,
|
|
|
|
framework::Tensor *y,
|
|
|
|
framework::Tensor* out, bool trans_x,
|
|
|
|
framework::Tensor *out, bool trans_x,
|
|
|
|
bool trans_y) {
|
|
|
|
bool trans_y) {
|
|
|
|
auto x_dim = RowMatrixFromVector(x->dims());
|
|
|
|
auto x_dim = RowMatrixFromVector(x->dims());
|
|
|
|
auto y_dim = ColumnMatrixFromVector(y->dims());
|
|
|
|
auto y_dim = ColumnMatrixFromVector(y->dims());
|
|
|
@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void MatMul(const framework::ExecutionContext& context,
|
|
|
|
void MatMul(const framework::ExecutionContext &context,
|
|
|
|
const framework::Tensor& a, bool trans_a,
|
|
|
|
const framework::Tensor &a, bool trans_a,
|
|
|
|
const framework::Tensor& b, bool trans_b,
|
|
|
|
const framework::Tensor &b, bool trans_b,
|
|
|
|
framework::Tensor* out) const {
|
|
|
|
framework::Tensor *out) const {
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(context);
|
|
|
|
auto blas = math::GetBlas<DeviceContext, T>(context);
|
|
|
|
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
|
|
|
|
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
|
|
|
@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
|
|
|
|
blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void CalcInputGrad(const framework::ExecutionContext& context,
|
|
|
|
void CalcInputGrad(const framework::ExecutionContext &context,
|
|
|
|
const framework::Tensor& a, bool trans_a,
|
|
|
|
const framework::Tensor &a, bool trans_a,
|
|
|
|
bool is_fold_init_dims_a, const framework::Tensor& b,
|
|
|
|
bool is_fold_init_dims_a, const framework::Tensor &b,
|
|
|
|
bool trans_b, bool is_fold_init_dims_b,
|
|
|
|
bool trans_b, bool is_fold_init_dims_b,
|
|
|
|
framework::Tensor* out) const {
|
|
|
|
framework::Tensor *out) const {
|
|
|
|
if (out == nullptr) return;
|
|
|
|
if (out == nullptr) return;
|
|
|
|
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
|
|
|
|
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
|
|
|
|
out->dims().size() == 2;
|
|
|
|
out->dims().size() == 2;
|
|
|
|
if (!need_combine) {
|
|
|
|
if (!need_combine) {
|
|
|
|
MatMul(context, a, trans_a, b, trans_b, out);
|
|
|
|
MatMul(context, a, trans_a, b, trans_b, out);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
auto& ctx = context.template device_context<DeviceContext>();
|
|
|
|
auto &ctx = context.template device_context<DeviceContext>();
|
|
|
|
MatMul(context, is_fold_init_dims_a
|
|
|
|
MatMul(context, is_fold_init_dims_a
|
|
|
|
? FoldInitDims(a)
|
|
|
|
? FoldInitDims(a)
|
|
|
|
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
|
|
|
|
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
|
|
|
@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
auto x = *context.Input<framework::Tensor>("X");
|
|
|
|
auto x = *context.Input<framework::Tensor>("X");
|
|
|
|
auto y = *context.Input<framework::Tensor>("Y");
|
|
|
|
auto y = *context.Input<framework::Tensor>("Y");
|
|
|
|
auto dout =
|
|
|
|
auto dout =
|
|
|
|
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
*context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
|
|
|
|
auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
|
|
|
|
bool transpose_x = context.Attr<bool>("transpose_X");
|
|
|
|
bool transpose_x = context.Attr<bool>("transpose_X");
|
|
|
|
bool transpose_y = context.Attr<bool>("transpose_Y");
|
|
|
|
bool transpose_y = context.Attr<bool>("transpose_Y");
|
|
|
|
|
|
|
|
|
|
|
@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(framework::InferShapeContext* context) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *context) const override {
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"),
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"),
|
|
|
|
"Input(X) of MatMulOp should not be null.");
|
|
|
|
"Input(X) of MatMulOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(context->HasInput("Y"),
|
|
|
|
PADDLE_ENFORCE(context->HasInput("Y"),
|
|
|
@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(framework::InferShapeContext* context) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *context) const override {
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
|
|
|
|
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
|
|
|
|
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
|
|
|
@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
auto* retv = new framework::OpDesc();
|
|
|
|
auto *retv = new framework::OpDesc();
|
|
|
|
retv->SetType("matmul_grad");
|
|
|
|
retv->SetType("matmul_grad");
|
|
|
|
retv->SetInput("X", Input("X"));
|
|
|
|
retv->SetInput("X", Input("X"));
|
|
|
|
retv->SetInput("Y", Input("Y"));
|
|
|
|
retv->SetInput("Y", Input("Y"));
|
|
|
@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
|
|
|
|
ops::MatMulOpGradMaker);
|
|
|
|
ops::MatMulOpGradMaker);
|
|
|
|
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
|
|
|
|
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
matmul_grad,
|
|
|
|
matmul_grad,
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
matmul_grad,
|
|
|
|
matmul_grad,
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|