|
|
|
@ -20,6 +20,8 @@
|
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
|
#include "schema/model_generated.h"
|
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
|
#include "src/runtime/runtime_api.h"
|
|
|
|
|
#include "src/runtime/thread_pool.h"
|
|
|
|
|
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
|
|
|
using mindspore::lite::KernelRegistrar;
|
|
|
|
@ -42,12 +44,7 @@ int ConcatCPUKernel::Init() {
|
|
|
|
|
|
|
|
|
|
int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); }
|
|
|
|
|
|
|
|
|
|
int ConcatCPUKernel::Run() {
|
|
|
|
|
auto prepare_ret = Prepare();
|
|
|
|
|
if (prepare_ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
|
|
|
|
return prepare_ret;
|
|
|
|
|
}
|
|
|
|
|
int ConcatCPUKernel::DoConcat(int task_id) {
|
|
|
|
|
auto input_num = in_tensors_.size();
|
|
|
|
|
std::vector<void *> inputs_addr(input_num, nullptr);
|
|
|
|
|
std::vector<int *> inputs_output_shape(input_num + 1, nullptr);
|
|
|
|
@ -63,7 +60,27 @@ int ConcatCPUKernel::Run() {
|
|
|
|
|
auto output_addr = out_tensors_.at(0)->MutableData();
|
|
|
|
|
|
|
|
|
|
Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(),
|
|
|
|
|
output_shape.size(), output_addr);
|
|
|
|
|
output_shape.size(), output_addr, task_id, thread_count_);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ConcatsRun(void *cdata, int task_id) {
|
|
|
|
|
auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata);
|
|
|
|
|
auto error_code = concat_kernel->DoConcat(task_id);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ConcatCPUKernel::Run() {
|
|
|
|
|
auto prepare_ret = Prepare();
|
|
|
|
|
if (prepare_ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
|
|
|
|
return prepare_ret;
|
|
|
|
|
}
|
|
|
|
|
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConcatsRun, this, thread_count_);
|
|
|
|
|
return error_code;
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::kernel
|
|
|
|
|