!14211 [MSLITE] Support conv1d.
From: @wang_shaocong Reviewed-by: @zhanghaibo5, @zhang_xue_tong Signed-off-by: pull/14211/MERGE
commit
5588a398be
@ -0,0 +1,62 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/optimizer/fusion/squeeze_fusion.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore::opt {
const BaseRef SqueezeFusion::DefinePattern() const {
  // Match Unsqueeze(Activation(Squeeze(x))).
  auto squeeze_var = std::make_shared<CondVar>(IsSqueezeNode);
  auto act_var = std::make_shared<CondVar>(IsActivationNode);
  VectorRef act_ref = VectorRef({act_var, squeeze_var});
  auto unsqueeze_var = std::make_shared<CondVar>(IsSqueezeNode);
  return VectorRef({unsqueeze_var, act_ref});
}

const AnfNodePtr SqueezeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &unsqueeze_node,
                                        const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(unsqueeze_node != nullptr);
  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(unsqueeze_node) != lite::RET_OK) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }

  auto act_node = unsqueeze_node->cast<CNodePtr>()->input(1);
  if (CheckIfCNodeIsNull(act_node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  auto squeeze_node = act_node->cast<CNodePtr>()->input(1);
  if (CheckIfCNodeIsNull(squeeze_node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }
  auto pre_node = squeeze_node->cast<CNodePtr>()->input(1);
  if (CheckIfCNodeIsNull(pre_node->cast<CNodePtr>()) != lite::RET_OK) {
    return nullptr;
  }

  // Fuse only when the unsqueeze axis matches the squeeze axis.
  if (GetCNodePrimitive(unsqueeze_node)->GetAttr("axis") == GetCNodePrimitive(squeeze_node)->GetAttr("axis")) {
    auto manager = func_graph->manager();
    MS_ASSERT(manager != nullptr);
    manager->Replace(unsqueeze_node, act_node);
    manager->Replace(squeeze_node, pre_node);
    return pre_node;
  }
  return nullptr;
}
}  // namespace mindspore::opt
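For reference, the rewrite this pass performs can be pictured at the shape level, outside the AnfNode machinery: when the squeeze axis and the unsqueeze axis agree, Unsqueeze(Act(Squeeze(x))) collapses to Act(x). The following is a minimal standalone sketch of that idea in plain C++ on shape vectors; the Squeeze/Unsqueeze helpers and the shapes are illustrative only and are not the MindSpore ops.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative shape-level helpers; not the MindSpore operators.
std::vector<int64_t> Squeeze(std::vector<int64_t> shape, size_t axis) {
  assert(axis < shape.size() && shape[axis] == 1);
  shape.erase(shape.begin() + axis);
  return shape;
}

std::vector<int64_t> Unsqueeze(std::vector<int64_t> shape, size_t axis) {
  assert(axis <= shape.size());
  shape.insert(shape.begin() + axis, 1);
  return shape;
}

int main() {
  // An element-wise activation keeps the shape, so when both ops use the same
  // axis the pair cancels and the activation can run on the original tensor
  // directly -- which is what SqueezeFusion rewrites the graph into.
  std::vector<int64_t> x = {1, 1, 16, 32};
  size_t axis = 1;
  auto after_pair = Unsqueeze(/*activation keeps shape*/ Squeeze(x, axis), axis);
  assert(after_pair == x);
  return 0;
}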
@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_

#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace opt {
class SqueezeFusion : public PatternProcessPass {
 public:
  explicit SqueezeFusion(bool multigraph = true, const std::string &name = "squeeze_fusion")
      : PatternProcessPass(name, multigraph) {}
  ~SqueezeFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_SQUEEZE_FUSION_H_
@ -0,0 +1,106 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/conv1d_inout_adjust_pass.h"
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "mindspore/lite/include/errorcode.h"
#include "ops/conv2d.h"
#include "ops/squeeze.h"
#include "ops/unsqueeze.h"
#include "ops/primitive_c.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore::opt {
CNodePtr Conv1DInOutAdjustPass::NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
                                                   const std::vector<int64_t> &axis) {
  auto unsqueeze_prim = std::make_shared<ops::Unsqueeze>();
  if (unsqueeze_prim == nullptr) {
    MS_LOG(ERROR) << "create unsqueeze failed.";
    return nullptr;
  }
  unsqueeze_prim->set_attr("axis", MakeValue(axis));
  ValueNodePtr value_node = NewValueNode(unsqueeze_prim);
  std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
  auto unsqueeze = func_graph->NewCNode(op_inputs);
  unsqueeze->set_fullname_with_scope(input_node->fullname_with_scope() + "_unsqueeze");
  return unsqueeze;
}

CNodePtr Conv1DInOutAdjustPass::NewSqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
                                                 const std::vector<int64_t> &axis) {
  auto squeeze_prim = std::make_shared<ops::Squeeze>();
  if (squeeze_prim == nullptr) {
    MS_LOG(ERROR) << "create squeeze failed.";
    return nullptr;
  }
  squeeze_prim->set_attr("axis", MakeValue(axis));
  ValueNodePtr value_node = NewValueNode(squeeze_prim);
  std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
  auto squeeze = func_graph->NewCNode(op_inputs);
  squeeze->set_fullname_with_scope(input_node->fullname_with_scope() + "_squeeze");
  return squeeze;
}

bool Conv1DInOutAdjustPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    if (!CheckPrimitiveType(cnode, prim::kPrimConv2D) && !CheckPrimitiveType(cnode, prim::kPrimConv2DFusion)) {
      continue;
    }
    auto conv2d_node = GetValueNode<std::shared_ptr<mindspore::ops::Conv2D>>(cnode->input(0));
    if (conv2d_node == nullptr) {
      MS_LOG(ERROR) << "conv2d is nullptr.";
      return false;
    }
    if (conv2d_node->GetAttr(ops::kFormat) == nullptr) {
      MS_LOG(ERROR) << "The format of conv2d is nullptr.";
      return false;
    }

    // Conv1D is expressed as Conv2D; choose the axis to expand based on the 1-D layout.
    std::vector<int64_t> axis;
    switch (conv2d_node->get_format()) {
      case mindspore::Format::NWC:
        axis = {1};
        break;
      case mindspore::Format::NCW:
        axis = {2};
        break;
      default:
        continue;
    }

    // Wrap the conv input with Unsqueeze and the conv output with Squeeze.
    auto input_node = cnode->input(1);
    auto unsqueeze = NewUnsqueezeOpNode(func_graph, input_node, axis);
    if (unsqueeze == nullptr) {
      MS_LOG(ERROR) << "New unsqueeze node failed.";
      return false;
    }
    manager->Replace(input_node, unsqueeze);
    auto squeeze = NewSqueezeOpNode(func_graph, cnode, axis);
    if (squeeze == nullptr) {
      MS_LOG(ERROR) << "New squeeze node failed.";
      return false;
    }
    manager->Replace(cnode, squeeze);
  }
  return true;
}
}  // namespace mindspore::opt
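The shape bookkeeping behind this pass: a Conv1D expressed as Conv2D gets its input expanded to 4-D before the conv and its output collapsed back to 3-D afterwards, and the expansion axis depends on the 1-D layout (NWC inserts at axis 1, NCW at axis 2). Below is a standalone sketch of that axis logic in plain C++; the Layout enum is a hypothetical stand-in for mindspore::Format and the shapes are made up for illustration.

#include <cassert>
#include <cstdint>
#include <vector>

enum class Layout { NWC, NCW };  // hypothetical stand-in for mindspore::Format

// Axis at which a 3-D conv1d tensor is expanded to the 4-D conv2d layout,
// mirroring the switch in Conv1DInOutAdjustPass::Run.
int64_t ExpandAxis(Layout layout) { return layout == Layout::NWC ? 1 : 2; }

std::vector<int64_t> Unsqueeze(std::vector<int64_t> shape, int64_t axis) {
  shape.insert(shape.begin() + axis, 1);
  return shape;
}

int main() {
  // NCW {N, C, W} -> NCHW {N, C, 1, W}; NWC {N, W, C} -> NHWC {N, 1, W, C}.
  assert((Unsqueeze({4, 8, 100}, ExpandAxis(Layout::NCW)) == std::vector<int64_t>{4, 8, 1, 100}));
  assert((Unsqueeze({4, 100, 8}, ExpandAxis(Layout::NWC)) == std::vector<int64_t>{4, 1, 100, 8}));
  return 0;
}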
@ -0,0 +1,40 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_flags.h"

namespace mindspore::opt {
class Conv1DInOutAdjustPass : public Pass {
 public:
  Conv1DInOutAdjustPass() : Pass("conv1d_inout_adjust_pass") {}
  ~Conv1DInOutAdjustPass() override = default;

  bool Run(const FuncGraphPtr &func_graph) override;

 private:
  CNodePtr NewUnsqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
                              const std::vector<int64_t> &axis);
  CNodePtr NewSqueezeOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr input_node,
                            const std::vector<int64_t> &axis);
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CONV1D_INOUT_ADJUST_PASS_H_
@ -0,0 +1,80 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/graph/conv1d_weight_expanding_pass.h"
#include <memory>
#include <algorithm>
#include <vector>

namespace mindspore::opt {
namespace {
constexpr size_t kTripleNum = 3;
constexpr size_t kConvWeightIndex = 2;
}  // namespace
lite::STATUS Conv1DWeightExpandingPass::ExpandFilterShape(const ParamValueLitePtr &tensor) {
  if (tensor == nullptr) {
    return lite::RET_NULL_PTR;
  }
  auto shape = tensor->tensor_shape();
  std::vector<int> new_shape(shape);
  // Insert a unit dimension so the 3-D conv1d filter becomes a 4-D conv2d filter.
  switch (tensor->format()) {
    case schema::Format_NCHW:
    case schema::Format_KCHW:
      new_shape.insert(new_shape.begin() + 2, 1);
      break;
    case schema::Format_NHWC:
    case schema::Format_KHWC:
      new_shape.insert(new_shape.begin() + 1, 1);
      break;
    default:
      MS_LOG(ERROR) << "Unsupported format.";
      return RET_ERROR;
  }
  tensor->set_tensor_shape(new_shape);
  return RET_OK;
}

bool Conv1DWeightExpandingPass::Run(const FuncGraphPtr &func_graph) {
  MS_ASSERT(func_graph != nullptr);
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    if (!CheckPrimitiveType(node, prim::kPrimConv2D) && !CheckPrimitiveType(node, prim::kPrimConv2DFusion)) {
      continue;
    }

    auto conv_cnode = node->cast<CNodePtr>();
    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    MS_ASSERT(weight_node != nullptr);
    auto weight_value = GetLiteParamValue(weight_node);
    if (weight_value == nullptr) {
      MS_LOG(ERROR) << "weight node must be param value.";
      return false;
    }
    // Expand the weight tensor to 4 dimensions.
    if (weight_value->tensor_shape().size() == kTripleNum) {
      auto status = ExpandFilterShape(weight_value);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Expand filter shape failed.";
        return false;
      }
    }
  }
  return true;
}
}  // namespace mindspore::opt
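The weight-side counterpart of the same idea: a 3-D conv1d filter is padded with a unit H dimension so it becomes a valid 4-D conv2d filter, at position 2 for NCHW/KCHW and position 1 for NHWC/KHWC. Below is a minimal sketch of that insertion rule in plain C++; the Format enum and the example shapes are illustrative, not schema::Format or real model weights.

#include <cassert>
#include <vector>

enum class Format { NCHW, KCHW, NHWC, KHWC };  // illustrative, not schema::Format

// Mirrors the shape logic of Conv1DWeightExpandingPass::ExpandFilterShape:
// insert a unit dimension so a 3-D conv1d filter becomes a 4-D conv2d filter.
std::vector<int> ExpandFilterShape(std::vector<int> shape, Format format) {
  if (format == Format::NCHW || format == Format::KCHW) {
    shape.insert(shape.begin() + 2, 1);  // {K, C, W} -> {K, C, 1, W}
  } else {
    shape.insert(shape.begin() + 1, 1);  // {K, W, C} -> {K, 1, W, C}
  }
  return shape;
}

int main() {
  assert((ExpandFilterShape({64, 8, 3}, Format::KCHW) == std::vector<int>{64, 8, 1, 3}));
  assert((ExpandFilterShape({64, 3, 8}, Format::KHWC) == std::vector<int>{64, 1, 3, 8}));
  return 0;
}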
@ -0,0 +1,37 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::ParamValueLitePtr;
namespace mindspore::opt {
class Conv1DWeightExpandingPass : public Pass {
 public:
  Conv1DWeightExpandingPass() : Pass("conv1d_weight_expanding_pass") {}
  ~Conv1DWeightExpandingPass() override = default;
  bool Run(const FuncGraphPtr &graph) override;

 private:
  lite::STATUS ExpandFilterShape(const ParamValueLitePtr &tensor);
};
}  // namespace mindspore::opt
#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV1D_WEIGHT_EXPANDING_PASS_H_