|
|
@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
|
|
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
Graph *g) {
|
|
|
|
Graph *g) {
|
|
|
|
VLOG(40) << "handle FuseElewiseAddAct fuse";
|
|
|
|
VLOG(4) << "handle FuseElewiseAddAct fuse";
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
|
|
|
|
elewise_add_act_pattern);
|
|
|
|
elewise_add_act_pattern);
|
|
|
@ -77,10 +77,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
|
|
|
|
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
|
|
|
|
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
|
|
|
|
g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n);
|
|
|
|
g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n);
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(40) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
|
|
|
|
VLOG(4) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
|
|
|
|
<< ele_add->Name() << " -> " << ele_out_n << "\n"
|
|
|
|
<< ele_add->Name() << " -> " << ele_out_n << "\n"
|
|
|
|
<< "\t " << ele_out_n << " -> " << act->Name() << " -> "
|
|
|
|
<< "\t " << ele_out_n << " -> " << act->Name() << " -> "
|
|
|
|
<< act_out_n;
|
|
|
|
<< act_out_n;
|
|
|
|
|
|
|
|
|
|
|
|
ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node);
|
|
|
|
ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node);
|
|
|
|
found_elewise_add_act_count++;
|
|
|
|
found_elewise_add_act_count++;
|
|
|
@ -113,7 +113,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
|
|
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
Graph *g) {
|
|
|
|
Graph *g) {
|
|
|
|
VLOG(40) << "handle FuseElewiseAddAct fuse";
|
|
|
|
VLOG(4) << "handle FuseElewiseAddAct fuse";
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
|
|
|
@ -129,9 +129,9 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
|
|
|
|
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
|
|
|
|
Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
|
|
|
|
g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n);
|
|
|
|
g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n);
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(40) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
|
|
|
|
VLOG(4) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
|
|
|
|
<< "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
|
|
|
|
<< "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
|
|
|
|
<< ele_add->Name() << " -> " << elewise_add_out_n;
|
|
|
|
<< ele_add->Name() << " -> " << elewise_add_out_n;
|
|
|
|
|
|
|
|
|
|
|
|
ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node);
|
|
|
|
ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node);
|
|
|
|
found_elewise_add_act_count++;
|
|
|
|
found_elewise_add_act_count++;
|
|
|
@ -165,7 +165,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
|
|
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
|
|
|
|
Graph *g) {
|
|
|
|
Graph *g) {
|
|
|
|
VLOG(40) << "handle FuseElewiseAddActGrad1 fuse";
|
|
|
|
VLOG(4) << "handle FuseElewiseAddActGrad1 fuse";
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern);
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
|
|
|
@ -208,10 +208,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
|
|
|
|
|
|
|
|
|
|
|
|
auto fused_node = g->CreateOpNode(&desc);
|
|
|
|
auto fused_node = g->CreateOpNode(&desc);
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(40) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
|
|
|
|
VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
|
|
|
|
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
|
|
|
|
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
|
|
|
|
<< d_itermediate_out_n << " and " << act_out_n << " -> "
|
|
|
|
<< d_itermediate_out_n << " and " << act_out_n << " -> "
|
|
|
|
<< ele_add_grad->Name() << " -> " << d_itermediate_out_n;
|
|
|
|
<< ele_add_grad->Name() << " -> " << d_itermediate_out_n;
|
|
|
|
|
|
|
|
|
|
|
|
ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node);
|
|
|
|
ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node);
|
|
|
|
found_elewise_add_act_count++;
|
|
|
|
found_elewise_add_act_count++;
|
|
|
|