|
|
|
@ -129,18 +129,22 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph,
|
|
|
|
|
const EquivPtr &) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
std::vector<AnfNodePtr> bn_outputs;
|
|
|
|
|
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
|
|
|
|
|
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() < kBatchNormRealInputNum + 1) {
|
|
|
|
|
if (cnode->size() < kBatchNormRealInputNum + 1) {
|
|
|
|
|
MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum
|
|
|
|
|
<< ". The node should not be changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (!GetBoolAttr(cnode, kAttrIsTraining)) {
|
|
|
|
|
MS_LOG(INFO) << "is training should be true if do fusion";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> bn_outputs;
|
|
|
|
|
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
|
|
|
|
|
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
|
|
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_outputs;
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
|
|
|
|
|