@@ -110,44 +110,57 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
   manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
   return;
 }
-}  // namespace
-
-const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
-  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
-  VectorRef batch_norm_grad =
-    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
-  return batch_norm_grad;
-}
-
-const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                     const EquivPtr &) const {
+bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
   auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
   MS_EXCEPTION_IF_NULL(format_attr);
   auto format = GetValue<std::string>(format_attr);
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
-    return nullptr;
+    return false;
   }
 
   auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
   MS_EXCEPTION_IF_NULL(relu_grad);
   auto relu_users = GetRealNodeUsedList(graph, relu_grad);
   if (relu_users->size() != 2) {
-    return nullptr;
+    return false;
   }
 
   // process pattern as Relu(TensorAdd(BN#0, BN#1))
   auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
   MS_EXCEPTION_IF_NULL(tuple_getitem);
   if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
-    return nullptr;
+    return false;
   }
+  auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
+  if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
+    return false;
+  }
+
+  return true;
+}
+}  // namespace
+
+const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
+  VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
+  VectorRef batch_norm_grad =
+    VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
+  return batch_norm_grad;
+}
+
+const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  if (!PatternCheck(graph, node)) {
+    return nullptr;
+  }
+
   auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
   MS_EXCEPTION_IF_NULL(relu_grad);
   auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
   MS_EXCEPTION_IF_NULL(dy);
   auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);