pull/4320/head
xuanyue 5 years ago
parent d7795494e0
commit 3a07af9633

@ -106,7 +106,7 @@ if (BUILD_CONVERTER)
include_directories(${TOP_DIR}/third_party/protobuf/build/include)
link_directories(${TOP_DIR}/third_party/protobuf/build/lib)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
add_subdirectory(src/common/anf_exporter)
add_subdirectory(src/common/anf_importer)
endif()
if (BUILD_DEVICE)

@ -25,7 +25,7 @@
#include "abstract/abstract_value.h"
#include "base/core_ops.h"
#include "mindspore/core/ir/primitive.h"
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
// #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/ir/primitive_t_value.h"
#include "src/ir/tensor.h"
#include "src/param_value_lite.h"
@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
node->name = cnode->fullname_with_scope();
node->nodeType = schema::NodeType_CNode;
// populate primitive
if (primitive != nullptr) {
primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(primitive != nullptr);
std::string opType = primitive->name();
auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (nodeParser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr;
}
std::vector<schema::TensorT *> outputs;
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
outputs.resize(abstract_cnode->size());
}
nodeParser->Parse(cnode, node.get(), &outputs);
SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
metaGraphT->nodes.emplace_back(std::move(node));
continue;
}
// if (primitive != nullptr) {
// primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
// MS_ASSERT(primitive != nullptr);
// std::string opType = primitive->name();
// auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
// if (nodeParser == nullptr) {
// MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
// return nullptr;
// }
// std::vector<schema::TensorT *> outputs;
// if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
// auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
// outputs.resize(abstract_cnode->size());
// }
//
// nodeParser->Parse(cnode, node.get(), &outputs);
// SetOpInputNode(cnode, metaGraphT.get(), node.get());
// SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
// metaGraphT->nodes.emplace_back(std::move(node));
// continue;
// }
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";

@ -1,7 +1,7 @@
file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.cc
)
add_library(anf_exporter_mid OBJECT
list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc)
add_library(anf_importer_mid OBJECT
${ANF_SRC_LIST}
)

@ -13,33 +13,33 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h"
#include "src/common/anf_importer/anf_populater/anf_activation_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ActivationT>();
if (p->name() == "ReLU") {
if (prim->name() == "ReLU") {
attr->type = schema::ActivationType_RELU;
} else if (p->name() == "Sigmoid") {
} else if (prim->name() == "Sigmoid") {
attr->type = schema::ActivationType_SIGMOID;
} else if (p->name() == "ReLU6") {
} else if (prim->name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Activation;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Activation;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater());
} // namespace mindspore::lite

@ -16,14 +16,15 @@
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfActivationPopulater : public AnfNodePopulater {
public:
AnfActivationPopulater() = default;
~AnfActivationPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,25 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h"
#include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FusedBatchNormT>();
attr->epsilon = GetValue<float>(p->GetAttr("epsilon"));
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
node->primitive->value.value = attr.release();
attr->epsilon = GetValue<float>(prim->GetAttr("epsilon"));
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser());
AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H
#define MINDSPORE_ANF_BATCHNORM_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfBatchnormParser : public AnfNodePopulater {
class AnfBatchnormPopulater : public AnfNodePopulater {
public:
AnfBatchnormParser() = default;
~AnfBatchnormParser() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
AnfBatchnormPopulater() = default;
~AnfBatchnormPopulater() override = default;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,25 +13,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h"
#include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::BiasAddT>();
attr->axis = {0};
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_BiasAdd;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_BiasAdd;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater());
AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_BIASADD_PARSER_H
#define MINDSPORE_ANF_BIASADD_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfBiasAddPopulater : public AnfNodePopulater {
public:
AnfBiasAddPopulater() = default;
~AnfBiasAddPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -16,30 +16,27 @@
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h"
#include "src/common/anf_importer/anf_populater/anf_concat_populater.h"
#include <string>
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ConcatT>();
auto prim_axis = GetValue<int>(p->GetAttr("axis"));
auto prim_axis = GetValue<int>(prim->GetAttr("axis"));
attr->axis = prim_axis;
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Concat;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Concat;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater());
AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater());
} // namespace mindspore::lite

