support control flow

pull/9481/head
hangangqiang 4 years ago
parent d9395e8b71
commit 6e10a6288a

@ -258,7 +258,8 @@ union PrimitiveType {
SmoothL1LossGrad,
SigmoidCrossEntropyWithLogits,
SigmoidCrossEntropyWithLogitsGrad,
Reciprocal
Reciprocal,
Merge,
}
enum QuantType: int {

@ -1222,4 +1222,7 @@ table SigmoidCrossEntropyWithLogitsGrad {
}
table Reciprocal {
}
}
table Merge {
}

@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "mindspore/lite/src/executor.h"
#include "nnacl/pack.h"
#include "src/executor.h"
#include <queue>
#include "include/errorcode.h"
namespace mindspore::lite {
@ -26,7 +26,7 @@ int Executor::CheckInputs(const std::vector<Tensor *> &in_tensors) {
return RET_ERROR;
}
if (inTensor->data_c() == nullptr) {
MS_LOG(ERROR) << "Graph input tensor data is nullptr";
MS_LOG(ERROR) << "Graph input tensor data is nullptr " << in_tensors;
return RET_ERROR;
}
auto shape = inTensor->shape();
@ -49,7 +49,52 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
std::queue<kernel::LiteKernel *> kernel_queue;
for (auto kernel : kernels) {
if (kernel->IsReady()) {
kernel_queue.push(kernel);
}
}
while (!kernel_queue.empty()) {
auto cur_kernel = kernel_queue.front();
kernel_queue.pop();
MS_ASSERT(nullptr != cur_kernel);
ret = cur_kernel->PreProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PreProcess kernel failed, name: " << cur_kernel->name();
return ret;
}
ret = cur_kernel->Run(before, after);
if (RET_OK != ret) {
MS_LOG(ERROR) << "run kernel failed, name: " << cur_kernel->name();
return ret;
}
ret = cur_kernel->PostProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PostProcess kernel failed, name: " << cur_kernel->name();
return ret;
}
for (auto &out_kernel : cur_kernel->out_kernels()) {
if (out_kernel->IsReady()) {
kernel_queue.push(out_kernel);
}
}
}
return RET_OK;
}
int CpuExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, const KernelCallBack &before,
const KernelCallBack &after) {
MS_ASSERT(nullptr != allocator);
// not check input for merge. too hard
if (kernels.front()->Type() != schema::PrimitiveType_Merge) {
auto ret = this->CheckInputs(in_tensors);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CheckInputs failed";
return ret;
}
}
#ifdef SUPPORT_TRAIN
for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them
out_tensor->set_ref_count(out_tensor->ref_count() + 1);
@ -57,7 +102,7 @@ int Executor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_
#endif
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
ret = kernel->PreProcess();
auto ret = kernel->PreProcess();
if (RET_OK != ret) {
MS_LOG(ERROR) << "PreProcess kernel failed, name: " << kernel->name();
return ret;

@ -37,5 +37,16 @@ class Executor {
protected:
static int CheckInputs(const std::vector<Tensor *> &in_tensors);
};
class CpuExecutor : public Executor {
public:
CpuExecutor() = default;
virtual ~CpuExecutor() = default;
int Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors,
std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator = nullptr,
const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override;
};
} // namespace mindspore::lite
#endif

