set resize bilinear thread 1

pull/7727/head
zhaozhenlong 4 years ago
parent c9d6a7880c
commit 87a7b42471

@ -33,6 +33,7 @@ int ResizeCPUKernel::Init() {
if (ret != RET_OK) {
return ret;
}
thread_num_ = context_->thread_num_;
if (!InferShapeDone()) {
return RET_OK;
}
@ -42,6 +43,7 @@ int ResizeCPUKernel::Init() {
int ResizeCPUKernel::ReSize() {
int ret = RET_OK;
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
thread_num_ = 1;
FreeTmpBuffer();
ret = MallocTmpBuffer();
if (ret != RET_OK) {
@ -95,7 +97,7 @@ int ResizeCPUKernel::MallocTmpBuffer() {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
}
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * context_->thread_num_));
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * thread_num_));
if (line_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc data failed";
return RET_NULL_PTR;
@ -166,7 +168,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
int n_h_begin, n_h_end;
int n = out_tensors_.at(0)->shape()[0];
int h = new_height_;
int unit = UP_DIV(n * h, context_->thread_num_);
int unit = UP_DIV(n * h, thread_num_);
n_h_begin = unit * task_id;
n_h_end = std::min(n_h_begin + unit, n * h);
int c = in_tensors_.at(0)->shape()[3];
@ -191,7 +193,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
}
}
ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
align_corners_, task_id, context_->thread_num_);
align_corners_, task_id, thread_num_);
break;
}
case schema::ResizeMethod_UNKNOW:
@ -204,7 +206,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
}
int ResizeCPUKernel::Run() {
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, context_->thread_num_);
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, thread_num_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]";
FreeTmpBuffer();

@ -41,6 +41,7 @@ class ResizeCPUKernel : public ResizeBaseCPUKernel {
void FreeTmpBuffer();
private:
int thread_num_;
int *y_tops_ = nullptr;
int *y_bottoms_ = nullptr;
int *x_lefts_ = nullptr;

Loading…
Cancel
Save