// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_info.h"

// NOTE(dzhwinter): "inplace" means that an op output variable reuses the
// space of one of its input variables. By our design, an operator may only
// read its inputs (const Variable) and write its outputs (non-const
// Variable). If an operator is inplaced, the shared space may be written
// before the read happens, especially when certain optimized coding styles
// are applied.
//
// /* wrong case in an operator */
// /* In this case a larger allocation is made, and the input content is lost */
// const Tensor* in = ctx.Input<Tensor>("In");
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = 0;  // input content is overwritten.

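// A minimal sketch of the safe counterpart (an editorial illustration,
// assuming the same ctx/Tensor API as the wrong case above): read everything
// needed from the input before writing the output buffer, so inplace reuse
// stays correct.
//
// const Tensor* in = ctx.Input<Tensor>("In");
// T first = in->data<T>()[0];  // read the input first
// Tensor* out = ctx.Output<Tensor>("Out");
// auto* out_ptr = out->mutable_data<T>(ctx.GetPlace());
// out_ptr[0] = first;  // now writing cannot clobber a pending read
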
// NOTE(dzhwinter):
// Only for backward compatibility and stability. If
// enable_inplace_whitelist is turned on, only the ops in the whitelist will
// use the inplace strategy; otherwise, every op will be inplaced if it is
// registered with InplaceClass.
DEFINE_bool(
    enable_inplace_whitelist, false,
    "If this option is turned on, only the ops in the whitelist can be "
    "inplaced; if it is turned off, every running op is a candidate for "
    "inplace (such as scale and elementwise_add). "
    "By default, it is turned off.");

DECLARE_string(memory_optimize_debug);

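// Editorial note: since this is a standard gflags flag, it can be enabled on
// any binary that parses command-line flags with
//   --enable_inplace_whitelist=true
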
namespace paddle {
namespace framework {
namespace details {

// clang-format off
const std::string kInplacedOpWhiteList[] = { // NOLINT
    "sigmoid",
    "exp",
    "relu",
    "tanh",
    "sqrt",
    "ceil",
    "floor",
    "reciprocal",
    "relu6",
    "soft_relu",
    "hard_sigmoid",
    "batch_norm",
    "batch_norm_grad",
    "sum",
    "sum_grad",
    "scale",
    "reshape",
    "elementwise_add",
    "elementwise_add_grad",
};

// FIXME(zjl): The in/out shapes of some ops are exactly the same at runtime,
// but the static sizes computed at compile time would be wrong.
// Use a flag to mark such ops. Please fix me when a better way is found.
static const std::unordered_set<std::string> kSameShapeOpWhiteSet{ // NOLINT
    "reshape2", "reshape2_grad"
};
// clang-format on

class InplacePass : public ir::Pass {
 public:
  InplacePass();

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

 private:
  // Collect vars that cannot be reused,
  // e.g. the in/out vars of sub-block ops and distributed ops, and
  // op_role_var
  void CollectSkipVars(ir::Graph *graph,
                       const std::vector<ir::Node *> &ops) const;

  // Check whether var_name should be skipped
  bool IsSkipVar(const std::string &var_name) const;

  // Rename out with the name of in, and guarantee that the graph is
  // still an SSA graph
  void RenameInOut(ir::Node *op, ir::Node *in, ir::Node *out) const;

  // Check whether var is the last version of its name in the SSA graph
  bool IsLastVersionVar(ir::Node *var) const;

  // Check whether all `ops` are preceding ops of `op`
  bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;

  // Find nodes whose names are equal to the given name
  static std::unordered_set<ir::Node *> FindNodesByName(
      const std::string &name, const std::vector<ir::Node *> &nodes);

  // Get all versions of the var named var_name
  std::vector<ir::Node *> *AllVersionVars(const std::string &var_name) const;

 private:
  // SSA graph. var_name -> each version of the var
  mutable std::map<std::string, std::vector<ir::Node *>> ssa_map_;

  // Skipped vars, including the in/out vars of sub-block ops and
  // distributed ops, and op_role_var
  mutable std::unordered_set<std::string> skip_vars_;

  // Op whitelist; ops in it should not perform inplace.
  // Only filled when FLAGS_enable_inplace_whitelist is true.
  mutable std::unordered_set<std::string> whitelist_ops_;
};

InplacePass::InplacePass() {
  if (FLAGS_enable_inplace_whitelist) {
    for (auto &s : kInplacedOpWhiteList) {
      whitelist_ops_.emplace(s);
    }
  }
}

std::vector<ir::Node *> *InplacePass::AllVersionVars(
    const std::string &var_name) const {
  auto iter = ssa_map_.find(var_name);
  PADDLE_ENFORCE(iter != ssa_map_.end(), "cannot find var %s in ssa graph",
                 var_name);
  PADDLE_ENFORCE(!iter->second.empty(), "var %s is empty in ssa graph",
                 var_name);
  return &(iter->second);
}

bool InplacePass::IsSkipVar(const std::string &var_name) const {
  return skip_vars_.count(var_name) > 0;
}

bool InplacePass::IsLastVersionVar(ir::Node *var) const {
  return AllVersionVars(var->Name())->back() == var;
}

bool InplacePass::CheckOpDeps(ir::Node *op,
                              const std::vector<ir::Node *> &ops) const {
  std::unordered_set<ir::Node *> other_ops(ops.begin(), ops.end());
  other_ops.erase(op);
  if (other_ops.empty()) return true;

  // Traverse all preceding ops of op
  std::queue<ir::Node *> queue;
  std::unordered_set<ir::Node *> visited_ops;
  queue.push(op);
  visited_ops.insert(op);

  // Visit all preceding ops of `op`, erasing each one from other_ops when it
  // is found there. Return true only when other_ops becomes empty, which
  // means that all `ops` are preceding ops of `op`.
  while (!queue.empty()) {
    auto *cur_op = queue.front();
    queue.pop();

    for (auto *in_var : cur_op->inputs) {
      for (auto *in_op : in_var->inputs) {
        if (visited_ops.count(in_op) != 0) {
          continue;
        }

        visited_ops.insert(in_op);
        queue.push(in_op);
        other_ops.erase(in_op);
        if (other_ops.empty()) return true;
      }
    }
  }
  return false;
}

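// Illustrative example (an editorial sketch, not part of the original
// source): for the chain
//   a = op1(x);  b = op2(a);  c = op3(a, b);
// CheckOpDeps(op3, {op1, op2, op3}) returns true, because op1 and op2 are
// both reachable by walking backwards from op3's inputs, whereas
// CheckOpDeps(op2, {op2, op3}) returns false because op3 is not a
// predecessor of op2.
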
void InplacePass::CollectSkipVars(ir::Graph *graph,
                                  const std::vector<ir::Node *> &ops) const {
  // 1. Collect op role vars
  PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
                 "Graph should have attr %s", details::kMemOptSkipVars);
  auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
  for (const auto &var : mem_opt_whitelist) {
    skip_vars_.emplace(var);
  }

  // 2. Track the nodes used by the parameter server. These nodes cannot be
  // inplaced; otherwise the trainer and pserver cannot find each other by
  // name. Also check the ops which have sub-blocks.
  auto update_skip_set = [&](ir::Node *node) {
    for (auto &in : node->inputs) {
      if (in->IsVar() && in->Var() != nullptr) {
        skip_vars_.emplace(in->Name());
      }
    }
    for (auto &out : node->outputs) {
      if (out->IsVar() && out->Var() != nullptr) {
        skip_vars_.emplace(out->Name());
      }
    }
  };

  for (auto *node : ops) {
    if (!node->IsOp()) continue;
    // Avoid optimizing variables used in sub-blocks
    if (OpHasSubBlock(node->Op())) {
      update_skip_set(node);
      continue;
    }

    auto node_name = node->Name();
    if (node_name == "send" || node_name == "recv" || node_name == "prefetch") {
      update_skip_set(node);
    }
  }
}

void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var,
                              ir::Node *out_var) const {
  auto out_var_name = out_var->Name();
  auto in_var_name = in_var->Name();

  auto &all_out_nodes = *AllVersionVars(out_var_name);
  auto &all_in_nodes = *AllVersionVars(in_var_name);

  auto iter = std::find(all_out_nodes.begin(), all_out_nodes.end(), out_var);
  PADDLE_ENFORCE(iter != all_out_nodes.end(), "Cannot find out var %s",
                 out_var_name);

  // The following code guarantees that ssa_map_ is still an SSA graph after
  // inplace is performed.
  // Step 1: Rename out_var and all of its following versions to in_var_name.
  // Step 2: Remove those versions from out_var's list and append them to
  //         in_var's list.
  // Note that for the op generating out_var itself, only its outputs (not
  // its inputs) should be renamed.
  auto original_iter = iter;
  while (iter != all_out_nodes.end()) {
    auto *node = *iter;
    /* Step 1 */
    node->RenameVar(in_var_name);
    if (iter != original_iter) {
      for (auto *in : node->inputs) {
        if (in->IsOp() && in->Op()) {
          in->Op()->RenameOutput(out_var_name, in_var_name);
          in->Op()->RenameInput(out_var_name, in_var_name);
          in->Op()->Flush();
        }
      }
    }

    for (auto *out : node->outputs) {
      if (out->IsOp() && out->Op()) {
        out->Op()->RenameOutput(out_var_name, in_var_name);
        out->Op()->RenameInput(out_var_name, in_var_name);
        out->Op()->Flush();
      }
    }

    /* Step 2 */
    all_in_nodes.emplace_back(node);
    ++iter;
  }

  /* Step 2 */
  all_out_nodes.erase(original_iter, all_out_nodes.end());

  if (all_out_nodes.empty()) {
    ssa_map_.erase(out_var_name);
  }
  op->Op()->RenameOutput(out_var_name, in_var_name);
  op->Op()->Flush();
}

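// Illustrative example (an editorial sketch): suppose ssa_map_ contains
//   "in"  -> [in0]
//   "out" -> [out0, out1]
// and `op` writes out0 inplace over in0. RenameInOut renames out0, out1, and
// every op argument referring to "out" from out0 onwards to "in", yielding
//   "in"  -> [in0, out0(renamed), out1(renamed)]
// and erases the now-empty entry for "out".
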
std::unordered_set<ir::Node *> InplacePass::FindNodesByName(
    const std::string &name, const std::vector<ir::Node *> &nodes) {
  std::unordered_set<ir::Node *> ret;
  for (auto *node : nodes) {
    if (node->Name() == name) {
      ret.insert(node);
    }
  }
  return ret;
}

void InplacePass::ApplyImpl(ir::Graph *graph) const {
  // Step 1: topologically sort ops, collect skip vars
  auto ops = ir::TopologySortOperations(*graph);
  CollectSkipVars(graph, ops);

  // Step 2: build the SSA var map
  for (auto *op_node : ops) {
    for (auto *in : op_node->inputs) {
      PADDLE_ENFORCE(in->IsVar());
      // Only create a new var node when the var first occurs as an input of
      // an op.
      if (ssa_map_.count(in->Name()) == 0) {
        ssa_map_[in->Name()].emplace_back(in);
      }
    }

    // Always create a new var node for each output of an op.
    for (auto *out : op_node->outputs) {
      PADDLE_ENFORCE(out->IsVar());
      ssa_map_[out->Name()].emplace_back(out);
    }
  }

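  // Illustrative example (an editorial sketch): for a program like
  //   b = relu(a);  b = scale(b);
  // ssa_map_ ends up as { "a": [a0], "b": [b0, b1] }, where b0 is relu's
  // output node and b1 is scale's output node, i.e. every write creates a
  // new version of the var.
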
  // Step 3: traverse ops and try inplace if possible
  bool use_cuda = Get<bool>(kUseCuda);
  VLOG(4) << "Inplace pass is applied when use_cuda = "
          << (use_cuda ? "true" : "false");

  for (auto *op_node : ops) {
    PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr");

    auto *op_desc = op_node->Op();
    auto op_type = op_desc->Type();

    // Skip ops inside the whitelist
    if (whitelist_ops_.count(op_type) > 0) {
      continue;
    }

    auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;

    if (!infer_inplace) {
      continue;
    }

    auto in_to_outs = infer_inplace(*op_desc, use_cuda);
    for (auto &pair : in_to_outs) {
      auto &in_param = pair.first;
      auto &out_param = pair.second;

      auto &in_args = op_desc->Input(in_param);
      auto &out_args = op_desc->Output(out_param);

      if (in_args.empty()) {
        VLOG(4) << "Cannot inplace because Input(" << in_param
                << ") is empty in " << op_type;
        continue;
      }

      if (out_args.empty()) {
        VLOG(4) << "Cannot inplace because Output(" << out_param
                << ") is empty in " << op_type;
        continue;
      }

      auto &in_arg = in_args[0];
      auto &out_arg = out_args[0];

      if (IsSkipVar(in_arg)) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is skipped in " << op_type;
        continue;
      }

      if (IsSkipVar(out_arg)) {
        VLOG(4) << "Cannot inplace because Output(" << out_param
                << ")=" << out_arg << " is skipped in " << op_type;
        continue;
      }

      if (in_arg == out_arg) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is the same as Output(" << out_param << ")=" << out_arg
                << " in " << op_type;
        continue;
      }

      auto in_nodes = FindNodesByName(in_arg, op_node->inputs);
      PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s",
                     in_param, in_arg, op_type);

      if (in_nodes.size() > 1) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " occurs in other inputs of " << op_type;
        continue;
      }

      auto *in_node = *in_nodes.begin();

      if (!NodeCanReused(in_node)) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is not reusable in " << op_type;
        continue;
      }

      if (!IsLastVersionVar(in_node)) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is not the last version in " << op_type;
        continue;
      }

      // If in_node is used as an input of many ops, check whether all of
      // those ops depend on op_node. If not, in_node cannot be inplaced.
      if (in_node->outputs.size() > 1 &&
          !CheckOpDeps(op_node, in_node->outputs)) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is not lastly used in " << op_type;
        continue;
      }

      auto out_nodes = FindNodesByName(out_arg, op_node->outputs);
      PADDLE_ENFORCE(!out_nodes.empty(),
                     "Output(%s)=%s cannot be found in op %s", out_param,
                     out_arg, op_type);

      PADDLE_ENFORCE_EQ(
          out_nodes.size(), 1,
          "Wrong graph: Output(%s)=%s occurs in other outputs of op %s",
          out_param, out_arg, op_type);

      if (!FindNodesByName(in_arg, op_node->outputs).empty()) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " occurs in output of op " << op_type;
        continue;
      }

      if (!FindNodesByName(out_arg, op_node->inputs).empty()) {
        VLOG(4) << "Cannot inplace because Output(" << out_param
                << ")=" << out_arg << " occurs in input of op " << op_type;
        continue;
      }

      auto *out_node = *out_nodes.begin();

      if (!NodeCanReused(out_node)) {
        VLOG(4) << "Cannot inplace because Output(" << out_param
                << ")=" << out_arg << " is not reusable in " << op_type;
        continue;
      }

      if (in_node->Var()->GetType() != out_node->Var()->GetType()) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is not the same type as "
                << "Output(" << out_param << ")=" << out_arg << " in "
                << op_type;
        continue;
      }

      if (details::NodeSize(*in_node->Var()) !=
              details::NodeSize(*out_node->Var()) &&
          kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
        VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                << " is not the same size as "
                << "Output(" << out_param << ")=" << out_arg << " in "
                << op_type;
        continue;
      }

      // Debug interface: a var named by FLAGS_memory_optimize_debug is
      // forcibly skipped by the pass.
      if (out_arg == FLAGS_memory_optimize_debug) {
        VLOG(4) << "Skipped var by force. FLAGS_memory_optimize_debug="
                << out_node->Name();
        continue;
      }

      VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name()
              << " in " << op_type;
      RenameInOut(op_node, in_node, out_node);
    }
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
    .RequirePassAttr(paddle::framework::details::kUseCuda);
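
// A minimal usage sketch (an editorial illustration; the exact registry and
// attribute calls are assumptions based on the ir::Pass API, not part of
// this file):
//
//   auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
//   pass->Set(details::kUseCuda, new bool(use_cuda));  // required pass attr
//   pass->Apply(graph);  // graph is an ir::Graph*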