pull/4320/head
xuanyue 5 years ago
parent d7795494e0
commit 3a07af9633

@ -106,7 +106,7 @@ if (BUILD_CONVERTER)
include_directories(${TOP_DIR}/third_party/protobuf/build/include) include_directories(${TOP_DIR}/third_party/protobuf/build/include)
link_directories(${TOP_DIR}/third_party/protobuf/build/lib) link_directories(${TOP_DIR}/third_party/protobuf/build/lib)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
add_subdirectory(src/common/anf_exporter) add_subdirectory(src/common/anf_importer)
endif() endif()
if (BUILD_DEVICE) if (BUILD_DEVICE)

@ -25,7 +25,7 @@
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "mindspore/core/ir/primitive.h" #include "mindspore/core/ir/primitive.h"
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" // #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h"
#include "src/ir/primitive_t_value.h" #include "src/ir/primitive_t_value.h"
#include "src/ir/tensor.h" #include "src/ir/tensor.h"
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
@ -148,27 +148,27 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
node->name = cnode->fullname_with_scope(); node->name = cnode->fullname_with_scope();
node->nodeType = schema::NodeType_CNode; node->nodeType = schema::NodeType_CNode;
// populate primitive // populate primitive
if (primitive != nullptr) { // if (primitive != nullptr) {
primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); // primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_ASSERT(primitive != nullptr); // MS_ASSERT(primitive != nullptr);
std::string opType = primitive->name(); // std::string opType = primitive->name();
auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType);
if (nodeParser == nullptr) { // if (nodeParser == nullptr) {
MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType;
return nullptr; // return nullptr;
} // }
std::vector<schema::TensorT *> outputs; // std::vector<schema::TensorT *> outputs;
if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { // if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) {
auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); // auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract());
outputs.resize(abstract_cnode->size()); // outputs.resize(abstract_cnode->size());
} // }
//
nodeParser->Parse(cnode, node.get(), &outputs); // nodeParser->Parse(cnode, node.get(), &outputs);
SetOpInputNode(cnode, metaGraphT.get(), node.get()); // SetOpInputNode(cnode, metaGraphT.get(), node.get());
SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get());
metaGraphT->nodes.emplace_back(std::move(node)); // metaGraphT->nodes.emplace_back(std::move(node));
continue; // continue;
} // }
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) { if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; MS_LOG(ERROR) << "PrimitiveT_value is nullptr";

@ -1,7 +1,7 @@
file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.cc *.cc
) )
add_library(anf_exporter_mid OBJECT list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc)
add_library(anf_importer_mid OBJECT
${ANF_SRC_LIST} ${ANF_SRC_LIST}
) )

