|
|
|
@ -111,6 +111,20 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto mat_dim_b =
|
|
|
|
|
math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0,
|
|
|
|
|
context.Attr<bool>("transpose_Y"));
|
|
|
|
|
|
|
|
|
|
const auto &x_dims = x->dims();
|
|
|
|
|
const auto &y_dims = y->dims();
|
|
|
|
|
if (x_dims.size() == 3 && y_dims.size() <= 2) {
|
|
|
|
|
// if transpose_X is true, the transpose cost much time
|
|
|
|
|
if (!context.Attr<bool>("transpose_X")) {
|
|
|
|
|
mat_dim_a.height_ *= mat_dim_a.batch_size_;
|
|
|
|
|
mat_dim_a.batch_size_ = 0;
|
|
|
|
|
} else {
|
|
|
|
|
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
|
|
|
|
|
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
mat_dim_a.width_, mat_dim_b.height_,
|
|
|
|
|
platform::errors::InvalidArgument("Shape mistake in matmul_op"));
|
|
|
|
@ -224,12 +238,26 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
|
|
|
|
|
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
|
|
|
|
|
const auto &a_dims = a.dims();
|
|
|
|
|
const auto &b_dims = b.dims();
|
|
|
|
|
if (a_dims.size() == 3 && b_dims.size() <= 2) {
|
|
|
|
|
// if transpose_X is true, the transpose cost much time
|
|
|
|
|
if (!context.Attr<bool>("transpose_X")) {
|
|
|
|
|
mat_dim_a.height_ *= mat_dim_a.batch_size_;
|
|
|
|
|
mat_dim_a.batch_size_ = 0;
|
|
|
|
|
} else {
|
|
|
|
|
mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
|
|
|
|
|
mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
mat_dim_a.width_, mat_dim_b.height_,
|
|
|
|
|
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
mat_dim_a.batch_size_, mat_dim_b.batch_size_,
|
|
|
|
|
platform::errors::InvalidArgument("Shape mistake in matmul_grad_op"));
|
|
|
|
|
|
|
|
|
|
T alpha = static_cast<T>(context.Attr<float>("alpha"));
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|