|
|
|
@ -57,7 +57,7 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
|
|
|
|
|
params_->batch = batch;
|
|
|
|
|
params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2];
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
if (params_->row_ == 1) {
|
|
|
|
|
if (params_->a_init_shape_ && params_->row_ == 1) {
|
|
|
|
|
is_vector_a_ = true;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
@ -134,7 +134,7 @@ int MatmulCPUKernel::InitBias() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MatmulCPUKernel::ReSize() {
|
|
|
|
|
if (params_->a_const_ == false || params_->a_has_shape_ == false) {
|
|
|
|
|
if (params_->a_const_ == false || params_->a_init_shape_ == false) {
|
|
|
|
|
if (a_pack_ptr_ != nullptr) {
|
|
|
|
|
free(a_pack_ptr_);
|
|
|
|
|
a_pack_ptr_ = nullptr;
|
|
|
|
@ -145,7 +145,7 @@ int MatmulCPUKernel::ReSize() {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params_->b_const_ == false || params_->b_has_shape_ == false) {
|
|
|
|
|
if (params_->b_const_ == false || params_->b_init_shape_ == false) {
|
|
|
|
|
if (b_pack_ptr_ != nullptr) {
|
|
|
|
|
free(b_pack_ptr_);
|
|
|
|
|
b_pack_ptr_ = nullptr;
|
|
|
|
@ -222,16 +222,16 @@ void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MatmulCPUKernel::Init() {
|
|
|
|
|
params_->a_has_shape_ = (in_tensors_[0]->shape().size() != 0);
|
|
|
|
|
params_->b_has_shape_ = (in_tensors_[1]->shape().size() != 0);
|
|
|
|
|
if (params_->a_has_shape_ == true) {
|
|
|
|
|
params_->a_init_shape_ = (in_tensors_[0]->shape().size() != 0);
|
|
|
|
|
params_->b_init_shape_ = (in_tensors_[1]->shape().size() != 0);
|
|
|
|
|
if (params_->a_init_shape_ == true) {
|
|
|
|
|
auto ret = MallocMatrixABuffer();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params_->b_has_shape_ == true) {
|
|
|
|
|
if (params_->b_init_shape_ == true) {
|
|
|
|
|
auto ret = MallocMatrixBBuffer();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed";
|
|
|
|
@ -300,7 +300,7 @@ int MatmulCPUKernel::Run() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params_->b_const_ == false || is_train()) {
|
|
|
|
|
if (is_vector_a_) {
|
|
|
|
|
if (is_vector_a_ && params_->b_transpose_) {
|
|
|
|
|
b_ptr_ = b_src;
|
|
|
|
|
} else {
|
|
|
|
|
InitMatrixB(b_src, b_pack_ptr_);
|
|
|
|
|