@ -62,7 +62,7 @@ InnerContext::~InnerContext() {
}
}
int InnerContext::IsValid() {
int InnerContext::IsValid() const {
if (this->device_list_.empty()) {
MS_LOG(ERROR) << "Device list is empty.";
return RET_NOT_SUPPORT;
@ -86,33 +86,33 @@ int InnerContext::IsValid() {
return RET_OK;
}
bool InnerContext::IsCpuFloat16Enabled() {
bool InnerContext::IsCpuFloat16Enabled() const {
if (!IsCpuEnabled()) {
return false;
}
return GetCpuInfo().enable_float16_;
}
bool InnerContext::IsGpuFloat16Enabled() {
bool InnerContext::IsGpuFloat16Enabled() const {
if (!IsGpuEnabled()) {
return false;
}
return GetGpuInfo().enable_float16_;
}
bool InnerContext::IsCpuEnabled() {
bool InnerContext::IsCpuEnabled() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
}
bool InnerContext::IsGpuEnabled() {
bool InnerContext::IsGpuEnabled() const {
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
}
bool InnerContext::IsNpuEnabled() {
bool InnerContext::IsNpuEnabled() const {
#ifdef SUPPORT_NPU
return this->device_list_.end() !=
std::find_if(this->device_list_.begin(), this->device_list_.end(),
@ -123,7 +123,7 @@ bool InnerContext::IsNpuEnabled() {
#endif
}
CpuDeviceInfo InnerContext::GetCpuInfo() {
CpuDeviceInfo InnerContext::GetCpuInfo() const {
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_CPU; });
if (iter == this->device_list_.end()) {
@ -133,7 +133,7 @@ CpuDeviceInfo InnerContext::GetCpuInfo() {
}
}
GpuDeviceInfo InnerContext::GetGpuInfo() {
GpuDeviceInfo InnerContext::GetGpuInfo() const {
auto iter = std::find_if(this->device_list_.begin(), this->device_list_.end(),
[](const DeviceContext &device) { return device.device_type_ == DT_GPU; });
if (iter == this->device_list_.end()) {

@ -33,23 +33,23 @@ struct InnerContext : public Context {
int Init();
bool IsCpuFloat16Enabled();
bool IsCpuFloat16Enabled() const;
bool IsGpuFloat16Enabled();
bool IsGpuFloat16Enabled() const;
bool IsCpuEnabled();
bool IsCpuEnabled() const;
bool IsGpuEnabled();
bool IsGpuEnabled() const;
bool IsNpuEnabled();
bool IsNpuEnabled() const;
CpuDeviceInfo GetCpuInfo();
CpuDeviceInfo GetCpuInfo() const;
GpuDeviceInfo GetGpuInfo();
GpuDeviceInfo GetGpuInfo() const;
NpuDeviceInfo GetNpuInfo() const;
int IsValid();
int IsValid() const;
virtual ~InnerContext();
};

@ -41,9 +41,21 @@ void LiteKernel::FreeWorkspace() {
workspace_ = nullptr;
}
void LiteKernel::InitOutTensorRefCount() {
bool LiteKernel::IsReady() {
return std::all_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
});
}
void LiteKernel::InitOutTensorInitRefCount() {
for (auto *tensor : this->out_tensors_) {
tensor->set_ref_count(this->out_kernels_.size());
int init_ref_count = 0;
for (auto *post_kernel : this->out_kernels_) {
init_ref_count +=
std::count_if(post_kernel->in_tensors_.begin(), post_kernel->in_tensors_.end(),
[&tensor](const lite::Tensor *post_kernel_in_tensor) { return post_kernel_in_tensor == tensor; });
}
tensor->set_init_ref_count(init_ref_count);
}
}
@ -61,15 +73,20 @@ int LiteKernel::DecOutTensorRefCount() {
return 0;
}
int LiteKernel::FreeWorkTensor() const {
for (auto input_kernel : this->in_kernels()) {
MS_ASSERT(input_kernel != nullptr);
if (input_kernel->is_model_output()) {
int LiteKernel::FreeInWorkTensor() const {
for (auto &in_tensor : this->in_tensors_) {
MS_ASSERT(in_tensor != nullptr);
if (in_tensor->IsConst()) {
continue;
}
auto ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
MS_ASSERT(in_tensor->ref_count() > 0);
in_tensor->set_ref_count(in_tensor->ref_count() - 1);
if (in_tensor->ref_count() <= 0) {
auto ret = in_tensor->FreeData();
if (0 != ret) {
MS_LOG(ERROR) << "Free tensor data failed";
return ret;
}
}
}
return RET_OK;
@ -91,15 +108,12 @@ int LiteKernel::PreProcess() {
}
}
auto outputs = this->out_tensors();
for (auto *output : outputs) {
for (auto *output : this->out_tensors()) {
MS_ASSERT(output != nullptr);
if (output->ElementsNum() >= MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
MS_LOG(ERROR) << "The size of output tensor is too big";
return RET_ERROR;
}
auto ret = output->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MallocData failed";
@ -109,6 +123,28 @@ int LiteKernel::PreProcess() {
return RET_OK;
}
int LiteKernel::PostProcess() {
#ifdef SUPPORT_TRAIN
for (auto input_kernel : this->in_kernels()) {
MS_ASSERT(input_kernel != nullptr);
if (input_kernel->is_model_output()) {
continue;
}
auto ret = input_kernel->DecOutTensorRefCount();
if (0 != ret) {
MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed";
}
}
return RET_OK;
#else
for (auto *output : this->out_tensors()) {
MS_ASSERT(output != nullptr);
output->ResetRefCount();
}
return FreeInWorkTensor();
#endif
}
int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) {
if (before != nullptr) {
if (!before(TensorVectorCast(this->in_tensors_), TensorVectorCast(this->out_tensors_),
@ -153,6 +189,28 @@ std::string LiteKernel::ToString() const {
return oss.str();
}
void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels) {
// clean io kernels
this->in_kernels_.clear();
this->out_kernels_.clear();
// find io kernels
for (auto *scope_kernel : scope_kernels) {
if (scope_kernel == this) {
continue;
}
for (auto *tensor : this->in_tensors_) {
if (lite::IsContain(scope_kernel->out_tensors(), tensor)) {
this->AddInKernel(scope_kernel);
}
}
for (auto *tensor : this->out_tensors_) {
if (lite::IsContain(scope_kernel->in_tensors(), tensor)) {
this->AddOutKernel(scope_kernel);
}
}
}
}
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels(
const std::vector<kernel::LiteKernel *> &kernels) {
std::vector<kernel::LiteKernel *> input_kernels;
@ -202,7 +260,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
if (outer_in_kernels.empty()) {
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) {
if (!in_kernel_in_tensor->IsConst()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
if (!IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.push_back(in_kernel_in_tensor);
}
}
@ -219,7 +277,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect
auto outer_in_kernel_out_tensors_iter =
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor);
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) {
if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) {
if (!IsContain(input_tensors, in_kernel_in_tensor)) {
input_tensors.emplace_back(in_kernel_in_tensor);
}
}
@ -237,7 +295,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto &out_kernel_out_tensors = output_kernel->out_tensors();
if (outer_out_kernels.empty()) {
for (auto out_kernel_out_tensor : out_kernel_out_tensors) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
if (!IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.push_back(out_kernel_out_tensor);
}
}
@ -253,7 +311,7 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec
auto outer_out_kernel_in_tensors_iter =
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor);
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) {
if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) {
if (!IsContain(output_tensors, out_kernel_out_tensor)) {
output_tensors.emplace_back(out_kernel_out_tensor);
}
}
@ -299,33 +357,9 @@ int LiteKernelUtil::TopologicalSortKernels(std::vector<kernel::LiteKernel *> *ke
return RET_OK;
}
void LiteKernelUtil::InitIOKernels(std::vector<kernel::LiteKernel *> &kernels) {
for (auto *kernel : kernels) {
// clean io kernels
kernel->set_in_kernels({});
kernel->set_out_kernels({});
// find io kernels
for (auto *search_kernel : kernels) {
if (search_kernel == kernel) {
continue;
}
for (auto *tensor : kernel->in_tensors()) {
if (lite::IsContain(search_kernel->out_tensors(), tensor)) {
kernel->AddInKernel(search_kernel);
}
}
for (auto *tensor : kernel->out_tensors()) {
if (lite::IsContain(search_kernel->in_tensors(), tensor)) {
kernel->AddOutKernel(search_kernel);
}
}
}
}
}
void LiteKernelUtil::InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels) {
void LiteKernelUtil::InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels) {
for (auto *kernel : kernels) {
kernel->InitOutTensorRefCount();
kernel->InitOutTensorInitRefCount();
}
}

@ -87,10 +87,12 @@ class LiteKernel {
virtual int Run(const KernelCallBack &before, const KernelCallBack &after);
// called after Run
virtual int PostProcess() { return FreeWorkTensor(); }
virtual int PostProcess();
virtual int ReSize() { return mindspore::lite::RET_ERROR; }
virtual void FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope_kernels);
virtual int Init() { return mindspore::lite::RET_ERROR; }
std::string name() const { return this->name_; }
@ -154,11 +156,13 @@ class LiteKernel {
const std::vector<LiteKernel *> &out_kernels() const { return this->out_kernels_; }
void InitOutTensorRefCount();
virtual bool IsReady();
virtual void InitOutTensorInitRefCount();
int DecOutTensorRefCount();
int FreeWorkTensor() const;
virtual int FreeInWorkTensor() const;
KernelKey desc() const { return desc_; }
@ -203,8 +207,6 @@ typedef LiteKernel *(*KernelCreator)(const std::vector<lite::Tensor *> &inputs,
class LiteKernelUtil {
public:
static void InitIOKernels(std::vector<kernel::LiteKernel *> &kernels);
static std::vector<kernel::LiteKernel *> SubgraphInputKernels(const std::vector<kernel::LiteKernel *> &kernels);
static std::vector<kernel::LiteKernel *> SubgraphOutputKernels(const std::vector<kernel::LiteKernel *> &kernels);
@ -215,7 +217,7 @@ class LiteKernelUtil {
static int TopologicalSortKernels(std::vector<kernel::LiteKernel *> *kernels);
static void InitTensorRefCount(std::vector<kernel::LiteKernel *> &kernels);
static void InitTensorInitRefCount(std::vector<kernel::LiteKernel *> &kernels);
static int SetInput(LiteKernel &kernelMod, const std::vector<lite::Tensor *> &inputs);
};

@ -295,6 +295,21 @@ void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
}
}
void LiteSession::AdjustModelOutputTensorInitRefCount(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
MS_ASSERT(graph_out_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(graph_out_index);
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "out_tensor is null!";
return;
}
out_tensor->set_init_ref_count(out_tensor->init_ref_count() + 1);
}
}
void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphInputTensors(model);
InitGraphInputMSTensors();
@ -303,6 +318,7 @@ void LiteSession::InitGraphInOutTensors(const lite::Model *model) {
InitGraphOutputNodeMap(model);
InitGraphOutputTensorNames(model);
InitGraphOutputTensorMap(model);
AdjustModelOutputTensorInitRefCount(model);
}
int LiteSession::CompileGraph(Model *model) {
@ -334,12 +350,9 @@ int LiteSession::CompileGraph(Model *model) {
is_running_.store(false);
return ret;
}
InitGraphInOutTensors(model);
// scheduler kernels
Scheduler scheduler(context_);
ret = scheduler.Schedule(model, &tensors_, &kernels_);
Scheduler scheduler(context_, model, tensors_);
ret = scheduler.Schedule(&kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Schedule kernels failed: " << ret;
is_running_.store(false);
@ -353,6 +366,7 @@ int LiteSession::CompileGraph(Model *model) {
}
}
#endif
InitGraphInOutTensors(model);
ret = executor_->Prepare(this->kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare executor failed: " << ret;
@ -563,6 +577,32 @@ void LiteSession::ResetInputsShape(const std::vector<std::vector<int>> &dims) {
}
}
int LiteSession::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
bool infer_shape_interrupt = false;
for (auto kernel : kernels) {
if (kernel == nullptr) {
MS_LOG(ERROR) << "input kernel is nullptr!";
return RET_ERROR;
}
if (kernel->subgraph_type() == kernel::kNotSubGraph) {
MS_LOG(ERROR) << "All node in graph should be sub_graph";
return RET_ERROR;
}
auto sub_graph = reinterpret_cast<kernel::SubGraphKernel *>(kernel);
auto ret = sub_graph->ReSize(infer_shape_interrupt);
if (ret == RET_INFER_INVALID) {
MS_LOG(INFO) << "InferShape is interrupted";
infer_shape_interrupt = true;
continue;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed";
return RET_ERROR;
}
}
return RET_OK;
}
int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<std::vector<int>> &dims) {
bool expected = false;
@ -581,11 +621,10 @@ int LiteSession::Resize(const std::vector<mindspore::tensor::MSTensor *> &inputs
return ret;
}
Scheduler scheduler(context_);
ret = scheduler.ReSizeKernels(kernels_);
ret = ReSizeKernels(kernels_);
if (ret != RET_OK) {
ResetInputsShape(old_dims);
auto resize_ret = scheduler.ReSizeKernels(kernels_);
auto resize_ret = ReSizeKernels(kernels_);
if (resize_ret != RET_OK) {
MS_LOG(ERROR) << "restore kernel size fail!ret: " << resize_ret;
}

@ -92,10 +92,14 @@ class LiteSession : public session::LiteSession {
void InitGraphOutputTensorMap(const lite::Model *model);
void AdjustModelOutputTensorInitRefCount(const lite::Model *model);
int ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims);
int PrepareKernels();
static int ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels);
private:
void ResetInputsShape(const std::vector<std::vector<int>> &dims);

@ -0,0 +1,78 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/merge.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Merge::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Merge;
}
if (this->primitive_->value.type != schema::PrimitiveType_Merge) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::MergeT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else
int Merge::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Merge();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Merge return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateMerge(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Merge, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *MergeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Merge>(primitive); }
Registry MergeRegistry(schema::PrimitiveType_Merge, MergeCreator);
#endif
int Merge::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(outputs_.size() == 1);
MS_ASSERT(inputs_.size() == 2);
outputs_[0]->set_data_type(inputs_[0]->data_type());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
#define LITE_MINDSPORE_LITE_C_OPS_MERGE_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Merge : public PrimitiveC {
public:
Merge() = default;
~Merge() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Merge, PrimitiveC);
explicit Merge(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_MERGE_H_

@ -0,0 +1,83 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/partial.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Partial::GetSubGraphIndex() const { return this->primitive_->value.AsPartial()->subGraphIndex; }
int Partial::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Partial;
}
if (this->primitive_->value.type != schema::PrimitiveType_Partial) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::PartialT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Partial::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Partial();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Partial return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreatePartial(*fbb, attr->subGraphIndex());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Partial, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Partial::GetSubGraphIndex() const { return this->primitive_->value_as_Partial()->subGraphIndex(); }
PrimitiveC *PartialCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Partial>(primitive); }
Registry PartialRegistry(schema::PrimitiveType_Partial, PartialCreator);
#endif
int Partial::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
} // namespace lite
} // namespace mindspore

@ -0,0 +1,48 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
#define LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Partial : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Partial, PrimitiveC);
Partial() = default;
explicit Partial(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Partial() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetSubGraphIndex() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_PARTIAL_H_

@ -0,0 +1,35 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateMergeParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *merge_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (merge_parameter == nullptr) {
MS_LOG(ERROR) << "malloc Merge parameter failed.";
return nullptr;
}
memset(merge_parameter, 0, sizeof(OpParameter));
merge_parameter->type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(merge_parameter);
}
Registry MergeParameterRegistry(schema::PrimitiveType_Merge, PopulateMergeParameter);
} // namespace lite
} // namespace mindspore

@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/partial.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
typedef struct PartialParameter {
OpParameter op_parameter_;
int sub_graph_index_;
} PartialParameter;
OpParameter *PopulatePartialParameter(const mindspore::lite::PrimitiveC *primitive) {
PartialParameter *partial_parameter = reinterpret_cast<PartialParameter *>(malloc(sizeof(PartialParameter)));
if (partial_parameter == nullptr) {
MS_LOG(ERROR) << "malloc partial parameter failed.";
return nullptr;
}
memset(partial_parameter, 0, sizeof(PartialParameter));
partial_parameter->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Partial *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
partial_parameter->sub_graph_index_ = param->GetSubGraphIndex();
return reinterpret_cast<OpParameter *>(partial_parameter);
}
Registry PartialParameterRegistry(schema::PrimitiveType_Partial, PopulatePartialParameter);
} // namespace lite
} // namespace mindspore

@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/switch.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateSwitchParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *switch_parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (switch_parameter == nullptr) {
MS_LOG(ERROR) << "malloc SwitchParameter failed.";
return nullptr;
}
memset(switch_parameter, 0, sizeof(OpParameter));
switch_parameter->type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(switch_parameter);
}
Registry SwitchParameterRegistry(schema::PrimitiveType_Switch, PopulateSwitchParameter);
} // namespace lite
} // namespace mindspore

@ -155,6 +155,9 @@
#include "src/ops/tensorlistsetitem.h"
#include "src/ops/tensorlistreserve.h"
#include "src/ops/tensorliststack.h"
#include "src/ops/merge.h"
#include "src/ops/switch.h"
#include "src/ops/partial.h"
#ifdef SUPPORT_TRAIN
#include "src/ops/neg_grad.h"
@ -925,7 +928,12 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) TensorListReserve(primitive);
case schema::PrimitiveType_TensorListStack:
return new (std::nothrow) TensorListStack(primitive);
case schema::PrimitiveType_Switch:
return new (std::nothrow) Switch(primitive);
case schema::PrimitiveType_Merge:
return new (std::nothrow) Merge(primitive);
case schema::PrimitiveType_Partial:
return new (std::nothrow) Partial(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
return new (std::nothrow) ActivationGrad(primitive);

@ -0,0 +1,75 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/switch.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Switch::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Switch;
}
if (this->primitive_->value.type != schema::PrimitiveType_Switch) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SwitchT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Switch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Switch();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Switch return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateSwitch(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Switch, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *SwitchCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Switch>(primitive); }
Registry SwitchRegistry(schema::PrimitiveType_Switch, SwitchCreator);
#endif
int Switch::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { return RET_OK; }
} // namespace lite
} // namespace mindspore

