|
|
|
@ -46,7 +46,7 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace math {
|
|
|
|
|
|
|
|
|
|
struct MatDim {
|
|
|
|
|
struct MatDescriptor {
|
|
|
|
|
int64_t height_;
|
|
|
|
|
int64_t width_;
|
|
|
|
|
int64_t stride_{0};
|
|
|
|
@ -54,8 +54,8 @@ struct MatDim {
|
|
|
|
|
bool trans_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Declaration only; the definition is not part of this view.
// Returns a MatDim built from a tensor shape (`tensor`), a flatten position
// (`num_flatten_cols`) and a transpose flag (`trans`).
// NOTE(review): presumably dimensions before `num_flatten_cols` are folded
// into the matrix height and the rest into its width -- confirm against the
// definition.
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
|
|
|
|
|
bool trans);
|
|
|
|
|
// Declaration only; the definition is not part of this view.
// Returns a MatDescriptor built from a tensor shape (`tensor`), a flatten
// position (`num_flatten_cols`) and a transpose flag (`trans`).
// NOTE(review): presumably dimensions before `num_flatten_cols` are folded
// into the matrix height and the rest into its width -- confirm against the
// definition.
extern MatDescriptor GetMatDim(const framework::DDim& tensor,
|
|
|
|
|
int num_flatten_cols, bool trans);
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext>
|
|
|
|
|
class Blas {
|
|
|
|
@ -102,26 +102,9 @@ class Blas {
|
|
|
|
|
int batchCount, int64_t strideA, int64_t strideB) const;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
|
|
|
|
|
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
|
|
|
|
|
framework::Tensor* mat_out, T beta) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
|
|
|
|
|
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
|
|
|
|
|
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
|
|
|
|
|
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
|
|
|
|
|
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
|
|
|
|
|
dim_a.width_, alpha, mat_a.data<T>(),
|
|
|
|
|
mat_b.data<T>(), beta, mat_out->data<T>());
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
|
|
|
|
|
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
|
|
|
|
|
this->template BatchedGEMM<T>(
|
|
|
|
|
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
|
|
|
|
|
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
|
|
|
|
|
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
|
|
|
|
|
dim_a.stride_, dim_b.stride_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// MatMul overload that takes MatDescriptor shape descriptors for both
// operands. Declaration only; the definition is not part of this view.
// NOTE(review): `T` is presumably the parameter of a preceding
// `template <typename T>` header -- confirm in the full file.
void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
|
|
|
|
|
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
|
|
|
|
|
T alpha, framework::Tensor* mat_out, T beta) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const DeviceContext& context_;
|
|
|
|
|