!14211 [MSLITE] Support conv1d.

From: @wang_shaocong
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by:
pull/14211/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5588a398be

@@ -35,7 +35,7 @@ static std::map<std::string, int64_t> DataFormatToEnumMap = {
  {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK},
  {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC},
  {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
-  {"NCDHW", Format::NCDHW},
+  {"NCDHW", Format::NCDHW}, {"NWC", Format::NWC}, {"NCW", Format::NCW},
};
static std::map<int64_t, std::string> DataFormatToStrMap = {
@@ -44,7 +44,7 @@ static std::map<int64_t, std::string> DataFormatToStrMap = {
  {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"},
  {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"},
  {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
-  {Format::NCDHW, "NCDHW"},
+  {Format::NCDHW, "NCDHW"}, {Format::NWC, "NWC"}, {Format::NCW, "NCW"},
};
static std::map<std::string, int64_t> ReductionToEnumMap = {

@@ -62,7 +62,9 @@ enum Format : int64_t {
  NC4 = 12,
  NC4HW4 = 13,
  NUM_OF_FORMAT = 14,
-  NCDHW = 15
+  NCDHW = 15,
+  NWC = 16,
+  NCW = 17
};
enum ActivationType : int64_t {
NO_ACTIVATION = 0,

@@ -53,7 +53,9 @@ enum Format : int {
  NC4,
  NC4HW4,
  NUM_OF_FORMAT,
-  NCDHW
+  NCDHW,
+  NWC,
+  NCW
}
enum ActivationType : byte {
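
Note: NWC and NCW are the 1-D analogues of NHWC and NCHW (channels last and
channels first). A standalone sketch of how a conv1d layout maps onto its 2-D
counterpart once a unit H axis is inserted; the helper name and the shapes are
illustrative, not part of this patch:

#include <cassert>
#include <cstdint>
#include <vector>

// NCW {N, C, W} -> NCHW {N, C, 1, W} (axis 2); NWC {N, W, C} -> NHWC {N, 1, W, C} (axis 1).
std::vector<int64_t> ToConv2dShape(std::vector<int64_t> shape, bool channels_last) {
  shape.insert(shape.begin() + (channels_last ? 1 : 2), 1);
  return shape;
}

int main() {
  assert((ToConv2dShape({1, 16, 80}, false) == std::vector<int64_t>{1, 16, 1, 80}));  // NCW
  assert((ToConv2dShape({1, 80, 16}, true) == std::vector<int64_t>{1, 1, 80, 16}));   // NWC
  return 0;
}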

@@ -246,8 +246,11 @@ if(ENABLE_CONVERTER)
        ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
        ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/squeeze_fusion.cc
+        ${LITE_DIR}/tools/optimizer/graph/conv1d_inout_adjust_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
+        ${LITE_DIR}/tools/optimizer/graph/conv1d_weight_expanding_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/group_depthwise_op_convert_pass.cc
        ${LITE_DIR}/tools/optimizer/graph/tflite_inputs_adjust_pass.cc

@@ -56,8 +56,11 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        ../optimizer/fusion/gelu_fusion.cc
        ../optimizer/fusion/tf_gelu_fusion.cc
        ../optimizer/fusion/onnx_gelu_fusion.cc
+        ../optimizer/fusion/squeeze_fusion.cc
+        ../optimizer/graph/conv1d_inout_adjust_pass.cc
        ../optimizer/graph/weight_format_transform_pass.cc
        ../optimizer/graph/weight_format_hardcode_pass.cc
+        ../optimizer/graph/conv1d_weight_expanding_pass.cc
        ../optimizer/graph/clip_convert_activation_pass.cc
        ../optimizer/graph/group_depthwise_op_convert_pass.cc
        ../optimizer/graph/tflite_inputs_adjust_pass.cc

@@ -39,10 +39,13 @@
#include "tools/optimizer/graph/primitive_adjust_pass.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
+#include "tools/optimizer/fusion/squeeze_fusion.h"
+#include "tools/optimizer/graph/conv1d_inout_adjust_pass.h"
#include "tools/optimizer/graph/mindir_adjust_pass.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
+#include "tools/optimizer/graph/conv1d_weight_expanding_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/group_depthwise_op_convert_pass.h"
#include "tools/optimizer/graph/tflite_inputs_adjust_pass.h"
@@ -131,6 +134,8 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim
  weight_format_hardcode_pass->SetFmkType(config->fmk);
  weight_format_hardcode_pass->SetQuantType(config->quantType);
  graph_pm->AddPass(weight_format_hardcode_pass);
+  auto conv1d_weight_expanding_pass = std::make_shared<opt::Conv1DWeightExpandingPass>();
+  graph_pm->AddPass(conv1d_weight_expanding_pass);
  auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
  weight_format_transform_pass->SetFmkType(config->fmk);
  weight_format_transform_pass->SetQuantType(config->quantType);
@@ -198,6 +203,15 @@ int AnfTransform::RunAdjustPass(const FuncGraphPtr &old_graph, const converter::
    }
  }
+
+int AnfTransform::AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
+                                      const converter::Flags *config) {
+  auto conv1d_pm = std::make_shared<opt::PassManager>("conv1d adjust pass manager", true);
+  conv1d_pm->AddPass(std::make_shared<opt::Conv1DInOutAdjustPass>());
+  conv1d_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
+  optimizer->AddPassManager(conv1d_pm);
+  return RET_OK;
+}
int AnfTransform::RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto primitive_adjust_pass = std::make_shared<opt::PrimitiveAdjustPass>();
  primitive_adjust_pass->SetFmkType(config->fmk);
@@ -341,6 +355,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
    return nullptr;
  }
+
+  status = AddConv1DAdjustPass(optimizer, config);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Add conv1d adjust pass failed.";
+    return nullptr;
+  }
  auto new_graph = optimizer->Optimize(old_graph);
  if (new_graph == nullptr) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);

@@ -52,6 +52,8 @@ class AnfTransform {
  static int RunPrecedingPass(const FuncGraphPtr &old_graph, const converter::Flags &config);
+  static int AddConv1DAdjustPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
+                                 const converter::Flags *config);
  static int RunAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
  static int RunMindirAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

@@ -23,45 +23,87 @@
namespace mindspore::lite {
STATUS ParseVecAttr(const onnx::NodeProto &onnx_node, std::vector<int64_t> *kernels, std::vector<int64_t> *strides,
-                    std::vector<int64_t> *dilation, std::vector<int64_t> *pads) {
+                    std::vector<int64_t> *dilation, std::vector<int64_t> *pads, bool *conv1d) {
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    if (onnx_node_attr.name() == "dilations") {
-      if (onnx_node_attr.ints().size() != 2) {
-        MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 2";
-        return RET_ERROR;
+      switch (onnx_node_attr.ints().size()) {
+        case 1:
+          *conv1d = true;
+          dilation->push_back(1);
+          dilation->push_back(onnx_node_attr.ints(0));
+          break;
+        case 2:
+          dilation->push_back(onnx_node_attr.ints(0));
+          dilation->push_back(onnx_node_attr.ints(1));
+          break;
+        default:
+          MS_LOG(ERROR) << "dilations size " << onnx_node_attr.ints().size() << " is not 1 or 2";
+          return RET_ERROR;
      }
-      dilation->push_back(onnx_node_attr.ints(0));
-      dilation->push_back(onnx_node_attr.ints(1));
    } else if (onnx_node_attr.name() == "kernels") {
-      if (onnx_node_attr.ints().size() != 2) {
-        MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
-        return RET_ERROR;
+      switch (onnx_node_attr.ints().size()) {
+        case 1:
+          *conv1d = true;
+          kernels->push_back(1);
+          kernels->push_back(onnx_node_attr.ints(0));
+          break;
+        case 2:
+          kernels->push_back(onnx_node_attr.ints(0));
+          kernels->push_back(onnx_node_attr.ints(1));
+          break;
+        default:
+          MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 1 or 2";
+          return RET_ERROR;
      }
-      kernels->push_back(onnx_node_attr.ints(0));
-      kernels->push_back(onnx_node_attr.ints(1));
    } else if (onnx_node_attr.name() == "kernel_shape") {
-      if (onnx_node_attr.ints().size() != 2) {
-        MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 2";
-        return RET_ERROR;
+      switch (onnx_node_attr.ints().size()) {
+        case 1:
+          *conv1d = true;
+          kernels->push_back(1);
+          kernels->push_back(onnx_node_attr.ints(0));
+          break;
+        case 2:
+          kernels->push_back(onnx_node_attr.ints(0));
+          kernels->push_back(onnx_node_attr.ints(1));
+          break;
+        default:
+          MS_LOG(ERROR) << "kernel_shape size " << onnx_node_attr.ints().size() << " is not 1 or 2";
+          return RET_ERROR;
      }
-      kernels->push_back(onnx_node_attr.ints(0));
-      kernels->push_back(onnx_node_attr.ints(1));
    } else if (onnx_node_attr.name() == "pads") {
-      if (onnx_node_attr.ints().size() != 4) {
-        MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 4";
-        return RET_ERROR;
+      switch (onnx_node_attr.ints().size()) {
+        case 2:
+          *conv1d = true;
+          pads->push_back(0);
+          pads->push_back(0);
+          pads->push_back(onnx_node_attr.ints(0));
+          pads->push_back(onnx_node_attr.ints(1));
+          break;
+        case 4:
+          pads->push_back(onnx_node_attr.ints(0));
+          pads->push_back(onnx_node_attr.ints(2));
+          pads->push_back(onnx_node_attr.ints(1));
+          pads->push_back(onnx_node_attr.ints(3));
+          break;
+        default:
+          MS_LOG(ERROR) << "pads size " << onnx_node_attr.ints().size() << " is not 2 or 4";
+          return RET_ERROR;
      }
-      pads->push_back(onnx_node_attr.ints(0));
-      pads->push_back(onnx_node_attr.ints(2));
-      pads->push_back(onnx_node_attr.ints(1));
-      pads->push_back(onnx_node_attr.ints(3));
    } else if (onnx_node_attr.name() == "strides") {
-      if (onnx_node_attr.ints().size() != 2) {
-        MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 2";
-        return RET_ERROR;
+      switch (onnx_node_attr.ints().size()) {
+        case 1:
+          *conv1d = true;
+          strides->push_back(1);
+          strides->push_back(onnx_node_attr.ints(0));
+          break;
+        case 2:
+          strides->push_back(onnx_node_attr.ints(0));
+          strides->push_back(onnx_node_attr.ints(1));
+          break;
+        default:
+          MS_LOG(ERROR) << "strides size " << onnx_node_attr.ints().size() << " is not 1 or 2";
+          return RET_ERROR;
      }
-      strides->push_back(onnx_node_attr.ints(0));
-      strides->push_back(onnx_node_attr.ints(1));
    }
  }
  return RET_OK;
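
Note: every branch above applies the same rule: a 1-D attribute is padded with
the neutral H value (1 for kernels/strides/dilations, 0 for pads) and the node
is flagged as conv1d. A standalone sketch of the pads case (the helper below is
hypothetical, not part of the parser); the 4-element ONNX order
{h_begin, w_begin, h_end, w_end} is also reordered to {h_begin, h_end, w_begin, w_end}:

#include <cassert>
#include <cstdint>
#include <vector>

// Returns true when the node is a conv1d (2-element ONNX pads).
bool PadOnnxConvPads(const std::vector<int64_t> &onnx_pads, std::vector<int64_t> *pads) {
  if (onnx_pads.size() == 2) {  // conv1d: {w_begin, w_end} -> zero H padding in front
    *pads = {0, 0, onnx_pads[0], onnx_pads[1]};
    return true;
  }
  // conv2d: reorder {h_begin, w_begin, h_end, w_end} -> {h_begin, h_end, w_begin, w_end}
  *pads = {onnx_pads[0], onnx_pads[2], onnx_pads[1], onnx_pads[3]};
  return false;
}

int main() {
  std::vector<int64_t> pads;
  assert(PadOnnxConvPads({2, 2}, &pads) && (pads == std::vector<int64_t>{0, 0, 2, 2}));
  assert(!PadOnnxConvPads({1, 2, 3, 4}, &pads) && (pads == std::vector<int64_t>{1, 3, 2, 4}));
  return 0;
}
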
@@ -140,9 +182,13 @@ ops::PrimitiveC *OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const
  prim->set_pad_mode(pad_mode);
  prim->set_group(group);
-  if (ParseVecAttr(onnx_node, &kernels, &strides, &dilation, &pads) != RET_OK) {
+  bool conv1d = false;
+  if (ParseVecAttr(onnx_node, &kernels, &strides, &dilation, &pads, &conv1d) != RET_OK) {
    return nullptr;
  }
+  if (conv1d) {
+    prim->set_format(mindspore::Format::NCW);
+  }
  if (dilation.empty()) {
    prim->set_dilation({1, 1});
  } else {

@@ -34,7 +34,7 @@ class OnnxConvParser : public OnnxNodeParser {
};
STATUS ParseVecAttr(const onnx::NodeProto &onnx_node, std::vector<int64_t> *kernels, std::vector<int64_t> *strides,
-                    std::vector<int64_t> *dilation, std::vector<int64_t> *pads);
+                    std::vector<int64_t> *dilation, std::vector<int64_t> *pads, bool *conv1d);
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONV_PARSER_H

@@ -49,9 +49,13 @@ ops::PrimitiveC *OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, con
  prim->set_group(group);
  prim->set_pad_mode(pad_mode);
-  if (ParseVecAttr(onnx_node, &kernel, &stride, &dilate, &pads) != RET_OK) {
+  bool conv1d = false;
+  if (ParseVecAttr(onnx_node, &kernel, &stride, &dilate, &pads, &conv1d) != RET_OK) {
    return nullptr;
  }
+  if (conv1d) {
+    prim->set_format(mindspore::Format::NCW);
+  }
  if (!dilate.empty()) {
    prim->set_dilation(dilate);
  }

@@ -583,6 +583,14 @@ bool IsQuantNode(const BaseRef &n) {
  return false;
}
+
+bool IsSqueezeNode(const BaseRef &n) {
+  if (utils::isa<AnfNodePtr>(n)) {
+    return CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimSqueeze) ||
+           CheckPrimitiveType(utils::cast<AnfNodePtr>(n), prim::kPrimUnsqueeze);
+  }
+  return false;
+}
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
  if (node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);

@@ -78,6 +78,8 @@ bool IsQuantNode(const BaseRef &n);
bool IsActivationNode(const BaseRef &n);
+bool IsSqueezeNode(const BaseRef &n);
bool CheckIsAllInputsParam(const AnfNodePtr &node);
size_t GetOutputTensorNum(const AnfNodePtr &node);

@@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
const BaseRef SqueezeFusion::DefinePattern() const {
auto squeeze_var = std::make_shared<CondVar>(IsSqueezeNode);
auto act_var = std::make_shared<CondVar>(IsActivationNode);
VectorRef act_ref = VectorRef({act_var, squeeze_var});
auto unsqueeze_var = std::make_shared<CondVar>(IsSqueezeNode);
return VectorRef({unsqueeze_var, act_ref});
}
const AnfNodePtr SqueezeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &unsqueeze_node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(unsqueeze_node != nullptr);
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(unsqueeze_node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto act_node = unsqueeze_node->cast<CNodePtr>()->input(1);
if (CheckIfCNodeIsNull(act_node->cast<CNodePtr>()) != lite::RET_OK) {
return nullptr;
}
auto squeeze_node = act_node->cast<CNodePtr>()->input(1);
if (CheckIfCNodeIsNull(squeeze_node->cast<CNodePtr>()) != lite::RET_OK) {
return nullptr;
}
auto pre_node = squeeze_node->cast<CNodePtr>()->input(1);
if (CheckIfCNodeIsNull(pre_node->cast<CNodePtr>()) != lite::RET_OK) {
return nullptr;
}
  // Fuse only when the unsqueeze restores exactly the axis that the squeeze removed.
  if (GetCNodePrimitive(unsqueeze_node)->GetAttr("axis") == GetCNodePrimitive(squeeze_node)->GetAttr("axis")) {
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
manager->Replace(unsqueeze_node, act_node);
manager->Replace(squeeze_node, pre_node);
return pre_node;
}
return nullptr;
}
} // namespace mindspore::opt
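
Note: this fusion exists because Conv1DInOutAdjustPass (added below) wraps every
1-D convolution as Unsqueeze -> Conv2D -> Squeeze. In a chain such as conv1d ->
activation -> conv1d, the middle of the graph then carries a redundant
Squeeze -> Activation -> Unsqueeze span, which this pass removes when both ops
use the same axis. A toy standalone model of that rewrite (simplified node type,
not the real ANF IR):

#include <cassert>
#include <memory>
#include <string>

struct Node {
  std::string op;
  int axis = -1;  // meaningful only for Squeeze/Unsqueeze
  std::shared_ptr<Node> input;
};

// Match Unsqueeze(Activation(Squeeze(x))) and bypass the pair when the axes agree.
std::shared_ptr<Node> FuseSqueezePair(const std::shared_ptr<Node> &unsqueeze) {
  if (!unsqueeze || unsqueeze->op != "Unsqueeze") return unsqueeze;
  auto act = unsqueeze->input;
  if (!act || act->op != "Activation") return unsqueeze;
  auto squeeze = act->input;
  if (!squeeze || squeeze->op != "Squeeze") return unsqueeze;
  if (squeeze->axis != unsqueeze->axis) return unsqueeze;
  act->input = squeeze->input;  // activation now reads the original tensor
  return act;                   // the Unsqueeze/Squeeze pair is dropped
}

int main() {
  auto x = std::make_shared<Node>(Node{"Input"});
  auto squeeze = std::make_shared<Node>(Node{"Squeeze", 2, x});
  auto act = std::make_shared<Node>(Node{"Activation", -1, squeeze});
  auto unsqueeze = std::make_shared<Node>(Node{"Unsqueeze", 2, act});
  auto fused = FuseSqueezePair(unsqueeze);
  assert(fused->op == "Activation" && fused->input == x);
  return 0;
}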

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace opt {
class SqueezeFusion : public PatternProcessPass {
public:
explicit SqueezeFusion(bool multigraph = true, const std::string &name = "squeeze_fusion")
: PatternProcessPass(name, multigraph) {}
~SqueezeFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_

@@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/conv1d_inout_adjust_pass.h"
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "mindspore/lite/include/errorcode.h"
#include "ops/conv2d.h"
#include "ops/squeeze.h"
#include "ops/unsqueeze.h"
#include "ops/primitive_c.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
CNodePtr Conv1DInOutAdjustPass::NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
const std::vector<int64_t> &axis) {
auto unsqueeze_prim = std::make_shared<ops::Unsqueeze>();
if (unsqueeze_prim == nullptr) {
MS_LOG(ERROR) << "create unsqueeze failed.";
return nullptr;
}
unsqueeze_prim->set_attr("axis", MakeValue(axis));
ValueNodePtr value_node = NewValueNode(unsqueeze_prim);
std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
auto unsqueeze = func_graph->NewCNode(op_inputs);
unsqueeze->set_fullname_with_scope(input_node->fullname_with_scope() + "_unsqueeze");
return unsqueeze;
}
CNodePtr Conv1DInOutAdjustPass::NewSqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
const std::vector<int64_t> &axis) {
auto squeeze_prim = std::make_shared<ops::Squeeze>();
if (squeeze_prim == nullptr) {
MS_LOG(ERROR) << "create squeeze failed.";
return nullptr;
}
squeeze_prim->set_attr("axis", MakeValue(axis));
ValueNodePtr value_node = NewValueNode(squeeze_prim);
std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
auto squeeze = func_graph->NewCNode(op_inputs);
squeeze->set_fullname_with_scope(input_node->fullname_with_scope() + "_squeeze");
return squeeze;
}
bool Conv1DInOutAdjustPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
if (!CheckPrimitiveType(cnode, prim::kPrimConv2D) && !CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
continue;
}
auto conv2d_node = GetValueNode<std::shared_ptr<mindspore::ops::Conv2D>>(cnode->input(0));
if (conv2d_node == nullptr) {
MS_LOG(ERROR) << "conv2d is nullptr.";
return false;
}
if (conv2d_node->GetAttr(ops::kFormat) == nullptr) {
MS_LOG(ERROR) << "The format of conv2d is nullptr.";
return false;
}
std::vector<int64_t> axis;
switch (conv2d_node->get_format()) {
case mindspore::Format::NWC:
axis = {1};
break;
case mindspore::Format::NCW:
axis = {2};
break;
default:
continue;
}
auto input_node = cnode->input(1);
auto unsqueeze = NewUnsqueezeOpNode(func_graph, input_node, axis);
if (unsqueeze == nullptr) {
MS_LOG(ERROR) << "New unsqueeze node failed.";
return false;
}
manager->Replace(input_node, unsqueeze);
auto squeeze = NewSqueezeOpNode(func_graph, cnode, axis);
if (squeeze == nullptr) {
MS_LOG(ERROR) << "New squeeze node failed.";
return false;
}
manager->Replace(cnode, squeeze);
}
return true;
}
} // namespace mindspore::opt
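
Note: at the shape level this pass rewrites conv1d(x) into
Squeeze(Conv2D(Unsqueeze(x))) on the axis chosen from the format (1 for NWC,
2 for NCW). A minimal standalone trace of the NCW case; the helper, the VALID
padding, and the concrete sizes are assumptions for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

// Emulate the rewrite on shapes: unsqueeze axis 2, conv with a 1-high kernel,
// squeeze axis 2 back out (stride 1, no padding).
std::vector<int64_t> Conv1dViaConv2d(const std::vector<int64_t> &ncw,
                                     int64_t out_channels, int64_t kernel_w) {
  std::vector<int64_t> nchw = {ncw[0], ncw[1], 1, ncw[2]};                        // Unsqueeze
  std::vector<int64_t> out = {nchw[0], out_channels, 1, nchw[3] - kernel_w + 1};  // Conv2D
  return {out[0], out[1], out[3]};                                               // Squeeze
}

int main() {
  // 16 channels, width 80, 32 filters of width 5 -> {1, 32, 76}.
  assert((Conv1dViaConv2d({1, 16, 80}, 32, 5) == std::vector<int64_t>{1, 32, 76}));
  return 0;
}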

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_flags.h"
namespace mindspore::opt {
class Conv1DInOutAdjustPass : public Pass {
public:
Conv1DInOutAdjustPass() : Pass("conv1d_inout_adjust_pass") {}
~Conv1DInOutAdjustPass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
CNodePtr NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
const std::vector<int64_t> &axis);
CNodePtr NewSqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
const std::vector<int64_t> &axis);
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/conv1d_weight_expanding_pass.h"
#include <memory>
#include <algorithm>
#include <vector>
namespace mindspore::opt {
namespace {
constexpr size_t kTripleNum = 3;
constexpr size_t kConvWeightIndex = 2;
} // namespace
lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const ParamValueLitePtr &tensor) {
if (tensor == nullptr) {
return lite::RET_NULL_PTR;
}
auto shape = tensor->tensor_shape();
std::vector<int> new_shape(shape);
switch (tensor->format()) {
case schema::Format_NCHW:
case schema::Format_KCHW:
new_shape.insert(new_shape.begin() + 2, 1);
break;
case schema::Format_NHWC:
case schema::Format_KHWC:
new_shape.insert(new_shape.begin() + 1, 1);
break;
default:
MS_LOG(ERROR) << "Unsupported format.";
return RET_ERROR;
}
tensor->set_tensor_shape(new_shape);
return RET_OK;
}
bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
if (!CheckPrimitiveType(node, prim::kPrimConv2D) && !CheckPrimitiveType(node, prim::kPrimConv2DFusion)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto weight_value = GetLiteParamValue(weight_node);
if (weight_value == nullptr) {
MS_LOG(ERROR) << "weight node must be param value.";
return false;
}
// expand weight tensor to 4 dimensions.
if (weight_value->tensor_shape().size() == kTripleNum) {
auto status = ExpandFilterShape(weight_value);
if (status != RET_OK) {
MS_LOG(ERROR) << "Expand filter shape failed.";
return false;
}
}
}
return true;
}
} // namespace mindspore::opt
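
Note: the weight rule mirrors the input adjustment: a 3-D conv1d filter gains a
unit axis at the H position, so unmodified 2-D kernels can consume it. A
standalone sketch of ExpandFilterShape's shape arithmetic (simplified to plain
vectors; the real pass mutates a ParamValueLite tensor in place):

#include <cassert>
#include <vector>

// KCHW-style filters: {K, C, W} -> {K, C, 1, W}; KHWC-style: {K, W, C} -> {K, 1, W, C}.
std::vector<int> ExpandFilter(std::vector<int> shape, bool channels_last) {
  shape.insert(shape.begin() + (channels_last ? 1 : 2), 1);
  return shape;
}

int main() {
  assert((ExpandFilter({32, 16, 5}, false) == std::vector<int>{32, 16, 1, 5}));  // KCHW
  assert((ExpandFilter({32, 5, 16}, true) == std::vector<int>{32, 1, 5, 16}));   // KHWC
  return 0;
}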

@@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::ParamValueLitePtr;
namespace mindspore::opt {
class Conv1DWeightExpandingPass : public Pass {
public:
Conv1DWeightExpandingPass() : Pass("conv1d_weight_expanding_pass") {}
~Conv1DWeightExpandingPass() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
lite::STATUS ExpandFilterShape(const ParamValueLitePtr &tensor);
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_

@@ -24,7 +24,7 @@
namespace mindspore::opt {
class OnnxPadAdjustPass : public Pass {
 public:
-  OnnxPadAdjustPass() : Pass("onnx_pad_adjust") {}
+  OnnxPadAdjustPass() : Pass("onnx_pad_adjust_pass") {}
  ~OnnxPadAdjustPass() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;
