add onnx loop if support

pull/10373/head
zhengjun10 4 years ago
parent 2645ed3c90
commit 401d42a103

@ -264,6 +264,7 @@ union PrimitiveType {
If,
GeLU,
Gru,
NonZero,
}
enum QuantType: int {

@ -236,7 +236,8 @@ union PrimitiveType {
LpNormalization,
DropoutGrad,
MaximumGrad,
MinimumGrad
MinimumGrad,
NonZero,
}
enum QuantType: int {

@ -1241,3 +1241,6 @@ table Merge {
table GeLU {
approximate : bool = false;
}
table NonZero {
}

@ -1143,3 +1143,6 @@ table LpNormalization {
axis : int;
p : int;
}
table NonZero {
}

@ -0,0 +1,124 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/nonzero.h"
#include <algorithm>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int NonZero::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_NonZero;
}
if (this->primitive_->value.type != schema::PrimitiveType_NonZero) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::NonZeroT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
PopulaterQuantParam(prim, inputs);
return RET_OK;
}
#else
int NonZero::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_NonZero();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_NonZero return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateNonZero(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonZero, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *NonZeroCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NonZero>(primitive); }
Registry NonZeroRegistry(schema::PrimitiveType_NonZero, NonZeroCreator);
#endif
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
int input_count = inputs[0]->ElementsNum();
int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size();
(*out_shape)[0] = input_dim_size;
int nonzero_size = 0;
for (int i = 0; i < input_count; i++) {
if (static_cast<int>(data[i]) != 0) {
nonzero_size++;
}
}
if (nonzero_size == 0) {
*out_shape = {};
} else {
(*out_shape)[1] = nonzero_size / input_dim_size;
}
}
int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
MS_ASSERT(inputs_.size() == 1);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_data_type(input->data_type());
output->set_format(input->format());
if (!infer_flag()) {
return RET_INFER_INVALID;
}
std::vector<int> out_shape;
if (inputs_.size() == kSingleNum) {
auto input_tensor = inputs_.at(0);
if (input_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
switch (input_tensor->data_type()) {
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(input_tensor->MutableData());
CalShape<float>(data, inputs_, &out_shape);
} break;
default: {
MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type();
return RET_INFER_ERR;
}
}
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return RET_INFER_ERR;
}
output->set_shape(out_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,45 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_NONZERO_H_
#define MINDSPORE_LITE_SRC_OPS_NONZERO_H_
#include <cmath>
#include <memory>
#include <set>
#include <vector>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class NonZero : public PrimitiveC {
public:
NonZero() = default;
~NonZero() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NonZero, PrimitiveC);
explicit NonZero(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_NONZERO_H_

@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/nonzero_fp32.h"
#include "include/errorcode.h"
#include "nnacl/op_base.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/tensor.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_NonZero;
namespace mindspore::kernel {
int NonZeroCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int NonZeroCPUKernel::ReSize() { return RET_OK; }
int NonZeroCPUKernel::Run() {
auto in_tensor = in_tensors_.front();
auto out_tensor = out_tensors_.front();
auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
auto output_data = reinterpret_cast<int *>(out_tensor->MutableData());
auto input_dim_size = in_tensor->shape().size();
if (out_tensor->shape().size() != 2) {
MS_LOG(ERROR) << "out tensor shape size must be equal to 2!";
return RET_ERROR;
}
auto non_zero_nums = out_tensor->shape()[1];
int non_zero_count = 0;
std::vector coordiate_values(in_tensor->shape().size(), 0);
for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
if (input_data[i] != 0) {
for (size_t j = 0; j < input_dim_size; j++) {
output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j];
}
non_zero_count++;
}
for (int idx = input_dim_size - 1; idx >= 0; --idx) {
if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) {
coordiate_values[idx] = coordiate_values[idx] + 1;
break;
}
coordiate_values[idx] = 0;
}
}
return RET_OK;
}
kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "Input context is nullptr!";
free(opParameter);
return nullptr;
}
if (ctx->thread_num_ == 0) {
MS_LOG(ERROR) << "context thread num is 0!";
free(opParameter);
return nullptr;
}
auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new NonZeroCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
#include <vector>
#include "src/lite_kernel.h"
namespace mindspore::kernel {
class NonZeroCPUKernel : public LiteKernel {
public:
NonZeroCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~NonZeroCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
protected:
int thread_count_ = 1;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_

@ -100,7 +100,8 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
}
}
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF) {
if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
config->fmk == lite::converter::FmkType_ONNX) {
graph_pm->AddPass(std::make_shared<opt::WhilePass>());
graph_pm->AddPass(std::make_shared<opt::IfPass>());
}

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_if_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxIfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx IfParser";
auto attr = std::make_unique<schema::IfT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_If;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxIfParser("If", new OnnxIfParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxIfParser : public OnnxNodeParser {
public:
OnnxIfParser() : OnnxNodeParser("If") {}
~OnnxIfParser() override = default;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_loop_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxLoopParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx LoopParser";
auto attr = std::make_unique<schema::WhileT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_While;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxLoopParser("Loop", new OnnxLoopParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxLoopParser : public OnnxNodeParser {
public:
OnnxLoopParser() : OnnxNodeParser("Loop") {}
~OnnxLoopParser() override = default;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H

@ -54,29 +54,60 @@ class OnnxModelParser : public ModelParser {
private:
STATUS InitOriginModel(const std::string &model_file);
STATUS ConvertNodes();
STATUS ConvertConstTensors();
STATUS ConvertGraphInputs();
STATUS ConvertGraphOutputs();
STATUS BuildReturnNode(const std::vector<AnfNodePtr> &return_inputs);
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
const std::string &root_node_name);
STATUS ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name);
STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map);
STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *nodes_map);
STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map);
STATUS BuildReturnNode(const FuncGraphPtr &func_graph_ptr, const std::vector<AnfNodePtr> &return_inputs);
STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor);
STATUS BuildParameterNodeForQuantParam(void *data, const std::string &name, TypeId type);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c, const std::string &name);
STATUS BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
lite::PrimitiveC *primitive_c, std::string loop_name);
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
lite::PrimitiveC *primitive_c);
STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c);
STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, lite::PrimitiveC *primitive_c,
const std::string &name);
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, lite::PrimitiveC *primitive_c);
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);
STATUS ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
const std::string &root_node_name);
STATUS AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
const std::string &loop_node_name, std::vector<AnfNodePtr> *body_graph_inputs,
int act_output_num);
STATUS BuildCondGraph(const FuncGraphPtr &cond_graph, const AnfNodePtr &root_while_node, int inputs_num,
const std::string &cond_graph_name);
STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
const std::string &subgrah_name, const std::string &if_node_name,
const std::string &root_node_name);
onnx::ModelProto onnx_model_;
onnx::GraphProto onnx_graph_;
std::unordered_map<std::string, AnfNodePtr> nodes_;
FuncGraphPtr func_graph_ptr_ = nullptr;
onnx::GraphProto onnx_root_graph_;
std::vector<FuncGraphPtr> all_subgraphs_;
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
FuncGraphPtr anf_root_graph_ = nullptr;
};
} // namespace lite
} // namespace mindspore

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/onnx/onnx_nonzero_parser.h"
#include <memory>
#include "tools/converter/parser/onnx/onnx_model_parser.h"
namespace mindspore {
namespace lite {
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node) {
MS_LOG(DEBUG) << "onnx NonZeroParser";
auto attr = std::make_unique<schema::NonZeroT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return nullptr;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "new primitive failed";
return nullptr;
}
primitive->value.type = schema::PrimitiveType_NonZero;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
}
OnnxNodeRegistrar g_onnxNonZeroParser("NonZero", new OnnxNonZeroParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
namespace mindspore {
namespace lite {
class OnnxNonZeroParser : public OnnxNodeParser {
public:
OnnxNonZeroParser() : OnnxNodeParser("NonZero") {}
~OnnxNonZeroParser() override = default;
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H

@ -96,6 +96,19 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) {
status = ReplaceIdentity(node, manager);
} else if (type == schema::PrimitiveType_TupleGetItem) {
status = ReplaceTupleGetItem(node, manager);
} else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)Run(sub_func_graph);
sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
if (sub_func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
(void)Run(sub_func_graph);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "remove identity pass is failed.";

@ -296,6 +296,45 @@ STATUS OnnxInputAdjustOpPass::AdjustStridedSlice(const FuncGraphPtr &func_graph,
return lite::RET_OK;
}
STATUS OnnxInputAdjustOpPass::AdjustResize(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto node = cnode->input(0);
MS_ASSERT(value_node != nullptr);
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
MS_LOG(ERROR) << "cnode input0 is not a valuenode.";
return lite::RET_ERROR;
}
MS_ASSERT(value_node->value() != nullptr);
auto primitive_c = value_node->value()->cast<PrimitiveCPtr>();
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "cnode has no primitive_c.";
return lite::RET_ERROR;
}
auto primitive = primitive_c->primitiveT();
if (primitive == nullptr) {
MS_LOG(ERROR) << "cnode has no schema::primitive.";
return lite::RET_ERROR;
}
if (primitive->value.type != schema::PrimitiveType_Resize) {
MS_LOG(DEBUG) << "cnode is not cast node.";
return RET_OK;
}
auto value = primitive->value.value;
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr.";
return lite::RET_ERROR;
}
auto attr = reinterpret_cast<schema::ResizeT *>(value);
if (cnode->inputs().size() > 3 &&
attr->coordinateTransformMode == schema::CoordinateTransformMode_TF_CROP_AND_RESIZE) {
auto new_resize_inputs = cnode->inputs();
new_resize_inputs.erase(new_resize_inputs.begin() + 1);
cnode->set_inputs(new_resize_inputs);
}
return lite::RET_OK;
}
STATUS OnnxInputAdjustOpPass::AdjustConvOrDeConv(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (!CheckInputs(cnode)) {

@ -40,6 +40,7 @@ class OnnxInputAdjustOpPass : public Pass {
STATUS AdjustConvOrDeConv(const CNodePtr &cnode);
STATUS AdjustTile(const CNodePtr &cnode);
STATUS AdjustCast(const CNodePtr &cnode);
STATUS AdjustResize(const CNodePtr &cnode);
STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &func_graph) override;

Loading…
Cancel
Save