code review fix for buffer fusion

pull/2047/head
huanghui 5 years ago
parent 9442516f56
commit 4acb61d59d
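
This change hardens the Ascend UB (buffer) fusion passes against null-pointer dereferences and disambiguates IR dump files: every node input fetched via input(n) or cast<CNodePtr>() is now guarded with MS_EXCEPTION_IF_NULL before use, TupleGetitemNodeCompare validates input sizes before reading input(2), and the dumped IR file names carry the kernel graph id so dumps from different graphs no longer overwrite each other. The conv_bn, conv_bn_relu, and conv_bn_add_relu IR fusion passes and the conv_bn fusion test are also deleted. A minimal sketch of the two recurring patterns (illustrative only; identifiers are taken from the diff below):

// Guard a fetched input before dereferencing it, so a malformed graph
// fails fast with a clear exception instead of a crash deep in the pass.
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);

// Make IR dump file names unique per kernel graph.
std::string file_path =
    save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);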

@@ -346,7 +346,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_before.ir";
std::string file_path =
save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
@@ -372,7 +373,8 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_after.ir";
std::string file_path =
save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
}

@@ -34,16 +34,22 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(relu_input);
auto add = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add);
auto tuple_getitem = add->input(1);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
auto getitem = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2));
output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());

@@ -34,12 +34,17 @@ void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const A
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(relu_input);
auto getitem = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2));
output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());

@@ -35,6 +35,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input);
} else {
@@ -43,6 +44,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
return;

@@ -36,6 +36,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;

@@ -35,6 +35,7 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
auto conv = cnode->input(1);
MS_EXCEPTION_IF_NULL(conv);
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);

@@ -35,6 +35,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input);
} else {
@@ -43,6 +44,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
return;

@@ -44,6 +44,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
break;
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;

@@ -35,6 +35,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
while (CheckEltWiseNode(manager.get(), eltwise_input)) {
(void)record.insert(eltwise_input);
if (record.size() == MAX_ELTWISE_SIZE) {
@@ -57,6 +58,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;

@@ -25,6 +25,7 @@ namespace mindspore {
namespace opt {
bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
@@ -38,6 +39,7 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt
bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
@@ -51,6 +53,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const A
bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}

@@ -55,6 +55,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
}

@@ -35,6 +35,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);

@@ -45,6 +45,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
break;
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;

@@ -44,6 +44,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
break;
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;

@@ -45,6 +45,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
write_input = input_cnode->input(1);
}
MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
fusion_id_allocator->HasFusionIdAttr(write_input)) {
return;
@@ -57,6 +58,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
(void)record.insert(write_input);
auto conv_input = conv_cnode->input(1);
MS_EXCEPTION_IF_NULL(conv_input);
if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) ||
fusion_id_allocator->HasFusionIdAttr(conv_input)) {
return;

@@ -206,6 +206,7 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto nodes = TopoSort(kernel_graph->get_return());
for (auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
@@ -231,6 +232,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
auto fusion_info = buffer_fusion_info.second;
for (const auto &node : fusion_info.anf_nodes) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
@@ -253,6 +255,14 @@ bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
auto getitem2 = node2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1);
MS_EXCEPTION_IF_NULL(getitem2);
if (getitem1->size() < kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
<< getitem1->DebugString() << "]";
}
if (getitem2->size() < kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1["
<< getitem2->DebugString() << "]";
}
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
return output_idx1 < output_idx2;
@@ -285,6 +295,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
for (auto getitem : tuple_getitem_nodes) {
MS_EXCEPTION_IF_NULL(getitem);
auto getitem_ptr = getitem->cast<CNodePtr>();
auto input2 = getitem_ptr->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2));
@@ -313,6 +324,7 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
MS_EXCEPTION_IF_NULL(manager);
for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
auto output = outputs_list[idx];
MS_EXCEPTION_IF_NULL(output);
if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
auto real_output = AnfAlgo::VisitKernel(output, 0);
auto output_cnode = output->cast<CNodePtr>();
@@ -393,6 +405,7 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos,
int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr,
session::KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
buffer_fusion_info.anf_nodes, kernel_graph);

@@ -1,157 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h"
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include <tuple>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBn2AddReluOutputNum = 4;
enum Bn2AddReluOutput {
kBn2AddReluOutput = 0,
kBn2AddReluRunningMean,
kBn2AddReluRunningVariance,
kBn2AddReluSaveInvVariance,
};
std::tuple<CNodePtr, CNodePtr, CNodePtr, CNodePtr> GetUsedCNode(const AnfNodePtr &node) {
auto relu_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kReluInputNum);
MS_EXCEPTION_IF_NULL(relu_cnode);
auto add_cnode = CheckAnfNodeIfCNodeAndInputSize(relu_cnode->input(1), kAddInputNum);
MS_EXCEPTION_IF_NULL(add_cnode);
auto add_input1_cnode = CheckAnfNodeIfCNodeAndInputSize(add_cnode->input(1), kTupleGetitemInputNum);
MS_EXCEPTION_IF_NULL(add_input1_cnode);
auto bn_cnode = CheckAnfNodeIfCNodeAndInputSize(add_input1_cnode->input(1), kBnInputNum);
MS_EXCEPTION_IF_NULL(bn_cnode);
auto conv_cnode = CheckAnfNodeIfCNodeAndInputSize(bn_cnode->input(kX), kConvInputNum);
return std::make_tuple(conv_cnode, bn_cnode, add_cnode, relu_cnode);
}
void CreateOutputsOfBn2AddRelu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
const CNodePtr &bn_node, const CNodePtr &add_node, const CNodePtr &relu_node,
std::vector<AnfNodePtr> *bn2_add_relu_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(add_node);
MS_EXCEPTION_IF_NULL(relu_node);
MS_EXCEPTION_IF_NULL(bn_node);
auto prim = std::make_shared<Primitive>(kBN2AddReluOpName);
std::vector<AnfNodePtr> bn2_add_relu_inputs = {NewValueNode(prim)};
// The inputs of bn2_add_relu are from the outputs of conv_bn1, the 2nd input of add, and the 2nd to 5th inputs of bn
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_add_relu_inputs));
bn2_add_relu_inputs.push_back(add_node->input(2));
for (size_t i = kX + 1; i <= kVariance; i++) {
bn2_add_relu_inputs.push_back(bn_node->input(i));
}
auto bn2_add_relu_cnode = func_graph->NewCNode(bn2_add_relu_inputs);
MS_EXCEPTION_IF_NULL(bn2_add_relu_cnode);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
bn2_add_relu_cnode->set_kernel_info(kernel_info);
// Set attr for bn2_add_relu
AnfAlgo::CopyNodeAttrs(bn_node, bn2_add_relu_cnode);
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_add_relu_cnode);
// Set abstract of bn2_add_relu
auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_node->abstract());
MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
MS_LOG(EXCEPTION) << "Abstract tuple size of FusedBatchNorm must be " << kBnOutputNum << ", but it is "
<< bn_abstract_tuple->elements().size();
}
auto relu_abstract = relu_node->abstract();
MS_EXCEPTION_IF_NULL(relu_abstract);
// The abstracts of node bn2_add_relu are taken from the abstracts of the bn and relu nodes.
AbstractBasePtrList bn2_add_relu_abstract_list{relu_abstract, bn_abstract_tuple->elements()[kRunningMean],
bn_abstract_tuple->elements()[kRunningVariance],
bn_abstract_tuple->elements()[kSaveInvVariance]};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(bn2_add_relu_abstract_list);
MS_EXCEPTION_IF_NULL(abstract_tuple);
bn2_add_relu_cnode->set_abstract(abstract_tuple);
CreateMultipleOutputsOfAnfNode(func_graph, bn2_add_relu_cnode, kBn2AddReluOutputNum, bn2_add_relu_outputs);
}
} // namespace
const BaseRef ConvBnAddReluFusion::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
VarPtr W = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(W);
VarPtr Ys = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Ys);
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Zs);
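// Pattern: Relu(TensorAdd(TupleGetItem(FusedBatchNorm(Conv2D(Ys...), Zs...), W), X))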
return VectorRef(
{prim::kPrimRelu,
PatternListType(
{prim::kPrimTensorAdd,
PatternListType({prim::kPrimTupleGetItem,
PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Ys}), Zs}),
W}),
X})});
}
const AnfNodePtr ConvBnAddReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
CNodePtr conv_cnode = nullptr;
CNodePtr bn_cnode = nullptr;
CNodePtr add_cnode = nullptr;
CNodePtr relu_cnode = nullptr;
std::tie(conv_cnode, bn_cnode, add_cnode, relu_cnode) = GetUsedCNode(node);
// Create conv_bn1 node and get outputs of conv_bn1
std::vector<AnfNodePtr> conv_bn1_outputs;
CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
<< conv_bn1_outputs.size();
}
// Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others
(void)manager->Replace(conv_cnode, conv_bn1_outputs[kData]);
// Create bn2_add_relu node and get outputs of bn2_add_relu
std::vector<AnfNodePtr> bn2_add_relu_outputs;
CreateOutputsOfBn2AddRelu(func_graph, conv_bn1_outputs, bn_cnode, add_cnode, relu_cnode, &bn2_add_relu_outputs);
if (bn2_add_relu_outputs.size() != kBn2AddReluOutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn2_add_relu must be " << kBn2AddReluOutputNum << ", but it is "
<< bn2_add_relu_outputs.size();
}
// Create a make_tuple to replace the bn node here; the outputs are from node bn2_add_relu and conv_bn1.
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
bn2_add_relu_outputs[kBn2AddReluOutput],
bn2_add_relu_outputs[kBn2AddReluRunningMean],
bn2_add_relu_outputs[kBn2AddReluRunningVariance],
conv_bn1_outputs[kMean],
bn2_add_relu_outputs[kBn2AddReluSaveInvVariance]};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
(void)manager->Replace(bn_cnode, make_tuple);
return bn2_add_relu_outputs[kBn2AddReluOutput];
}
} // namespace opt
} // namespace mindspore

@@ -1,34 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
class ConvBnAddReluFusion : public PatternProcessPass {
public:
explicit ConvBnAddReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_add_relu_fusion", multigraph) {}
~ConvBnAddReluFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_

@@ -1,93 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#include <memory>
#include <vector>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
namespace mindspore {
namespace opt {
const BaseRef ConvBnFusion::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
VarPtr Ys = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Ys);
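// Pattern: FusedBatchNorm(Conv2D(Xs...), Ys...)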
return VectorRef({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys});
}
const AnfNodePtr ConvBnFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The bn node is expected to be a cnode";
}
auto bn_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_cnode);
if (bn_cnode->inputs().size() < kVariance + 1) {
auto op_name = AnfAlgo::GetCNodeName(bn_cnode);
MS_LOG(EXCEPTION) << "op[" << op_name << "] has less than " << kVariance + 1 << " inputs.";
}
AnfNodePtr conv_node = bn_cnode->input(kX);
MS_EXCEPTION_IF_NULL(conv_node);
if (!conv_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The conv node is expected to be a cnode";
}
auto conv_cnode = conv_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv_cnode);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
// Create conv_bn1 node and get outputs of conv_bn1
std::vector<AnfNodePtr> conv_bn1_outputs;
CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
<< conv_bn1_outputs.size();
}
// Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others
(void)manager->Replace(conv_node, conv_bn1_outputs[kData]);
// Create bn2 node and get outputs of bn2
std::vector<AnfNodePtr> bn2_outputs;
std::vector<AnfNodePtr> bn1_outputs = {conv_bn1_outputs[2], conv_bn1_outputs[1]};
CreateOutputsOfFusedBn2(func_graph, bn1_outputs, bn_cnode, &bn2_outputs);
if (bn2_outputs.size() != kBN2OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node fusedbn2 must be " << kBN2OutputNum << ", but it is "
<< bn2_outputs.size();
}
// Create bn3 node and get outputs of bn3
std::vector<AnfNodePtr> bn3_outputs;
CreateOutputsOfFusedBn3(func_graph, conv_bn1_outputs[0], bn1_outputs, bn2_outputs, bn_cnode, &bn3_outputs);
if (bn3_outputs.size() != kBN3OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node fusedbn3 must be " << kBN3OutputNum << ", but it is "
<< bn3_outputs.size();
}
// Return a make_tuple to replace the bn node here; the outputs are from node bn2 and conv_bn1.
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
bn3_outputs[0],
bn2_outputs[1],
bn2_outputs[2],
conv_bn1_outputs[2],
bn2_outputs[0]};
return func_graph->NewCNode(make_tuple_inputs);
}
} // namespace opt
} // namespace mindspore

@@ -1,34 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
class ConvBnFusion : public PatternProcessPass {
public:
explicit ConvBnFusion(bool multigraph = true) : PatternProcessPass("conv_bn_fusion", multigraph) {}
~ConvBnFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_

@@ -1,140 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <tuple>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "device/kernel_info.h"
namespace mindspore {
namespace opt {
namespace {
std::tuple<CNodePtr, CNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto relu_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(relu_node);
if (relu_node->inputs().size() < kReluInputNum) {
MS_LOG(EXCEPTION) << "relu has wrong input size";
}
auto tuple_getitem_anf = relu_node->input(1);
MS_EXCEPTION_IF_NULL(tuple_getitem_anf);
auto tuple_getitem = tuple_getitem_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) {
MS_LOG(EXCEPTION) << "tuple getitem has wrong input size";
}
auto bn_node_anf = tuple_getitem->input(1);
MS_EXCEPTION_IF_NULL(bn_node_anf);
auto bn_node = bn_node_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(bn_node);
if (bn_node->inputs().size() < kBnInputNum) {
MS_LOG(EXCEPTION) << "bn_node has wrong input size";
}
auto conv_node_anf = bn_node->input(1);
MS_EXCEPTION_IF_NULL(conv_node_anf);
CNodePtr conv_node = conv_node_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv_node);
return std::make_tuple(relu_node, bn_node, conv_node);
}
void CreateOutputsOfBn2Relu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
const CNodePtr &bn_node, const CNodePtr &relu_node,
std::vector<AnfNodePtr> *bn2_relu_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_node);
MS_EXCEPTION_IF_NULL(relu_node);
// The inputs of bn2_relu are from the outputs of conv_bn1 and the 2nd to 5th inputs of bn
std::vector<AnfNodePtr> bn2_relu_inputs = {NewValueNode(std::make_shared<Primitive>(kBN2ReLUOpName))};
(void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_relu_inputs));
for (size_t i = 2; i <= 5; i++) {
bn2_relu_inputs.push_back(bn_node->input(i));
}
auto bn2_relu = func_graph->NewCNode(bn2_relu_inputs);
MS_EXCEPTION_IF_NULL(bn2_relu);
auto kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(kernel_info);
bn2_relu->set_kernel_info(kernel_info);
auto types = {AnfAlgo::GetOutputInferDataType(relu_node, 0), AnfAlgo::GetOutputInferDataType(bn_node, 1),
AnfAlgo::GetOutputInferDataType(bn_node, 2), AnfAlgo::GetOutputInferDataType(bn_node, 4)};
auto shapes = {AnfAlgo::GetOutputInferShape(relu_node, 0), AnfAlgo::GetOutputInferShape(bn_node, 1),
AnfAlgo::GetOutputInferShape(bn_node, 2), AnfAlgo::GetOutputInferShape(bn_node, 4)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn2_relu.get());
// Set attr for bn2_add_relu
AnfAlgo::CopyNodeAttrs(bn_node, bn2_relu);
AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_relu);
CreateMultipleOutputsOfAnfNode(func_graph, bn2_relu, kBn2ReluOutputNum, bn2_relu_outputs);
}
} // namespace
const BaseRef ConvBnReluFusion::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Z = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Z);
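// Pattern: Relu(TupleGetItem(FusedBatchNorm(Conv2D(Xs...), Ys...), Z))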
return VectorRef(
{prim::kPrimRelu,
PatternListType({prim::kPrimTupleGetItem,
PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys}), Z})});
}
const AnfNodePtr ConvBnReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
CNodePtr relu_node = nullptr;
CNodePtr bn_node = nullptr;
CNodePtr conv_node = nullptr;
std::tie(relu_node, bn_node, conv_node) = GetPrevNodes(node);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> conv_bn1_outputs;
CreateOutputsOfConvBn1(func_graph, conv_node, bn_node, &conv_bn1_outputs);
if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
MS_LOG(EXCEPTION) << "conv_bn1 outputs has wrong size: " << conv_bn1_outputs.size();
}
(void)manager->Replace(conv_node, conv_bn1_outputs[0]);
std::vector<AnfNodePtr> bn2_relu_outputs;
CreateOutputsOfBn2Relu(func_graph, conv_bn1_outputs, bn_node, relu_node, &bn2_relu_outputs);
if (bn2_relu_outputs.size() != kBn2ReluOutputNum) {
MS_LOG(EXCEPTION) << "bn2_relu outputs has wrong size: " << bn2_relu_outputs.size();
}
std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
bn2_relu_outputs[0],
bn2_relu_outputs[1],
bn2_relu_outputs[2],
conv_bn1_outputs[2],
bn2_relu_outputs[3]};
auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
MS_EXCEPTION_IF_NULL(make_tuple);
(void)manager->Replace(bn_node, make_tuple);
return bn2_relu_outputs[0];
}
} // namespace opt
} // namespace mindspore

@@ -1,33 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
class ConvBnReluFusion : public PatternProcessPass {
public:
explicit ConvBnReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_relu_fusion", multigraph) {}
~ConvBnReluFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_

@@ -28,6 +28,7 @@
namespace mindspore {
namespace opt {
void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@@ -37,7 +38,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
save_graphs_path = ".";
}
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_common_before.ir";
std::string file_path =
save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
auto optimizer = std::make_shared<GraphOptimizer>();
@@ -51,7 +53,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
if (save_graphs) {
std::string file_path = save_graphs_path + "/" + "hwopt_common_after.ir";
std::string file_path =
save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
DumpIR(file_path, kernel_graph);
}
}

@@ -45,6 +45,7 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
for (auto &nd : node_list) {
MS_EXCEPTION_IF_NULL(nd);
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
auto control_depend = nd->cast<CNodePtr>();
auto prior_node = control_depend->input(kControlDependPriorIndex);
@@ -157,6 +158,7 @@ const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(func_graph);
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum);
MS_EXCEPTION_IF_NULL(transop_cnode);
auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum);
auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum);
MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1));
@@ -545,14 +547,22 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
MS_EXCEPTION_IF_NULL(a_node);
MS_EXCEPTION_IF_NULL(b_node);
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
auto a_value_node = a_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(a_value_node);
auto a_value = a_value_node->value();
MS_EXCEPTION_IF_NULL(a_value);
auto a_prim = a_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(a_prim);
auto b_value_node = b_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(b_value_node);
auto b_value = b_value_node->value();
MS_EXCEPTION_IF_NULL(b_value);
auto b_prim = b_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(b_prim);
return a_prim->name() == b_prim->name();
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {

@@ -1,77 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
class TestHWConvBnFusion : public BackendCommon {
public:
TestHWConvBnFusion() : getPyFun_("gtest_input.pre_activate.ir_fusion_test", true) {}
~TestHWConvBnFusion() override = default;
UT::PyFuncGraphFetcher getPyFun_;
};
TEST_F(TestHWConvBnFusion, test_conv_bn_fusion) {
/*
* def before(x, y):
* conv_output = conv(x, y)
* bn_output = bn(conv_output)
* item0 = tuple_getitem(bn_output, 0)
* item1 = tuple_getitem(bn_output, 3)
* item2 = tuple_getitem(bn_output, 4)
* res = make_tuple(item0, item1, item2)
* return res
*/
getPyFun_.SetDoResolve(true);
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "before");
std::vector<int> shp_x{32, 3, 224, 224};
std::vector<int> shp_w{64, 3, 7, 7};
std::vector<int> shp_b{64};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w);
auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
auto fg = GetKernelGraph(g, args_spec_list);
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
auto pass_manager = std::make_shared<opt::PassManager>();
auto conv_bn_fusion_pass = std::make_shared<opt::ConvBnFusion>();
pass_manager->AddPass(conv_bn_fusion_pass);
graph_optimizer->AddPassManager(pass_manager);
auto new_g = graph_optimizer->Optimize(fg);
FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
}
} // namespace opt
} // namespace mindspore

Some files were not shown because too many files have changed in this diff.
