diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index f6157e19fd..19737f823f 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -258,7 +258,8 @@ union PrimitiveType { SmoothL1LossGrad, SigmoidCrossEntropyWithLogits, SigmoidCrossEntropyWithLogitsGrad, - Reciprocal + Reciprocal, + Merge, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 5b4a714c5e..9dea38d0af 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1222,4 +1222,7 @@ table SigmoidCrossEntropyWithLogitsGrad { } table Reciprocal { -} \ No newline at end of file +} + +table Merge { +} diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 0a4e4217c5..0ea9b0d442 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "mindspore/lite/src/executor.h" -#include "nnacl/pack.h" +#include "src/executor.h" +#include #include "include/errorcode.h" namespace mindspore::lite { @@ -26,7 +26,7 @@ int Executor::CheckInputs(const std::vector &in_tensors) { return RET_ERROR; } if (inTensor->data_c() == nullptr) { - MS_LOG(ERROR) << "Graph input tensor data is nullptr"; + MS_LOG(ERROR) << "Graph input tensor data is nullptr " << in_tensors; return RET_ERROR; } auto shape = inTensor->shape(); @@ -49,7 +49,52 @@ int Executor::Run(std::vector &in_tensors, std::vector &out_ MS_LOG(ERROR) << "CheckInputs failed"; return ret; } - kernel::LiteKernelUtil::InitTensorRefCount(kernels); + std::queue kernel_queue; + for (auto kernel : kernels) { + if (kernel->IsReady()) { + kernel_queue.push(kernel); + } + } + while (!kernel_queue.empty()) { + auto cur_kernel = kernel_queue.front(); + kernel_queue.pop(); + MS_ASSERT(nullptr != cur_kernel); + ret = cur_kernel->PreProcess(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "PreProcess kernel failed, name: " << cur_kernel->name(); + 
return ret; + } + ret = cur_kernel->Run(before, after); + if (RET_OK != ret) { + MS_LOG(ERROR) << "run kernel failed, name: " << cur_kernel->name(); + return ret; + } + ret = cur_kernel->PostProcess(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "PostProcess kernel failed, name: " << cur_kernel->name(); + return ret; + } + for (auto &out_kernel : cur_kernel->out_kernels()) { + if (out_kernel->IsReady()) { + kernel_queue.push(out_kernel); + } + } + } + return RET_OK; +} + +int CpuExecutor::Run(std::vector &in_tensors, std::vector &out_tensors, + std::vector &kernels, Allocator *allocator, const KernelCallBack &before, + const KernelCallBack &after) { + MS_ASSERT(nullptr != allocator); + // not check input for merge. too hard + if (kernels.front()->Type() != schema::PrimitiveType_Merge) { + auto ret = this->CheckInputs(in_tensors); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CheckInputs failed"; + return ret; + } + } #ifdef SUPPORT_TRAIN for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them out_tensor->set_ref_count(out_tensor->ref_count() + 1); @@ -57,7 +102,7 @@ int Executor::Run(std::vector &in_tensors, std::vector &out_ #endif for (auto *kernel : kernels) { MS_ASSERT(nullptr != kernel); - ret = kernel->PreProcess(); + auto ret = kernel->PreProcess(); if (RET_OK != ret) { MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name(); return ret; diff --git a/mindspore/lite/src/executor.h b/mindspore/lite/src/executor.h index 220361c9b8..0fc59bea93 100644 --- a/mindspore/lite/src/executor.h +++ b/mindspore/lite/src/executor.h @@ -37,5 +37,16 @@ class Executor { protected: static int CheckInputs(const std::vector &in_tensors); }; + +class CpuExecutor : public Executor { + public: + CpuExecutor() = default; + virtual ~CpuExecutor() = default; + + int Run(std::vector &in_tensors, std::vector &out_tensors, + std::vector &kernels, Allocator *allocator = nullptr, + const KernelCallBack &before = nullptr, const 
KernelCallBack &after = nullptr) override; +}; + } // namespace mindspore::lite #endif diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc index 4f931d8fdc..184ed79628 100644 --- a/mindspore/lite/src/inner_context.cc +++ b/mindspore/lite/src/inner_context.cc @@ -62,7 +62,7 @@ InnerContext::~InnerContext() { } } -int InnerContext::IsValid() { +int InnerContext::IsValid() const { if (this->device_list_.empty()) { MS_LOG(ERROR) << "Device list is empty."; return RET_NOT_SUPPORT; @@ -86,33 +86,33 @@ int InnerContext::IsValid() { return RET_OK; } -bool InnerContext::IsCpuFloat16Enabled() { +bool InnerContext::IsCpuFloat16Enabled() const { if (!IsCpuEnabled()) { return false; } return GetCpuInfo().enable_float16_; } -bool InnerContext::IsGpuFloat16Enabled() { +bool InnerContext::IsGpuFloat16Enabled() const { if (!IsGpuEnabled()) { return false; } return GetGpuInfo().enable_float16_; } -bool InnerContext::IsCpuEnabled() { +bool InnerContext::IsCpuEnabled() const { return this->device_list_.end() != std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_CPU; }); } -bool InnerContext::IsGpuEnabled() { +bool InnerContext::IsGpuEnabled() const { return this->device_list_.end() != std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_GPU; }); } -bool InnerContext::IsNpuEnabled() { +bool InnerContext::IsNpuEnabled() const { #ifdef SUPPORT_NPU return this->device_list_.end() != std::find_if(this->device_list_.begin(), this->device_list_.end(), @@ -123,7 +123,7 @@ bool InnerContext::IsNpuEnabled() { #endif } -CpuDeviceInfo InnerContext::GetCpuInfo() { +CpuDeviceInfo InnerContext::GetCpuInfo() const { auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_CPU; }); if (iter == this->device_list_.end()) 
{ @@ -133,7 +133,7 @@ CpuDeviceInfo InnerContext::GetCpuInfo() { } } -GpuDeviceInfo InnerContext::GetGpuInfo() { +GpuDeviceInfo InnerContext::GetGpuInfo() const { auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) { return device.device_type_ == DT_GPU; }); if (iter == this->device_list_.end()) { diff --git a/mindspore/lite/src/inner_context.h b/mindspore/lite/src/inner_context.h index 115943b62c..7a19826d3c 100644 --- a/mindspore/lite/src/inner_context.h +++ b/mindspore/lite/src/inner_context.h @@ -33,23 +33,23 @@ struct InnerContext : public Context { int Init(); - bool IsCpuFloat16Enabled(); + bool IsCpuFloat16Enabled() const; - bool IsGpuFloat16Enabled(); + bool IsGpuFloat16Enabled() const; - bool IsCpuEnabled(); + bool IsCpuEnabled() const; - bool IsGpuEnabled(); + bool IsGpuEnabled() const; - bool IsNpuEnabled(); + bool IsNpuEnabled() const; - CpuDeviceInfo GetCpuInfo(); + CpuDeviceInfo GetCpuInfo() const; - GpuDeviceInfo GetGpuInfo(); + GpuDeviceInfo GetGpuInfo() const; NpuDeviceInfo GetNpuInfo() const; - int IsValid(); + int IsValid() const; virtual ~InnerContext(); }; diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 9ffbcab6cc..f553f71a39 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -41,9 +41,21 @@ void LiteKernel::FreeWorkspace() { workspace_ = nullptr; } -void LiteKernel::InitOutTensorRefCount() { +bool LiteKernel::IsReady() { + return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) { + return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1; + }); +} + +void LiteKernel::InitOutTensorInitRefCount() { for (auto *tensor : this->out_tensors_) { - tensor->set_ref_count(this->out_kernels_.size()); + int init_ref_count = 0; + for (auto *post_kernel : this->out_kernels_) { + init_ref_count += + std::count_if(post_kernel->in_tensors_.begin(), 
post_kernel->in_tensors_.end(), + [&tensor](const lite::Tensor *post_kernel_in_tensor) { return post_kernel_in_tensor == tensor; }); + } + tensor->set_init_ref_count(init_ref_count); } } @@ -61,15 +73,20 @@ int LiteKernel::DecOutTensorRefCount() { return 0; } -int LiteKernel::FreeWorkTensor() const { - for (auto input_kernel : this->in_kernels()) { - MS_ASSERT(input_kernel != nullptr); - if (input_kernel->is_model_output()) { +int LiteKernel::FreeInWorkTensor() const { + for (auto &in_tensor : this->in_tensors_) { + MS_ASSERT(in_tensor != nullptr); + if (in_tensor->IsConst()) { continue; } - auto ret = input_kernel->DecOutTensorRefCount(); - if (0 != ret) { - MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed"; + MS_ASSERT(in_tensor->ref_count() > 0); + in_tensor->set_ref_count(in_tensor->ref_count() - 1); + if (in_tensor->ref_count() <= 0) { + auto ret = in_tensor->FreeData(); + if (0 != ret) { + MS_LOG(ERROR) << "Free tensor data failed"; + return ret; + } } } return RET_OK; @@ -91,15 +108,12 @@ int LiteKernel::PreProcess() { } } - auto outputs = this->out_tensors(); - for (auto *output : outputs) { + for (auto *output : this->out_tensors()) { MS_ASSERT(output != nullptr); - if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast(sizeof(int64_t))) { MS_LOG(ERROR) << "The size of output tensor is too big"; return RET_ERROR; } - auto ret = output->MallocData(); if (ret != RET_OK) { MS_LOG(ERROR) << "MallocData failed"; @@ -109,6 +123,28 @@ int LiteKernel::PreProcess() { return RET_OK; } +int LiteKernel::PostProcess() { +#ifdef SUPPORT_TRAIN + for (auto input_kernel : this->in_kernels()) { + MS_ASSERT(input_kernel != nullptr); + if (input_kernel->is_model_output()) { + continue; + } + auto ret = input_kernel->DecOutTensorRefCount(); + if (0 != ret) { + MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed"; + } + } + return RET_OK; +#else + for (auto *output : this->out_tensors()) { + MS_ASSERT(output != 
nullptr); + output->ResetRefCount(); + } + return FreeInWorkTensor(); +#endif +} + int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) { if (before != nullptr) { if (!before(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_), @@ -153,6 +189,28 @@ std::string LiteKernel::ToString() const { return oss.str(); } +void LiteKernel::FindInoutKernels(const std::vector &scope_kernels) { + // clean io kernels + this->in_kernels_.clear(); + this->out_kernels_.clear(); + // find io kernels + for (auto *scope_kernel : scope_kernels) { + if (scope_kernel == this) { + continue; + } + for (auto *tensor : this->in_tensors_) { + if (lite::IsContain(scope_kernel->out_tensors(), tensor)) { + this->AddInKernel(scope_kernel); + } + } + for (auto *tensor : this->out_tensors_) { + if (lite::IsContain(scope_kernel->in_tensors(), tensor)) { + this->AddOutKernel(scope_kernel); + } + } + } +} + std::vector LiteKernelUtil::SubgraphInputKernels( const std::vector &kernels) { std::vector input_kernels; @@ -202,7 +260,7 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect if (outer_in_kernels.empty()) { for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { if (!in_kernel_in_tensor->IsConst()) { - if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { + if (!IsContain(input_tensors, in_kernel_in_tensor)) { input_tensors.push_back(in_kernel_in_tensor); } } @@ -219,7 +277,7 @@ std::vector LiteKernelUtil::SubgraphInputTensors(const std::vect auto outer_in_kernel_out_tensors_iter = std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { - if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { + if (!IsContain(input_tensors, in_kernel_in_tensor)) { input_tensors.emplace_back(in_kernel_in_tensor); } } @@ -237,7 +295,7 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec auto 
&out_kernel_out_tensors = output_kernel->out_tensors(); if (outer_out_kernels.empty()) { for (auto out_kernel_out_tensor : out_kernel_out_tensors) { - if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + if (!IsContain(output_tensors, out_kernel_out_tensor)) { output_tensors.push_back(out_kernel_out_tensor); } } @@ -253,7 +311,7 @@ std::vector LiteKernelUtil::SubgraphOutputTensors(const std::vec auto outer_out_kernel_in_tensors_iter = std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { - if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { + if (!IsContain(output_tensors, out_kernel_out_tensor)) { output_tensors.emplace_back(out_kernel_out_tensor); } } @@ -299,33 +357,9 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector *ke return RET_OK; } -void LiteKernelUtil::InitIOKernels(std::vector &kernels) { - for (auto *kernel : kernels) { - // clean io kernels - kernel->set_in_kernels({}); - kernel->set_out_kernels({}); - // find io kernels - for (auto *search_kernel : kernels) { - if (search_kernel == kernel) { - continue; - } - for (auto *tensor : kernel->in_tensors()) { - if (lite::IsContain(search_kernel->out_tensors(), tensor)) { - kernel->AddInKernel(search_kernel); - } - } - for (auto *tensor : kernel->out_tensors()) { - if (lite::IsContain(search_kernel->in_tensors(), tensor)) { - kernel->AddOutKernel(search_kernel); - } - } - } - } -} - -void LiteKernelUtil::InitTensorRefCount(std::vector &kernels) { +void LiteKernelUtil::InitTensorInitRefCount(std::vector &kernels) { for (auto *kernel : kernels) { - kernel->InitOutTensorRefCount(); + kernel->InitOutTensorInitRefCount(); } } diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index ffb659fa15..46a1ce791c 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -87,10 +87,12 @@ class LiteKernel { 
virtual int Run(const KernelCallBack &before, const KernelCallBack &after); // called after Run - virtual int PostProcess() { return FreeWorkTensor(); } + virtual int PostProcess(); virtual int ReSize() { return mindspore::lite::RET_ERROR; } + virtual void FindInoutKernels(const std::vector &scope_kernels); + virtual int Init() { return mindspore::lite::RET_ERROR; } std::string name() const { return this->name_; } @@ -154,11 +156,13 @@ class LiteKernel { const std::vector &out_kernels() const { return this->out_kernels_; } - void InitOutTensorRefCount(); + virtual bool IsReady(); + + virtual void InitOutTensorInitRefCount(); int DecOutTensorRefCount(); - int FreeWorkTensor() const; + virtual int FreeInWorkTensor() const; KernelKey desc() const { return desc_; } @@ -203,8 +207,6 @@ typedef LiteKernel *(*KernelCreator)(const std::vector &inputs, class LiteKernelUtil { public: - static void InitIOKernels(std::vector &kernels); - static std::vector SubgraphInputKernels(const std::vector &kernels); static std::vector SubgraphOutputKernels(const std::vector &kernels); @@ -215,7 +217,7 @@ class LiteKernelUtil { static int TopologicalSortKernels(std::vector *kernels); - static void InitTensorRefCount(std::vector &kernels); + static void InitTensorInitRefCount(std::vector &kernels); static int SetInput(LiteKernel &kernelMod, const std::vector &inputs); }; diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 80913a8431..2166cc80b0 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -295,6 +295,21 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) { } } +void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) { + MS_ASSERT(model != nullptr); + auto graph_out_size = model->sub_graphs_.front()->output_indices_.size(); + for (size_t i = 0; i < graph_out_size; ++i) { + size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i]; + 
MS_ASSERT(graph_out_index < this->tensors_.size()); + auto *out_tensor = this->tensors_.at(graph_out_index); + if (out_tensor == nullptr) { + MS_LOG(ERROR) << "out_tensor is null!"; + return; + } + out_tensor->set_init_ref_count(out_tensor->init_ref_count() + 1); + } +} + void LiteSession::InitGraphInOutTensors(const lite::Model *model) { InitGraphInputTensors(model); InitGraphInputMSTensors(); @@ -303,6 +318,7 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) { InitGraphOutputNodeMap(model); InitGraphOutputTensorNames(model); InitGraphOutputTensorMap(model); + AdjustModelOutputTensorInitRefCount(model); } int LiteSession::CompileGraph(Model *model) { @@ -334,12 +350,9 @@ int LiteSession::CompileGraph(Model *model) { is_running_.store(false); return ret; } - - InitGraphInOutTensors(model); - // scheduler kernels - Scheduler scheduler(context_); - ret = scheduler.Schedule(model, &tensors_, &kernels_); + Scheduler scheduler(context_, model, tensors_); + ret = scheduler.Schedule(&kernels_); if (ret != RET_OK) { MS_LOG(ERROR) << "Schedule kernels failed: " << ret; is_running_.store(false); @@ -353,6 +366,7 @@ int LiteSession::CompileGraph(Model *model) { } } #endif + InitGraphInOutTensors(model); ret = executor_->Prepare(this->kernels_); if (ret != RET_OK) { MS_LOG(ERROR) << "Prepare executor failed: " << ret; @@ -563,6 +577,32 @@ void LiteSession::ResetInputsShape(const std::vector> &dims) { } } +int LiteSession::ReSizeKernels(const std::vector &kernels) { + bool infer_shape_interrupt = false; + for (auto kernel : kernels) { + if (kernel == nullptr) { + MS_LOG(ERROR) << "input kernel is nullptr!"; + return RET_ERROR; + } + if (kernel->subgraph_type() == kernel::kNotSubGraph) { + MS_LOG(ERROR) << "All node in graph should be sub_graph"; + return RET_ERROR; + } + auto sub_graph = reinterpret_cast(kernel); + auto ret = sub_graph->ReSize(infer_shape_interrupt); + if (ret == RET_INFER_INVALID) { + MS_LOG(INFO) << "InferShape is interrupted"; + 
infer_shape_interrupt = true; + continue; + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; + return RET_ERROR; + } + } + return RET_OK; +} + int LiteSession::Resize(const std::vector &inputs, const std::vector> &dims) { bool expected = false; @@ -581,11 +621,10 @@ int LiteSession::Resize(const std::vector &inputs return ret; } - Scheduler scheduler(context_); - ret = scheduler.ReSizeKernels(kernels_); + ret = ReSizeKernels(kernels_); if (ret != RET_OK) { ResetInputsShape(old_dims); - auto resize_ret = scheduler.ReSizeKernels(kernels_); + auto resize_ret = ReSizeKernels(kernels_); if (resize_ret != RET_OK) { MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret; } diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 55f953d124..d0d3ec1bfd 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -92,10 +92,14 @@ class LiteSession : public session::LiteSession { void InitGraphOutputTensorMap(const lite::Model *model); + void AdjustModelOutputTensorInitRefCount(const lite::Model *model); + int ResizeInputs(const std::vector &inputs, const std::vector> &dims); int PrepareKernels(); + static int ReSizeKernels(const std::vector &kernels); + private: void ResetInputsShape(const std::vector> &dims); diff --git a/mindspore/lite/src/ops/merge.cc b/mindspore/lite/src/ops/merge.cc new file mode 100644 index 0000000000..e9895164cb --- /dev/null +++ b/mindspore/lite/src/ops/merge.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/merge.h" +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int Merge::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Merge; + } + if (this->primitive_->value.type != schema::PrimitiveType_Merge) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + this->primitive_->value.value = new (std::nothrow) schema::MergeT(); + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + } + PopulaterQuantParam(prim, inputs); + return RET_OK; +} + +#else +int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Merge(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Merge return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateMerge(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *MergeCreator(const schema::Primitive 
*primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator); +#endif + +int Merge::InferShape(std::vector inputs_, std::vector outputs_) { + MS_ASSERT(outputs_.size() == 1); + MS_ASSERT(inputs_.size() == 2); + outputs_[0]->set_data_type(inputs_[0]->data_type()); + + return RET_OK; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/merge.h b/mindspore/lite/src/ops/merge.h new file mode 100644 index 0000000000..446fc76e09 --- /dev/null +++ b/mindspore/lite/src/ops/merge.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { + +class Merge : public PrimitiveC { + public: + Merge() = default; + ~Merge() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Merge, PrimitiveC); + explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_ diff --git a/mindspore/lite/src/ops/partial.cc b/mindspore/lite/src/ops/partial.cc new file mode 100644 index 0000000000..deb4d80b20 --- /dev/null +++ b/mindspore/lite/src/ops/partial.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/partial.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE + +int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; } + +int Partial::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Partial; + } + if (this->primitive_->value.type != schema::PrimitiveType_Partial) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::PartialT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} + +#else + +int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Partial(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Partial return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreatePartial(*fbb, attr->subGraphIndex()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); } + +PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry 
PartialRegistry(schema::PrimitiveType_Partial, PartialCreator); + +#endif + +int Partial::InferShape(std::vector inputs_, std::vector outputs_) { return RET_OK; } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/partial.h b/mindspore/lite/src/ops/partial.h new file mode 100644 index 0000000000..6ef3e70255 --- /dev/null +++ b/mindspore/lite/src/ops/partial.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ + +#include +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Partial : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Partial, PrimitiveC); + Partial() = default; + explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + +#else + Partial() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; + int GetSubGraphIndex() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_ diff --git a/mindspore/lite/src/ops/populate/merge_populate.cc b/mindspore/lite/src/ops/populate/merge_populate.cc new file mode 100644 index 0000000000..ec23291934 --- /dev/null +++ b/mindspore/lite/src/ops/populate/merge_populate.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { + +OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *merge_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + if (merge_parameter == nullptr) { + MS_LOG(ERROR) << "malloc Merge parameter failed."; + return nullptr; + } + memset(merge_parameter, 0, sizeof(OpParameter)); + merge_parameter->type_ = primitive->Type(); + return reinterpret_cast(merge_parameter); +} +Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/partial_populate.cc b/mindspore/lite/src/ops/populate/partial_populate.cc new file mode 100644 index 0000000000..300f5e2827 --- /dev/null +++ b/mindspore/lite/src/ops/populate/partial_populate.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/partial.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { +typedef struct PartialParameter { + OpParameter op_parameter_; + int sub_graph_index_; +} PartialParameter; + +OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) { + PartialParameter *partial_parameter = reinterpret_cast(malloc(sizeof(PartialParameter))); + if (partial_parameter == nullptr) { + MS_LOG(ERROR) << "malloc partial parameter failed."; + return nullptr; + } + memset(partial_parameter, 0, sizeof(PartialParameter)); + partial_parameter->op_parameter_.type_ = primitive->Type(); + + auto param = reinterpret_cast(const_cast(primitive)); + partial_parameter->sub_graph_index_ = param->GetSubGraphIndex(); + + return reinterpret_cast(partial_parameter); +} +Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/switch_populate.cc b/mindspore/lite/src/ops/populate/switch_populate.cc new file mode 100644 index 0000000000..c895b9ae6c --- /dev/null +++ b/mindspore/lite/src/ops/populate/switch_populate.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/ops/switch.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { +OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *switch_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + if (switch_parameter == nullptr) { + MS_LOG(ERROR) << "malloc SwitchParameter failed."; + return nullptr; + } + memset(switch_parameter, 0, sizeof(OpParameter)); + switch_parameter->type_ = primitive->Type(); + + return reinterpret_cast(switch_parameter); +} +Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 206ef8b1bf..c838262204 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -155,6 +155,9 @@ #include "src/ops/tensorlistsetitem.h" #include "src/ops/tensorlistreserve.h" #include "src/ops/tensorliststack.h" +#include "src/ops/merge.h" +#include "src/ops/switch.h" +#include "src/ops/partial.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -925,7 +928,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) TensorListReserve(primitive); case schema::PrimitiveType_TensorListStack: return new (std::nothrow) TensorListStack(primitive); - + case schema::PrimitiveType_Switch: + return new (std::nothrow) Switch(primitive); + case schema::PrimitiveType_Merge: + return new (std::nothrow) Merge(primitive); + case schema::PrimitiveType_Partial: + return new (std::nothrow) Partial(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/src/ops/switch.cc b/mindspore/lite/src/ops/switch.cc new file mode 100644 index 0000000000..70277c2d27 --- /dev/null +++ 
b/mindspore/lite/src/ops/switch.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/switch.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int Switch::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Switch; + } + if (this->primitive_->value.type != schema::PrimitiveType_Switch) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SwitchT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else +int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Switch(); + if (attr 
== nullptr) { + MS_LOG(ERROR) << "value_as_Switch return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSwitch(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator); +#endif + +int Switch::InferShape(std::vector inputs_, std::vector outputs_) { return RET_OK; } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/switch.h b/mindspore/lite/src/ops/switch.h new file mode 100644 index 0000000000..c52d43c7d3 --- /dev/null +++ b/mindspore/lite/src/ops/switch.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ +#define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ + +#include +#include +#include +#include + +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Switch : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Switch, PrimitiveC); + Switch() = default; + explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + +#else + Switch() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.cc b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc new file mode 100644 index 0000000000..153428b941 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/arm/base/merge.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Merge; + +namespace mindspore::kernel { +// if one of input of merge is const-tensor, merge is always ready, this will cause error. +bool MergeCPUKernel::IsReady() { + MS_ASSERT(in_tensors().size() == 2); + return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) { + return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1; + }); +} + +int MergeCPUKernel::Init() { return RET_OK; } + +int MergeCPUKernel::ReSize() { return RET_ERROR; } + +int MergeCPUKernel::Run() { + MS_ASSERT(in_tensors_.size() == 2); + MS_ASSERT(out_tensors_.size() == 1); + auto out_data = out_tensors_.front()->data_c(); + MS_ASSERT(out_data != nullptr); + for (size_t i = 0; i < in_tensors().size(); i++) { + if (in_tensors()[i]->data_c() != nullptr) { + auto in_data = in_tensors_[i]->data_c(); + MS_ASSERT(in_data != nullptr); + memcpy(out_data, in_data, in_tensors_[i]->Size()); + } + } + return RET_OK; +} + +kernel::LiteKernel *CpuMergeKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is nullptr"; + return nullptr; + } + if (desc.type != PrimitiveType_Merge) { + MS_LOG(ERROR) << "type in desc is not Merge"; + free(parameter); + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "ctx is nullptr"; + free(parameter); + return nullptr; + } + + auto *kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + free(parameter); 
+ return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/merge.h b/mindspore/lite/src/runtime/kernel/arm/base/merge.h new file mode 100644 index 0000000000..be81359768 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/merge.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +typedef struct MergeParameter { + OpParameter op_parameter_; +} MergeParameter; + +class MergeCPUKernel : public LiteKernel { + public: + MergeCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + merge_param_ = reinterpret_cast(op_parameter_); + } + ~MergeCPUKernel() override {} + bool IsReady() override; + int Init() override; + int ReSize() override; + int Run() override; + + private: + MergeParameter *merge_param_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.cc b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc new file mode 100644 index 0000000000..c56583a172 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/arm/base/switch.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Switch; + +namespace mindspore::kernel { +int SwitchCPUKernel::PostProcess() { + auto bool_tensor = in_tensors_.front(); + MS_ASSERT(bool_tensor != nullptr); + MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); + MS_ASSERT(bool_tensor->shape().size() == 1); + MS_ASSERT(bool_tensor->shape().front() == 1); + auto *active = static_cast(bool_tensor->data_c()); + if (active == nullptr) { + MS_LOG(ERROR) << "data of bool tensor is nullptr"; + return lite::RET_NULL_PTR; + } + size_t in_index = 1; + size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2); + while (in_index < in_tensors_.size()) { + in_index++; + auto out_tensor = out_tensors_.at(out_index++); + out_tensor->ResetRefCount(); + } + return FreeInWorkTensor(); +} + +int SwitchCPUKernel::Init() { return RET_OK; } + +int SwitchCPUKernel::ReSize() { return RET_ERROR; } + +// inputs: bool*1 data*n +// output: true-data*n, false-data*n +int SwitchCPUKernel::Run() { + MS_ASSERT(in_tensors_.size() >= 2); + MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size()); + auto bool_tensor = in_tensors_.front(); + MS_ASSERT(bool_tensor != nullptr); + MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); + MS_ASSERT(bool_tensor->shape().size() == 1); + MS_ASSERT(bool_tensor->shape().front() == 1); + auto active = static_cast(bool_tensor->data_c()); + if (active == nullptr) { + MS_LOG(ERROR) << "data of bool tensor is nullptr"; + return lite::RET_NULL_PTR; + } + size_t in_index = 1; + size_t out_index = (*active) ? 
0 : (out_tensors_.size() / 2); + while (in_index < in_tensors_.size()) { + auto in_tensor = in_tensors_.at(in_index++); + auto out_tensor = out_tensors_.at(out_index++); + MS_ASSERT(in_tensor != nullptr); + MS_ASSERT(out_tensor != nullptr); + auto input = reinterpret_cast(in_tensor->data_c()); + auto output = reinterpret_cast(out_tensor->data_c()); + MS_ASSERT(in_tensor->Size() == out_tensor->Size()); + if (input == nullptr || output == nullptr) { + MS_LOG(ERROR) << "input tensor or output tensor have not been malloced"; + return lite::RET_NULL_PTR; + } + memcpy(output, input, in_tensor->Size()); + } + return RET_OK; +} + +kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "parameter is nullptr"; + return nullptr; + } + if (desc.type != PrimitiveType_Switch) { + MS_LOG(ERROR) << "type in desc is not Switch"; + free(parameter); + return nullptr; + } + if (ctx == nullptr) { + MS_LOG(ERROR) << "ctx is nullptr"; + free(parameter); + return nullptr; + } + auto *kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_; + free(parameter); + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/switch.h b/mindspore/lite/src/runtime/kernel/arm/base/switch.h new file mode 100644 index 0000000000..7e7530a088 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/base/switch.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +typedef struct SwitchParameter { + OpParameter op_parameter_; +} SwitchParameter; + +class SwitchCPUKernel : public LiteKernel { + public: + SwitchCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + switch_param_ = reinterpret_cast(op_parameter_); + } + ~SwitchCPUKernel() override = default; + int PostProcess() override; + int Init() override; + int ReSize() override; + int Run() override; + + private: + SwitchParameter *switch_param_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc index 84c93fe2bc..d8346112ea 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc @@ -71,6 +71,14 @@ int OpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { return RET_OK; } +int OpenCLKernel::PostProcess() { + for (auto *output : this->out_tensors()) { + MS_ASSERT(output != nullptr); + 
output->ResetRefCount(); + } + return FreeInWorkTensor(); +} + std::vector OpenCLKernel::GenerateTuningParam() { size_t ndim = global_size_.size(); std::vector tuning_params = {}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h index 5f536a1f2f..ee9d927676 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ -164,6 +164,7 @@ class OpenCLKernel : public LiteKernel { int Prepare() override { return RET_OK; } int PreProcess() override { return RET_ERROR; } + int PostProcess() override; int ReSize() override { return RET_ERROR; } int Run() override { return RET_ERROR; } diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc index aca0c32e3b..941a25623e 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -36,7 +36,6 @@ int OpenCLExecutor::RunOrTune(std::vector &inputs, std::vectorSetProfiling(true); } - kernel::LiteKernelUtil::InitTensorRefCount(kernels); for (auto *kernel : kernels) { MS_ASSERT(kernel); CallBackParam callbackParam; @@ -82,6 +81,11 @@ int OpenCLExecutor::RunOrTune(std::vector &inputs, std::vectorname(); return ret; } + ret = kernel->PostProcess(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "PostProcess kernel failed, name: " << kernel->name(); + return ret; + } if (profiling_tmp) { MS_LOG(INFO) << "OpenCl kernel " << kernel->name() << "(" << kernel->type_str() << ") execute time is: " << op_kernel->GetProfilingTimeMs() << "ms"; @@ -92,13 +96,6 @@ int OpenCLExecutor::RunOrTune(std::vector &inputs, std::vectorname(); } } - for (auto input_kernel : kernel->in_kernels()) { - MS_ASSERT(input_kernel); - ret = input_kernel->DecOutTensorRefCount(); - if (ret != RET_OK) { - MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << kernel->name() << " failed"; - } - 
} } opencl_runtime_ins->SetProfiling(profiling_tmp); return ret; diff --git a/mindspore/lite/src/runtime/parallel_executor.cc b/mindspore/lite/src/runtime/parallel_executor.cc index 7f1c80aa9a..056005f48d 100644 --- a/mindspore/lite/src/runtime/parallel_executor.cc +++ b/mindspore/lite/src/runtime/parallel_executor.cc @@ -40,9 +40,9 @@ static int RunKernel(void *data, int index) { return 0; } - ret = kernel->FreeWorkTensor(); + ret = kernel->FreeInWorkTensor(); if (RET_OK != ret) { - MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << kernel->name(); + MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << kernel->name(); return ret; } return 0; @@ -62,7 +62,7 @@ int ParallelExecutor::Run(std::vector &in_tensors, std::vectorin_kernels().empty()) { @@ -96,9 +96,9 @@ int ParallelExecutor::Run(std::vector &in_tensors, std::vectorFreeWorkTensor(); + auto ret = completed->FreeInWorkTensor(); if (RET_OK != ret) { - MS_LOG(ERROR) << "FreeWorkTensor failed, name: " << completed->name(); + MS_LOG(ERROR) << "FreeInWorkTensor failed, name: " << completed->name(); return ret; } } diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 9e17dcb206..9ec55d58a5 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -19,6 +19,7 @@ #include #include #include +#include "src/ops/partial.h" #include "include/errorcode.h" #include "src/common/graph_util.h" #include "src/common/utils.h" @@ -36,152 +37,255 @@ namespace mindspore::lite { using kernel::KERNEL_ARCH::kCPU; using kernel::KERNEL_ARCH::kGPU; using kernel::KERNEL_ARCH::kNPU; +constexpr int kMainSubGraphIndex = 0; -int Scheduler::Schedule(const lite::Model *model, std::vector *tensors, - std::vector *kernels) { - int ret = InferShape(model, tensors); +int Scheduler::Schedule(std::vector *dst_kernels) { + if (src_model_ == nullptr) { + MS_LOG(ERROR) << "Input model is nullptr"; + return RET_PARAM_INVALID; + } + if (src_model_->sub_graphs_.empty()) { + MS_LOG(ERROR) << 
"Model should have a subgraph at least"; + return RET_PARAM_INVALID; + } + + this->graph_output_node_indexes_ = GetGraphOutputNodes(src_model_); + bool infer_shape_interrupt = false; + auto ret = InferSubGraphShape(kMainSubGraphIndex, &infer_shape_interrupt); if (ret != RET_OK) { MS_LOG(ERROR) << "op infer shape failed."; return ret; } - ret = BuildKernels(model, tensors, kernels); + ret = ScheduleSubGraphToKernels(kMainSubGraphIndex, dst_kernels); if (ret != RET_OK) { - MS_LOG(ERROR) << "init op to kernel failed."; + MS_LOG(ERROR) << "Schedule main subgraph to kernels failed."; return ret; } - - kernel::LiteKernelUtil::InitIOKernels(*kernels); - - ret = ConstructSubGraphs(kernels); + FindAllInoutKernels(*dst_kernels); + ret = ConstructSubGraphs(dst_kernels); if (ret != RET_OK) { MS_LOG(ERROR) << "ConstructSubGraphs failed."; return ret; } - - kernel::LiteKernelUtil::InitIOKernels(*kernels); - + FindAllInoutKernels(*dst_kernels); + kernel::LiteKernelUtil::InitTensorInitRefCount(*dst_kernels); MS_LOG(DEBUG) << "schedule kernels success."; return RET_OK; } -int Scheduler::ReSizeKernels(const std::vector &kernels) { - bool infer_shape_interrupt = false; - for (auto kernel : kernels) { - if (kernel == nullptr) { - MS_LOG(ERROR) << "input kernel is nullptr!"; - return RET_ERROR; - } - if (kernel->subgraph_type() == kernel::kNotSubGraph) { - MS_LOG(ERROR) << "All node in graph should be sub_graph"; - return RET_ERROR; - } - auto sub_graph = reinterpret_cast(kernel); - auto ret = sub_graph->ReSize(infer_shape_interrupt); - if (ret == RET_INFER_INVALID) { - MS_LOG(INFO) << "InferShape is interrupted"; - infer_shape_interrupt = true; - continue; - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; - return RET_ERROR; +void Scheduler::FindNodeInoutTensors(const lite::Model::Node &node, std::vector *inputs, + std::vector *outputs) { + MS_ASSERT(inputs != nullptr); + MS_ASSERT(outputs != nullptr); + auto in_size = 
node.input_indices_.size(); + inputs->reserve(in_size); + for (size_t j = 0; j < in_size; ++j) { + inputs->emplace_back(src_tensors_.at(node.input_indices_[j])); + } + auto out_size = node.output_indices_.size(); + outputs->reserve(out_size); + for (size_t j = 0; j < out_size; ++j) { + outputs->emplace_back(src_tensors_.at(node.output_indices_[j])); + } +} + +int Scheduler::InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { + MS_ASSERT(node != nullptr); + MS_ASSERT(infer_shape_interrupt != nullptr); + auto primitive = node->primitive_; + MS_ASSERT(primitive != nullptr); + if (primitive->Type() == schema::PrimitiveType_Partial) { + return InferPartialShape(node, infer_shape_interrupt); + } + std::vector inputs; + std::vector outputs; + FindNodeInoutTensors(*node, &inputs, &outputs); + bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { + auto shape = tensor->shape(); + return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); + }); + if (!infer_valid) { + *infer_shape_interrupt = true; + } + primitive->set_infer_flag(!(*infer_shape_interrupt)); + auto ret = primitive->InferShape(inputs, outputs); + if (ret == RET_OK) { + for (auto &output : outputs) { + if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast(sizeof(int64_t))) { + MS_LOG(ERROR) << "The size of output tensor is too big"; + return RET_ERROR; + } } } - return RET_OK; + return ret; } -int Scheduler::InferShape(const lite::Model *model, std::vector *tensors) { - MS_ASSERT(model != nullptr); - MS_ASSERT(tensors != nullptr); - bool infer_shape_interrupt = false; - uint32_t kernelCount = model->all_nodes_.size(); - for (uint32_t i = 0; i < kernelCount; ++i) { - auto node = model->all_nodes_[i]; +int Scheduler::InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt) { + MS_ASSERT(src_model_ != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(infer_shape_interrupt != nullptr); + auto 
primitive = node->primitive_; + MS_ASSERT(primitive != nullptr); + if (primitive->Type() != schema::PrimitiveType_Partial) { + MS_LOG(ERROR) << "Node is not a partial"; + return RET_PARAM_INVALID; + } + auto partial_primitive = reinterpret_cast(node->primitive_); + return InferSubGraphShape(partial_primitive->GetSubGraphIndex(), infer_shape_interrupt); +} + +int Scheduler::InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt) { + MS_ASSERT(infer_shape_interrupt != nullptr); + MS_ASSERT(src_model_ != nullptr); + MS_ASSERT(!src_model_->sub_graphs_.empty()); + MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); + auto subgraph = src_model_->sub_graphs_.at(subgraph_index); + for (auto node_index : subgraph->node_indices_) { + auto node = src_model_->all_nodes_[node_index]; MS_ASSERT(node != nullptr); - std::vector inputs; - std::vector outputs; - auto in_size = node->input_indices_.size(); - inputs.reserve(in_size); - for (size_t j = 0; j < in_size; ++j) { - inputs.emplace_back(tensors->at(node->input_indices_[j])); - } - auto out_size = node->output_indices_.size(); - outputs.reserve(out_size); - for (size_t j = 0; j < out_size; ++j) { - outputs.emplace_back(tensors->at(node->output_indices_[j])); - } auto *primitive = node->primitive_; if (primitive == nullptr) { MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; return RET_ERROR; } - bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { - auto shape = tensor->shape(); - return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); - }); - if (!infer_valid) { - infer_shape_interrupt = true; - } - primitive->set_infer_flag(!infer_shape_interrupt); - auto ret = primitive->InferShape(inputs, outputs); + auto ret = InferNodeShape(node, infer_shape_interrupt); if (ret == RET_INFER_INVALID) { - MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_ + MS_LOG(INFO) << "InferShape interrupted, 
name: " << node->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(primitive->Type())) - << "flag set to false."; + << ", set infer flag to false."; primitive->set_infer_flag(false); - infer_shape_interrupt = true; + *infer_shape_interrupt = true; } else if (ret != RET_OK) { MS_LOG(ERROR) << "InferShape failed, name: " << node->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(primitive->Type())); return RET_INFER_ERR; - } else { - for (auto &output : outputs) { - if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast(sizeof(int64_t))) { - MS_LOG(ERROR) << "The size of output tensor is too big"; - return RET_ERROR; - } - } } } - return RET_OK; } -int Scheduler::BuildKernels(const lite::Model *model, const std::vector *tensors, - std::vector *kernels) { - MS_ASSERT(model != nullptr); - MS_ASSERT(tensors != nullptr); - uint32_t kernelCount = model->all_nodes_.size(); - auto graph_output_node_indexes = GetGraphOutputNodes(model); - for (uint32_t i = 0; i < kernelCount; ++i) { - auto node = model->all_nodes_[i]; - MS_ASSERT(node != nullptr); - std::vector inputs; - std::vector outputs; - auto in_size = node->input_indices_.size(); - inputs.reserve(in_size); - for (size_t j = 0; j < in_size; ++j) { - inputs.emplace_back(tensors->at(node->input_indices_[j])); +kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in_tensors, + const std::vector &out_tensors, + const mindspore::lite::PrimitiveC *primitive, + const Model::Node *node) { + MS_ASSERT(primitive != nullptr); + TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); + kernel::KernelKey desc{kCPU, data_type, static_cast(primitive->Type())}; +#if SUPPORT_GPU + if (context_->IsGpuEnabled()) { + kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; + auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); + if (kernel != nullptr) { + MS_LOG(DEBUG) << "Get gpu op success: " << 
schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; + return kernel; + } else { + MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " + << node->name_; + } + } +#endif +#if SUPPORT_NPU + if (context_->IsNpuEnabled()) { + kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; + auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); + if (kernel != nullptr) { + MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; + return kernel; + } else { + MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " + << node->name_; } - auto out_size = node->output_indices_.size(); - outputs.reserve(out_size); - for (size_t j = 0; j < out_size; ++j) { - outputs.emplace_back(tensors->at(node->output_indices_[j])); + } +#endif + if (mindspore::lite::IsSupportFloat16() && + ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { + kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; + auto *kernel = + KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); + if (kernel != nullptr) { + MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " + << node->name_; + return kernel; } + } + if (data_type == kNumberTypeFloat16) { + MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; + desc.data_type = kNumberTypeFloat32; + } + auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); + if (kernel != nullptr) { + return kernel; + } + return nullptr; +} + +kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *src_node) { + MS_ASSERT(src_model_ != nullptr); + MS_ASSERT(src_node != nullptr); + auto 
*primitive = src_node->primitive_; + MS_ASSERT(primitive != nullptr); + if (primitive->Type() != schema::PrimitiveType_Partial) { + return nullptr; + } + auto partial_primitive = reinterpret_cast(primitive); + auto sub_graph_index = partial_primitive->GetSubGraphIndex(); + std::vector sub_kernels; + auto ret = ScheduleSubGraphToKernels(sub_graph_index, &sub_kernels); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Schedule partial failed, name: " << src_node->name_; + return nullptr; + } + auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(sub_kernels.front()); + // for kernel::LiteKernelUtil::SubgraphInputTensors in CreateSubGraphKernel + FindAllInoutKernels(sub_kernels); + auto subgraph = CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); + subgraph->set_name("subgraph_" + src_node->name_); + return subgraph; +} + +kernel::LiteKernel *Scheduler::ScheduleNodeToKernel(const lite::Model::Node *src_node) { + auto *primitive = src_node->primitive_; + MS_ASSERT(primitive != nullptr); + std::vector inputs; + std::vector outputs; + FindNodeInoutTensors(*src_node, &inputs, &outputs); + auto *kernel = this->FindBackendKernel(inputs, outputs, primitive, src_node); + if (kernel == nullptr) { + MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << src_node->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(primitive->Type())); + return nullptr; + } + SetKernelTensorDataType(kernel); + kernel->set_name(src_node->name_); + return kernel; +} + +int Scheduler::ScheduleSubGraphToKernels(size_t subgraph_index, std::vector *dst_kernels) { + MS_ASSERT(src_model_ != nullptr); + MS_ASSERT(!src_model_->sub_graphs_.empty()); + MS_ASSERT(src_model_->sub_graphs_.size() > subgraph_index); + MS_ASSERT(dst_kernels != nullptr); + MS_ASSERT(dst_kernels->empty()); + auto subgraph = src_model_->sub_graphs_.at(subgraph_index); + for (auto node_index : subgraph->node_indices_) { + auto node = src_model_->all_nodes_[node_index]; + MS_ASSERT(node != 
nullptr); auto *primitive = node->primitive_; MS_ASSERT(primitive != nullptr); - auto *kernel = this->ScheduleNode(inputs, outputs, primitive, node); + kernel::LiteKernel *kernel = nullptr; + if (primitive->Type() == schema::PrimitiveType_Partial) { // sub_graph + kernel = SchedulePartialToKernel(node); + } else { // kernel + kernel = ScheduleNodeToKernel(node); + } if (kernel == nullptr) { - MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << node->name_ << ", type: " + MS_LOG(ERROR) << "FindBackendKernel return nullptr, name: " << node->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(primitive->Type())); return RET_ERROR; } - SetKernelTensorDataType(kernel); - kernel->set_name(node->name_); - kernel->set_is_model_output(IsContain(graph_output_node_indexes, size_t(i))); - kernels->emplace_back(kernel); + kernel->set_is_model_output(IsContain(graph_output_node_indexes_, size_t(node_index))); + dst_kernels->emplace_back(kernel); } - return RET_OK; } @@ -190,6 +294,11 @@ std::vector Scheduler::FindAllSubGraphKernels( MS_ASSERT(head_kernel != nullptr); MS_ASSERT(sinked_kernel_map != nullptr); std::vector sub_kernels; + if (head_kernel->Type() == schema::PrimitiveType_Switch || head_kernel->Type() == schema::PrimitiveType_Merge) { + (*sinked_kernel_map)[head_kernel] = true; + sub_kernels.emplace_back(head_kernel); + return sub_kernels; + } std::queue kernel_queue; kernel_queue.emplace(head_kernel); auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); @@ -200,6 +309,10 @@ std::vector Scheduler::FindAllSubGraphKernels( sub_kernels.emplace_back(cur_kernel); auto post_kernels = cur_kernel->out_kernels(); for (auto post_kernel : post_kernels) { + if (post_kernel->subgraph_type() != kernel::kNotSubGraph || post_kernel->Type() == schema::PrimitiveType_Merge || + post_kernel->Type() == schema::PrimitiveType_Switch) { + continue; + } if (cur_sub_graph_type == 
mindspore::lite::Scheduler::GetKernelSubGraphType(post_kernel)) { auto post_kernel_inputs = post_kernel->in_kernels(); if (std::all_of(post_kernel_inputs.begin(), post_kernel_inputs.end(), @@ -215,28 +328,41 @@ std::vector Scheduler::FindAllSubGraphKernels( int Scheduler::ConstructSubGraphs(std::vector *kernels) { auto old_kernels = *kernels; kernels->clear(); - std::map is_kernel_sinked; + std::map is_kernel_finish; for (auto kernel : old_kernels) { - is_kernel_sinked[kernel] = false; + is_kernel_finish[kernel] = false; } while (true) { auto head_kernel_iter = std::find_if(old_kernels.begin(), old_kernels.end(), [&](const kernel::LiteKernel *kernel) { auto kernel_inputs = kernel->in_kernels(); - return !is_kernel_sinked[kernel] && - std::all_of(kernel_inputs.begin(), kernel_inputs.end(), - [&](kernel::LiteKernel *kernel) { return is_kernel_sinked[kernel]; }); + if (is_kernel_finish[kernel]) { + return false; + } + // when merge is removed, this if is removed automatically + if (kernel->Type() == schema::PrimitiveType_Merge) { + MS_ASSERT(kernel->in_kernels().size() == 2); + return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]); + } else { + return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), + [&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; }); + } }); if (head_kernel_iter == old_kernels.end()) { break; } auto head_kernel = *head_kernel_iter; + if (head_kernel->subgraph_type() != kernel::kNotSubGraph) { + is_kernel_finish[head_kernel] = true; + kernels->emplace_back(head_kernel); + continue; + } if (head_kernel->desc().arch == mindspore::kernel::kAPU) { MS_LOG(ERROR) << "Not support APU now"; return RET_NOT_SUPPORT; } auto cur_sub_graph_type = mindspore::lite::Scheduler::GetKernelSubGraphType(head_kernel); - auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_sinked); + auto sub_kernels = FindAllSubGraphKernels(head_kernel, &is_kernel_finish); auto subgraph = 
CreateSubGraphKernel(sub_kernels, cur_sub_graph_type); if (subgraph == nullptr) { MS_LOG(ERROR) << "Create SubGraphKernel failed"; @@ -296,60 +422,6 @@ kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector &in_tensors, - const std::vector &out_tensors, - const mindspore::lite::PrimitiveC *primitive, const Model::Node *node) { - MS_ASSERT(primitive != nullptr); - TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors); - kernel::KernelKey desc{kCPU, data_type, static_cast(primitive->Type())}; -#if SUPPORT_NPU - if (context_->IsNpuEnabled()) { - kernel::KernelKey npu_desc{kNPU, desc.data_type, desc.type}; - auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, npu_desc); - if (kernel != nullptr) { - MS_LOG(DEBUG) << "Get npu op success: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " << node->name_; - return kernel; - } else { - MS_LOG(DEBUG) << "Get npu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(npu_desc.type) << " " - << node->name_; - } - } -#endif -#if SUPPORT_GPU - if (context_->IsGpuEnabled()) { - kernel::KernelKey gpu_desc{kGPU, desc.data_type, desc.type}; - auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, gpu_desc); - if (kernel != nullptr) { - MS_LOG(DEBUG) << "Get gpu op success: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " << node->name_; - return kernel; - } else { - MS_LOG(DEBUG) << "Get gpu op failed, scheduler to cpu: " << schema::EnumNamePrimitiveType(gpu_desc.type) << " " - << node->name_; - } - } -#endif - if (mindspore::lite::IsSupportFloat16() && - ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { - kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; - auto *kernel = - KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); - if (kernel != nullptr) { - 
MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " " - << node->name_; - return kernel; - } - } - if (data_type == kNumberTypeFloat16) { - MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; - desc.data_type = kNumberTypeFloat32; - } - auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); - if (kernel != nullptr) { - return kernel; - } - return nullptr; -} - TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors) { for (const auto &tensor : in_tensors) { auto dtype = tensor->data_type(); @@ -411,4 +483,11 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(const kernel::LiteKernel * } return kernel::kNotSubGraph; } + +void Scheduler::FindAllInoutKernels(const std::vector &kernels) { + for (auto *kernel : kernels) { + MS_ASSERT(kernel != nullptr); + kernel->FindInoutKernels(kernels); + } +} } // namespace mindspore::lite diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index 15275ac8e9..755ce8fcb4 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_SRC_SCHEDULER_H_ #define MINDSPORE_LITE_SRC_SCHEDULER_H_ +#include #include #include #include "src/sub_graph_kernel.h" @@ -27,30 +28,47 @@ namespace mindspore::lite { class Scheduler { public: - explicit Scheduler(const InnerContext *ctx) { context_ = const_cast(ctx); } + Scheduler(const InnerContext *ctx, Model *src_model, std::vector src_tensors) + : context_(ctx), src_model_(src_model), src_tensors_(std::move(src_tensors)) {} ~Scheduler() = default; - int Schedule(const lite::Model *model, std::vector *tensors, std::vector *kernels); - - static int ReSizeKernels(const std::vector &kernels); - - protected: - kernel::LiteKernel *ScheduleNode(const std::vector &in_tensors, const std::vector &out_tensors, - const mindspore::lite::PrimitiveC *primitive, const Model::Node *cnode); - - int 
BuildKernels(const lite::Model *model, const std::vector *tensors, - std::vector *kernels); - - static int InferShape(const lite::Model *model, std::vector *tensors); - + int Schedule(std::vector *dst_kernels); + + private: + void FindNodeInoutTensors(const lite::Model::Node &node, std::vector *inputs, + std::vector *outputs); + // infer shape for a partial node + int InferPartialShape(const lite::Model::Node *node, bool *infer_shape_interrupt); + // infer shape for a node + int InferNodeShape(const lite::Model::Node *node, bool *infer_shape_interrupt); + // infer shape for a subgraph + int InferSubGraphShape(size_t subgraph_index, bool *infer_shape_interrupt); + + // schedule a node to kernel according to context and kernels registered + kernel::LiteKernel *FindBackendKernel(const std::vector &in_tensors, + const std::vector &out_tensors, + const mindspore::lite::PrimitiveC *primitive, const Model::Node *node); + // schedule a partial node to a subgraph_kernel + kernel::LiteKernel *SchedulePartialToKernel(const lite::Model::Node *src_node); + // schedule a node to a kernel + kernel::LiteKernel *ScheduleNodeToKernel(const lite::Model::Node *src_node); + // schedule a Model::SubGraph into a vector of kernel and subgraph_kernel + int ScheduleSubGraphToKernels(size_t subgraph_index, std::vector *dst_kernels); + + // find in_kernels_ and out_kernels of kernel, sub_graph and nodes_ in sub_graph + static void FindAllInoutKernels(const std::vector &kernels); + + // vector --> vector int ConstructSubGraphs(std::vector *kernels); + // create subgraph_kernel from a vector of kernel kernel::SubGraphKernel *CreateSubGraphKernel(const std::vector &kernels, kernel::SubGraphType type); std::vector FindAllSubGraphKernels( kernel::LiteKernel *head_kernel, std::map *sinked_kernel_map); + // other methods static TypeId GetFirstFp32Fp16OrInt8Type(const std::vector &in_tensors); static void SetKernelTensorDataType(kernel::LiteKernel *kernel); @@ -58,7 +76,10 @@ class Scheduler { static 
kernel::SubGraphType GetKernelSubGraphType(const kernel::LiteKernel *kernel); protected: - InnerContext *context_ = nullptr; + const InnerContext *context_ = nullptr; + Model *src_model_ = nullptr; + std::vector src_tensors_; + std::vector graph_output_node_indexes_; }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index 2ba75b73ca..4f2de9a1f4 100644 --- a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -149,6 +149,12 @@ int SubGraphKernel::ReSize(bool is_interrupt) { return RET_OK; } +void SubGraphKernel::InitOutTensorInitRefCount() { + for (auto *node : nodes_) { + node->InitOutTensorInitRefCount(); + } +} + int CpuSubGraph::Prepare() { auto ret = SubGraphKernel::Prepare(); if (ret != RET_OK) { diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h index 29ca117977..8bcbd4d289 100644 --- a/mindspore/lite/src/sub_graph_kernel.h +++ b/mindspore/lite/src/sub_graph_kernel.h @@ -84,6 +84,8 @@ class SubGraphKernel : public LiteKernel { int ReSize(bool is_interrupt); + void InitOutTensorInitRefCount() override; + std::string ToString() const override; std::vector nodes() { return this->nodes_; } @@ -104,11 +106,10 @@ class CpuSubGraph : public SubGraphKernel { const std::vector &nodes, const lite::InnerContext *ctx) : SubGraphKernel(inputs, outputs, in_kernels, out_kernels, nodes, ctx) { subgraph_type_ = kCpuFP32SubGraph; - this->executor_ = new (std::nothrow) mindspore::lite::Executor; + this->executor_ = new (std::nothrow) mindspore::lite::CpuExecutor; } ~CpuSubGraph() override { delete this->executor_; } - int Prepare() override; int Init() override { return SubGraphKernel::Init(); } int PreProcess() override { return SubGraphKernel::PreProcess(); } diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index b97bd3c9d7..ed6f3b4620 100644 --- a/mindspore/lite/src/tensor.h +++ 
b/mindspore/lite/src/tensor.h @@ -110,8 +110,14 @@ class Tensor : public mindspore::tensor::MSTensor { size_t ref_count() const { return this->ref_count_; } + size_t init_ref_count() const { return this->init_ref_count_; } + void set_ref_count(size_t ref_count) { this->ref_count_ = ref_count; } + void set_init_ref_count(size_t ref_count) { this->init_ref_count_ = ref_count; } + + void ResetRefCount() { this->ref_count_ = this->init_ref_count_; } + void DecRefCount() { this->ref_count_--; } std::string ToString() const; @@ -156,6 +162,8 @@ class Tensor : public mindspore::tensor::MSTensor { schema::Format format_; Category category_; size_t ref_count_ = 0; + size_t init_ref_count_ = 0; + size_t ready_count_ = 0; std::vector quant_params_; std::vector quant_clusters_; mindspore::lite::Allocator *allocator_ = nullptr; diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index e6e3c80d51..bdb1c2f3cb 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -128,7 +128,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a return lite::RET_NULL_PTR; } auto run_kernel = (train_mode_) ? 
train_kernels_ : inference_kernels_; - lite::Executor executor; + lite::CpuExecutor executor; if (before == nullptr && after == nullptr) { return executor.Run(this->inputs_, this->outputs_, run_kernel, this->context_->allocator.get()); } else { diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 928e360bad..e961df4ecb 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -261,6 +261,8 @@ if (ENABLE_CONVERTER) set(TEST_SRC ${TEST_SRC} ${TEST_DIR}/st/converter_test.cc + ${TEST_DIR}/st/control_flow_test.cc + ${TEST_DIR}/st/sub_graph_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc ${TEST_DIR}/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -300,7 +302,7 @@ endif () add_executable(lite-test ${TEST_SRC}) - +add_dependencies(lite-test fbs_src) target_link_libraries(lite-test dl ${GTEST_LIBRARY}) if (PLATFORM_ARM64) target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) @@ -321,6 +323,7 @@ if (SUPPORT_NPU) target_link_libraries(lite-test npu_kernel_mid) endif () if (ENABLE_CONVERTER) + add_dependencies(lite-test fbs_inner_src) target_link_libraries(lite-test anf_importer_mid anf_exporter_mid diff --git a/mindspore/lite/test/models_tflite_posttraining.cfg b/mindspore/lite/test/models_tflite_posttraining.cfg index c17917cbfd..5e3710533f 100644 --- a/mindspore/lite/test/models_tflite_posttraining.cfg +++ b/mindspore/lite/test/models_tflite_posttraining.cfg @@ -1,3 +1,3 @@ mobilenet.tflite 0.5 -transformer_20200831_encoder_fp32.tflite 68 +transformer_20200831_encoder_fp32.tflite 69 transformer_20200831_decoder_fp32.tflite 35 diff --git a/mindspore/lite/test/st/control_flow_test.cc b/mindspore/lite/test/st/control_flow_test.cc new file mode 100644 index 0000000000..4d5d7fce22 --- /dev/null +++ b/mindspore/lite/test/st/control_flow_test.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "mindspore/lite/include/model.h" +#include "common/common_test.h" +#include "include/lite_session.h" +#include "include/context.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "src/lite_session.h" +#include "include/version.h" + +namespace mindspore { +class ControlFlowTest : public mindspore::CommonTest { + public: + ControlFlowTest() {} +}; + +TEST_F(ControlFlowTest, TestMergeWhileModel) { + // make graph + auto meta_graph = std::make_shared(); + MS_LOG(DEBUG) << "make subgraph"; + meta_graph->name = "graph"; + meta_graph->version = lite::Version(); + meta_graph->inputIndex = {0}; + meta_graph->outputIndex = {9}; + // subgraph 0 : main graph + auto sub_graph_0 = std::make_unique(); + sub_graph_0->name = "main_graph"; + + // subgraph 1 : cond graph + auto sub_graph_1 = std::make_unique(); + sub_graph_1->name = "cond_graph"; + + // subgraph 2: body graph + auto sub_graph_2 = std::make_unique(); + sub_graph_2->name = "body_graph"; + + MS_LOG(DEBUG) << "make subgraph"; + + // subgraph 0: node 0 before-add-1 + auto sub_graph_0_node_0 = std::make_unique(); + sub_graph_0_node_0->inputIndex = {0, 1}; + sub_graph_0_node_0->outputIndex = {2}; + sub_graph_0_node_0->primitive = std::make_unique(); + sub_graph_0_node_0->primitive->value.type = schema::PrimitiveType_Add; + auto 
primitive_sub_graph_0_node_0 = new schema::AddT; + primitive_sub_graph_0_node_0->activationType = schema::ActivationType_NO_ACTIVATION; + sub_graph_0_node_0->primitive->value.value = primitive_sub_graph_0_node_0; + sub_graph_0_node_0->name = "before_Add_1"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_0)); + sub_graph_0->nodeIndices.push_back(0); + MS_LOG(DEBUG) << "node 0"; + + // subgraph 0: node 1 before-add-1 + auto sub_graph_0_node_1 = std::make_unique(); + sub_graph_0_node_1->inputIndex = {2, 3}; + sub_graph_0_node_1->outputIndex = {4}; + sub_graph_0_node_1->primitive = std::make_unique(); + sub_graph_0_node_1->primitive->value.type = schema::PrimitiveType_Add; + auto primitive_sub_graph_0_node_1 = new schema::AddT; + primitive_sub_graph_0_node_1->activationType = schema::ActivationType_NO_ACTIVATION; + sub_graph_0_node_1->primitive->value.value = primitive_sub_graph_0_node_1; + sub_graph_0_node_1->name = "before_Add_2"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_1)); + sub_graph_0->nodeIndices.push_back(1); + MS_LOG(DEBUG) << "node 1"; + + // subgraph 0: node 2 merge + auto sub_graph_0_node_2 = std::make_unique(); + sub_graph_0_node_2->inputIndex = {4, 17}; + sub_graph_0_node_2->outputIndex = {16}; + sub_graph_0_node_2->primitive = std::make_unique(); + sub_graph_0_node_2->primitive->value.type = schema::PrimitiveType_Merge; + auto primitive_sub_graph_0_node_2 = new schema::MergeT; + sub_graph_0_node_2->primitive->value.value = primitive_sub_graph_0_node_2; + sub_graph_0_node_2->name = "merge"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_2)); + sub_graph_0->nodeIndices.push_back(2); + MS_LOG(DEBUG) << "node 2"; + + // subgraph 0: node 3 partial cond subGraph + auto sub_graph_0_node_3 = std::make_unique(); + sub_graph_0_node_3->inputIndex = {16}; + sub_graph_0_node_3->outputIndex = {5}; // 5 : bool + sub_graph_0_node_3->primitive = std::make_unique(); + sub_graph_0_node_3->primitive->value.type = 
schema::PrimitiveType_Partial; + auto primitive_sub_graph_0_node_3 = new schema::PartialT; + primitive_sub_graph_0_node_3->subGraphIndex = 1; + sub_graph_0_node_3->primitive->value.value = primitive_sub_graph_0_node_3; + sub_graph_0_node_3->name = "Partial_cond"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_3)); + sub_graph_0->nodeIndices.push_back(3); + MS_LOG(DEBUG) << "node 2"; + + // subgraph 0: node 4 switch + auto sub_graph_0_node_4 = std::make_unique(); + sub_graph_0_node_4->inputIndex = {5, 16}; // 5 : bool; 16 data + sub_graph_0_node_4->outputIndex = {6, 7}; + sub_graph_0_node_4->primitive = std::make_unique(); + sub_graph_0_node_4->primitive->value.type = schema::PrimitiveType_Switch; + auto primitive_sub_graph_0_node_4 = new schema::SwitchT; + sub_graph_0_node_4->primitive->value.value = primitive_sub_graph_0_node_4; + sub_graph_0_node_4->name = "Switch"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_4)); + sub_graph_0->nodeIndices.push_back(4); + MS_LOG(DEBUG) << "node 4"; + + // subgraph 0: node 5 partial body subgraph + auto sub_graph_0_node_5 = std::make_unique(); + sub_graph_0_node_5->inputIndex = {6}; + sub_graph_0_node_5->outputIndex = {17}; + sub_graph_0_node_5->primitive = std::make_unique(); + sub_graph_0_node_5->primitive->value.type = schema::PrimitiveType_Partial; + auto primitive_sub_graph_0_node_5 = new schema::PartialT; + primitive_sub_graph_0_node_5->subGraphIndex = 2; + sub_graph_0_node_5->primitive->value.value = primitive_sub_graph_0_node_5; + sub_graph_0_node_5->name = "Partial_body"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_5)); + sub_graph_0->nodeIndices.push_back(5); + MS_LOG(DEBUG) << "node 5"; + + // subgraph 0: node 6 add-after + auto sub_graph_0_node_6 = std::make_unique(); + sub_graph_0_node_6->inputIndex = {7, 8}; + sub_graph_0_node_6->outputIndex = {9}; + sub_graph_0_node_6->primitive = std::make_unique(); + sub_graph_0_node_6->primitive->value.type = 
schema::PrimitiveType_Add; + auto primitive_sub_graph_0_node_6 = new schema::AddT; + sub_graph_0_node_6->primitive->value.value = primitive_sub_graph_0_node_6; + sub_graph_0_node_6->name = "Add-after"; + meta_graph->nodes.emplace_back(std::move(sub_graph_0_node_6)); + sub_graph_0->nodeIndices.push_back(6); + MS_LOG(DEBUG) << "node 6"; + + sub_graph_0->inputIndices = {0}; + sub_graph_0->outputIndices = {9}; + sub_graph_0->tensorIndices = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17}; + + meta_graph->subGraph.push_back(std::move(sub_graph_0)); + + // subgraph 1 ; node:0 add cond + auto sub_graph_1_node_0 = std::make_unique(); + sub_graph_1_node_0->inputIndex = {16, 10}; + sub_graph_1_node_0->outputIndex = {11}; + sub_graph_1_node_0->primitive = std::make_unique(); + sub_graph_1_node_0->primitive->value.type = schema::PrimitiveType_Add; + auto primitive_sub_graph_1_node_0 = new schema::AddT; + sub_graph_1_node_0->primitive->value.value = primitive_sub_graph_1_node_0; + sub_graph_1_node_0->name = "cond_add"; + meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_0)); + sub_graph_1->nodeIndices.push_back(7); + MS_LOG(DEBUG) << "node 6"; + + // subgraph 1 ; node:1 Less cond + auto sub_graph_1_node_1 = std::make_unique(); + sub_graph_1_node_1->inputIndex = {11, 12}; + sub_graph_1_node_1->outputIndex = {5}; + sub_graph_1_node_1->primitive = std::make_unique(); + sub_graph_1_node_1->primitive->value.type = schema::PrimitiveType_Less; + auto primitive_sub_graph_1_node_1 = new schema::LessT; + sub_graph_1_node_1->primitive->value.value = primitive_sub_graph_1_node_1; + sub_graph_1_node_1->name = "cond_Less"; + meta_graph->nodes.emplace_back(std::move(sub_graph_1_node_1)); + sub_graph_1->nodeIndices.push_back(8); + MS_LOG(DEBUG) << "node 7"; + + sub_graph_1->inputIndices = {16}; + sub_graph_1->outputIndices = {5}; + sub_graph_1->tensorIndices = {16, 10, 11, 12, 5}; + meta_graph->subGraph.push_back(std::move(sub_graph_1)); + + // subgraph 2 ; node:0 body add-1 + auto 
sub_graph_2_node_0 = std::make_unique(); + sub_graph_2_node_0->inputIndex = {6, 13}; + sub_graph_2_node_0->outputIndex = {14}; + sub_graph_2_node_0->primitive = std::make_unique(); + sub_graph_2_node_0->primitive->value.type = schema::PrimitiveType_Add; + auto primitive_sub_graph_2_node_0 = new schema::AddT; + sub_graph_2_node_0->primitive->value.value = primitive_sub_graph_2_node_0; + sub_graph_2_node_0->name = "body_add_1"; + meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_0)); + sub_graph_2->nodeIndices.push_back(9); + MS_LOG(DEBUG) << "node 8"; + + // subgraph 2 ; node:1 body add-2 + auto sub_graph_2_node_1 = std::make_unique(); + sub_graph_2_node_1->inputIndex = {14, 15}; + sub_graph_2_node_1->outputIndex = {17}; + sub_graph_2_node_1->primitive = std::make_unique(); + sub_graph_2_node_1->primitive->value.type = schema::PrimitiveType_Add; + auto primitive_sub_graph_2_node_1 = new schema::AddT; + sub_graph_2_node_1->primitive->value.value = primitive_sub_graph_2_node_1; + sub_graph_2_node_1->name = "body_add_2"; + meta_graph->nodes.emplace_back(std::move(sub_graph_2_node_1)); + sub_graph_2->nodeIndices.push_back(10); + MS_LOG(DEBUG) << "node 9"; + + sub_graph_2->inputIndices = {6}; + sub_graph_2->outputIndices = {17}; + sub_graph_2->tensorIndices = {13, 14, 15, 6, 17}; + + meta_graph->subGraph.push_back(std::move(sub_graph_2)); + + // ------- tensor --------- + // tensor: 0 before-add input0
+ auto tensor_0 = std::make_unique(); + tensor_0->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_0->format = schema::Format_NHWC; + tensor_0->dataType = TypeId::kNumberTypeFloat32; + tensor_0->dims = {1}; + tensor_0->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_0)); + MS_LOG(DEBUG) << "tensor 0"; + + // tensor: 1 before-add input1 + auto tensor_1 = std::make_unique(); + tensor_1->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_1->format = schema::Format_NHWC; + tensor_1->dataType = TypeId::kNumberTypeFloat32; + tensor_1->dims = {1}; + tensor_1->data.resize(sizeof(float) * 1); + float input1_data[] = {1}; + memcpy(tensor_1->data.data(), input1_data, sizeof(float) * 1); + tensor_1->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_1)); + MS_LOG(DEBUG) << "tensor 1"; + + // tensor: 2 before-add output/partial input + auto tensor_2 = std::make_unique(); + tensor_2->nodeType = schema::NodeType::NodeType_Parameter; + tensor_2->format = schema::Format_NHWC; + tensor_2->dataType = TypeId::kNumberTypeFloat32; + tensor_2->dims = {1}; + tensor_2->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_2)); + MS_LOG(DEBUG) << "tensor 2"; + + // tensor: 3 before-add input1 + auto tensor_3 = std::make_unique(); + tensor_3->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_3->format = schema::Format_NHWC; + tensor_3->dataType = TypeId::kNumberTypeFloat32; + tensor_3->dims = {1}; + tensor_3->data.resize(sizeof(float) * 1); + float tensor_3_data[] = {1}; + memcpy(tensor_3->data.data(), tensor_3_data, sizeof(float) * 1); + tensor_3->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_3)); + MS_LOG(DEBUG) << "tensor 3"; + + auto tensor_4 = std::make_unique(); + tensor_4->nodeType = schema::NodeType::NodeType_Parameter; + tensor_4->format = schema::Format_NHWC; + tensor_4->dataType = TypeId::kNumberTypeFloat32; + tensor_4->dims = {1}; + tensor_4->offset = -1; + 
meta_graph->allTensors.emplace_back(std::move(tensor_4)); + MS_LOG(DEBUG) << "tensor 4"; + + // tensor :5 partial output + auto tensor_5 = std::make_unique(); + tensor_5->nodeType = schema::NodeType::NodeType_Parameter; + tensor_5->format = schema::Format_NHWC; + tensor_5->dataType = TypeId::kNumberTypeBool; + tensor_5->dims = {1}; + tensor_5->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_5)); + MS_LOG(DEBUG) << "tensor_4"; + + // tensor: 6 switch true output + auto tensor_6 = std::make_unique(); + tensor_6->nodeType = schema::NodeType::NodeType_Parameter; + tensor_6->format = schema::Format_NHWC; + tensor_6->dataType = TypeId::kNumberTypeFloat32; + tensor_6->dims = {1}; + tensor_6->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_6)); + MS_LOG(DEBUG) << "tensor 6"; + + // tensor: 5 switch False output + auto tensor_7 = std::make_unique(); + tensor_7->nodeType = schema::NodeType::NodeType_Parameter; + tensor_7->format = schema::Format_NHWC; + tensor_7->dataType = TypeId::kNumberTypeFloat32; + tensor_7->dims = {1}; + tensor_7->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_7)); + MS_LOG(DEBUG) << "tensor_7"; + + // tensor: 6 body-add input ,other input is switch true output + auto tensor_8 = std::make_unique(); + tensor_8->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_8->format = schema::Format_NHWC; + tensor_8->dataType = TypeId::kNumberTypeFloat32; + tensor_8->dims = {1}; + tensor_8->data.resize(sizeof(float) * 1); + float tensor_8_data[] = {10}; + memcpy(tensor_8->data.data(), tensor_8_data, sizeof(float) * 1); + tensor_8->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_8)); + MS_LOG(DEBUG) << "tensor_8"; + + auto tensor_9 = std::make_unique(); + tensor_9->nodeType = schema::NodeType::NodeType_Parameter; + tensor_9->format = schema::Format_NHWC; + tensor_9->dataType = TypeId::kNumberTypeFloat32; + tensor_9->dims = {1}; + tensor_9->offset = -1; + 
meta_graph->allTensors.emplace_back(std::move(tensor_9)); + MS_LOG(DEBUG) << "tensor_9"; + + // tensor: 7 after-add input ,other input is switch false output + auto tensor_10 = std::make_unique(); + tensor_10->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_10->format = schema::Format_NHWC; + tensor_10->dataType = TypeId::kNumberTypeFloat32; + tensor_10->dims = {1}; + tensor_10->data.resize(sizeof(float) * 1); + float tensor_10_data[] = {1}; + memcpy(tensor_10->data.data(), tensor_10_data, sizeof(float) * 1); + tensor_10->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_10)); + MS_LOG(DEBUG) << "tensor_10"; + + // tensor: 8 main graph output + auto tensor_11 = std::make_unique(); + tensor_11->nodeType = schema::NodeType::NodeType_Parameter; + tensor_11->format = schema::Format_NHWC; + tensor_11->dataType = TypeId::kNumberTypeFloat32; + tensor_11->dims = {1}; + tensor_11->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_11)); + MS_LOG(DEBUG) << "tensor 11"; + + // tensor: 9 cond-Less input, other input is tensor 2 + auto tensor_12 = std::make_unique(); + tensor_12->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_12->format = schema::Format_NHWC; + tensor_12->dataType = TypeId::kNumberTypeFloat32; + tensor_12->dims = {1}; + tensor_12->data.resize(sizeof(float) * 1); + float tensor_12_data[] = {10}; + memcpy(tensor_12->data.data(), tensor_12_data, sizeof(float) * 1); + tensor_12->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_12)); + MS_LOG(DEBUG) << "tensor_12"; + + auto tensor_13 = std::make_unique(); + tensor_13->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_13->format = schema::Format_NHWC; + tensor_13->dataType = TypeId::kNumberTypeFloat32; + tensor_13->dims = {1}; + tensor_13->data.resize(sizeof(float) * 1); + float tensor_13_data[] = {1}; + memcpy(tensor_13->data.data(), tensor_13_data, sizeof(float) * 1); + tensor_13->offset = -1; + 
meta_graph->allTensors.emplace_back(std::move(tensor_13)); + MS_LOG(DEBUG) << "tensor_13"; + + auto tensor_14 = std::make_unique(); + tensor_14->nodeType = schema::NodeType::NodeType_Parameter; + tensor_14->format = schema::Format_NHWC; + tensor_14->dataType = TypeId::kNumberTypeFloat32; + tensor_14->dims = {1}; + tensor_14->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_14)); + MS_LOG(DEBUG) << "tensor 14"; + + auto tensor_15 = std::make_unique(); + tensor_15->nodeType = schema::NodeType::NodeType_ValueNode; + tensor_15->format = schema::Format_NHWC; + tensor_15->dataType = TypeId::kNumberTypeFloat32; + tensor_15->dims = {1}; + tensor_15->data.resize(sizeof(float) * 1); + float tensor_15_data[] = {1}; + memcpy(tensor_15->data.data(), tensor_15_data, sizeof(float) * 1); + tensor_15->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_15)); + MS_LOG(DEBUG) << "tensor_15"; + + auto tensor_16 = std::make_unique(); + tensor_16->nodeType = schema::NodeType::NodeType_Parameter; + tensor_16->format = schema::Format_NHWC; + tensor_16->dataType = TypeId::kNumberTypeFloat32; + tensor_16->dims = {1}; + tensor_16->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_16)); + MS_LOG(DEBUG) << "tensor_16"; + + auto tensor_17 = std::make_unique(); + tensor_17->nodeType = schema::NodeType::NodeType_Parameter; + tensor_17->format = schema::Format_NHWC; + tensor_17->dataType = TypeId::kNumberTypeFloat32; + tensor_17->dims = {1}; + tensor_17->offset = -1; + meta_graph->allTensors.emplace_back(std::move(tensor_17)); + MS_LOG(DEBUG) << "tensor_17"; + // ----------------------------------------------------------------------- + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph.get()); + builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); + size_t size = builder.GetSize(); + const char *content = reinterpret_cast(builder.GetBufferPointer()); + + auto model = 
std::shared_ptr(lite::Model::Import(content, size)); + ASSERT_NE(model, nullptr); + lite::Context context; + context.thread_num_ = 2; + auto &cpu_device_ctx = context.device_list_[0]; + cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU; + cpu_device_ctx.device_info_.cpu_device_info_.enable_float16_ = false; + auto session = std::shared_ptr(session::LiteSession::CreateSession(&context)); + ASSERT_NE(session, nullptr); + auto ret = session->CompileGraph(model.get()); + ASSERT_EQ(ret, lite::RET_OK); + model->Free(); + auto inputs = session->GetInputs(); + ASSERT_EQ(inputs.size(), 1); + auto input = inputs.front(); + ASSERT_NE(input, nullptr); + ASSERT_EQ(input->data_type(), kNumberTypeFloat32); + ASSERT_EQ(input->shape().size(), 1); + ASSERT_EQ(input->shape().at(0), 1); + auto in_data = reinterpret_cast(input->MutableData()); + ASSERT_NE(in_data, nullptr); + in_data[0] = 1; + ret = session->RunGraph(); + ASSERT_EQ(ret, lite::RET_OK); + auto outputs = session->GetOutputs(); + ASSERT_EQ(outputs.size(), 1); + auto output = outputs.begin()->second; + ASSERT_NE(output, nullptr); + ASSERT_EQ(output->data_type(), kNumberTypeFloat32); + ASSERT_EQ(output->shape().size(), 1); + ASSERT_EQ(output->shape().at(0), 1); + auto out_data = reinterpret_cast(output->MutableData()); + ASSERT_NE(out_data, nullptr); + ASSERT_EQ(out_data[0], 19); +} +} // namespace mindspore diff --git a/mindspore/lite/test/st/sub_graph_test.cc b/mindspore/lite/test/st/sub_graph_test.cc new file mode 100644 index 0000000000..fbe29e4f7b --- /dev/null +++ b/mindspore/lite/test/st/sub_graph_test.cc @@ -0,0 +1,217 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <utility>
+#include "schema/inner/model_generated.h"
+#include "mindspore/lite/include/model.h"
+#include "common/common_test.h"
+#include "include/lite_session.h"
+#include "include/context.h"
+#include "include/model.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/lite_session.h"
+#include "src/runtime/parallel_executor.h"
+#include "tools/common/storage.h"
+#include "include/version.h"
+
+namespace mindspore {
+// Fixture for the sub-graph system tests; it adds nothing on top of mindspore::CommonTest.
+class SubGraphTest : public mindspore::CommonTest {
+ public:
+  SubGraphTest() {}
+};
+
+// Hand-builds a MetaGraph with four sub-graphs and checks that it imports, compiles and runs.
+//
+// Node indices (order of insertion into meta_graph->nodes):
+//   0 Add0, 1 Partial1, 2 Partial2, 3 Partial3, 4 Add1, 5 Add2, 6 Add3
+// Sub-graphs (order of insertion into meta_graph->subGraph):
+//   0 main_graph  : nodes {0, 1, 2}, input tensor 0,        output tensor 7
+//   1 sub_graph_1 : nodes {4, 3},    input tensor 2,        output tensor 7
+//   2 sub_graph_2 : nodes {5, 3},    input tensor 2,        output tensor 7
+//   3 sub_graph_3 : nodes {6},       input tensors 4 and 6, output tensor 7
+// PartialN points at sub-graph N through subGraphIndex.
+//
+// Only the return codes of CompileGraph() and RunGraph() are asserted; output values are not inspected.
+TEST_F(SubGraphTest, RecursiveSubGraphTest) {
+  // ---- main graph: Add0, Partial1, Partial2; tensors 0, 1, 2 ----
+  auto add_0 = std::make_unique<schema::CNodeT>();
+  add_0->inputIndex = {0, 1};
+  add_0->outputIndex = {2};
+  add_0->primitive = std::make_unique<schema::PrimitiveT>();
+  add_0->primitive->value.type = schema::PrimitiveType_Add;
+  // The raw pointer is handed to the primitive's union value, which keeps it from here on.
+  auto add_0_prim = new schema::AddT;
+  add_0_prim->activationType = schema::ActivationType_NO_ACTIVATION;
+  add_0->primitive->value.value = add_0_prim;
+  add_0->name = "Add0";
+
+  auto partial_1 = std::make_unique<schema::CNodeT>();
+  partial_1->inputIndex = {2};
+  partial_1->outputIndex = {7};
+  partial_1->primitive = std::make_unique<schema::PrimitiveT>();
+  partial_1->primitive->value.type = schema::PrimitiveType_Partial;
+  auto partial_1_prim = new schema::PartialT;
+  partial_1_prim->subGraphIndex = 1;
+  partial_1->primitive->value.value = partial_1_prim;
+  partial_1->name = "Partial1";
+
+  auto partial_2 = std::make_unique<schema::CNodeT>();
+  partial_2->inputIndex = {2};
+  partial_2->outputIndex = {7};
+  partial_2->primitive = std::make_unique<schema::PrimitiveT>();
+  partial_2->primitive->value.type = schema::PrimitiveType_Partial;
+  auto partial_2_prim = new schema::PartialT;
+  partial_2_prim->subGraphIndex = 2;
+  partial_2->primitive->value.value = partial_2_prim;
+  partial_2->name = "Partial2";
+
+  auto partial_3 = std::make_unique<schema::CNodeT>();
+  partial_3->inputIndex = {4, 6};
+  partial_3->outputIndex = {7};
+  partial_3->primitive = std::make_unique<schema::PrimitiveT>();
+  partial_3->primitive->value.type = schema::PrimitiveType_Partial;
+  auto partial_3_prim = new schema::PartialT;
+  partial_3_prim->subGraphIndex = 3;
+  partial_3->primitive->value.value = partial_3_prim;
+  partial_3->name = "Partial3";
+
+  auto tensor_0 = std::make_unique<schema::TensorT>();
+  tensor_0->nodeType = schema::NodeType::NodeType_Parameter;
+  tensor_0->format = schema::Format_NHWC;
+  tensor_0->dataType = TypeId::kNumberTypeFloat32;
+  tensor_0->dims = {1, 2};
+  // NOTE(review): tensor_1, tensor_3 and tensor_5 are ValueNode tensors that declare dims but carry no
+  // data bytes - confirm that this is intended for the Add inputs.
+  auto tensor_1 = std::make_unique<schema::TensorT>();
+  tensor_1->nodeType = schema::NodeType::NodeType_ValueNode;
+  tensor_1->format = schema::Format_NHWC;
+  tensor_1->dataType = TypeId::kNumberTypeFloat32;
+  tensor_1->dims = {1, 2};
+  auto tensor_2 = std::make_unique<schema::TensorT>();
+  tensor_2->nodeType = schema::NodeType::NodeType_Parameter;
+  tensor_2->format = schema::Format_NHWC;
+  tensor_2->dataType = TypeId::kNumberTypeFloat32;
+
+  auto sub_graph_0 = std::make_unique<schema::SubGraphT>();
+  sub_graph_0->name = "main_graph";
+  sub_graph_0->inputIndices = {0};
+  sub_graph_0->outputIndices = {7};
+  sub_graph_0->nodeIndices = {0, 1, 2};
+  sub_graph_0->tensorIndices = {0, 1, 2, 7};
+
+  // ---- sub_graph_1: Add1, Partial3; tensors 3, 4 ----
+  auto add_1 = std::make_unique<schema::CNodeT>();
+  add_1->inputIndex = {2, 3};
+  add_1->outputIndex = {4};
+  add_1->primitive = std::make_unique<schema::PrimitiveT>();
+  add_1->primitive->value.type = schema::PrimitiveType_Add;
+  auto add_1_prim = new schema::AddT;
+  add_1_prim->activationType = schema::ActivationType_NO_ACTIVATION;
+  add_1->primitive->value.value = add_1_prim;
+  add_1->name = "Add1";
+
+  auto tensor_3 = std::make_unique<schema::TensorT>();
+  tensor_3->nodeType = schema::NodeType::NodeType_ValueNode;
+  tensor_3->format = schema::Format_NHWC;
+  tensor_3->dataType = TypeId::kNumberTypeFloat32;
+  tensor_3->dims = {1, 2};
+  auto tensor_4 = std::make_unique<schema::TensorT>();
+  tensor_4->nodeType = schema::NodeType::NodeType_Parameter;
+  tensor_4->format = schema::Format_NHWC;
+  tensor_4->dataType = TypeId::kNumberTypeFloat32;
+
+  auto sub_graph_1 = std::make_unique<schema::SubGraphT>();
+  sub_graph_1->name = "sub_graph_1";
+  sub_graph_1->inputIndices = {2};
+  sub_graph_1->outputIndices = {7};
+  sub_graph_1->nodeIndices = {4, 3};
+  sub_graph_1->tensorIndices = {2, 3, 4, 7};
+
+  // ---- sub_graph_2: Add2, Partial3; tensors 5, 6 ----
+  auto add_2 = std::make_unique<schema::CNodeT>();
+  add_2->inputIndex = {2, 5};
+  add_2->outputIndex = {6};
+  add_2->primitive = std::make_unique<schema::PrimitiveT>();
+  add_2->primitive->value.type = schema::PrimitiveType_Add;
+  auto add_2_prim = new schema::AddT;
+  add_2_prim->activationType = schema::ActivationType_NO_ACTIVATION;
+  add_2->primitive->value.value = add_2_prim;
+  add_2->name = "Add2";
+
+  auto tensor_5 = std::make_unique<schema::TensorT>();
+  tensor_5->nodeType = schema::NodeType::NodeType_ValueNode;
+  tensor_5->format = schema::Format_NHWC;
+  tensor_5->dataType = TypeId::kNumberTypeFloat32;
+  tensor_5->dims = {1, 2};
+  auto tensor_6 = std::make_unique<schema::TensorT>();
+  tensor_6->nodeType = schema::NodeType::NodeType_Parameter;
+  tensor_6->format = schema::Format_NHWC;
+  tensor_6->dataType = TypeId::kNumberTypeFloat32;
+
+  auto sub_graph_2 = std::make_unique<schema::SubGraphT>();
+  sub_graph_2->name = "sub_graph_2";
+  sub_graph_2->inputIndices = {2};
+  sub_graph_2->outputIndices = {7};
+  sub_graph_2->nodeIndices = {5, 3};
+  sub_graph_2->tensorIndices = {2, 5, 6, 7};
+
+  // ---- sub_graph_3: Add3; tensor 7 ----
+  auto add_3 = std::make_unique<schema::CNodeT>();
+  add_3->inputIndex = {4, 6};
+  add_3->outputIndex = {7};
+  add_3->primitive = std::make_unique<schema::PrimitiveT>();
+  add_3->primitive->value.type = schema::PrimitiveType_Add;
+  auto add_3_prim = new schema::AddT;
+  add_3_prim->activationType = schema::ActivationType_NO_ACTIVATION;
+  add_3->primitive->value.value = add_3_prim;
+  add_3->name = "Add3";
+
+  auto tensor_7 = std::make_unique<schema::TensorT>();
+  tensor_7->nodeType = schema::NodeType::NodeType_Parameter;
+  tensor_7->format = schema::Format_NHWC;
+  tensor_7->dataType = TypeId::kNumberTypeFloat32;
+
+  auto sub_graph_3 = std::make_unique<schema::SubGraphT>();
+  sub_graph_3->name = "sub_graph_3";
+  sub_graph_3->inputIndices = {4, 6};
+  sub_graph_3->outputIndices = {7};
+  sub_graph_3->nodeIndices = {6};
+  sub_graph_3->tensorIndices = {4, 6, 7};
+
+  // ---- assemble the meta graph; insertion order defines the node/tensor/sub-graph indices used above ----
+  auto meta_graph = std::make_shared<schema::MetaGraphT>();
+  meta_graph->name = "graph";
+  meta_graph->nodes.emplace_back(std::move(add_0));
+  meta_graph->nodes.emplace_back(std::move(partial_1));
+  meta_graph->nodes.emplace_back(std::move(partial_2));
+  meta_graph->nodes.emplace_back(std::move(partial_3));
+  meta_graph->nodes.emplace_back(std::move(add_1));
+  meta_graph->nodes.emplace_back(std::move(add_2));
+  meta_graph->nodes.emplace_back(std::move(add_3));
+  meta_graph->allTensors.emplace_back(std::move(tensor_0));
+  meta_graph->allTensors.emplace_back(std::move(tensor_1));
+  meta_graph->allTensors.emplace_back(std::move(tensor_2));
+  meta_graph->allTensors.emplace_back(std::move(tensor_3));
+  meta_graph->allTensors.emplace_back(std::move(tensor_4));
+  meta_graph->allTensors.emplace_back(std::move(tensor_5));
+  meta_graph->allTensors.emplace_back(std::move(tensor_6));
+  meta_graph->allTensors.emplace_back(std::move(tensor_7));
+  meta_graph->subGraph.emplace_back(std::move(sub_graph_0));
+  meta_graph->subGraph.emplace_back(std::move(sub_graph_1));
+  meta_graph->subGraph.emplace_back(std::move(sub_graph_2));
+  meta_graph->subGraph.emplace_back(std::move(sub_graph_3));
+  meta_graph->version = lite::Version();
+  // -----------------------------------------------------------------------
+  // Serialize in memory rather than saving to and re-reading from an absolute path on the author's
+  // machine: the test must not depend on a directory that exists on one workstation only. The call
+  // sequence mirrors the one used by the other hand-built-graph test added in this change.
+  flatbuffers::FlatBufferBuilder builder(1024);
+  auto offset = schema::MetaGraph::Pack(builder, meta_graph.get());
+  builder.Finish(offset);
+  schema::FinishMetaGraphBuffer(builder, offset);
+  size_t size = builder.GetSize();
+  // The buffer stays owned by `builder`, which outlives every use of `content` below.
+  const char *content = reinterpret_cast<char *>(builder.GetBufferPointer());
+  // -----------------------------------------------------------------------
+
+  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(content, size));
+  ASSERT_NE(model, nullptr);
+  lite::Context context;
+  auto &cpu_device_ctx = context.device_list_[0];
+  cpu_device_ctx.device_info_.cpu_device_info_.cpu_bind_mode_ = lite::MID_CPU;
+  context.thread_num_ = 2;
+  auto session = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(&context));
+  ASSERT_NE(session, nullptr);
+  auto ret = session->CompileGraph(model.get());
+  ASSERT_EQ(ret, lite::RET_OK);
+  auto inputs = session->GetInputs();
+  // MutableData() is called on each graph input before running; the contents are left untouched.
+  for (auto *input : inputs) {
+    (void)input->MutableData();
+  }
+  ret = session->RunGraph();
+  ASSERT_EQ(ret, lite::RET_OK);
+}
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index bb48b0b2f7..16085c79bf 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -142,7 +142,6 @@ add_executable(converter_lite ${KERNEL_SRC} ${LITE_SRC} ) -add_dependencies(converter_lite tflite_fbs_src) add_dependencies(converter_lite fbs_src) add_dependencies(converter_lite fbs_inner_src) diff --git a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt index 69614b1977..682bd2c4a0 100644 --- a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt +++ b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt @@ -5,4 +5,5 @@ set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID add_library(tflite_parser_mid OBJECT ${TFLITE_SRC_LIST} ) +add_dependencies(tflite_parser_mid tflite_fbs_src) target_link_libraries(tflite_parser_mid mindspore::flatbuffers)