diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 2663323898..5ef406aec6 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -480,6 +480,8 @@ table DeConv2D {
     dilateH: int;
     hasBias: bool = false; // DEPRECATED
     activationType: ActivationType = 0;
+    outputPaddingW: int;
+    outputPaddingH: int;
 }

 table DeConv2DGradFilter {
diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc
index d6878602ed..fb3e757596 100644
--- a/mindspore/lite/src/ops/deconv2d.cc
+++ b/mindspore/lite/src/ops/deconv2d.cc
@@ -47,6 +47,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value.AsDeConv2D()-
 int DeConv2D::GetDilateW() const { return this->primitive_->value.AsDeConv2D()->dilateW; }
 int DeConv2D::GetDilateH() const { return this->primitive_->value.AsDeConv2D()->dilateH; }
 int DeConv2D::GetActivationType() const { return this->primitive_->value.AsDeConv2D()->activationType; }
+int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value.AsDeConv2D()->outputPaddingW; }
+int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value.AsDeConv2D()->outputPaddingH; }

 void DeConv2D::SetFormat(int format) { this->primitive_->value.AsDeConv2D()->format = (schema::Format)format; }
 void DeConv2D::SetGroup(int group) { this->primitive_->value.AsDeConv2D()->group = group; }
@@ -295,6 +297,8 @@ int DeConv2D::GetPadRight() const { return this->primitive_->value_as_DeConv2D()
 int DeConv2D::GetDilateW() const { return this->primitive_->value_as_DeConv2D()->dilateW(); }
 int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()->dilateH(); }
 int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
+int DeConv2D::GetOutputPaddingW() const { return this->primitive_->value_as_DeConv2D()->outputPaddingW(); }
+int DeConv2D::GetOutputPaddingH() const { return this->primitive_->value_as_DeConv2D()->outputPaddingH(); }

 PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) {
   return PrimitiveC::NewPrimitiveC<DeConv2D>(primitive);
@@ -347,6 +351,8 @@ int DeConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::
+  output_h += GetOutputPaddingH();
+  output_w += GetOutputPaddingW();

   std::vector<int> out_shape = {output_n, output_h, output_w, output_c};
   output->set_shape(out_shape);
diff --git a/mindspore/lite/src/ops/deconv2d.h b/mindspore/lite/src/ops/deconv2d.h
index 011ab1b4db..5ffe92b83a 100644
--- a/mindspore/lite/src/ops/deconv2d.h
+++ b/mindspore/lite/src/ops/deconv2d.h
@@ -71,7 +71,8 @@ class DeConv2D : public PrimitiveC {
   int GetDilateW() const;
   int GetDilateH() const;
   int GetActivationType() const;
-
+  int GetOutputPaddingW() const;
+  int GetOutputPaddingH() const;
   int PadUp() const { return this->pad_u_; }
   int PadDown() const { return this->pad_d_; }
   int PadLeft() const { return this->pad_l_; }
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 7e4fe5d11e..28407c2d75 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -59,62 +59,8 @@ AnfTransform::AnfTransform() = default;

 AnfTransform::~AnfTransform() = default;

-FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
-  MS_ASSERT(nullptr != old_graph);
-  if (config == nullptr) {
-    MS_LOG(ERROR) << "config should be specified";
-    return nullptr;
-  }
-  if (old_graph->has_flag("HasTransformed")) {
-    old_graph->set_flag("HasTransformed", false);
-    return old_graph;
-  }
-  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
   auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
-  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
-  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
-
-  if (config->fmk == converter::FmkType_MS) {
-    auto mindir_adjust_pass = std::make_shared();
-    mindir_adjust_pass->SetFmkType(config->fmk);
-    mindir_adjust_pass->SetQuantType(config->quantType);
-    if (!mindir_adjust_pass->Run(old_graph)) {
-      MS_LOG(ERROR) << "mindir adjust failed.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
-    }
-    auto mindir_inputs_adjust_pass = std::make_shared();
-    if (!mindir_inputs_adjust_pass->Run(old_graph)) {
-      MS_LOG(ERROR) << "mindir inputs adjust failed.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
-    }
-  }
-
-  // onnx pre adjustment
-  if (config->fmk == converter::FmkType_ONNX) {
-    auto onnx_adjust_pass = std::make_shared();
-    if (!onnx_adjust_pass->Run(old_graph)) {
-      MS_LOG(ERROR) << "onnx adjust failed.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
-    }
-  }
-
-  if (config->fmk == lite::converter::FmkType_TF) {
-    auto functionalize_control_op_pass = std::make_shared();
-    if (!functionalize_control_op_pass->Run(old_graph)) {
-      MS_LOG(ERROR) << "functionalize control op pass failed.";
-      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
-    }
-  }
-
-  if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
-      config->fmk == lite::converter::FmkType_ONNX) {
-    graph_pm->AddPass(std::make_shared());
-    graph_pm->AddPass(std::make_shared());
-  }

   // for now - training is not supporting fuse operations
   if (!config->trainModel) {
@@ -137,26 +83,11 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
   }
-  auto weight_format_hardcode_pass = std::make_shared();
-  weight_format_hardcode_pass->SetFmkType(config->fmk);
-  weight_format_hardcode_pass->SetQuantType(config->quantType);
-  graph_pm->AddPass(weight_format_hardcode_pass);
-  auto weight_format_transform_pass = std::make_shared();
-  weight_format_transform_pass->SetFmkType(config->fmk);
-  weight_format_transform_pass->SetQuantType(config->quantType);
-  graph_pm->AddPass(weight_format_transform_pass);
-  auto infershape_pass = std::make_shared();
-  infershape_pass->SetFmkType(config->fmk);
-  graph_pm->AddPass(infershape_pass);
-  auto slice_prepose_pass = std::make_shared();
-  slice_prepose_pass->SetFmkType(config->fmk);
-  graph_pm->AddPass(slice_prepose_pass);
-
   if (config->fmk == lite::converter::FmkType_MS) {
     auto remove_unused_cast_pass = std::make_shared();
     if (remove_unused_cast_pass == nullptr) {
       MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
-      return nullptr;
+      return RET_ERROR;
     }
     remove_unused_cast_pass->SetFmkType(config->fmk);
     fusion_pm->AddPass(remove_unused_cast_pass);
@@ -165,11 +96,55 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     auto remove_unused_transpose_pass = std::make_shared();
     if (remove_unused_transpose_pass == nullptr) {
       MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass should be specified";
-      return nullptr;
+      return RET_ERROR;
     }
     remove_unused_transpose_pass->SetFmkType(config->fmk);
     fusion_pm->AddPass(remove_unused_transpose_pass);
   }
+  fusion_pm->AddPass(std::make_shared());
+  optimizer->AddPassManager(fusion_pm);
+  return RET_OK;
+}
+
+int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config) {
+  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
+  if (config->fmk == lite::converter::FmkType_TFLITE || config->fmk == lite::converter::FmkType_TF ||
+      config->fmk == lite::converter::FmkType_ONNX) {
+    graph_pm->AddPass(std::make_shared());
+    graph_pm->AddPass(std::make_shared());
+  }
+  auto weight_format_hardcode_pass = std::make_shared();
+  weight_format_hardcode_pass->SetFmkType(config->fmk);
+  weight_format_hardcode_pass->SetQuantType(config->quantType);
+  graph_pm->AddPass(weight_format_hardcode_pass);
+  auto weight_format_transform_pass = std::make_shared();
+  weight_format_transform_pass->SetFmkType(config->fmk);
+  weight_format_transform_pass->SetQuantType(config->quantType);
+  graph_pm->AddPass(weight_format_transform_pass);
+  auto infershape_pass = std::make_shared();
+  infershape_pass->SetFmkType(config->fmk);
+  graph_pm->AddPass(infershape_pass);
+  auto slice_prepose_pass = std::make_shared();
+  slice_prepose_pass->SetFmkType(config->fmk);
+  graph_pm->AddPass(slice_prepose_pass);
+  optimizer->AddPassManager(graph_pm);
+  return RET_OK;
+}
+
+int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
+                                 const converter::Flags *config) {
+  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
+  convert_pm->AddPass(std::make_shared());
+  if (config->fmk == lite::converter::FmkType_TFLITE) {
+    convert_pm->AddPass(std::make_shared());
+    convert_pm->AddPass(std::make_shared());
+  }
+  optimizer->AddPassManager(convert_pm);
+  return RET_OK;
+}
+
+int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
+                                   const converter::Flags *config) {
   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
   if (!config->trainModel) {
     auto inne_context_ptr = std::make_shared();
@@ -179,47 +154,90 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
   auto update_conv2d_param_pass = std::make_shared();
   update_conv2d_param_pass->SetFmkType(config->fmk);
   const_fold_pm->AddPass(update_conv2d_param_pass);
-  fusion_pm->AddPass(std::make_shared());
-  convert_pm->AddPass(std::make_shared());
-  if (config->fmk == lite::converter::FmkType_TFLITE) {
-    convert_pm->AddPass(std::make_shared());
-    convert_pm->AddPass(std::make_shared());
-  }
   optimizer->AddPassManager(const_fold_pm);
-  optimizer->AddPassManager(convert_pm);
-  optimizer->AddPassManager(fusion_pm);
-  optimizer->AddPassManager(graph_pm);
-  auto new_graph = optimizer->Optimize(old_graph);
-  if (new_graph == nullptr) {
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
-    return nullptr;
+  return RET_OK;
+}
+
+int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
+  switch (config->fmk) {
+    case converter::FmkType_MS:
+      return RunMindirAdjustPass(old_graph, config);
+    case converter::FmkType_ONNX:
+      return RunOnnxAdjustPass(old_graph, config);
+    case converter::FmkType_TF:
+      return RunTFAdjustPass(old_graph, config);
+    default:
+      return RET_OK;
+  }
+}
+
+int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
+  auto mindir_adjust_pass = std::make_shared();
+  mindir_adjust_pass->SetFmkType(config->fmk);
+  mindir_adjust_pass->SetQuantType(config->quantType);
+  if (!mindir_adjust_pass->Run(old_graph)) {
+    MS_LOG(ERROR) << "mindir adjust failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+    return RET_ERROR;
+  }
+  auto mindir_inputs_adjust_pass = std::make_shared();
+  if (!mindir_inputs_adjust_pass->Run(old_graph)) {
+    MS_LOG(ERROR) << "mindir inputs adjust failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int AnfTransform::RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
+  // onnx pre adjustment
+  auto onnx_adjust_pass = std::make_shared();
+  if (!onnx_adjust_pass->Run(old_graph)) {
+    MS_LOG(ERROR) << "onnx adjust failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int AnfTransform::RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
+  auto functionalize_control_op_pass = std::make_shared();
+  if (!functionalize_control_op_pass->Run(old_graph)) {
+    MS_LOG(ERROR) << "functionalize control op pass failed.";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+    return RET_ERROR;
   }
+  return RET_OK;
+}
+
+int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config,
+                             const FuncGraphPtr &new_graph) {
   // quant
   if (config->quantType == schema::QuantType_PostTraining) {
     if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
       MS_LOG(ERROR) << "bitNum must be valid pos num.";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
+      return RET_ERROR;
     }
     this->mQuantizer =
       std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
     if (mQuantizer == nullptr) {
       MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
-      return nullptr;
+      return RET_ERROR;
     }
   } else if (config->quantType == schema::QuantType_WeightQuant) {
     if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
       MS_LOG(ERROR) << "weight quant input param error";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
-      return nullptr;
+      return RET_ERROR;
     }
     this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
                                                                 config->quantWeightChannel, config->bitNum);
     if (mQuantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
-      return nullptr;
+      return RET_ERROR;
     }
   }
   if (mQuantizer != nullptr) {
@@ -228,9 +246,65 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Quant failed " << status;
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-      return nullptr;
+      return RET_ERROR;
     }
   }
+  return RET_OK;
+}
+
+FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
+  MS_ASSERT(nullptr != old_graph);
+  if (config == nullptr) {
+    MS_LOG(ERROR) << "config should be specified";
+    return nullptr;
+  }
+  if (old_graph->has_flag("HasTransformed")) {
+    old_graph->set_flag("HasTransformed", false);
+    return old_graph;
+  }
+
+  auto status = RunAdjustPass(old_graph, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Run Adjust pass failed.";
+    return nullptr;
+  }
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+
+  status = AddConstFoldPass(optimizer, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Add const fold pass failed.";
+    return nullptr;
+  }
+
+  status = AddConvertPass(optimizer, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Add convert pass failed.";
+    return nullptr;
+  }
+
+  status = AddFusionPass(optimizer, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Add fusion pass failed.";
+    return nullptr;
+  }
+  status = AddGraphPass(optimizer, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Add graph pass failed.";
+    return nullptr;
+  }
+
+  auto new_graph = optimizer->Optimize(old_graph);
+  if (new_graph == nullptr) {
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
+    return nullptr;
+  }
+
+  status = DoQuantize(old_graph, config, new_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Do Quantize failed.";
+    return nullptr;
+  }

   return new_graph;
 }
diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h
index 38e0d30b43..e4de7d5d3d 100644
--- a/mindspore/lite/tools/converter/anf_transform.h
+++ b/mindspore/lite/tools/converter/anf_transform.h
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
+#include "backend/optimizer/common/optimizer.h"
 #include "schema/inner/model_generated.h"
 #include "tools/common/storage.h"
 #include "tools/converter/converter_flags.h"
@@ -39,6 +40,24 @@ class AnfTransform {
                             std::vector *vnodes);
   FuncGraphPtr TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
   std::unique_ptr<quant::Quantizer> mQuantizer = nullptr;
+
+  int AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
+
+  int AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
+
+  int AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
+
+  int AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer, const converter::Flags *config);
+
+  int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
+
+  int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
+
+  int RunOnnxAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
+
+  int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
+
+  int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
index d7cbb3b192..e44948df62 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc
@@ -53,42 +53,36 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeC
   return true;
 }

-lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
-                                                       const onnx::NodeProto &onnx_node) {
-  MS_LOG(DEBUG) << "onnx DeConvParser";
-  auto attr = std::make_unique<schema::DeConv2DT>();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "new op failed";
-    return nullptr;
-  }
-
+int OnnxDeConvParser::ParseParameters(const onnx::NodeProto &onnx_node,
+                                      const std::unique_ptr<schema::DeConv2DT> &attr) {
   attr->padMode = schema::PadMode_NOTSET;
   attr->group = 1;
   attr->strideW = 1;
   attr->strideH = 1;
   attr->dilateW = 1;
   attr->dilateH = 1;
+
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     if (onnx_node_attr.name() == "group") {
       attr->group = static_cast<int32_t>(onnx_node_attr.i());
     } else if (onnx_node_attr.name() == "dilations") {
       if (onnx_node_attr.ints().size() != 2) {
         MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
-        return nullptr;
+        return RET_ERROR;
       }
       attr->dilateH = static_cast<int32_t>(onnx_node_attr.ints(0));
       attr->dilateW = static_cast<int32_t>(onnx_node_attr.ints(1));
     } else if (onnx_node_attr.name() == "kernels") {
       if (onnx_node_attr.ints().size() != 2) {
         MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
-        return nullptr;
+        return RET_ERROR;
       }
       attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
       attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
     } else if (onnx_node_attr.name() == "kernel_shape") {
       if (onnx_node_attr.ints().size() != 2) {
         MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
-        return nullptr;
+        return RET_ERROR;
       }
       attr->kernelH = static_cast<int32_t>(onnx_node_attr.ints(0));
       attr->kernelW = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -97,7 +91,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
     } else if (onnx_node_attr.name() == "pads") {
       if (onnx_node_attr.ints().size() != 4) {
         MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
-        return nullptr;
+        return RET_ERROR;
       }
       attr->padUp = static_cast<int32_t>(onnx_node_attr.ints(0));
       attr->padLeft = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -106,7 +100,7 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
     } else if (onnx_node_attr.name() == "strides") {
      if (onnx_node_attr.ints().size() != 2) {
         MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
-        return nullptr;
+        return RET_ERROR;
       }
       attr->strideH = static_cast<int32_t>(onnx_node_attr.ints(0));
       attr->strideW = static_cast<int32_t>(onnx_node_attr.ints(1));
@@ -115,13 +109,30 @@ lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &o
         attr->format = schema::Format::Format_NHWC;
       } else {
         MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s().c_str();
-        return nullptr;
+        return RET_ERROR;
       }
     } else if (onnx_node_attr.name() == "output_padding") {
-      MS_LOG(ERROR) << "output_padding param hasn't been supported";
-      return nullptr;
+      attr->outputPaddingH = static_cast<int32_t>(onnx_node_attr.ints(0));
+      attr->outputPaddingW = static_cast<int32_t>(onnx_node_attr.ints(1));
     }
   }
+  return RET_OK;
+}
+
+lite::PrimitiveC *OnnxDeConvParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
+                                                       const onnx::NodeProto &onnx_node) {
+  MS_LOG(DEBUG) << "onnx DeConvParser";
+  auto attr = std::make_unique<schema::DeConv2DT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return nullptr;
+  }
+
+  auto status = ParseParameters(onnx_node, attr);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Parse parameters failed.";
+    return nullptr;
+  }

   const auto &onnx_conv_weight = onnx_node.input(1);
   auto node_iter =
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h
index 2b83c223cf..a5f6c582cc 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.h
@@ -32,6 +32,8 @@ class OnnxDeConvParser : public OnnxNodeParser {
  private:
   bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::PrimitiveT *primitive);
+
+  int ParseParameters(const onnx::NodeProto &onnx_node, const std::unique_ptr<schema::DeConv2DT> &attr);
 };
 }  // namespace lite
 }  // namespace mindspore
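Background note (not part of the patch): the new outputPaddingH/outputPaddingW fields carry ONNX ConvTranspose's output_padding attribute through to shape inference. Below is a minimal standalone sketch of the standard ConvTranspose output-size arithmetic the schema change is meant to support; the helper and variable names are illustrative only and do not come from the MindSpore sources.

#include <cstdio>

// Standard ONNX ConvTranspose output size for one spatial dimension with
// explicit pads:
//   out = (in - 1) * stride + ((kernel - 1) * dilation + 1)
//         - pad_begin - pad_end + output_padding
// Helper name is illustrative, not a MindSpore API.
static int DeconvOutputDim(int in, int stride, int kernel, int dilation,
                           int pad_begin, int pad_end, int output_padding) {
  return (in - 1) * stride + ((kernel - 1) * dilation + 1) - pad_begin - pad_end + output_padding;
}

int main() {
  // A forward 3x3 convolution with stride 2 and pad 1 maps input heights 9 and
  // 10 to the same output height 5, so the transposed direction is ambiguous;
  // output_padding lets the exporter pick the larger size (10 instead of 9).
  std::printf("%d\n", DeconvOutputDim(5, 2, 3, 1, 1, 1, 0));  // prints 9
  std::printf("%d\n", DeconvOutputDim(5, 2, 3, 1, 1, 1, 1));  // prints 10
  return 0;
}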