!3687 keep lite import with mindspore export

Merge pull request !3687 from yankai10/merge123
pull/3687/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 68666bd35e

@ -0,0 +1,35 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_reshape_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfReshapePopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto attr = std::make_unique<schema::FlattenT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive->value.value = attr.release();
return 0;
}
AnfNodePopulaterRegistrar anfReshapeParser("Reshape", new AnfReshapePopulater());
} // namespace mindspore::lite

@ -0,0 +1,30 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_ANF_RESHAPE_PARSER_H
#define MINDSPORE_ANF_RESHAPE_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfReshapePopulater : public AnfNodePopulater {
public:
AnfReshapePopulater() = default;
~AnfReshapePopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_ANF_RESHAPE_PARSER_H

@ -47,6 +47,7 @@ class AnfImporterFromProtobuf : public AnfImporter {
bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
#if 0
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
@ -76,6 +77,30 @@ class AnfImporterFromProtobuf : public AnfImporter {
const onnx::TensorProto &attr_tensor);
std::unordered_map<std::string, abstract::AbstractTensorPtr>
GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
#endif
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const CNodePtr &cnode_ptr);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name,
const onnx::TensorProto &attr_tensor);
bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name,
const onnx::TensorProto &attr_tensor);
bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor);
abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto);
private:
std::string producer_name_;

Loading…
Cancel
Save