@@ -72,8 +72,21 @@ class MatMulKernel : public framework::OpKernel<T> {
         ColumnMatrixFromVector(y.dims()), 0, context.Attr<bool>("transpose_Y"));
     auto scale = static_cast<T>(context.Attr<float>("alpha"));
+
+    int head_number = 1;
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    head_number = context.Attr<int>("head_number");
+#endif
 
     const auto &x_dims = x.dims();
     const auto &y_dims = y.dims();
-    if (x_dims.size() == 3 && y_dims.size() <= 2) {
+    if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) {
       // transpose_X must be false; if it is true, the transpose costs much time
       if (!context.Attr<bool>("transpose_X")) {
         mat_dim_a.height_ *= mat_dim_a.batch_size_;
         mat_dim_a.batch_size_ = 0;
       }
     }
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_);
+
+    if (head_number > 1) {
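
The hunk above stops at the opening of the `head_number > 1` branch, so the dispatch itself is not shown. As a rough illustration only of what a multi-head GEMM computes (this is not PaddlePaddle's kernel; `MultiHeadMatMulRef` and all shapes below are made up for this sketch): X of shape [M, H*K] and Y of shape [H*K, N] are cut into H head-wise blocks, and the per-head [M, N] products are laid side by side into an [M, H*N] output. This corresponds to the case where `mat_dim_a.width_ == mat_dim_b.height_`, i.e. `split_vertical_y` would be false; when those dimensions differ, Y is presumably split column-wise instead.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical reference semantics for a multi-head matmul, illustration only:
// X [M, H*K] and Y [H*K, N] are cut into H head blocks X_h [M, K], Y_h [K, N];
// each head's [M, N] product is written side by side, giving Out [M, H*N].
void MultiHeadMatMulRef(const std::vector<float> &x, const std::vector<float> &y,
                        std::vector<float> *out, int m, int k, int n, int heads) {
  out->assign(static_cast<std::size_t>(m) * heads * n, 0.0f);
  for (int h = 0; h < heads; ++h) {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          // Row i of X has stride heads * k; head h occupies columns [h*k, h*k + k).
          // Row (h*k + p) of Y has stride n.
          acc += x[i * heads * k + h * k + p] * y[(h * k + p) * n + j];
        }
        // Row i of Out has stride heads * n; head h occupies columns [h*n, h*n + n).
        (*out)[i * heads * n + h * n + j] = acc;
      }
    }
  }
}

int main() {
  const int m = 2, k = 3, n = 2, heads = 2;
  std::vector<float> x(m * heads * k), y(heads * k * n);
  for (int i = 0; i < static_cast<int>(x.size()); ++i) x[i] = 0.1f * i;
  for (int i = 0; i < static_cast<int>(y.size()); ++i) y[i] = 0.2f * i;
  std::vector<float> out;
  MultiHeadMatMulRef(x, y, &out, m, k, n, heads);
  std::printf("Out is %d x %d; Out[0][0] = %.3f\n", m, heads * n, out[0]);
  return 0;
}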
@@ -210,6 +223,19 @@ class MatMulGradKernel : public framework::OpKernel<T> {
     auto blas = math::GetBlas<DeviceContext, T>(context);
     auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+
+    int head_number = 1;
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    head_number = context.Attr<int>("head_number");
+#endif
+
+    if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) {
+      // transpose_X must be false; if it is true, the transpose costs much time
+      if (!trans_a) {
+        mat_dim_a.height_ *= mat_dim_a.batch_size_;
+        mat_dim_a.batch_size_ = 0;
+      }
+    }
     blas.MatMul(a, mat_dim_a, b, mat_dim_b,
                 static_cast<T>(context.Attr<float>("alpha")), out, T(0));
   }
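
Both hunks guard the same batch-folding optimization: when A has shape [batch, M, K], is not transposed, and B is at most 2-D, setting `mat_dim_a.height_ *= mat_dim_a.batch_size_` and `mat_dim_a.batch_size_ = 0` collapses the batched multiply into a single GEMM over a [(batch*M), K] matrix. That is valid because the batch slices of a row-major tensor are contiguous and B is shared across the batch; the new `head_number <= 1` condition keeps the fold away from the multi-head path. A minimal standalone check of the equivalence (everything here is local to the sketch):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Plain row-major GEMM: C[m x n] = A[m x k] * B[k x n].
void Gemm(const float *a, const float *b, float *c, int m, int k, int n) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      c[i * n + j] = acc;
    }
}

int main() {
  const int batch = 4, m = 3, k = 5, n = 2;
  std::vector<float> a(batch * m * k), b(k * n);
  for (int i = 0; i < static_cast<int>(a.size()); ++i) a[i] = 0.01f * i;
  for (int i = 0; i < static_cast<int>(b.size()); ++i) b[i] = 0.03f * i;

  // Batched version: one GEMM per [m x k] slice of A.
  std::vector<float> out_batched(batch * m * n);
  for (int s = 0; s < batch; ++s)
    Gemm(&a[s * m * k], b.data(), &out_batched[s * m * n], m, k, n);

  // Folded version: treat A as one [(batch*m) x k] matrix, as the kernel
  // does via mat_dim_a.height_ *= mat_dim_a.batch_size_; batch_size_ = 0;
  std::vector<float> out_folded(batch * m * n);
  Gemm(a.data(), b.data(), out_folded.data(), batch * m, k, n);

  float max_diff = 0.0f;
  for (int i = 0; i < static_cast<int>(out_folded.size()); ++i)
    max_diff = std::max(max_diff, std::fabs(out_folded[i] - out_batched[i]));
  std::printf("max difference: %g\n", max_diff);  // expected: 0
  return 0;
}

This also shows why the comment insists on `transpose_X` being false: the transpose of each [M, K] slice is not a slice of the transpose of the stacked [(batch*M), K] matrix, so the folded GEMM would no longer compute the same result.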