|
|
|
|
@ -183,7 +183,7 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
|
|
|
|
|
std::string d_ele_y_n = d_ele_y->Name();
|
|
|
|
|
|
|
|
|
|
OpDesc desc;
|
|
|
|
|
desc.SetType("fused_elemwise_activation_grad");
|
|
|
|
|
desc.SetType("fused_elemwise_add_activation_grad");
|
|
|
|
|
desc.SetInput("IntermediateOut", {});
|
|
|
|
|
desc.SetInput("X", {});
|
|
|
|
|
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
|
|
|
|
|
@ -231,7 +231,7 @@ Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
|
|
|
|
|
desc.SetInput("Y", std::vector<std::string>({ele_y_n}));
|
|
|
|
|
desc.SetOutput("Out", std::vector<std::string>({act_out_n}));
|
|
|
|
|
desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
|
|
|
|
|
desc.SetType("fused_elemwise_activation");
|
|
|
|
|
desc.SetType("fused_elemwise_add_activation");
|
|
|
|
|
desc.SetAttr("save_intermediate_out", true);
|
|
|
|
|
desc.SetAttr("functor_list", std::vector<std::string>(
|
|
|
|
|
{op_1->Op()->Type(), op_2->Op()->Type()}));
|
|
|
|
|
@ -251,7 +251,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
|
|
|
|
|
std::unordered_set<const Node *> need_removed_nodes;
|
|
|
|
|
for (auto &cur_node : graph->Nodes()) {
|
|
|
|
|
if (cur_node->IsVar()) continue;
|
|
|
|
|
if (cur_node->Name() == "fused_elemwise_activation") {
|
|
|
|
|
if (cur_node->Name() == "fused_elemwise_add_activation") {
|
|
|
|
|
bool save_intermediate_out = BOOST_GET_CONST(
|
|
|
|
|
bool, cur_node->Op()->GetAttr("save_intermediate_out"));
|
|
|
|
|
auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
|
|
|
|
|
@ -272,7 +272,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (cur_node->Name() == "fused_elemwise_activation_grad") {
|
|
|
|
|
} else if (cur_node->Name() == "fused_elemwise_add_activation_grad") {
|
|
|
|
|
auto intermediate_out_grad_args =
|
|
|
|
|
cur_node->Op()->Output(GradVarName("IntermediateOut"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|