|
|
|
@ -56,7 +56,7 @@ void MatmulFp32BaseCPUKernel::InitParameter() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatmulFp32BaseCPUKernel::ResizeParameter() {
|
|
|
|
|
if (params_->row_ == 1 && params_->b_const_ == false) {
|
|
|
|
|
if (params_->row_ == 1) {
|
|
|
|
|
vec_matmul_ = true;
|
|
|
|
|
}
|
|
|
|
|
params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_);
|
|
|
|
@ -238,10 +238,15 @@ int MatmulFp32BaseCPUKernel::Init() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params_->b_const_ == true) {
|
|
|
|
|
if (RET_OK != InitBufferB()) {
|
|
|
|
|
/* copy origin b data, pack in resize
|
|
|
|
|
* pack after a infershape done */
|
|
|
|
|
auto b_tensor = in_tensors_[1];
|
|
|
|
|
src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_ * params_->deep_ * sizeof(float)));
|
|
|
|
|
if (src_b_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Matmul fp16 malloc src_b_ failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()));
|
|
|
|
|
memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->col_ * params_->deep_ * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -249,6 +254,15 @@ int MatmulFp32BaseCPUKernel::Init() {
|
|
|
|
|
int MatmulFp32BaseCPUKernel::ReSize() {
|
|
|
|
|
ResizeParameter();
|
|
|
|
|
|
|
|
|
|
if (params_->b_const_ == true && src_b_ != nullptr) {
|
|
|
|
|
if (RET_OK != InitBufferB()) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
InitMatrixB(src_b_);
|
|
|
|
|
free(src_b_);
|
|
|
|
|
src_b_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
|
|
|
|
|
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|