@ -0,0 +1,47 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
#define LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Switch : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Switch, PrimitiveC);
Switch() = default;
explicit Switch(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Switch() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_SWITCH_H_

@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/merge.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Merge;
namespace mindspore::kernel {
// if one of input of merge is const-tensor, merge is always ready, this will cause error.
bool MergeCPUKernel::IsReady() {
MS_ASSERT(in_tensors().size() == 2);
return std::any_of(this->in_tensors().begin(), this->in_tensors().end(), [&](lite::Tensor *kernel_in_tensor) {
return kernel_in_tensor->IsConst() || kernel_in_tensor->ref_count() >= 1;
});
}
int MergeCPUKernel::Init() { return RET_OK; }
int MergeCPUKernel::ReSize() { return RET_ERROR; }
int MergeCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() == 2);
MS_ASSERT(out_tensors_.size() == 1);
auto out_data = out_tensors_.front()->data_c();
MS_ASSERT(out_data != nullptr);
for (size_t i = 0; i < in_tensors().size(); i++) {
if (in_tensors()[i]->data_c() != nullptr) {
auto in_data = in_tensors_[i]->data_c();
MS_ASSERT(in_data != nullptr);
memcpy(out_data, in_data, in_tensors_[i]->Size());
}
}
return RET_OK;
}
kernel::LiteKernel *CpuMergeKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr";
return nullptr;
}
if (desc.type != PrimitiveType_Merge) {
MS_LOG(ERROR) << "type in desc is not Merge";
free(parameter);
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "ctx is nullptr";
free(parameter);
return nullptr;
}
auto *kernel = new (std::nothrow) MergeCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
free(parameter);
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, CpuMergeKernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, CpuMergeKernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
typedef struct MergeParameter {
OpParameter op_parameter_;
} MergeParameter;
class MergeCPUKernel : public LiteKernel {
public:
MergeCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
merge_param_ = reinterpret_cast<MergeParameter *>(op_parameter_);
}
~MergeCPUKernel() override {}
bool IsReady() override;
int Init() override;
int ReSize() override;
int Run() override;
private:
MergeParameter *merge_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_MERGE_H_

