|
|
|
@ -33,6 +33,7 @@ int ResizeCPUKernel::Init() {
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
thread_num_ = context_->thread_num_;
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -42,6 +43,7 @@ int ResizeCPUKernel::Init() {
|
|
|
|
|
int ResizeCPUKernel::ReSize() {
|
|
|
|
|
int ret = RET_OK;
|
|
|
|
|
if (method_ == static_cast<int>(schema::ResizeMethod_LINEAR)) {
|
|
|
|
|
thread_num_ = 1;
|
|
|
|
|
FreeTmpBuffer();
|
|
|
|
|
ret = MallocTmpBuffer();
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
@ -95,7 +97,7 @@ int ResizeCPUKernel::MallocTmpBuffer() {
|
|
|
|
|
MS_LOG(ERROR) << "malloc data failed";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * context_->thread_num_));
|
|
|
|
|
line_buffer_ = reinterpret_cast<float *>(malloc(sizeof(float) * w * c * 2 * thread_num_));
|
|
|
|
|
if (line_buffer_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc data failed";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
@ -166,7 +168,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
|
|
|
|
|
int n_h_begin, n_h_end;
|
|
|
|
|
int n = out_tensors_.at(0)->shape()[0];
|
|
|
|
|
int h = new_height_;
|
|
|
|
|
int unit = UP_DIV(n * h, context_->thread_num_);
|
|
|
|
|
int unit = UP_DIV(n * h, thread_num_);
|
|
|
|
|
n_h_begin = unit * task_id;
|
|
|
|
|
n_h_end = std::min(n_h_begin + unit, n * h);
|
|
|
|
|
int c = in_tensors_.at(0)->shape()[3];
|
|
|
|
@ -191,7 +193,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ret = ResizeNearestNeighbor(input_data, output_data, input_shape.data(), out_tensors_[0]->shape().data(),
|
|
|
|
|
align_corners_, task_id, context_->thread_num_);
|
|
|
|
|
align_corners_, task_id, thread_num_);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case schema::ResizeMethod_UNKNOW:
|
|
|
|
@ -204,7 +206,7 @@ int ResizeCPUKernel::RunImpl(int task_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ResizeCPUKernel::Run() {
|
|
|
|
|
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, context_->thread_num_);
|
|
|
|
|
int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeImpl, this, thread_num_);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Resize run error, error_code[" << error_code << "]";
|
|
|
|
|
FreeTmpBuffer();
|
|
|
|
|