You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/ir/fuse_bn_act_pass.cc

333 lines
13 KiB

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_bn_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace framework {
namespace ir {
// Pass entry point.
//
// Runs the forward fusion (FuseBatchNormAct) followed by the backward fusion
// (FuseBatchNormActGrad) on `graph`. The whole body is compiled only when
// PADDLE_WITH_CUDA is defined and CUDNN_VERSION_MIN(7, 4, 1) holds; otherwise
// this function does nothing.
void FuseBatchNormActPass::ApplyImpl(ir::Graph *graph) const {
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
  // Forward direction: only "relu" is fused.
  const std::unordered_set<std::string> forward_act_types{"relu"};
  graph = FuseBatchNormAct(graph, forward_act_types);

  // Backward direction: the matching gradient op.
  const std::unordered_set<std::string> backward_act_types{"relu_grad"};
  graph = FuseBatchNormActGrad(graph, backward_act_types);
#endif
#endif
}
// Forward fusion: act(bn(x)).
//
// Anchors the search on a variable that is the "X" input of a batch_norm op
// and has dtype FP16, then lets patterns::BatchNormAct describe the rest of
// the subgraph. For every match a fused_batch_norm_act op is created and the
// matched batch_norm / activation ops are re-linked to it.
//
// graph:     graph to rewrite in place; must not be nullptr.
// act_types: activation op types that may follow batch_norm.
// Returns the same `graph` pointer.
ir::Graph *FuseBatchNormActPass::FuseBatchNormAct(
    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument(
                 "The input graph of FuseBatchNormAct should not be nullptr."));
  FusePassBase::Init("bn_act", graph);

  GraphPatternDetector gpd;
  auto *pd_pattern = gpd.mutable_pattern();
  auto *x = pd_pattern->NewNode("bn_act/x")
                ->AsInput()
                ->assert_is_op_input("batch_norm", "X")
                ->assert_var_dtype(proto::VarType::FP16);
  patterns::BatchNormAct bn_act_pattern(pd_pattern, "bn_act");
  bn_act_pattern(x, act_types);

  int fused_count = 0;

  // Invoked once per matched subgraph. The parameter must stay named
  // `subgraph`: it is also looked up directly below via subgraph.at(x).
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle FuseBatchNormAct fuse";

    // batch_norm input variables
    GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, bn_act_pattern);
    // batch_norm output variables
    GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
                              bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
                              bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
    // activation output variable
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
    // the two operators being fused
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);

    // Variable names wired into the fused op and printed in the log below.
    const std::string x_name = subgraph.at(x)->Name();
    const std::string scale_name = bn_scale->Name();
    const std::string bias_name = bn_bias->Name();
    const std::string variance_name = bn_variance->Name();
    const std::string mean_name = bn_mean->Name();
    const std::string mean_out_name = bn_mean_out->Name();
    const std::string variance_out_name = bn_variance_out->Name();
    const std::string saved_variance_name = bn_saved_variance->Name();
    const std::string saved_mean_name = bn_saved_mean->Name();
    const std::string reserve_space_name = bn_reserve_space->Name();
    const std::string bn_out_name = bn_out->Name();
    const std::string act_out_name = act_out->Name();

    Node *fused_node = CreateFusedBatchNormActNode(
        g, act, batch_norm, x_name, scale_name, bias_name, variance_name,
        mean_name, mean_out_name, variance_out_name, saved_variance_name,
        saved_mean_name, reserve_space_name, act_out_name);

    VLOG(4) << "\n\t " << x_name << ", " << scale_name << ", " << bias_name
            << ", " << variance_name << " and " << mean_name << " -> "
            << batch_norm->Name() << " -> " << mean_out_name << ", "
            << variance_out_name << ", " << saved_variance_name << ", "
            << saved_mean_name << ", " << reserve_space_name << " and "
            << bn_out_name << "\n"
            << "\t " << bn_out_name << " -> " << act->Name() << " -> "
            << act_out_name;

    // bn_out is the intermediate variable between batch_norm and act.
    ReLinkNodes(g, bn_out, batch_norm, act, fused_node);
    ++fused_count;
  };

  gpd(graph, handler);
  AddStatis(fused_count);
  return graph;
}
// Builds a "fused_batch_norm_act" op node in graph `g`.
//
// The *_n parameters are variable names that become the single-element
// argument lists of the fused op's inputs (X, Scale, Bias, Mean, Variance)
// and outputs (Y, MeanOut, VarianceOut, SavedMean, SavedVariance,
// ReserveSpace). "Y" is bound to the activation output name.
//
// The "act_type" attribute is set to the activation node's name first; every
// attribute of the activation op and then of the batch_norm op is copied onto
// the fused op afterwards, so a same-named attribute copied later wins.
//
// Returns the node created by g->CreateOpNode.
Node *FuseBatchNormActPass::CreateFusedBatchNormActNode(
    Graph *g, const Node *act, const Node *bn, const std::string &bn_x_n,
    const std::string &bn_scale_n, const std::string &bn_bias_n,
    const std::string &bn_variance_n, const std::string &bn_mean_n,
    const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
    const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n,
    const std::string &bn_reserve_space_n, const std::string &act_out_n) const {
  // Wraps one variable name into the one-element list the OpDesc setters take.
  const auto single = [](const std::string &var_name) {
    return std::vector<std::string>({var_name});
  };

  OpDesc desc;

  // Inputs.
  desc.SetInput("X", single(bn_x_n));
  desc.SetInput("Scale", single(bn_scale_n));
  desc.SetInput("Bias", single(bn_bias_n));
  desc.SetInput("Mean", single(bn_mean_n));
  desc.SetInput("Variance", single(bn_variance_n));

  // Outputs.
  desc.SetOutput("Y", single(act_out_n));
  desc.SetOutput("MeanOut", single(bn_mean_out_n));
  desc.SetOutput("VarianceOut", single(bn_variance_out_n));
  desc.SetOutput("SavedMean", single(bn_saved_mean_n));
  desc.SetOutput("SavedVariance", single(bn_saved_variance_n));
  desc.SetOutput("ReserveSpace", single(bn_reserve_space_n));

  desc.SetType("fused_batch_norm_act");
  desc.SetAttr("act_type", act->Name());

  // Carry over the attributes of both source ops: activation first, then
  // batch_norm.
  for (auto *source_op : {act->Op(), bn->Op()}) {
    for (const auto &attr : source_op->GetAttrMap()) {
      desc.SetAttr(attr.first, attr.second);
    }
  }

  return g->CreateOpNode(&desc);
}
// Backward fusion of act(bn(x)).
//
// Ops being merged and their slots:
//   act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
//   bn_grad:  in["X", "Y@GRAD", "Scale", "Bias", "SavedMean", "SavedVariance",
//                "ReserveSpace"],
//             out["X@GRAD", "Scale@GRAD", "Bias@GRAD"]
//
// The search is anchored on a variable that is the Out@GRAD input of one of
// `act_grad_types`; patterns::BatchNormActGrad describes the rest. Each match
// is replaced by one fused_batch_norm_act_grad op.
//
// graph:          graph to rewrite in place; must not be nullptr.
// act_grad_types: activation gradient op types eligible for fusion.
// Returns the same `graph` pointer.
ir::Graph *FuseBatchNormActPass::FuseBatchNormActGrad(
    ir::Graph *graph,
    const std::unordered_set<std::string> &act_grad_types) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::InvalidArgument(
          "The input graph of FuseBatchNormActGrad should not be nullptr."));
  FusePassBase::Init("bn_act_grad", graph);

  GraphPatternDetector gpd;
  auto *pd_pattern = gpd.mutable_pattern();
  auto *d_act_out =
      pd_pattern->NewNode("bn_act_grad/x")
          ->AsInput()
          ->assert_is_ops_input(act_grad_types, GradVarName("Out"));
  patterns::BatchNormActGrad bn_act_grad_pattern(pd_pattern, "bn_act_grad");
  bn_act_grad_pattern(d_act_out, act_grad_types);

  int fused_count = 0;

  // Invoked once per matched subgraph. The parameter must stay named
  // `subgraph`: it is also looked up directly below via subgraph.at(...).
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle FuseBatchNormActGrad fuse";

    // The two operators being fused.
    GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad,
                              bn_act_grad_pattern);
    // Variables around act_grad. The pattern accessor is spelled
    // "d_itermediate_out"; only the local variable gets the corrected name.
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_intermediate_out, d_itermediate_out,
                              bn_act_grad_pattern);
    // Inputs of batch_norm_grad.
    GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
                              bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
                              bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
                              bn_act_grad_pattern);
    // Outputs of batch_norm_grad.
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_act_grad_pattern);

    const std::string d_act_out_name = subgraph.at(d_act_out)->Name();  // Y@GRAD
    const std::string act_out_name = act_out->Name();                   // Y
    const std::string d_intermediate_name = d_intermediate_out->Name();
    const std::string x_name = bn_x->Name();
    const std::string scale_name = bn_scale->Name();
    const std::string bias_name = bn_bias->Name();
    const std::string saved_mean_name = bn_saved_mean->Name();
    const std::string saved_variance_name = bn_saved_variance->Name();
    const std::string reserve_space_name = bn_reserve_space->Name();
    const std::string d_x_name = d_bn_x->Name();
    const std::string d_scale_name = d_bn_scale->Name();
    const std::string d_bias_name = d_bn_bias->Name();

    // Wraps one variable name into the one-element list the setters take.
    const auto single = [](const std::string &var_name) {
      return std::vector<std::string>({var_name});
    };

    OpDesc desc;
    desc.SetType("fused_batch_norm_act_grad");

    desc.SetInput("X", single(x_name));
    desc.SetInput("Y", single(act_out_name));
    desc.SetInput(GradVarName("Y"), single(d_act_out_name));
    desc.SetInput("Scale", single(scale_name));
    desc.SetInput("Bias", single(bias_name));
    desc.SetInput("SavedMean", single(saved_mean_name));
    desc.SetInput("SavedVariance", single(saved_variance_name));
    desc.SetInput("ReserveSpace", single(reserve_space_name));

    desc.SetOutput(GradVarName("X"), single(d_x_name));
    desc.SetOutput(GradVarName("Scale"), single(d_scale_name));
    desc.SetOutput(GradVarName("Bias"), single(d_bias_name));

    // The forward activation type is the grad op's name minus its last five
    // characters, i.e. the "_grad" suffix.
    const std::string act_grad_name = act_grad->Name();
    const std::string act_type =
        act_grad_name.substr(0, act_grad_name.length() - 5);
    desc.SetAttr("act_type", act_type);

    // Carry over the attributes of both source ops: act_grad first, then
    // batch_norm_grad.
    for (auto *source_op : {act_grad->Op(), batch_norm_grad->Op()}) {
      for (const auto &attr : source_op->GetAttrMap()) {
        desc.SetAttr(attr.first, attr.second);
      }
    }

    auto fused_node = g->CreateOpNode(&desc);

    VLOG(4) << "\n\t " << d_act_out_name << " and " << act_out_name << " -> "
            << act_grad->Name() << " -> " << d_intermediate_name << "\n\t "
            << x_name << ", " << d_intermediate_name << ", " << scale_name
            << ", " << bias_name << ", " << saved_mean_name << ", "
            << saved_variance_name << " and " << reserve_space_name << " -> "
            << batch_norm_grad->Name() << " -> " << d_x_name << ", "
            << d_scale_name << " and " << d_bias_name;

    // d_intermediate_out is the variable between act_grad and
    // batch_norm_grad.
    ReLinkNodes(g, d_intermediate_out, act_grad, batch_norm_grad, fused_node);
    ++fused_count;
  };

  gpd(graph, handler);
  AddStatis(fused_count);
  return graph;
}
// Splices `fused_op` into `graph` in place of the chain op_1 -> op_2.
//
// graph:            graph that owns all the nodes involved.
// intermediate_out: variable produced by op_1 and consumed by op_2.
// op_1, op_2:       the two operator nodes being replaced; both are removed.
// fused_op:         the already-created replacement operator node.
//
// Steps, in order:
//   1. Every input of op_1 becomes an input of fused_op.
//   2. Outputs of op_1 that op_2 also reads are scheduled for deletion; the
//      remaining outputs of op_1 become outputs of fused_op.
//   3. Inputs of op_2 that are neither `intermediate_out` nor scheduled for
//      deletion become inputs of fused_op.
//   4. Every output of op_2 becomes an output of fused_op.
//   5. op_1, op_2 and the scheduled variables are removed from the graph.
void FuseBatchNormActPass::ReLinkNodes(Graph *graph,
                                       const Node *intermediate_out, Node *op_1,
                                       Node *op_2, Node *fused_op) const {
  // Step 1: redirect the producers' edges from op_1 to fused_op.
  for (Node *input_var : op_1->inputs) {
    fused_op->inputs.emplace_back(input_var);
    input_var->outputs = this->ReplaceNode(op_1, fused_op, input_var->outputs);
  }

  // Step 2: classify op_1's outputs.
  std::unordered_set<const Node *> nodes2delete;
  for (Node *output_var : op_1->outputs) {
    const bool feeds_op_2 =
        std::find(op_2->inputs.begin(), op_2->inputs.end(), output_var) !=
        op_2->inputs.end();
    if (feeds_op_2) {
      // Only links op_1 to op_2, so it is dropped with them.
      nodes2delete.emplace(output_var);
    } else {
      // Still needed outside the fused pair.
      IR_OP_VAR_LINK(fused_op, output_var);
    }
  }

  // Step 3: adopt op_2's remaining inputs.
  for (Node *input_var : op_2->inputs) {
    const bool is_internal =
        input_var == intermediate_out || nodes2delete.count(input_var);
    if (is_internal) {
      continue;
    }
    fused_op->inputs.emplace_back(input_var);
    input_var->outputs = this->ReplaceNode(op_2, fused_op, input_var->outputs);
  }

  // Step 4: adopt op_2's outputs.
  for (Node *output_var : op_2->outputs) {
    IR_OP_VAR_LINK(fused_op, output_var);
  }

  // Step 5: remove the two original ops together with the internal variables.
  nodes2delete.insert(op_1);
  nodes2delete.insert(op_2);
  GraphSafeRemoveNodes(graph, nodes2delete);
}
// Returns a copy of `nodes` in which every occurrence of `cur_node` has been
// replaced by `new_node`; all other entries keep their value and position.
//
// Raises platform::errors::NotFound through PADDLE_ENFORCE_EQ when `cur_node`
// does not occur in `nodes` at all.
std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
    Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
  std::vector<Node *> new_list(nodes);
  bool has_replaced = false;
  for (Node *&entry : new_list) {
    if (entry == cur_node) {
      entry = new_node;
      has_replaced = true;
    }
  }
  PADDLE_ENFORCE_EQ(has_replaced, true,
                    platform::errors::NotFound("Not find %s in the node list.",
                                               cur_node->Name()));
  return new_list;
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers FuseBatchNormActPass under the name fuse_bn_act_pass through the
// REGISTER_PASS macro (defined elsewhere). The invocation sits outside the
// paddle::framework::ir namespace, hence the fully qualified class name.
REGISTER_PASS(fuse_bn_act_pass, paddle::framework::ir::FuseBatchNormActPass);