parent
9442516f56
commit
4acb61d59d
@ -1,157 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h"
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <string>
|
|
||||||
#include <tuple>
|
|
||||||
#include "session/anf_runtime_algorithm.h"
|
|
||||||
#include "device/kernel_info.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
namespace {
|
|
||||||
// Number of outputs the bn2_add_relu node is split into.
constexpr size_t kBn2AddReluOutputNum = 4;

// Index of each output inside the output list of the bn2_add_relu node.
enum Bn2AddReluOutput {
  kBn2AddReluOutput = 0,
  kBn2AddReluRunningMean = 1,
  kBn2AddReluRunningVariance = 2,
  kBn2AddReluSaveInvVariance = 3,
};
|
|
||||||
|
|
||||||
std::tuple<CNodePtr, CNodePtr, CNodePtr, CNodePtr> GetUsedCNode(const AnfNodePtr &node) {
|
|
||||||
auto relu_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kReluInputNum);
|
|
||||||
MS_EXCEPTION_IF_NULL(relu_cnode);
|
|
||||||
auto add_cnode = CheckAnfNodeIfCNodeAndInputSize(relu_cnode->input(1), kAddInputNum);
|
|
||||||
MS_EXCEPTION_IF_NULL(add_cnode);
|
|
||||||
auto add_input1_cnode = CheckAnfNodeIfCNodeAndInputSize(add_cnode->input(1), kTupleGetitemInputNum);
|
|
||||||
MS_EXCEPTION_IF_NULL(add_input1_cnode);
|
|
||||||
auto bn_cnode = CheckAnfNodeIfCNodeAndInputSize(add_input1_cnode->input(1), kBnInputNum);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_cnode);
|
|
||||||
auto conv_cnode = CheckAnfNodeIfCNodeAndInputSize(bn_cnode->input(kX), kConvInputNum);
|
|
||||||
|
|
||||||
return std::make_tuple(conv_cnode, bn_cnode, add_cnode, relu_cnode);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateOutputsOfBn2AddRelu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
|
|
||||||
const CNodePtr &bn_node, const CNodePtr &add_node, const CNodePtr &relu_node,
|
|
||||||
std::vector<AnfNodePtr> *bn2_add_relu_outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
MS_EXCEPTION_IF_NULL(add_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(relu_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_node);
|
|
||||||
auto prim = std::make_shared<Primitive>(kBN2AddReluOpName);
|
|
||||||
std::vector<AnfNodePtr> bn2_add_relu_inputs = {NewValueNode(prim)};
|
|
||||||
// The inputs of bn2_add_relu are from the outputs of conv_bn1, the 2nd input of add, and the 2nd to 5th inputs of bn
|
|
||||||
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_add_relu_inputs));
|
|
||||||
bn2_add_relu_inputs.push_back(add_node->input(2));
|
|
||||||
for (size_t i = kX + 1; i <= kVariance; i++) {
|
|
||||||
bn2_add_relu_inputs.push_back(bn_node->input(i));
|
|
||||||
}
|
|
||||||
auto bn2_add_relu_cnode = func_graph->NewCNode(bn2_add_relu_inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn2_add_relu_cnode);
|
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
||||||
bn2_add_relu_cnode->set_kernel_info(kernel_info);
|
|
||||||
|
|
||||||
// Set attr for bn2_add_relu
|
|
||||||
AnfAlgo::CopyNodeAttrs(bn_node, bn2_add_relu_cnode);
|
|
||||||
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_add_relu_cnode);
|
|
||||||
|
|
||||||
// Set abstract of bn2_add_relu
|
|
||||||
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_node->abstract());
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
|
|
||||||
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
|
|
||||||
MS_LOG(EXCEPTION) << "Abstract tuple size of FusedBatchNorm must be " << kBnOutputNum << ", but it is "
|
|
||||||
<< bn_abstract_tuple->elements().size();
|
|
||||||
}
|
|
||||||
auto relu_abstract = relu_node->abstract();
|
|
||||||
MS_EXCEPTION_IF_NULL(relu_abstract);
|
|
||||||
// The abstracts of node bn2_add_relu are from the some abstracts of bn and relu nodes.
|
|
||||||
AbstractBasePtrList bn2_add_relu_abstract_list{relu_abstract, bn_abstract_tuple->elements()[kRunningMean],
|
|
||||||
bn_abstract_tuple->elements()[kRunningVariance],
|
|
||||||
bn_abstract_tuple->elements()[kSaveInvVariance]};
|
|
||||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(bn2_add_relu_abstract_list);
|
|
||||||
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
|
||||||
bn2_add_relu_cnode->set_abstract(abstract_tuple);
|
|
||||||
|
|
||||||
CreateMultipleOutputsOfAnfNode(func_graph, bn2_add_relu_cnode, kBn2AddReluOutputNum, bn2_add_relu_outputs);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Describes the sub-graph to match:
//   Relu(TensorAdd(TupleGetItem(FusedBatchNorm(Conv2D(conv_inputs...), bn_rest...), getitem_index), add_rhs))
// conv_inputs and bn_rest are sequence variables, getitem_index and add_rhs are single variables.
const BaseRef ConvBnAddReluFusion::DefinePattern() const {
  VarPtr add_rhs = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(add_rhs);
  VarPtr getitem_index = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(getitem_index);
  VarPtr conv_inputs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(conv_inputs);
  VarPtr bn_rest = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(bn_rest);

  return VectorRef({prim::kPrimRelu,
                    PatternListType({prim::kPrimTensorAdd,
                                     PatternListType({prim::kPrimTupleGetItem,
                                                      PatternListType({prim::kPrimFusedBatchNorm,
                                                                       PatternListType({prim::kPrimConv2D, conv_inputs}),
                                                                       bn_rest}),
                                                      getitem_index}),
                                     add_rhs})});
}
|
|
||||||
|
|
||||||
// Rewrites one matched conv -> bn -> tuple_getitem -> add -> relu chain.
//   func_graph  graph that owns the matched nodes; it must have a manager.
//   node        the matched relu node.
// Steps:
//   1. create conv_bn1 from conv and bn and replace the conv node with output kData of conv_bn1;
//   2. create bn2_add_relu from the outputs of conv_bn1, add, bn and relu;
//   3. replace the bn node with a make_tuple assembled from outputs of bn2_add_relu and conv_bn1.
// Returns output kBn2AddReluOutput of bn2_add_relu.
// Raises an exception when a required pointer is null or a created node has an unexpected number of outputs.
const AnfNodePtr ConvBnAddReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  CNodePtr conv_cnode = nullptr;
  CNodePtr bn_cnode = nullptr;
  CNodePtr add_cnode = nullptr;
  CNodePtr relu_cnode = nullptr;
  std::tie(conv_cnode, bn_cnode, add_cnode, relu_cnode) = GetUsedCNode(node);
  // Create conv_bn1 node and get outputs of conv_bn1
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
                      << conv_bn1_outputs.size();
  }
  // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others
  (void)manager->Replace(conv_cnode, conv_bn1_outputs[kData]);

  // Create bn2_add_relu node and get outputs of bn2_add_relu
  std::vector<AnfNodePtr> bn2_add_relu_outputs;
  CreateOutputsOfBn2AddRelu(func_graph, conv_bn1_outputs, bn_cnode, add_cnode, relu_cnode, &bn2_add_relu_outputs);
  if (bn2_add_relu_outputs.size() != kBn2AddReluOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node bn2_add_relu must be " << kBn2AddReluOutputNum << ", but it is "
                      << bn2_add_relu_outputs.size();
  }

  // Create a make_tuple to replace the bn node here, the outputs are from node bn2_add_relu and conv_bn1.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn2_add_relu_outputs[kBn2AddReluOutput],
                                            bn2_add_relu_outputs[kBn2AddReluRunningMean],
                                            bn2_add_relu_outputs[kBn2AddReluRunningVariance],
                                            conv_bn1_outputs[kMean],
                                            bn2_add_relu_outputs[kBn2AddReluSaveInvVariance]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  // Validate the new node before handing it to the manager, as ConvBnReluFusion::Process does.
  MS_EXCEPTION_IF_NULL(make_tuple);
  (void)manager->Replace(bn_cnode, make_tuple);
  return bn2_add_relu_outputs[kBn2AddReluOutput];
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
@ -1,34 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
|
|
||||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
|
|
||||||
|
|
||||||
#include "pre_activate/common/optimizer.h"
|
|
||||||
#include "pre_activate/common/helper.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
class ConvBnAddReluFusion : public PatternProcessPass {
|
|
||||||
public:
|
|
||||||
explicit ConvBnAddReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_add_relu_fusion", multigraph) {}
|
|
||||||
~ConvBnAddReluFusion() override = default;
|
|
||||||
const BaseRef DefinePattern() const override;
|
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
|
|
@ -1,93 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
#include "session/anf_runtime_algorithm.h"
|
|
||||||
#include "device/kernel_info.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
// Describes the sub-graph to match: FusedBatchNorm(Conv2D(conv_inputs...), bn_rest...).
// Both conv_inputs and bn_rest are sequence variables.
const BaseRef ConvBnFusion::DefinePattern() const {
  VarPtr conv_inputs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(conv_inputs);
  VarPtr bn_rest = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(bn_rest);

  return VectorRef({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, conv_inputs}), bn_rest});
}
|
|
||||||
|
|
||||||
// Rewrites one matched conv -> bn pair.
//   func_graph  graph that owns the matched nodes; it must have a manager.
//   node        the matched batch norm node; its input kX must be the conv node.
// The conv node is replaced by output kData of a new conv_bn1 node, then fusedbn2 and fusedbn3 nodes are created.
// Returns a make_tuple built from outputs of fusedbn3, fusedbn2 and conv_bn1.
// Raises an exception when a required pointer is null, a node is not a cnode, the bn node has fewer than
// kVariance + 1 inputs, or a created node has an unexpected number of outputs.
const AnfNodePtr ConvBnFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  // Validate the batch norm node.
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The bn node is expected to be a cnode";
  }
  auto bn_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(bn_cnode);
  const auto min_bn_input_size = kVariance + 1;
  if (bn_cnode->inputs().size() < min_bn_input_size) {
    auto bn_op_name = AnfAlgo::GetCNodeName(bn_cnode);
    MS_LOG(EXCEPTION) << "op[" << bn_op_name << "] has less than " << min_bn_input_size << " inputs.";
  }

  // Validate the conv node feeding the batch norm.
  AnfNodePtr conv_node = bn_cnode->input(kX);
  MS_EXCEPTION_IF_NULL(conv_node);
  if (!conv_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The conv node is expected to be a cnode";
  }
  auto conv_cnode = conv_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(conv_cnode);

  auto graph_manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);

  // conv_bn1: created from conv and bn.
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
                      << conv_bn1_outputs.size();
  }
  // The conv node may also be consumed by other nodes, so it is replaced directly by output kData of conv_bn1.
  (void)graph_manager->Replace(conv_node, conv_bn1_outputs[kData]);

  // fusedbn2: fed with outputs 2 and 1 of conv_bn1, in that order.
  std::vector<AnfNodePtr> bn2_outputs;
  std::vector<AnfNodePtr> reordered_bn1_outputs = {conv_bn1_outputs[2], conv_bn1_outputs[1]};
  CreateOutputsOfFusedBn2(func_graph, reordered_bn1_outputs, bn_cnode, &bn2_outputs);
  if (bn2_outputs.size() != kBN2OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn2 must be " << kBN2OutputNum << ", but it is "
                      << bn2_outputs.size();
  }

  // fusedbn3: fed with output 0 of conv_bn1 plus the inputs and outputs of fusedbn2.
  std::vector<AnfNodePtr> bn3_outputs;
  CreateOutputsOfFusedBn3(func_graph, conv_bn1_outputs[0], reordered_bn1_outputs, bn2_outputs, bn_cnode,
                          &bn3_outputs);
  if (bn3_outputs.size() != kBN3OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn3 must be " << kBN3OutputNum << ", but it is "
                      << bn3_outputs.size();
  }

  // The returned make_tuple stands in for the bn node; its elements come from fusedbn3, fusedbn2 and conv_bn1.
  std::vector<AnfNodePtr> bn_replacement_inputs{NewValueNode(prim::kPrimMakeTuple),
                                                bn3_outputs[0],
                                                bn2_outputs[1],
                                                bn2_outputs[2],
                                                conv_bn1_outputs[2],
                                                bn2_outputs[0]};
  return func_graph->NewCNode(bn_replacement_inputs);
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
@ -1,34 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
|
|
||||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
|
|
||||||
|
|
||||||
#include "pre_activate/common/optimizer.h"
|
|
||||||
#include "pre_activate/common/helper.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
class ConvBnFusion : public PatternProcessPass {
|
|
||||||
public:
|
|
||||||
explicit ConvBnFusion(bool multigraph = true) : PatternProcessPass("conv_bn_fusion", multigraph) {}
|
|
||||||
~ConvBnFusion() override = default;
|
|
||||||
const BaseRef DefinePattern() const override;
|
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
|
|
@ -1,140 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <tuple>
|
|
||||||
|
|
||||||
#include "utils/utils.h"
|
|
||||||
#include "session/anf_runtime_algorithm.h"
|
|
||||||
#include "common/utils.h"
|
|
||||||
#include "device/kernel_info.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
namespace {
|
|
||||||
std::tuple<CNodePtr, CNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
auto relu_node = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(relu_node);
|
|
||||||
if (relu_node->inputs().size() < kReluInputNum) {
|
|
||||||
MS_LOG(EXCEPTION) << "relu has wrong input size";
|
|
||||||
}
|
|
||||||
auto tuple_getitem_anf = relu_node->input(1);
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_getitem_anf);
|
|
||||||
auto tuple_getitem = tuple_getitem_anf->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
|
||||||
if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) {
|
|
||||||
MS_LOG(EXCEPTION) << "tuple getitem has wrong input size";
|
|
||||||
}
|
|
||||||
auto bn_node_anf = tuple_getitem->input(1);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_node_anf);
|
|
||||||
auto bn_node = bn_node_anf->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_node);
|
|
||||||
if (bn_node->inputs().size() < kBnInputNum) {
|
|
||||||
MS_LOG(EXCEPTION) << "bn_node has wrong input size";
|
|
||||||
}
|
|
||||||
auto conv_node_anf = bn_node->input(1);
|
|
||||||
MS_EXCEPTION_IF_NULL(conv_node_anf);
|
|
||||||
CNodePtr conv_node = conv_node_anf->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(conv_node);
|
|
||||||
return std::make_tuple(bn_node, bn_node, conv_node);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CreateOutputsOfBn2Relu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
|
|
||||||
const CNodePtr &bn_node, const CNodePtr &relu_node,
|
|
||||||
std::vector<AnfNodePtr> *bn2_relu_outputs) {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn_node);
|
|
||||||
MS_EXCEPTION_IF_NULL(relu_node);
|
|
||||||
// The inputs of bn2_relu are from the outputs of conv_bn1 and the 2nd to 5th inputs of bn
|
|
||||||
std::vector<AnfNodePtr> bn2_relu_inputs = {NewValueNode(std::make_shared<Primitive>(kBN2ReLUOpName))};
|
|
||||||
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_relu_inputs));
|
|
||||||
for (size_t i = 2; i <= 5; i++) {
|
|
||||||
bn2_relu_inputs.push_back(bn_node->input(i));
|
|
||||||
}
|
|
||||||
auto bn2_relu = func_graph->NewCNode(bn2_relu_inputs);
|
|
||||||
MS_EXCEPTION_IF_NULL(bn2_relu);
|
|
||||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
||||||
bn2_relu->set_kernel_info(kernel_info);
|
|
||||||
auto types = {AnfAlgo::GetOutputInferDataType(relu_node, 0), AnfAlgo::GetOutputInferDataType(bn_node, 1),
|
|
||||||
AnfAlgo::GetOutputInferDataType(bn_node, 2), AnfAlgo::GetOutputInferDataType(bn_node, 4)};
|
|
||||||
auto shapes = {AnfAlgo::GetOutputInferShape(relu_node, 0), AnfAlgo::GetOutputInferShape(bn_node, 1),
|
|
||||||
AnfAlgo::GetOutputInferShape(bn_node, 2), AnfAlgo::GetOutputInferShape(bn_node, 4)};
|
|
||||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn2_relu.get());
|
|
||||||
// Set attr for bn2_add_relu
|
|
||||||
AnfAlgo::CopyNodeAttrs(bn_node, bn2_relu);
|
|
||||||
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_relu);
|
|
||||||
|
|
||||||
CreateMultipleOutputsOfAnfNode(func_graph, bn2_relu, kBn2ReluOutputNum, bn2_relu_outputs);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Describes the sub-graph to match:
//   Relu(TupleGetItem(FusedBatchNorm(Conv2D(conv_inputs...), bn_rest...), getitem_index))
// conv_inputs and bn_rest are sequence variables, getitem_index is a single variable.
const BaseRef ConvBnReluFusion::DefinePattern() const {
  VarPtr conv_inputs = std::make_shared<SeqVar>();
  VarPtr bn_rest = std::make_shared<SeqVar>();
  VarPtr getitem_index = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(conv_inputs);
  MS_EXCEPTION_IF_NULL(bn_rest);
  MS_EXCEPTION_IF_NULL(getitem_index);

  return VectorRef({prim::kPrimRelu,
                    PatternListType({prim::kPrimTupleGetItem,
                                     PatternListType({prim::kPrimFusedBatchNorm,
                                                      PatternListType({prim::kPrimConv2D, conv_inputs}), bn_rest}),
                                     getitem_index})});
}
|
|
||||||
|
|
||||||
// Rewrites one matched conv -> bn -> tuple_getitem -> relu chain.
//   func_graph  graph that owns the matched nodes; it must have a manager.
//   node        the matched node the chain is walked back from.
// The three nodes handed back by GetPrevNodes are received as (relu, bn, conv). The conv node is replaced by
// output 0 of a new conv_bn1 node and the bn node by a make_tuple built from outputs of bn2_relu and conv_bn1.
// Returns output 0 of bn2_relu.
// Raises an exception when a required pointer is null or a created node has an unexpected number of outputs.
const AnfNodePtr ConvBnReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  CNodePtr relu_cnode = nullptr;
  CNodePtr bn_cnode = nullptr;
  CNodePtr conv_cnode = nullptr;
  std::tie(relu_cnode, bn_cnode, conv_cnode) = GetPrevNodes(node);

  auto graph_manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);

  // conv_bn1: created from conv and bn; its output 0 takes the place of the conv node.
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "conv_bn1 outputs has wrong size: " << conv_bn1_outputs.size();
  }
  (void)graph_manager->Replace(conv_cnode, conv_bn1_outputs[0]);

  // bn2_relu: created from the outputs of conv_bn1, bn and relu.
  std::vector<AnfNodePtr> bn2_relu_outputs;
  CreateOutputsOfBn2Relu(func_graph, conv_bn1_outputs, bn_cnode, relu_cnode, &bn2_relu_outputs);
  if (bn2_relu_outputs.size() != kBn2ReluOutputNum) {
    MS_LOG(EXCEPTION) << "bn2_relu outputs has wrong size: " << bn2_relu_outputs.size();
  }

  // The bn node is replaced by a make_tuple of outputs 0, 1, 2 of bn2_relu, output 2 of conv_bn1 and output 3 of
  // bn2_relu.
  std::vector<AnfNodePtr> bn_replacement_inputs{NewValueNode(prim::kPrimMakeTuple),
                                                bn2_relu_outputs[0],
                                                bn2_relu_outputs[1],
                                                bn2_relu_outputs[2],
                                                conv_bn1_outputs[2],
                                                bn2_relu_outputs[3]};
  auto bn_replacement = func_graph->NewCNode(bn_replacement_inputs);
  MS_EXCEPTION_IF_NULL(bn_replacement);
  (void)graph_manager->Replace(bn_cnode, bn_replacement);

  return bn2_relu_outputs[0];
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
@ -1,33 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
|
|
||||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
|
|
||||||
|
|
||||||
#include "pre_activate/common/optimizer.h"
|
|
||||||
#include "pre_activate/common/helper.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
class ConvBnReluFusion : public PatternProcessPass {
|
|
||||||
public:
|
|
||||||
explicit ConvBnReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_relu_fusion", multigraph) {}
|
|
||||||
~ConvBnReluFusion() override = default;
|
|
||||||
const BaseRef DefinePattern() const override;
|
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
|
|
@ -1,77 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "common/backend_common_test.h"
|
|
||||||
#include "operator/ops.h"
|
|
||||||
#include "debug/anf_ir_dump.h"
|
|
||||||
#include "common/py_func_graph_fetcher.h"
|
|
||||||
#include "pre_activate/common/optimizer.h"
|
|
||||||
#include "pre_activate/common/pass_manager.h"
|
|
||||||
#include "session/anf_runtime_algorithm.h"
|
|
||||||
#include "device/kernel_info.h"
|
|
||||||
|
|
||||||
#define private public
|
|
||||||
#define protected public
|
|
||||||
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
|
|
||||||
#undef private
|
|
||||||
#undef protected
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
|
|
||||||
|
|
||||||
class TestHWConvBnFusion : public BackendCommon {
|
|
||||||
public:
|
|
||||||
TestHWConvBnFusion() : getPyFun_("gtest_input.pre_activate.ir_fusion_test", true) {}
|
|
||||||
~TestHWConvBnFusion() override = default;
|
|
||||||
|
|
||||||
UT::PyFuncGraphFetcher getPyFun_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Runs ConvBnFusion on the "before" graph of test_conv_bn_fusion and compares the result with the "after" graph.
TEST_F(TestHWConvBnFusion, test_conv_bn_fusion) {
  /*
   * The graph under test:
   *
   * def before(x, y):
   *     conv_output = conv(x, y)
   *     bn_output = bn(conv_output)
   *     item0 = tuple_getitem(bn_output, 0)
   *     item1 = tuple_getitem(bn_output, 3)
   *     item2 = tuple_getitem(bn_output, 4)
   *     res = make_tuple(item0, item1, item2)
   *     return res
   */
  getPyFun_.SetDoResolve(true);
  FuncGraphPtr before_graph = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "before");

  // Abstracts of the six graph arguments: one for x, one for w, and the same 1-D abstract four times.
  std::vector<int> x_shape{32, 3, 224, 224};
  std::vector<int> w_shape{64, 3, 7, 7};
  std::vector<int> b_shape{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, x_shape);
  auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, w_shape);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, b_shape);
  AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(before_graph, args_spec_list);

  // Optimizer holding a single pass manager with only the ConvBnFusion pass.
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto passes = std::make_shared<opt::PassManager>();
  auto fusion_pass = std::make_shared<opt::ConvBnFusion>();
  passes->AddPass(fusion_pass);
  optimizer->AddPassManager(passes);
  auto optimized_graph = optimizer->Optimize(kernel_graph);

  FuncGraphPtr expected_graph = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, optimized_graph));
}
|
|
||||||
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue