|
|
|
@ -81,7 +81,7 @@ bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) {
|
|
|
|
|
bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm_grad) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
auto tuple_getitem = node->cast<CNodePtr>();
|
|
|
|
@ -93,12 +93,12 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(batchnormgrad_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(batchnormgrad);
|
|
|
|
|
*batchnormgrad = batchnormgrad_anf->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(*batchnormgrad);
|
|
|
|
|
return CheckBatchNormGrad(graph, *batchnormgrad);
|
|
|
|
|
AnfNodePtr batchnorm_grad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(batchnorm_grad_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(batchnorm_grad);
|
|
|
|
|
*batchnorm_grad = batchnorm_grad_anf->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(*batchnorm_grad);
|
|
|
|
|
return CheckBatchNormGrad(graph, *batchnorm_grad);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
@ -117,11 +117,13 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
|
|
|
|
CNodePtr batchnormgrad = nullptr;
|
|
|
|
|
if (!NeedFusion(graph, node, &batchnormgrad)) {
|
|
|
|
|
CNodePtr batchnorm_grad = nullptr;
|
|
|
|
|
if (!NeedFusion(graph, node, &batchnorm_grad)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return CreateBNInferGrad(graph, batchnormgrad, node);
|
|
|
|
|
auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
|
|
|
|
|
TransferDepend(batchnorm_grad, graph, bn_infer_grad);
|
|
|
|
|
return bn_infer_grad;
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|