parent 9442516f56
commit 4acb61d59d
pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.cc
@@ -1,157 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h"
#include <memory>
#include <vector>
#include <algorithm>
#include <string>
#include <tuple>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
namespace {
constexpr size_t kBn2AddReluOutputNum = 4;
enum Bn2AddReluOutput {
  kBn2AddReluOutput = 0,
  kBn2AddReluRunningMean,
  kBn2AddReluRunningVariance,
  kBn2AddReluSaveInvVariance,
};

std::tuple<CNodePtr, CNodePtr, CNodePtr, CNodePtr> GetUsedCNode(const AnfNodePtr &node) {
  auto relu_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kReluInputNum);
  MS_EXCEPTION_IF_NULL(relu_cnode);
  auto add_cnode = CheckAnfNodeIfCNodeAndInputSize(relu_cnode->input(1), kAddInputNum);
  MS_EXCEPTION_IF_NULL(add_cnode);
  auto add_input1_cnode = CheckAnfNodeIfCNodeAndInputSize(add_cnode->input(1), kTupleGetitemInputNum);
  MS_EXCEPTION_IF_NULL(add_input1_cnode);
  auto bn_cnode = CheckAnfNodeIfCNodeAndInputSize(add_input1_cnode->input(1), kBnInputNum);
  MS_EXCEPTION_IF_NULL(bn_cnode);
  auto conv_cnode = CheckAnfNodeIfCNodeAndInputSize(bn_cnode->input(kX), kConvInputNum);

  return std::make_tuple(conv_cnode, bn_cnode, add_cnode, relu_cnode);
}

void CreateOutputsOfBn2AddRelu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
                               const CNodePtr &bn_node, const CNodePtr &add_node, const CNodePtr &relu_node,
                               std::vector<AnfNodePtr> *bn2_add_relu_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(add_node);
  MS_EXCEPTION_IF_NULL(relu_node);
  MS_EXCEPTION_IF_NULL(bn_node);
  auto prim = std::make_shared<Primitive>(kBN2AddReluOpName);
  std::vector<AnfNodePtr> bn2_add_relu_inputs = {NewValueNode(prim)};
  // The inputs of bn2_add_relu are from the outputs of conv_bn1, the 2nd input of add, and the 2nd to 5th inputs of bn
  (void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_add_relu_inputs));
  bn2_add_relu_inputs.push_back(add_node->input(2));
  for (size_t i = kX + 1; i <= kVariance; i++) {
    bn2_add_relu_inputs.push_back(bn_node->input(i));
  }
  auto bn2_add_relu_cnode = func_graph->NewCNode(bn2_add_relu_inputs);
  MS_EXCEPTION_IF_NULL(bn2_add_relu_cnode);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn2_add_relu_cnode->set_kernel_info(kernel_info);

  // Set attr for bn2_add_relu
  AnfAlgo::CopyNodeAttrs(bn_node, bn2_add_relu_cnode);
  AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_add_relu_cnode);

  // Set abstract of bn2_add_relu
  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_node->abstract());
  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
  if (bn_abstract_tuple->elements().size() != kBnOutputNum) {
    MS_LOG(EXCEPTION) << "Abstract tuple size of FusedBatchNorm must be " << kBnOutputNum << ", but it is "
                      << bn_abstract_tuple->elements().size();
  }
  auto relu_abstract = relu_node->abstract();
  MS_EXCEPTION_IF_NULL(relu_abstract);
  // The abstracts of node bn2_add_relu are taken from some of the abstracts of the bn and relu nodes.
  AbstractBasePtrList bn2_add_relu_abstract_list{relu_abstract, bn_abstract_tuple->elements()[kRunningMean],
                                                 bn_abstract_tuple->elements()[kRunningVariance],
                                                 bn_abstract_tuple->elements()[kSaveInvVariance]};
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(bn2_add_relu_abstract_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  bn2_add_relu_cnode->set_abstract(abstract_tuple);

  CreateMultipleOutputsOfAnfNode(func_graph, bn2_add_relu_cnode, kBn2AddReluOutputNum, bn2_add_relu_outputs);
}
}  // namespace

const BaseRef ConvBnAddReluFusion::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X);
  VarPtr W = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(W);
  VarPtr Ys = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Ys);
  VarPtr Zs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Zs);

  return VectorRef(
    {prim::kPrimRelu,
     PatternListType(
       {prim::kPrimTensorAdd,
        PatternListType({prim::kPrimTupleGetItem,
                         PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Ys}), Zs}),
                         W}),
        X})});
}
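
// The pattern above matches a subgraph of the form
//   Relu(TensorAdd(TupleGetItem(FusedBatchNorm(Conv2D(Ys...), Zs...), W), X)),
// i.e. one output of a Conv2D + FusedBatchNorm chain, selected by tuple_getitem with index W,
// is added to another input X and then passed through ReLU.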

const AnfNodePtr ConvBnAddReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                              const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  CNodePtr conv_cnode = nullptr;
  CNodePtr bn_cnode = nullptr;
  CNodePtr add_cnode = nullptr;
  CNodePtr relu_cnode = nullptr;
  std::tie(conv_cnode, bn_cnode, add_cnode, relu_cnode) = GetUsedCNode(node);
  // Create conv_bn1 node and get outputs of conv_bn1
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
                      << conv_bn1_outputs.size();
  }
  // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by others
  (void)manager->Replace(conv_cnode, conv_bn1_outputs[kData]);

  // Create bn2_add_relu node and get outputs of bn2_add_relu
  std::vector<AnfNodePtr> bn2_add_relu_outputs;
  CreateOutputsOfBn2AddRelu(func_graph, conv_bn1_outputs, bn_cnode, add_cnode, relu_cnode, &bn2_add_relu_outputs);
  if (bn2_add_relu_outputs.size() != kBn2AddReluOutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node bn2_add_relu must be " << kBn2AddReluOutputNum << ", but it is "
                      << bn2_add_relu_outputs.size();
  }

  // Create a make_tuple to replace the bn node here, the outputs are from node bn2_add_relu and conv_bn1.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn2_add_relu_outputs[kBn2AddReluOutput],
                                            bn2_add_relu_outputs[kBn2AddReluRunningMean],
                                            bn2_add_relu_outputs[kBn2AddReluRunningVariance],
                                            conv_bn1_outputs[kMean],
                                            bn2_add_relu_outputs[kBn2AddReluSaveInvVariance]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  (void)manager->Replace(bn_cnode, make_tuple);
  return bn2_add_relu_outputs[kBn2AddReluOutput];
}
}  // namespace opt
}  // namespace mindspore
pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h
@@ -1,34 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
class ConvBnAddReluFusion : public PatternProcessPass {
 public:
  explicit ConvBnAddReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_add_relu_fusion", multigraph) {}
  ~ConvBnAddReluFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_ADD_RELU_FUSION_H_
pre_activate/ascend/ir_fusion/conv_bn_fusion.cc
@@ -1,93 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#include <memory>
#include <vector>
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
const BaseRef ConvBnFusion::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Xs);
  VarPtr Ys = std::make_shared<SeqVar>();
  MS_EXCEPTION_IF_NULL(Ys);
  return VectorRef({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys});
}
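
// The pattern above matches a FusedBatchNorm node whose first input is a Conv2D,
// i.e. a subgraph of the form FusedBatchNorm(Conv2D(Xs...), Ys...).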

const AnfNodePtr ConvBnFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The bn node is expected to be a cnode";
  }
  auto bn_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(bn_cnode);
  if (bn_cnode->inputs().size() < kVariance + 1) {
    auto op_name = AnfAlgo::GetCNodeName(bn_cnode);
    MS_LOG(EXCEPTION) << "op[" << op_name << "] has less than " << kVariance + 1 << " inputs.";
  }
  AnfNodePtr conv_node = bn_cnode->input(kX);
  MS_EXCEPTION_IF_NULL(conv_node);
  if (!conv_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The conv node is expected to be a cnode";
  }
  auto conv_cnode = conv_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(conv_cnode);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Create conv_bn1 node and get outputs of conv_bn1
  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_cnode, bn_cnode, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node conv_bn1 must be " << kConvBn1OutputNum << ", but it is "
                      << conv_bn1_outputs.size();
  }
  // Replace conv_node with the output 0 of conv_bn1 directly because the conv node may be used as input by other nodes
  (void)manager->Replace(conv_node, conv_bn1_outputs[kData]);

  // Create bn2 node and get outputs of bn2
  std::vector<AnfNodePtr> bn2_outputs;
  std::vector<AnfNodePtr> bn1_outputs = {conv_bn1_outputs[2], conv_bn1_outputs[1]};
  CreateOutputsOfFusedBn2(func_graph, bn1_outputs, bn_cnode, &bn2_outputs);
  if (bn2_outputs.size() != kBN2OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn2 must be " << kBN2OutputNum << ", but it is "
                      << bn2_outputs.size();
  }

  // Create bn3 node and get outputs of bn3
  std::vector<AnfNodePtr> bn3_outputs;
  CreateOutputsOfFusedBn3(func_graph, conv_bn1_outputs[0], bn1_outputs, bn2_outputs, bn_cnode, &bn3_outputs);

  if (bn3_outputs.size() != kBN3OutputNum) {
    MS_LOG(EXCEPTION) << "The output size of node fusedbn3 must be " << kBN3OutputNum << ", but it is "
                      << bn3_outputs.size();
  }

  // Return a make_tuple to replace the bn node here, the outputs are from node bn3, bn2 and conv_bn1.
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn3_outputs[0],
                                            bn2_outputs[1],
                                            bn2_outputs[2],
                                            conv_bn1_outputs[2],
                                            bn2_outputs[0]};

  return func_graph->NewCNode(make_tuple_inputs);
}
}  // namespace opt
}  // namespace mindspore
pre_activate/ascend/ir_fusion/conv_bn_fusion.h
@@ -1,34 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
class ConvBnFusion : public PatternProcessPass {
 public:
  explicit ConvBnFusion(bool multigraph = true) : PatternProcessPass("conv_bn_fusion", multigraph) {}
  ~ConvBnFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_FUSION_H_
pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.cc
@@ -1,140 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"

#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <tuple>

#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
#include "device/kernel_info.h"

namespace mindspore {
namespace opt {
namespace {
std::tuple<CNodePtr, CNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto relu_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(relu_node);
  if (relu_node->inputs().size() < kReluInputNum) {
    MS_LOG(EXCEPTION) << "relu has wrong input size";
  }
  auto tuple_getitem_anf = relu_node->input(1);
  MS_EXCEPTION_IF_NULL(tuple_getitem_anf);
  auto tuple_getitem = tuple_getitem_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) {
    MS_LOG(EXCEPTION) << "tuple getitem has wrong input size";
  }
  auto bn_node_anf = tuple_getitem->input(1);
  MS_EXCEPTION_IF_NULL(bn_node_anf);
  auto bn_node = bn_node_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(bn_node);
  if (bn_node->inputs().size() < kBnInputNum) {
    MS_LOG(EXCEPTION) << "bn_node has wrong input size";
  }
  auto conv_node_anf = bn_node->input(1);
  MS_EXCEPTION_IF_NULL(conv_node_anf);
  CNodePtr conv_node = conv_node_anf->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(conv_node);
  return std::make_tuple(relu_node, bn_node, conv_node);
}

void CreateOutputsOfBn2Relu(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &conv_bn1_outputs,
                            const CNodePtr &bn_node, const CNodePtr &relu_node,
                            std::vector<AnfNodePtr> *bn2_relu_outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(bn_node);
  MS_EXCEPTION_IF_NULL(relu_node);
  // The inputs of bn2_relu are from the outputs of conv_bn1 and the 2nd to 5th inputs of bn
  std::vector<AnfNodePtr> bn2_relu_inputs = {NewValueNode(std::make_shared<Primitive>(kBN2ReLUOpName))};
  (void)std::copy(conv_bn1_outputs.begin(), conv_bn1_outputs.end(), std::back_inserter(bn2_relu_inputs));
  for (size_t i = 2; i <= 5; i++) {
    bn2_relu_inputs.push_back(bn_node->input(i));
  }
  auto bn2_relu = func_graph->NewCNode(bn2_relu_inputs);
  MS_EXCEPTION_IF_NULL(bn2_relu);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  bn2_relu->set_kernel_info(kernel_info);
  auto types = {AnfAlgo::GetOutputInferDataType(relu_node, 0), AnfAlgo::GetOutputInferDataType(bn_node, 1),
                AnfAlgo::GetOutputInferDataType(bn_node, 2), AnfAlgo::GetOutputInferDataType(bn_node, 4)};
  auto shapes = {AnfAlgo::GetOutputInferShape(relu_node, 0), AnfAlgo::GetOutputInferShape(bn_node, 1),
                 AnfAlgo::GetOutputInferShape(bn_node, 2), AnfAlgo::GetOutputInferShape(bn_node, 4)};
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn2_relu.get());
  // Set attr for bn2_relu
  AnfAlgo::CopyNodeAttrs(bn_node, bn2_relu);
  AnfAlgo::CopyNodeAttr("epsilon", "eps", bn_node, bn2_relu);

  CreateMultipleOutputsOfAnfNode(func_graph, bn2_relu, kBn2ReluOutputNum, bn2_relu_outputs);
}
}  // namespace

const BaseRef ConvBnReluFusion::DefinePattern() const {
  VarPtr Xs = std::make_shared<SeqVar>();
  VarPtr Ys = std::make_shared<SeqVar>();
  VarPtr Z = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(Xs);
  MS_EXCEPTION_IF_NULL(Ys);
  MS_EXCEPTION_IF_NULL(Z);
  return VectorRef(
    {prim::kPrimRelu,
     PatternListType({prim::kPrimTupleGetItem,
                      PatternListType({prim::kPrimFusedBatchNorm, PatternListType({prim::kPrimConv2D, Xs}), Ys}), Z})});
}
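
// The pattern above matches a subgraph of the form
//   Relu(TupleGetItem(FusedBatchNorm(Conv2D(Xs...), Ys...), Z)),
// i.e. one output of a Conv2D + FusedBatchNorm chain, selected by tuple_getitem with index Z,
// passed through ReLU.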

const AnfNodePtr ConvBnReluFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  CNodePtr relu_node = nullptr;
  CNodePtr bn_node = nullptr;
  CNodePtr conv_node = nullptr;
  std::tie(relu_node, bn_node, conv_node) = GetPrevNodes(node);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  std::vector<AnfNodePtr> conv_bn1_outputs;
  CreateOutputsOfConvBn1(func_graph, conv_node, bn_node, &conv_bn1_outputs);
  if (conv_bn1_outputs.size() != kConvBn1OutputNum) {
    MS_LOG(EXCEPTION) << "conv_bn1 outputs has wrong size: " << conv_bn1_outputs.size();
  }
  (void)manager->Replace(conv_node, conv_bn1_outputs[0]);

  std::vector<AnfNodePtr> bn2_relu_outputs;
  CreateOutputsOfBn2Relu(func_graph, conv_bn1_outputs, bn_node, relu_node, &bn2_relu_outputs);
  if (bn2_relu_outputs.size() != kBn2ReluOutputNum) {
    MS_LOG(EXCEPTION) << "bn2_relu outputs has wrong size: " << bn2_relu_outputs.size();
  }
  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple),
                                            bn2_relu_outputs[0],
                                            bn2_relu_outputs[1],
                                            bn2_relu_outputs[2],
                                            conv_bn1_outputs[2],
                                            bn2_relu_outputs[3]};
  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple);
  (void)manager->Replace(bn_node, make_tuple);
  return bn2_relu_outputs[0];
}
}  // namespace opt
}  // namespace mindspore
pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h
@@ -1,33 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"

namespace mindspore {
namespace opt {
class ConvBnReluFusion : public PatternProcessPass {
 public:
  explicit ConvBnReluFusion(bool multigraph = true) : PatternProcessPass("conv_bn_relu_fusion", multigraph) {}
  ~ConvBnReluFusion() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONV_BN_RELU_FUSION_H_
@@ -1,77 +0,0 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/backend_common_test.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"

#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#undef private
#undef protected

namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

class TestHWConvBnFusion : public BackendCommon {
 public:
  TestHWConvBnFusion() : getPyFun_("gtest_input.pre_activate.ir_fusion_test", true) {}
  ~TestHWConvBnFusion() override = default;

  UT::PyFuncGraphFetcher getPyFun_;
};

TEST_F(TestHWConvBnFusion, test_conv_bn_fusion) {
  /*
   * def before(x, y):
   *     conv_output = conv(x, y)
   *     bn_output = bn(conv_output)
   *     item0 = tuple_getitem(bn_output, 0)
   *     item1 = tuple_getitem(bn_output, 3)
   *     item2 = tuple_getitem(bn_output, 4)
   *     res = make_tuple(item0, item1, item2)
   *     return res
   */
  getPyFun_.SetDoResolve(true);
  FuncGraphPtr g = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "before");
  std::vector<int> shp_x{32, 3, 224, 224};
  std::vector<int> shp_w{64, 3, 7, 7};
  std::vector<int> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto w_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_w);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, w_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto fg = GetKernelGraph(g, args_spec_list);

  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  auto conv_bn_fusion_pass = std::make_shared<opt::ConvBnFusion>();
  pass_manager->AddPass(conv_bn_fusion_pass);
  graph_optimizer->AddPassManager(pass_manager);
  auto new_g = graph_optimizer->Optimize(fg);

  FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_conv_bn_fusion", "after");
  EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
}

}  // namespace opt
}  // namespace mindspore
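
For orientation, the sketch below shows how the three fusion passes in this diff could be wired into the backend optimizer together; it mirrors the GraphOptimizer/PassManager sequence used in the unit test above. The wrapper function RunConvBnFusionPasses, the grouping of all three passes in one PassManager, the pass ordering, and the FuncGraphPtr return type are illustrative assumptions, not code taken from this change.

#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "pre_activate/ascend/ir_fusion/conv_bn_fusion.h"
#include "pre_activate/ascend/ir_fusion/conv_bn_relu_fusion.h"
#include "pre_activate/ascend/ir_fusion/conv_bn_add_relu_fusion.h"

namespace mindspore {
namespace opt {
// Sketch only: run the conv+bn fusion passes over a kernel graph, mirroring
// the GraphOptimizer/PassManager usage in the unit test above. The function
// name, the single PassManager, and the pass ordering are assumptions.
FuncGraphPtr RunConvBnFusionPasses(const FuncGraphPtr &kernel_graph) {
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto pm = std::make_shared<PassManager>();
  // Register the more specific patterns first so a conv+bn+(add+)relu chain is
  // fused as a whole rather than being broken up by the plain conv+bn pass.
  pm->AddPass(std::make_shared<ConvBnAddReluFusion>());
  pm->AddPass(std::make_shared<ConvBnReluFusion>());
  pm->AddPass(std::make_shared<ConvBnFusion>());
  optimizer->AddPassManager(pm);
  return optimizer->Optimize(kernel_graph);
}
}  // namespace opt
}  // namespace mindspore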
Some files were not shown because too many files have changed in this diff.