fix single-batchnorm-fission && softmax-grad-ext-fusion pass

pull/1996/head
huanghui 5 years ago
parent 62110becf8
commit 88eec2b894

@@ -128,7 +128,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
 std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
   std::string convert_format;
-  size_t counter = 0;
+  const size_t counter = 0;
   for (const auto &iter : format_counter) {
     if (counter < iter.second) {
       convert_format = iter.first;
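With `counter` now const, the threshold can never rise, so — assuming the loop body shown in the hunk is complete — every format with a nonzero count overwrites `convert_format`, and the last qualifying key in the map's iteration order wins. A minimal standalone sketch of that behavior:

#include <cstddef>
#include <iostream>
#include <map>
#include <string>

// Standalone sketch of the selection loop above (loop body assumed complete
// as shown). With counter fixed at 0, any format whose count is nonzero
// overwrites convert_format, so the last such key in the map's sorted
// iteration order is returned.
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) {
  std::string convert_format;
  const size_t counter = 0;
  for (const auto &iter : format_counter) {
    if (counter < iter.second) {
      convert_format = iter.first;
    }
  }
  return convert_format;
}

int main() {
  // Prints "NC1HWC0": both counts are nonzero and it sorts after "DefaultFormat".
  std::cout << GetConvertFormat({{"DefaultFormat", 2}, {"NC1HWC0", 3}}) << std::endl;
  return 0;
}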

@@ -129,18 +129,22 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph,
                                                  const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
-  std::vector<AnfNodePtr> bn_outputs;
-  if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
-    MS_LOG(INFO) << "The BatchNorm node should only have outputs 0, 3 and 4. The node should not be changed";
-    return nullptr;
-  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+  if (cnode->size() < kBatchNormRealInputNum + 1) {
     MS_LOG(INFO) << "The input num of BatchNorm is less than " << kBatchNormRealInputNum
                  << ". The node should not be changed";
     return nullptr;
   }
+  if (!GetBoolAttr(cnode, kAttrIsTraining)) {
+    MS_LOG(INFO) << "The is_training attr should be true to do fusion";
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> bn_outputs;
+  if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
+    MS_LOG(INFO) << "The BatchNorm node should only have outputs 0, 3 and 4. The node should not be changed";
+    return nullptr;
+  }
   AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
   std::vector<AnfNodePtr> bn_training_reduce_outputs;
   CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
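For readers outside the codebase: a CNode's input vector stores the primitive itself at index 0 followed by the real inputs, which is why the guard compares against kBatchNormRealInputNum + 1 (cnode->size() is shorthand for inputs().size()). A toy model of that layout, with the type, names, and the constant's value assumed for illustration rather than taken from MindSpore:

#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a MindSpore CNode (not the real API): inputs_[0] holds
// the primitive, so N real inputs give an inputs vector of size N + 1.
struct ToyCNode {
  std::vector<std::string> inputs_;
  std::size_t size() const { return inputs_.size(); }
};

// Assumed value: BatchNorm's real inputs are x, scale, bias, mean, variance.
constexpr std::size_t kBatchNormRealInputNum = 5;

bool HasEnoughInputs(const ToyCNode &cnode) {
  return cnode.size() >= kBatchNormRealInputNum + 1;
}

int main() {
  ToyCNode bn{{"BatchNorm", "x", "scale", "bias", "mean", "variance"}};
  return HasEnoughInputs(bn) ? 0 : 1;  // enough inputs: exits 0
}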

@@ -58,6 +58,10 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
   auto input1 = GetAnfNodeByVar(equiv, input1_);
   auto input2 = GetAnfNodeByVar(equiv, input2_);
   auto sum = GetAnfNodeByVar(equiv, sum_var_);
+  if (!GetBoolAttr(sum, kAttrKeepDims)) {
+    MS_LOG(INFO) << "The sum's keep_dims attr should be true to do fusion";
+    return nullptr;
+  }
   auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
   auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
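For context: the Mul/ReduceSum/Sub pattern this pass matches (see the test stubs below) corresponds to the textbook softmax backward — a reasonable reading of the pattern, not something the hunk itself confirms. For y = softmax(x) and loss L:

\frac{\partial L}{\partial x_i} = y_i \left( \frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j}\, y_j \right)

The sum runs over the softmax axis; with keep_dims=True the reduced axis survives with size 1, so the subtraction broadcasts element-wise against the upstream gradient. With keep_dims=False the shapes no longer line up the way the fused SoftmaxGradExt kernel expects, which is presumably why the pass now bails out.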

@@ -722,5 +722,16 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
   MS_EXCEPTION_IF_NULL(value_node2);
   return GetValue<int>(value_node1->value()) < GetValue<int>(value_node2->value());
 }
+
+bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    MS_LOG(INFO) << "node is not a cnode";
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
+}
 } // namespace opt
 } // namespace mindspore
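The HasNodeAttr check before GetNodeAttr is the important part: reading an absent attr would raise, while this helper treats absence as false. The same check-then-read pattern in plain C++, with a std::map standing in for the node's attr storage (illustrative names, not the MindSpore API):

#include <iostream>
#include <map>
#include <string>

// Sketch of GetBoolAttr's check-then-read guard using a plain map in place
// of node attr storage: an absent attr is treated as false, not an error.
bool GetBoolAttr(const std::map<std::string, bool> &attrs, const std::string &attr_name) {
  auto iter = attrs.find(attr_name);
  return iter != attrs.end() && iter->second;
}

int main() {
  std::map<std::string, bool> attrs{{"is_training", true}};
  std::cout << GetBoolAttr(attrs, "is_training") << "\n";  // 1
  std::cout << GetBoolAttr(attrs, "keep_dims") << "\n";    // 0: attr absent, treated as false
  return 0;
}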

@@ -180,6 +180,9 @@ AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
 // Compare tuple getitem's index, return bool[n1's index < n2's index]
 bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);
+// Get a bool attr from a cnode; returns false if the attr is missing
+bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
 } // namespace opt
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_

@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
-BatchNorm = P.BatchNorm()
+BatchNorm = P.BatchNorm(is_training=True)
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdateV3 = Primitive('BNTrainingUpdateV3')

@@ -16,7 +16,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 Mul = P.Mul()
-ReduceSum = P.ReduceSum()
+ReduceSum = P.ReduceSum(keep_dims=True)
 Sub = P.Sub()
 SoftmaxGradExt = Primitive('SoftmaxGradExt')
 MakeTuple = Primitive('make_tuple')