@ -0,0 +1,115 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/switch.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Switch;
namespace mindspore::kernel {
int SwitchCPUKernel::PostProcess() {
auto bool_tensor = in_tensors_.front();
MS_ASSERT(bool_tensor != nullptr);
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
MS_ASSERT(bool_tensor->shape().size() == 1);
MS_ASSERT(bool_tensor->shape().front() == 1);
auto *active = static_cast<bool *>(bool_tensor->data_c());
if (active == nullptr) {
MS_LOG(ERROR) << "data of bool tensor is nullptr";
return lite::RET_NULL_PTR;
}
size_t in_index = 1;
size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
while (in_index < in_tensors_.size()) {
in_index++;
auto out_tensor = out_tensors_.at(out_index++);
out_tensor->ResetRefCount();
}
return FreeInWorkTensor();
}
int SwitchCPUKernel::Init() { return RET_OK; }
int SwitchCPUKernel::ReSize() { return RET_ERROR; }
// inputs: bool*1 data*n
// output: true-data*n, false-data*n
int SwitchCPUKernel::Run() {
MS_ASSERT(in_tensors_.size() >= 2);
MS_ASSERT(out_tensors_.size() == 2 * in_tensors_.size());
auto bool_tensor = in_tensors_.front();
MS_ASSERT(bool_tensor != nullptr);
MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
MS_ASSERT(bool_tensor->shape().size() == 1);
MS_ASSERT(bool_tensor->shape().front() == 1);
auto active = static_cast<bool *>(bool_tensor->data_c());
if (active == nullptr) {
MS_LOG(ERROR) << "data of bool tensor is nullptr";
return lite::RET_NULL_PTR;
}
size_t in_index = 1;
size_t out_index = (*active) ? 0 : (out_tensors_.size() / 2);
while (in_index < in_tensors_.size()) {
auto in_tensor = in_tensors_.at(in_index++);
auto out_tensor = out_tensors_.at(out_index++);
MS_ASSERT(in_tensor != nullptr);
MS_ASSERT(out_tensor != nullptr);
auto input = reinterpret_cast<float *>(in_tensor->data_c());
auto output = reinterpret_cast<float *>(out_tensor->data_c());
MS_ASSERT(in_tensor->Size() == out_tensor->Size());
if (input == nullptr || output == nullptr) {
MS_LOG(ERROR) << "input tensor or output tensor have not been malloced";
return lite::RET_NULL_PTR;
}
memcpy(output, input, in_tensor->Size());
}
return RET_OK;
}
kernel::LiteKernel *CpuSwitchKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "parameter is nullptr";
return nullptr;
}
if (desc.type != PrimitiveType_Switch) {
MS_LOG(ERROR) << "type in desc is not Switch";
free(parameter);
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "ctx is nullptr";
free(parameter);
return nullptr;
}
auto *kernel = new (std::nothrow) SwitchCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, name: " << parameter->name_;
free(parameter);
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, CpuSwitchKernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, CpuSwitchKernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
typedef struct SwitchParameter {
OpParameter op_parameter_;
} SwitchParameter;
class SwitchCPUKernel : public LiteKernel {
public:
SwitchCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
switch_param_ = reinterpret_cast<SwitchParameter *>(op_parameter_);
}
~SwitchCPUKernel() override = default;
int PostProcess() override;
int Init() override;
int ReSize() override;
int Run() override;
private:
SwitchParameter *switch_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_SWITCH_H_

@ -71,6 +71,14 @@ int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
return RET_OK;
}
int OpenCLKernel::PostProcess() {
for (auto *output : this->out_tensors()) {
MS_ASSERT(output != nullptr);
output->ResetRefCount();
}
return FreeInWorkTensor();
}
std::vector<BaseTuningParameter> OpenCLKernel::GenerateTuningParam() {
size_t ndim = global_size_.size();
std::vector<BaseTuningParameter> tuning_params = {};

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save