diff --git a/mindspore/lite/src/executor.h b/mindspore/lite/src/executor.h index c1b073a8a0..d951ffb6e0 100644 --- a/mindspore/lite/src/executor.h +++ b/mindspore/lite/src/executor.h @@ -28,7 +28,7 @@ class Executor { Executor() = default; virtual ~Executor() = default; - virtual int Prepare(const std::vector &kernels) { return 0; } + virtual int Prepare(const std::vector &kernels) { return RET_OK; } virtual int Run(std::vector &in_tensors, std::vector &out_tensors, std::vector &kernels, Allocator *allocator = nullptr, diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index 33f36c0283..8b328e2965 100644 --- a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -63,7 +63,7 @@ int SubGraphKernel::Prepare() { return mindspore::lite::RET_NULL_PTR; } auto ret = node->Prepare(); - if (ret == RET_OK) { + if (ret != RET_OK) { MS_LOG(ERROR) << "prepare node " << node->name() << " failed"; return ret; } @@ -180,6 +180,20 @@ int SubGraphKernel::ReSize(bool is_interrupt) { return RET_OK; } +int CpuSubGraph::Prepare() { + auto ret = SubGraphKernel::Prepare(); + if (ret != RET_OK) { + return ret; + } + for (auto node : nodes_) { + for (auto tensor : node->out_tensors()) { + MS_ASSERT(tensor != nullptr); + tensor->set_allocator(this->context_->allocator.get()); + } + } + return RET_OK; +} + int CpuFp32SubGraph::PreProcess() { return RET_OK; } int CpuFp16SubGraph::PreProcess() { diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h index 2d4607f2fe..077b8536e4 100644 --- a/mindspore/lite/src/sub_graph_kernel.h +++ b/mindspore/lite/src/sub_graph_kernel.h @@ -61,12 +61,34 @@ class SubGraphKernel : public LiteKernel { mindspore::lite::Executor *executor_ = nullptr; }; -class CpuFp32SubGraph : public SubGraphKernel { +class CpuSubGraph : public SubGraphKernel { + public: + explicit CpuSubGraph(const std::vector &inputs, const std::vector &outputs, + const std::vector &in_kernels, const std::vector &out_kernels, + const std::vector &nodes, const lite::InnerContext *ctx) + : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { + subgraph_type_ = kCpuFP32SubGraph; + this->executor_ = new mindspore::lite::Executor; + } + + ~CpuSubGraph() override = default; + + int Prepare() override; + int Init() override { return SubGraphKernel::Init(); } + int PreProcess() override { return SubGraphKernel::PreProcess(); } + int Run() override { return SubGraphKernel::Run(); } + int Run(const KernelCallBack &before, const KernelCallBack &after) override { + return SubGraphKernel::Run(before, after); + }; + int PostProcess() override { return mindspore::lite::RET_OK; } +}; + +class CpuFp32SubGraph : public CpuSubGraph { public: explicit CpuFp32SubGraph(const std::vector &inputs, const std::vector &outputs, const std::vector &in_kernels, const std::vector &out_kernels, const std::vector &nodes, const lite::InnerContext *ctx) - : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { + : CpuSubGraph(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { subgraph_type_ = kCpuFP32SubGraph; this->name_ = "CpuFP32SubGraph"; this->executor_ = new mindspore::lite::Executor; @@ -82,12 +104,12 @@ class CpuFp32SubGraph : public SubGraphKernel { int PostProcess() override { return mindspore::lite::RET_OK; } }; -class CpuFp16SubGraph : public SubGraphKernel { +class CpuFp16SubGraph : public CpuSubGraph { public: explicit CpuFp16SubGraph(const std::vector &inputs, const std::vector &outputs, const std::vector &in_kernels, const std::vector &out_kernels, const std::vector &nodes, const lite::InnerContext *ctx) - : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { + : CpuSubGraph(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { subgraph_type_ = kCpuFP16SubGraph; this->name_ = "CpuFP16SubGraph"; this->executor_ = new mindspore::lite::Executor; diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index 892289f10f..65ed534007 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -285,6 +285,51 @@ std::string Tensor::ToString() const { return oss.str(); } +int Tensor::MallocData(mindspore::lite::Allocator *allocator) { + if (nullptr != this->data_) { + return 0; + } + if (allocator != nullptr) { + allocator_ = allocator; + } + if (allocator_ == nullptr) { + this->data_ = malloc(this->Size()); + } else { + this->data_ = allocator_->Malloc(this->Size()); + } + if (nullptr == this->data_) { + MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size(); + return -1; + } + + return 0; +} + +int Tensor::FreeData() { + if (nullptr == this->data_) { + return 0; + } + if (nullptr == allocator_) { + free(this->data_); + this->data_ = nullptr; + } else { + allocator_->Free(this->data_); + this->data_ = nullptr; + } + return 0; +} + +void *Tensor::MutableData() { + if (this->data_ == nullptr) { + auto ret = this->MallocData(); + if (ret != 0) { + MS_LOG(WARNING) << "Malloc data failed"; + } + } + Prepare(); + return this->data_; +} + void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push_back(quant_arg); } std::vector Tensor::GetQuantParams() const { return this->quant_params_; } diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 97b2af97b2..7f219de192 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -147,50 +147,11 @@ class Tensor : public mindspore::tensor::MSTensor { void set_allocator(mindspore::lite::Allocator *allocator) { allocator_ = allocator; } - int MallocData(mindspore::lite::Allocator *allocator = nullptr) { - if (nullptr != this->data_) { - return 0; - } - if (allocator != nullptr) { - allocator_ = allocator; - } - if (allocator_ == nullptr) { - this->data_ = malloc(this->Size()); - } else { - this->data_ = allocator_->Malloc(this->Size()); - } - if (nullptr == this->data_) { - MS_LOG(ERROR) << "Malloc tensor data failed, size=" << this->Size(); - return -1; - } + int MallocData(mindspore::lite::Allocator *allocator = nullptr); - return 0; - } + int FreeData(); - int FreeData() { - if (nullptr == this->data_) { - return 0; - } - if (nullptr == allocator_) { - free(this->data_); - this->data_ = nullptr; - } else { - allocator_->Free(this->data_); - this->data_ = nullptr; - } - return 0; - } - - void *MutableData() override { - if (this->data_ == nullptr) { - auto ret = this->MallocData(); - if (ret != 0) { - MS_LOG(WARNING) << "Malloc data failed"; - } - } - Prepare(); - return this->data_; - } + void *MutableData() override; void *data_c() const { return data_; }