@ -18,14 +18,15 @@
#ifndef MINDSPORE_ANF_CONCAT_PARSER_H
#define MINDSPORE_ANF_CONCAT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfConcatPopulater : public AnfNodePopulater {
public:
AnfConcatPopulater() = default;
~AnfConcatPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -16,23 +16,22 @@
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h"
#include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
#include <string>
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int group = GetValue<int>(p->GetAttr("group"));
int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
int group = GetValue<int>(prim->GetAttr("group"));
auto primitive = std::make_unique<schema::PrimitiveT>();
if (group > 1) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
attr->padMode = schema::PadMode_NOTSET;
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
} else {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(p->GetAttr("out_channel"));
attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else {
attr->padMode = schema::PadMode_NOTSET;
}
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Conv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Conv2D;
primitive->value.value = attr.release();
}
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater());
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
} // namespace mindspore::lite

@ -18,14 +18,15 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater {
public:
AnfConvPopulater() = default;
~AnfConvPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,21 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h"
#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format"));
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads"));
auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation"));
auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size"));
auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride"));
auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode"));
auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
attr->padMode = schema::PadMode_NOTSET;
}
auto channel_multiplier = GetValue<int>(p->GetAttr("channel_multiplier"));
auto channel_multiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
attr->channelMultiplier = channel_multiplier;
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo);
MS_ASSERT(inputs.size() == kAnfPopulaterThree);
auto inputNode = inputs[kAnfPopulaterTwo];
MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
}
}
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
public:
AnfDepwiseconv2DPopulater() = default;
~AnfDepwiseconv2DPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h"
#include "src/common/anf_importer/anf_populater/anf_dequant_populater.h"
#include <vector>
#include <string>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater());
AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
#define MINDSPORE_ANF_DEQUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfDequantPopulater : public AnfNodePopulater {
public:
AnfDequantPopulater() = default;
~AnfDequantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h"
#include "src/common/anf_importer/anf_populater/anf_flatten_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FlattenT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Flatten;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater());
AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H
#define MINDSPORE_ANF_FLATTEN_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfFlattenPopulater : public AnfNodePopulater {
public:
AnfFlattenPopulater() = default;
~AnfFlattenPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,26 +13,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h"
#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
auto p = GetCNodePrimitive(cnodePtr);
int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MatMulT>();
attr->transposeA = GetValue<bool>(p->GetAttr("transpose_a"));
attr->transposeB = GetValue<bool>(p->GetAttr("transpose_b"));
attr->transposeA = GetValue<bool>(prim->GetAttr("transpose_a"));
attr->transposeB = GetValue<bool>(prim->GetAttr("transpose_b"));
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_MatMul;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater());
AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_MATMUL_PARSER_H
#define MINDSPORE_ANF_MATMUL_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfMatmulPopulater : public AnfNodePopulater {
public:
AnfMatmulPopulater() = default;
~AnfMatmulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -13,23 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h"
#include "src/common/anf_importer/anf_populater/anf_mul_populater.h"
#include <vector>
#include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore::lite {
int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node,
std::vector<schema::TensorT *> *outputs) {
int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MulT>();
node->nodeType = schema::NodeType_CNode;
node->primitive = std::make_unique<schema::PrimitiveT>();
node->primitive->value.type = schema::PrimitiveType_Mul;
node->primitive->value.value = attr.release();
primitive->value.type = schema::PrimitiveType_Mul;
primitive->value.value = attr.release();
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0;
}
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater());
AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater());
} // namespace mindspore::lite

@ -15,14 +15,15 @@
*/
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector>
namespace mindspore::lite {
class AnfMulPopulater : public AnfNodePopulater {
public:
AnfMulPopulater() = default;
~AnfMulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override;
int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
};
} // namespace mindspore::lite

@ -14,6 +14,6 @@
* limitations under the License.
*/
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h"
#include "src/common/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {} // namespace mindspore::lite

@ -19,6 +19,7 @@
#include <vector>
#include "ir/anf.h"
#include "src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
constexpr int kAnfPopulaterOne = 1;
@ -28,7 +29,9 @@ class AnfNodePopulater {
public:
AnfNodePopulater() = default;
virtual ~AnfNodePopulater() = default;
virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) = 0;
virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) = 0;
};
} // namespace mindspore::lite

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save