@@ -29,7 +29,8 @@ template <>
 struct CBlas<int8_t> {
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
-    PADDLE_THROW("Blas VCOPY don't support int8_t");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Blas VCOPY is not supported on CPU, please check your code"));
   }
 };
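The hunk above shows the pattern applied throughout this file: a bare-string `PADDLE_THROW("...")` becomes `PADDLE_THROW(platform::errors::Unimplemented(...))`, so the failure carries an error category rather than just text. Below is a minimal standalone C++ sketch of the same idea using plain exceptions instead of Paddle's macros; `FakeCBlas` and the message text are illustrative only and not part of the diff.

```cpp
// Standalone sketch (not Paddle's macros): a catch-all stub specialization
// that fails loudly with a categorized, descriptive message instead of a
// bare string. The "category" is approximated here by a distinct exception
// type plus a message prefix.
#include <cstdint>
#include <stdexcept>

template <typename T>
struct FakeCBlas;  // primary template intentionally left undefined

template <>
struct FakeCBlas<int8_t> {
  template <typename... ARGS>
  static void VCOPY(ARGS...) {
    // Callers and tests can distinguish "not implemented" from bad input.
    throw std::logic_error(
        "Unimplemented: Blas VCOPY is not supported on CPU for int8_t");
  }
};

int main() {
  try {
    FakeCBlas<int8_t>::VCOPY(nullptr, nullptr, 0);
  } catch (const std::logic_error &) {
    return 0;  // the stub failed as intended
  }
  return 1;
}
```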
@@ -347,22 +348,47 @@ struct CBlas<double> {
 template <>
 struct CBlas<platform::float16> {
-  static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
+  static void GEMM(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 GEMM not supported on CPU, please check your code"));
+  }
   static void SMM_GEMM(...) {
-    PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 SMM_GEMM not supported on CPU, please check your code"));
   }
-  static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
-  static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
-  static void VSQUARE(...) {
-    PADDLE_THROW("float16 VSQUARE not supported on CPU");
-  }
-  static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
-  static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
-  static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
-  static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
+  static void VMUL(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VMUL not supported on CPU, please check your code"));
+  }
+  static void VEXP(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VEXP not supported on CPU, please check your code"));
+  }
+  static void VSQUARE(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VSQUARE not supported on CPU, please check your code"));
+  }
+  static void VPOW(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 VPOW not supported on CPU, please check your code"));
+  }
+  static void DOT(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 DOT not supported on CPU, please check your code"));
+  };
+  static void SCAL(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 SCAL not supported on CPU, please check your code"));
+  };
+  static void ASUM(...) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 ASUM not supported on CPU, please check your code"));
+  };
 #ifdef PADDLE_WITH_MKLML
   static void GEMM_BATCH(...) {
-    PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "float16 GEMM_BATCH not supported on CPU, please check your code"));
   }
 #endif
 };
@@ -446,11 +472,18 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
   auto dim_a = mat_a.dims();
   auto dim_b = mat_b.dims();
   auto dim_out = mat_out->dims();
-  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
-                 "The input and output of matmul be matrix");
-  PADDLE_ENFORCE(
-      mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(),
-      "The places of matrices must be same");
+  PADDLE_ENFORCE_EQ(
+      dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true,
+      platform::errors::InvalidArgument(
+          "The input and output of matmul should be matrices, "
+          "the dim size must be 2, "
+          "but received dim size input_a:%d, input_b:%d, output:%d",
+          dim_a.size(), dim_b.size(), dim_out.size()));
+  PADDLE_ENFORCE_EQ(
+      mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true,
+      platform::errors::InvalidArgument("The places of matrices in the matmul "
+                                        "should be the same, please check "
+                                        "your code."));
 
   int M = dim_out[0];
   int N = dim_out[1];
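Here the boolean `PADDLE_ENFORCE(cond, "msg")` checks become `PADDLE_ENFORCE_EQ(cond, true, platform::errors::InvalidArgument(fmt, args...))`, so the raised error reports the actual dimension sizes. The following is a minimal standalone sketch of the same check in plain C++; the helper `CheckMatMulRank` and its signature are hypothetical and used only to illustrate formatting the offending values into an invalid-argument error.

```cpp
// Plain-C++ sketch of the rank-2 check above, carrying the actual sizes in
// the message the way the "%d" arguments added in the diff do.
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <vector>

void CheckMatMulRank(const std::vector<int64_t> &dim_a,
                     const std::vector<int64_t> &dim_b,
                     const std::vector<int64_t> &dim_out) {
  const bool all_rank2 =
      dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2;
  if (!all_rank2) {
    std::ostringstream msg;
    msg << "InvalidArgument: The input and output of matmul should be "
           "matrices with dim size 2, but received dim size input_a:"
        << dim_a.size() << ", input_b:" << dim_b.size()
        << ", output:" << dim_out.size();
    throw std::invalid_argument(msg.str());
  }
}

int main() {
  try {
    CheckMatMulRank({2, 3}, {3, 4, 1}, {2, 4});  // second input has rank 3
  } catch (const std::invalid_argument &) {
    return 0;  // rejected with a message naming the offending ranks
  }
  return 1;
}
```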
@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
     }
   } else {
-    PADDLE_ENFORCE_EQ(W1, H2);
+    PADDLE_ENFORCE_EQ(
+        W1, H2,
+        platform::errors::InvalidArgument(
+            "The first matrix width should be the same as the second matrix "
+            "height, but received first matrix width %d, "
+            "second matrix height %d",
+            W1, H2));
     int ldc = W2 * head_number;
     int sub_width = W1 / head_number;
@@ -785,7 +824,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                  const framework::Tensor &mat_b,
                                  const MatDescriptor &dim_b, T alpha,
                                  framework::Tensor *mat_out, T beta) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+  PADDLE_ENFORCE_EQ(
+      dim_a.width_, dim_b.height_,
+      platform::errors::InvalidArgument(
+          "The first matrix width should be the same as the second matrix "
+          "height, but received first matrix width %d, "
+          "second matrix height %d",
+          dim_a.width_, dim_b.height_));
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
@@ -793,12 +839,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                           dim_a.width_, alpha, mat_a.data<T>(),
                           mat_b.data<T>(), beta, mat_out->data<T>());
   } else {
-    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0,
-                   "dim_a.batch_size should be equal to dim_b.batch_size, or "
-                   "one of dim_a.batch_size and dim_b.batch_size should be 0. "
-                   "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
-                   dim_a.batch_size_, dim_b.batch_size_);
+    PADDLE_ENFORCE_EQ(
+        dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
+            dim_b.batch_size_ == 0,
+        true, platform::errors::InvalidArgument(
+                  "dim_a.batch_size should be equal to dim_b.batch_size, or "
+                  "one of dim_a.batch_size and dim_b.batch_size should be 0. "
+                  "But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
+                  dim_a.batch_size_, dim_b.batch_size_));
     this->template BatchedGEMM<T>(
         transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
         mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
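When the predicate is a disjunction rather than a single comparison, the diff keeps the whole boolean expression and compares it against `true`, while still formatting the underlying batch sizes into the message. A standalone sketch of that batch-size rule follows; the helper name is hypothetical.

```cpp
// Plain-C++ sketch of the batch-size compatibility rule enforced above,
// in place of PADDLE_ENFORCE_EQ(cond, true, ...).
#include <cstdint>
#include <sstream>
#include <stdexcept>

void CheckBatchSizesCompatible(int64_t batch_a, int64_t batch_b) {
  // Valid when both sizes match, or when either side is unbatched (0).
  const bool compatible = batch_a == batch_b || batch_a == 0 || batch_b == 0;
  if (!compatible) {
    std::ostringstream msg;
    msg << "InvalidArgument: dim_a.batch_size should be equal to "
           "dim_b.batch_size, or one of them should be 0. But got "
           "dim_a.batch_size = " << batch_a
        << ", dim_b.batch_size = " << batch_b << ".";
    throw std::invalid_argument(msg.str());
  }
}

int main() {
  CheckBatchSizesCompatible(8, 8);    // equal batch sizes: ok
  CheckBatchSizesCompatible(8, 0);    // one side unbatched: ok
  try {
    CheckBatchSizesCompatible(8, 4);  // mismatch: raises invalid_argument
  } catch (const std::invalid_argument &) {
    return 0;
  }
  return 1;
}
```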
@@ -834,15 +882,42 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
                                          int head_number,
                                          framework::Tensor *mat_out, T beta,
                                          bool mat_b_split_vertical) const {
-  PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
-  PADDLE_ENFORCE_GE(head_number, 1);
-  PADDLE_ENFORCE_LE(head_number, dim_a.width_);
+  PADDLE_ENFORCE_EQ(
+      dim_a.width_ % head_number, 0,
+      platform::errors::InvalidArgument(
+          "The first input width must be a multiple of the head number, "
+          "but received first input width %d, "
+          "head_number %d",
+          dim_a.width_, head_number));
+  PADDLE_ENFORCE_GE(head_number, 1,
+                    platform::errors::InvalidArgument(
+                        "The head number should be greater than or equal "
+                        "to 1, but received head number %d",
+                        head_number));
+  PADDLE_ENFORCE_LE(
+      head_number, dim_a.width_,
+      platform::errors::InvalidArgument(
+          "The head number should be less than or equal to the first input "
+          "width, but received first input width %d, "
+          "head_number %d",
+          dim_a.width_, head_number));
   CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
   CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
 
   if (mat_b_split_vertical) {
-    PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number);
-    PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0);
+    PADDLE_ENFORCE_EQ(
+        dim_b.height_, dim_a.width_ / head_number,
+        platform::errors::InvalidArgument(
+            "The second input height should be equal to the first input width "
+            "divided by the head number, but received height %d, expected %d",
+            dim_b.height_, dim_a.width_ / head_number));
+    PADDLE_ENFORCE_EQ(
+        dim_b.width_ % head_number, 0,
+        platform::errors::InvalidArgument(
+            "The second input width should be a multiple of the head number, "
+            "but received second input width %d, "
+            "head_number %d",
+            dim_b.width_, head_number));
   }
 
   if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
@@ -888,9 +963,16 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
                          mat_out->data<T>() + sub_matC_offset, ldc);
     }
   } else {
-    PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ ||
-                       dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0),
-                      true);
+    PADDLE_ENFORCE_EQ(
+        (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
+         dim_b.batch_size_ == 0),
+        true,
+        platform::errors::InvalidArgument(
+            "The first input batch size should be equal to the second input "
+            "batch size, or one of the two batch sizes should be 0, but "
+            "received first input batch size %d, "
+            "second input batch size %d",
+            dim_a.batch_size_, dim_b.batch_size_));
 
     this->template BatchedGEMMWithHead<T>(
         transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,