@ -13,33 +13,33 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_activation_populater.h" #include "src/common/anf_importer/anf_populater/anf_activation_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfActivationPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfActivationPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ActivationT>(); auto attr = std::make_unique<schema::ActivationT>();
if (p->name() == "ReLU") { if (prim->name() == "ReLU") {
attr->type = schema::ActivationType_RELU; attr->type = schema::ActivationType_RELU;
} else if (p->name() == "Sigmoid") { } else if (prim->name() == "Sigmoid") {
attr->type = schema::ActivationType_SIGMOID; attr->type = schema::ActivationType_SIGMOID;
} else if (p->name() == "ReLU6") { } else if (prim->name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6; attr->type = schema::ActivationType_RELU6;
} }
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_Activation;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_Activation; MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.value = attr.release(); primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfReLUParser("ReLU", new AnfActivationPopulater()); AnfNodePopulaterRegistrar anfReLUPopulater("ReLU", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfReLU6Parser("ReLU6", new AnfActivationPopulater()); AnfNodePopulaterRegistrar anfReLU6Populater("ReLU6", new AnfActivationPopulater());
AnfNodePopulaterRegistrar anfSigmoidParser("Sigmoid", new AnfActivationPopulater()); AnfNodePopulaterRegistrar anfSigmoidPopulater("Sigmoid", new AnfActivationPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -16,14 +16,15 @@
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfActivationPopulater : public AnfNodePopulater { class AnfActivationPopulater : public AnfNodePopulater {
public: public:
AnfActivationPopulater() = default; AnfActivationPopulater() = default;
~AnfActivationPopulater() override = default; ~AnfActivationPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,25 +13,24 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_batchnorm_populater.h" #include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfBatchnormParser::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfBatchnormPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FusedBatchNormT>(); auto attr = std::make_unique<schema::FusedBatchNormT>();
attr->epsilon = GetValue<float>(p->GetAttr("epsilon")); attr->epsilon = GetValue<float>(prim->GetAttr("epsilon"));
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
node->nodeType = schema::NodeType_CNode; primitive->value.value = attr.release();
node->primitive = std::make_unique<schema::PrimitiveT>(); MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.type = schema::PrimitiveType_FusedBatchNorm; primitiveTValuePtr->SetPrimitiveT(primitive.release());
node->primitive->value.value = attr.release();
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfBatchnormParser("BatchNorm", new AnfBatchnormParser()); AnfNodePopulaterRegistrar anfBatchnormPopulater("BatchNorm", new AnfBatchnormPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H
#define MINDSPORE_ANF_BATCHNORM_PARSER_H #define MINDSPORE_ANF_BATCHNORM_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfBatchnormParser : public AnfNodePopulater { class AnfBatchnormPopulater : public AnfNodePopulater {
public: public:
AnfBatchnormParser() = default; AnfBatchnormPopulater() = default;
~AnfBatchnormParser() override = default; ~AnfBatchnormPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,25 +13,25 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_biasadd_populater.h" #include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfBiasAddPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfBiasAddPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::BiasAddT>(); auto attr = std::make_unique<schema::BiasAddT>();
attr->axis = {0}; attr->axis = {0};
primitive->value.type = schema::PrimitiveType_BiasAdd;
node->nodeType = schema::NodeType_CNode; primitive->value.value = attr.release();
node->primitive = std::make_unique<schema::PrimitiveT>(); MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.type = schema::PrimitiveType_BiasAdd; primitiveTValuePtr->SetPrimitiveT(primitive.release());
node->primitive->value.value = attr.release();
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfBiasAddParser("BiasAdd", new AnfBiasAddPopulater()); AnfNodePopulaterRegistrar anfBiasAddPopulater("BiasAdd", new AnfBiasAddPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_BIASADD_PARSER_H #ifndef MINDSPORE_ANF_BIASADD_PARSER_H
#define MINDSPORE_ANF_BIASADD_PARSER_H #define MINDSPORE_ANF_BIASADD_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfBiasAddPopulater : public AnfNodePopulater { class AnfBiasAddPopulater : public AnfNodePopulater {
public: public:
AnfBiasAddPopulater() = default; AnfBiasAddPopulater() = default;
~AnfBiasAddPopulater() override = default; ~AnfBiasAddPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -16,30 +16,27 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_concat_populater.h" #include "src/common/anf_importer/anf_populater/anf_concat_populater.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfConcatPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfConcatPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::ConcatT>(); auto attr = std::make_unique<schema::ConcatT>();
auto prim_axis = GetValue<int>(prim->GetAttr("axis"));
auto prim_axis = GetValue<int>(p->GetAttr("axis"));
attr->axis = prim_axis; attr->axis = prim_axis;
primitive->value.type = schema::PrimitiveType_Concat;
node->nodeType = schema::NodeType_CNode; primitive->value.value = attr.release();
node->primitive = std::make_unique<schema::PrimitiveT>(); MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.type = schema::PrimitiveType_Concat; primitiveTValuePtr->SetPrimitiveT(primitive.release());
node->primitive->value.value = attr.release();
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfConcatParser("Concat", new AnfConcatPopulater()); AnfNodePopulaterRegistrar anfConcatPopulater("Concat", new AnfConcatPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -18,14 +18,15 @@
#ifndef MINDSPORE_ANF_CONCAT_PARSER_H #ifndef MINDSPORE_ANF_CONCAT_PARSER_H
#define MINDSPORE_ANF_CONCAT_PARSER_H #define MINDSPORE_ANF_CONCAT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfConcatPopulater : public AnfNodePopulater { class AnfConcatPopulater : public AnfNodePopulater {
public: public:
AnfConcatPopulater() = default; AnfConcatPopulater() = default;
~AnfConcatPopulater() override = default; ~AnfConcatPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtrr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -16,23 +16,22 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_conv_populater.h" #include "src/common/anf_importer/anf_populater/anf_conv_populater.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); int group = GetValue<int>(prim->GetAttr("group"));
int group = GetValue<int>(p->GetAttr("group")); auto primitive = std::make_unique<schema::PrimitiveT>();
if (group > 1) { if (group > 1) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>(); auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format")); auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
} else if (format == "NHWC") { } else if (format == "NHWC") {
@ -40,25 +39,25 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID; attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") { } else if (pad_mode == "same") {
@ -67,14 +66,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive->value.value = attr.release();
} else { } else {
auto attr = std::make_unique<schema::Conv2DT>(); auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group; attr->group = group;
auto format = GetValue<std::string>(p->GetAttr("data_format")); auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
} else if (format == "NHWC") { } else if (format == "NHWC") {
@ -82,27 +79,27 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pad_list")); auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pad_list"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
attr->channelOut = GetValue<int>(p->GetAttr("out_channel")); attr->channelOut = GetValue<int>(prim->GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID; attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") { } else if (pad_mode == "same") {
@ -110,12 +107,12 @@ int mindspore::lite::AnfConvPopulater::Parse(mindspore::CNodePtr cnodePtr, schem
} else { } else {
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.type = schema::PrimitiveType_Conv2D;
node->primitive->value.type = schema::PrimitiveType_Conv2D; primitive->value.value = attr.release();
node->primitive->value.value = attr.release();
} }
MS_ASSERT(primitiveTValuePtr != nullptr);
primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfConvPopulater("Conv2D", new AnfConvPopulater());
AnfNodePopulaterRegistrar anfConvParser("Conv2D", new AnfConvPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -18,14 +18,15 @@
#ifndef MINDSPORE_ANF_CONV_PARSER_H #ifndef MINDSPORE_ANF_CONV_PARSER_H
#define MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfConvPopulater : public AnfNodePopulater { class AnfConvPopulater : public AnfNodePopulater {
public: public:
AnfConvPopulater() = default; AnfConvPopulater() = default;
~AnfConvPopulater() override = default; ~AnfConvPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,21 +13,21 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_depthwiseconv2d_populater.h" #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h"
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::DepthwiseConv2DT>(); auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(p->GetAttr("data_format")); auto format = GetValue<std::string>(prim->GetAttr("data_format"));
if (format == "NCHW") { if (format == "NCHW") {
attr->format = schema::Format_NCHW; attr->format = schema::Format_NCHW;
} else if (format == "NHWC") { } else if (format == "NHWC") {
@ -35,25 +35,25 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
} else { } else {
attr->format = schema::Format_NUM_OF_FORMAT; attr->format = schema::Format_NUM_OF_FORMAT;
} }
auto pad_list = GetValue<std::vector<int>>(p->GetAttr("pads")); auto pad_list = GetValue<std::vector<int>>(prim->GetAttr("pads"));
attr->padUp = pad_list[0]; attr->padUp = pad_list[0];
attr->padDown = pad_list[1]; attr->padDown = pad_list[1];
attr->padLeft = pad_list[2]; attr->padLeft = pad_list[2];
attr->padRight = pad_list[3]; attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(p->GetAttr("dilation")); auto dilation = GetValue<std::vector<int>>(prim->GetAttr("dilation"));
attr->dilateH = dilation[0]; attr->dilateH = dilation[0];
attr->dilateW = dilation[1]; attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(p->GetAttr("kernel_size")); auto kernel_size = GetValue<std::vector<int>>(prim->GetAttr("kernel_size"));
attr->kernelH = kernel_size[0]; attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1]; attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(p->GetAttr("stride")); auto stride = GetValue<std::vector<int>>(prim->GetAttr("stride"));
attr->strideH = stride[2]; attr->strideH = stride[2];
attr->strideW = stride[3]; attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(p->GetAttr("pad_mode")); auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
if (pad_mode == "valid") { if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID; attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") { } else if (pad_mode == "same") {
@ -62,11 +62,11 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
attr->padMode = schema::PadMode_NOTSET; attr->padMode = schema::PadMode_NOTSET;
} }
auto channel_multiplier = GetValue<int>(p->GetAttr("channel_multiplier")); auto channel_multiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
attr->channelMultiplier = channel_multiplier; attr->channelMultiplier = channel_multiplier;
MS_ASSERT(cnodePtr->size() == kAnfPopulaterThree); MS_ASSERT(inputs.size() == kAnfPopulaterThree);
auto inputNode = cnodePtr->input(kAnfPopulaterTwo); auto inputNode = inputs[kAnfPopulaterTwo];
MS_ASSERT(inputNode != nullptr); MS_ASSERT(inputNode != nullptr);
if (inputNode->isa<Parameter>()) { if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>(); auto paramNode = inputNode->cast<ParameterPtr>();
@ -82,12 +82,12 @@ int mindspore::lite::AnfDepwiseconv2DPopulater::Parse(mindspore::CNodePtr cnodeP
} }
} }
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.value = attr.release(); primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfdepthwise2dParser("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater());
AnfNodePopulaterRegistrar anfdepthwise2dnativeParser("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfDepwiseconv2DPopulater : public AnfNodePopulater { class AnfDepwiseconv2DPopulater : public AnfNodePopulater {
public: public:
AnfDepwiseconv2DPopulater() = default; AnfDepwiseconv2DPopulater() = default;
~AnfDepwiseconv2DPopulater() override = default; ~AnfDepwiseconv2DPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_dequant_populater.h" #include "src/common/anf_importer/anf_populater/anf_dequant_populater.h"
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfDequantPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfDequantPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>(); auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize; MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.value = attr.release(); primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfDequantParser("Dequant", new AnfDequantPopulater()); AnfNodePopulaterRegistrar anfDequantPopulater("Dequant", new AnfDequantPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_DEQUANT_PARSER_H #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H
#define MINDSPORE_ANF_DEQUANT_PARSER_H #define MINDSPORE_ANF_DEQUANT_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfDequantPopulater : public AnfNodePopulater { class AnfDequantPopulater : public AnfNodePopulater {
public: public:
AnfDequantPopulater() = default; AnfDequantPopulater() = default;
~AnfDequantPopulater() override = default; ~AnfDequantPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,23 +13,24 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_flatten_populater.h" #include "src/common/anf_importer/anf_populater/anf_flatten_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfFlattenPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfFlattenPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::FlattenT>(); auto attr = std::make_unique<schema::FlattenT>();
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_Flatten;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_Flatten; MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.value = attr.release(); primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfFlattenParser("Flatten", new AnfFlattenPopulater()); AnfNodePopulaterRegistrar anfFlattenPopulater("Flatten", new AnfFlattenPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_FLATTEN_PARSER_H #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H
#define MINDSPORE_ANF_FLATTEN_PARSER_H #define MINDSPORE_ANF_FLATTEN_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfFlattenPopulater : public AnfNodePopulater { class AnfFlattenPopulater : public AnfNodePopulater {
public: public:
AnfFlattenPopulater() = default; AnfFlattenPopulater() = default;
~AnfFlattenPopulater() override = default; ~AnfFlattenPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,26 +13,26 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_matmul_populater.h" #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfMatmulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto p = GetCNodePrimitive(cnodePtr); auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MatMulT>(); auto attr = std::make_unique<schema::MatMulT>();
attr->transposeA = GetValue<bool>(p->GetAttr("transpose_a")); attr->transposeA = GetValue<bool>(prim->GetAttr("transpose_a"));
attr->transposeB = GetValue<bool>(p->GetAttr("transpose_b")); attr->transposeB = GetValue<bool>(prim->GetAttr("transpose_b"));
node->nodeType = schema::NodeType_CNode; primitive->value.type = schema::PrimitiveType_MatMul;
node->primitive = std::make_unique<schema::PrimitiveT>(); primitive->value.value = attr.release();
node->primitive->value.type = schema::PrimitiveType_MatMul; MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.value = attr.release(); primitiveTValuePtr->SetPrimitiveT(primitive.release());
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfMatmulParser("MatMul", new AnfMatmulPopulater()); AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_MATMUL_PARSER_H #ifndef MINDSPORE_ANF_MATMUL_PARSER_H
#define MINDSPORE_ANF_MATMUL_PARSER_H #define MINDSPORE_ANF_MATMUL_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfMatmulPopulater : public AnfNodePopulater { class AnfMatmulPopulater : public AnfNodePopulater {
public: public:
AnfMatmulPopulater() = default; AnfMatmulPopulater() = default;
~AnfMatmulPopulater() override = default; ~AnfMatmulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -13,23 +13,23 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_mul_populater.h" #include "src/common/anf_importer/anf_populater/anf_mul_populater.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
namespace mindspore::lite { namespace mindspore::lite {
int mindspore::lite::AnfMulPopulater::Parse(mindspore::CNodePtr cnodePtr, schema::CNodeT *node, int AnfMulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
std::vector<schema::TensorT *> *outputs) { const std::vector<AnfNodePtr> &inputs) {
auto primitive = std::make_unique<schema::PrimitiveT>();
auto attr = std::make_unique<schema::MulT>(); auto attr = std::make_unique<schema::MulT>();
primitive->value.type = schema::PrimitiveType_Mul;
node->nodeType = schema::NodeType_CNode; primitive->value.value = attr.release();
node->primitive = std::make_unique<schema::PrimitiveT>(); MS_ASSERT(primitiveTValuePtr != nullptr);
node->primitive->value.type = schema::PrimitiveType_Mul; primitiveTValuePtr->SetPrimitiveT(primitive.release());
node->primitive->value.value = attr.release();
return 0; return 0;
} }
AnfNodePopulaterRegistrar anfMulParser("Mul", new AnfMulPopulater()); AnfNodePopulaterRegistrar anfMulPopulater("Mul", new AnfMulPopulater());
} // namespace mindspore::lite } // namespace mindspore::lite

@ -15,14 +15,15 @@
*/ */
#ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H
#define MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
#include <vector> #include <vector>
namespace mindspore::lite { namespace mindspore::lite {
class AnfMulPopulater : public AnfNodePopulater { class AnfMulPopulater : public AnfNodePopulater {
public: public:
AnfMulPopulater() = default; AnfMulPopulater() = default;
~AnfMulPopulater() override = default; ~AnfMulPopulater() override = default;
int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) override; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) override;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

@ -14,6 +14,6 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/common/anf_exporter/anf_populater/anf_node_populater.h" #include "src/common/anf_importer/anf_populater/anf_node_populater.h"
namespace mindspore::lite {} // namespace mindspore::lite namespace mindspore::lite {} // namespace mindspore::lite

@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "ir/anf.h" #include "ir/anf.h"
#include "src/ir/primitive_t_value.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
namespace mindspore::lite { namespace mindspore::lite {
constexpr int kAnfPopulaterOne = 1; constexpr int kAnfPopulaterOne = 1;
@ -28,7 +29,9 @@ class AnfNodePopulater {
public: public:
AnfNodePopulater() = default; AnfNodePopulater() = default;
virtual ~AnfNodePopulater() = default; virtual ~AnfNodePopulater() = default;
virtual int Parse(CNodePtr cnodePtr, schema::CNodeT *node, std::vector<schema::TensorT *> *outputs) = 0;
virtual int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr,
const std::vector<AnfNodePtr> &inputs) = 0;
}; };
} // namespace mindspore::lite } // namespace mindspore::lite

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save