You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/details/memory_optimize_helper.cc

573 lines
17 KiB

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/cpu_info.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h"
#endif // PADDLE_WITH_CUDA
namespace paddle {
namespace framework {
namespace details {
using paddle::framework::VarDesc;
// Returns op nodes of `graph` ordered so that they follow the sequence recorded
// in the kStaleProgramOpDescs attribute wherever a ready op matches a recorded
// desc, while an op is only emitted after every op producing one of its inputs
// has been emitted.
//
// Enforces that the graph carries kStaleProgramOpDescs, that every input var
// with a producer has exactly one producer and it is an op node, and that all
// dependency counters reached zero when the loop finishes.
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
"Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
// 2. topology sort order
auto nodes = graph.Nodes();
std::deque<ir::Node*> ops;
// Keep only op nodes that carry an OpDesc.
FilterVariables(nodes, [&](ir::Node* op) {
if (op->IsOp() && op->Op() != nullptr) {
ops.emplace_back(op);
}
});
// op_deps: number of distinct producer ops an op still waits for.
// ready_ops: ops whose counter is zero and that were not emitted yet.
// pending_ops: producer op -> ops that consume one of its outputs.
std::unordered_map<ir::Node*, size_t> op_deps;
std::list<ir::Node*> ready_ops;
std::unordered_map<ir::Node*, std::unordered_set<ir::Node*>> pending_ops;
for (auto* op : ops) {
std::unordered_set<ir::Node*> preceding_op;
for (auto* in : op->inputs) {
// An input var without any producer adds no dependency.
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp());
preceding_op.emplace(in->inputs[0]);
pending_ops[in->inputs[0]].emplace(op);
}
// NOTE(review): a producer is counted here even when it is not in `ops`
// (e.g. its Op() is nullptr); such a dependency looks like it can never be
// released and would trip the final enforce -- confirm this cannot happen.
op_deps[op] = preceding_op.size();
if (preceding_op.empty()) {
ready_ops.emplace_back(op);
}
}
// 3. generated op list based desc order and the topology order
std::vector<ir::Node*> ret;
std::list<OpDesc*> op_descs_list(op_descs.begin(), op_descs.end());
// Emits `found_node`: releases its dependents (those reaching zero become
// ready), removes it from ready_ops and appends it to the result.
auto update_by_found_node = [&](ir::Node* found_node) {
for (auto* pending_op : pending_ops[found_node]) {
if (--op_deps[pending_op] == 0) {
ready_ops.emplace_back(pending_op);
}
}
ready_ops.remove(found_node);
ret.emplace_back(found_node);
};
while (!ready_ops.empty()) {
bool all_of_ready_op_unmatched = true;
// One pass over the remaining descs, in recorded order.
for (auto it = op_descs_list.begin(); it != op_descs_list.end();) {
auto op_desc = *it;
ir::Node* found_node = nullptr;
// Look for a ready op whose desc matches the current recorded desc.
for (auto* op : ready_ops) {
if (IsSameDesc(op->Op(), op_desc)) {
found_node = op;
break;
}
}
// 3.1 op desc deleted by other pass
if (found_node == nullptr) {
++it;
continue;
} else {
all_of_ready_op_unmatched = false;
// The desc is consumed; erase() returns the next desc to examine.
it = op_descs_list.erase(it);
}
update_by_found_node(found_node);
}
// 3.2 op descs are added by other pass
// preceding op non empty means some new op descs are
// created, but not contained in return node list.
// these new op desc may depend on each other.
// Iterate over a snapshot because update_by_found_node mutates ready_ops.
std::list<ir::Node*> prev_ready_ops(ready_ops);
if (all_of_ready_op_unmatched) {
for (auto op : prev_ready_ops) {
update_by_found_node(op);
}
}
}
// Every counted dependency must have been released by an emitted op.
PADDLE_ENFORCE(std::all_of(
op_deps.begin(), op_deps.end(),
[&](const std::pair<ir::Node*, size_t>& p) { return p.second == 0; }));
return ret;
}
// Returns the size in bytes described by `node`: the absolute value of the
// product of all shape dimensions multiplied by the size of the element type.
// Taking the absolute value keeps the result non-negative when a dimension is
// negative (e.g. -1). An empty shape yields a product of 1.
//
// The product is accumulated in 64 bits; accumulating it in `int` overflows
// (undefined behavior) once a tensor has 2^31 or more elements.
size_t NodeSize(const VarDesc& node) {
  const auto shape = node.GetShape();
  const int64_t numel =
      std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                      std::multiplies<int64_t>());
  const size_t type_size = SizeOfType(node.GetDataType());
  return type_size * static_cast<size_t>(std::abs(numel));
}
// Returns the size in bytes of the variable held by node `n`.
//
// When the op producing `n` carries an OpDesc, the VarDesc is looked up in
// that op's block via FindVarDescInBlock; otherwise the VarDesc attached to
// the node itself is used. Enforces that a VarDesc was found.
size_t NodeSize(ir::Node* n) {
  VarDesc* desc = nullptr;
  // some op do not have block pointer
  // A node without any producer also falls back to its own VarDesc instead of
  // indexing into an empty `inputs` vector.
  if (!n->inputs.empty() && n->inputs[0]->Op() != nullptr) {
    desc = FindVarDescInBlock(n);
  } else {
    desc = n->Var();
  }
  PADDLE_ENFORCE(desc != nullptr,
                 string::Sprintf("Var %s has no VarDesc.", n->Name()));
  return NodeSize(*desc);
}
// Formats `var` as "name[d0,d1,...]". If reading the shape throws, the text
// after "name[" is the fallback message instead of the dimension list.
std::string DebugStringImpl(VarDesc* var) {
  std::stringstream out;
  out << var->Name() << "[";
  try {
    const auto dims = var->GetShape();
    // Empty before the first dimension, "," before every following one.
    const char* separator = "";
    for (const auto& dim : dims) {
      out << separator << dim;
      separator = ",";
    }
    out << "]";
  } catch (...) {
    out << "Var has no VarDesc !!! Name:" << var->Name();
  }
  return out.str();
}
// Formats the var node `var` using the VarDesc found in its producer's block.
std::string DebugString(ir::Node* var) {
  VarDesc* block_desc = FindVarDescInBlock(var);
  return DebugStringImpl(block_desc);
}
// NOTE(dzh): based on the ir node, if a large node has been reused
// by a small size node, then the next time it appears in the pool it will
// have the small size. Find the original node shape from the BlockDesc.
//
// Returns the VarDesc named like `n` in the block of the op that produces `n`.
// Enforces that `n` is a non-control var node with exactly one producer, that
// the producer carries an OpDesc with a block, and that the block has the var.
VarDesc* FindVarDescInBlock(ir::Node* n) {
  PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
  // Fail with a message instead of dereferencing a null OpDesc / BlockDesc.
  auto* producer_desc = n->inputs[0]->Op();
  PADDLE_ENFORCE(
      producer_desc != nullptr,
      string::Sprintf("The op generating var %s has no OpDesc", n->Name()));
  BlockDesc* block = producer_desc->Block();
  PADDLE_ENFORCE(
      block != nullptr,
      string::Sprintf("The op generating var %s has no block", n->Name()));
  PADDLE_ENFORCE(block->HasVar(n->Name()),
                 string::Sprintf("Block do not has var %s", n->Name()));
  return block->FindVar(n->Name());
}
// Predicate deciding whether `lhs` fits into `rhs`: both must share the data
// type, both must agree on whether their first dimension is -1, and the byte
// size of `lhs` must not exceed the byte size of `rhs`.
struct NodeComparator {
  bool operator()(ir::Node* lhs, ir::Node* rhs) const {
    auto* left = FindVarDescInBlock(lhs);
    auto* right = FindVarDescInBlock(rhs);
    // Different element types never match.
    if (left->GetDataType() != right->GetDataType()) {
      return false;
    }
    auto left_dims = left->GetShape();
    auto right_dims = right->GetShape();
    const bool left_is_dynamic = (left_dims[0] == -1);
    const bool right_is_dynamic = (right_dims[0] == -1);
    // A -1 first dimension only matches another -1 first dimension.
    if (left_is_dynamic != right_is_dynamic) {
      return false;
    }
    return NodeSize(lhs) <= NodeSize(rhs);
  }
};
// Adds `var` to the pool. A var whose name is already tracked is appended to
// the existing same-name entry. Otherwise a new entry is inserted in front of
// the first entry that either does not compare as fitting into `var` (when
// both agree on a -1 first dimension) or has a non -1 first dimension while
// `var` has -1; entries with a -1 first dimension are skipped when `var` has
// a non -1 one.
void OrderedSet::Insert(ir::Node* var) {
  PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar());
  const auto known = mark_table_.find(var->Name());
  if (known != mark_table_.end()) {
    known->second->emplace_back(var);
    return;
  }

  const auto var_dims = FindVarDescInBlock(var)->GetShape();
  const int batch_size = static_cast<int>(var_dims[0]);
  const bool var_is_dynamic = (batch_size == -1);

  NodeComparator fits_into;
  Iter pos = nodes_.begin();
  for (; pos != nodes_.end(); ++pos) {
    ir::Node* head = pos->front();
    const int cache_batch_size =
        static_cast<int>(FindVarDescInBlock(head)->GetShape()[0]);
    const bool head_is_dynamic = (cache_batch_size == -1);
    if (head_is_dynamic == var_is_dynamic) {
      // Same kind of first dimension: stop at the first entry that does not
      // fit into `var`.
      if (!fits_into(head, var)) break;
    } else if (var_is_dynamic) {
      // Entry has a fixed first dimension while `var` has -1: insert here.
      break;
    }
    // Otherwise the entry has -1 and `var` does not: keep scanning.
  }
  pos = nodes_.insert(pos, {var});
  mark_table_[var->Name()] = pos;
}
// Returns the position (counted from the front of the pool) of the entry that
// tracks the name of `var`. Enforces that the name is tracked: looking it up
// with operator[] would insert a default-constructed iterator for an unknown
// name and measuring the distance to it is undefined.
int OrderedSet::GetNodeIndexInPool(ir::Node* var) {
  auto entry = mark_table_.find(var->Name());
  PADDLE_ENFORCE(entry != mark_table_.end(),
                 string::Sprintf("Var %s is not in the pool.", var->Name()));
  return static_cast<int>(std::distance(nodes_.begin(), entry->second));
}
// Returns the front node of the first pool entry that `var` fits into
// according to NodeComparator, or nullptr when no entry qualifies.
ir::Node* OrderedSet::FindBestFitNode(ir::Node* var) const {
  NodeComparator fits_into;
  for (const NodeVector& same_name_nodes : nodes_) {
    ir::Node* candidate = same_name_nodes.front();
    if (fits_into(var, candidate)) {
      return candidate;
    }
  }
  return nullptr;
}
// Like FindBestFitNode, but the search starts right after the entry whose
// front node is `prev`. Enforces that such an entry exists. Returns nullptr
// when no later entry qualifies.
ir::Node* OrderedSet::FindNextBestFitNode(ir::Node* var, ir::Node* prev) const {
  NodeComparator fits_into;
  // Locate the entry headed by `prev`.
  auto it = nodes_.begin();
  while (it != nodes_.end() && it->front() != prev) {
    ++it;
  }
  PADDLE_ENFORCE(it != nodes_.end(), "Not found previous in node list!");
  // Scan the entries behind it.
  ++it;
  while (it != nodes_.end()) {
    ir::Node* candidate = it->front();
    if (fits_into(var, candidate)) {
      return candidate;
    }
    ++it;
  }
  return nullptr;
}
// Returns true when the name of `var` is tracked and the tracked entry holds
// a node with that same name.
bool OrderedSet::Has(ir::Node* var) const {
  const auto entry = mark_table_.find(var->Name());
  if (entry == mark_table_.end()) {
    return false;
  }
  const auto& same_name_nodes = entry->second;
  for (ir::Node* n : *same_name_nodes) {
    if (n->Name() == var->Name()) {
      return true;
    }
  }
  return false;
}
// Removes the pool entry tracked under the name `var` together with its
// bookkeeping record. Enforces that the name is tracked.
void OrderedSet::Erase(const std::string& var) {
  PADDLE_ENFORCE(mark_table_.count(var));
  auto entry = mark_table_.find(var);
  nodes_.erase(entry->second);
  mark_table_.erase(entry);
}
// Removes the pool entry tracked under the name of `var`.
// Enforces that `var` is not null.
void OrderedSet::Erase(ir::Node* var) {
  PADDLE_ENFORCE(var != nullptr);
  const std::string& tracked_name = var->Name();
  Erase(tracked_name);
}
std::string OrderedSet::ToString() const {
std::stringstream ss;
for (auto it = nodes_.begin(); it != nodes_.end(); ++it) {
for (auto& node : *it) {
ss << DebugString(node) << " ";
}
}
return ss.str();
}
// Returns true when the memory of `node` may be reused.
//
// A node is rejected when it is null, not a var, a control var, named
// kEmptyVarName, produced by an op whose "force_cpu" attribute is true, has no
// VarDesc, or when its VarDesc fails NodeCanReused(const VarDesc&).
bool NodeCanReused(ir::Node* node) {
  // valid the node is a var node
  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
  if (node == nullptr || !node->IsVar() || node->IsCtrlVar() ||
      node->Name() == kEmptyVarName)
    return false;
  // op output force generated in cpu, can not be reused.
  for (auto* op : node->inputs) {
    auto* op_desc = op->Op();
    // Producers without an OpDesc carry no attributes to inspect.
    if (op_desc == nullptr) continue;
    if (op_desc->HasAttr("force_cpu") &&
        framework::AttrReader(op_desc->GetAttrMap()).Get<bool>("force_cpu")) {
      return false;
    }
  }
  // var desc validation; a node without a VarDesc cannot be validated.
  if (node->Var() == nullptr) return false;
  return NodeCanReused(*node->Var());
}
// Returns the platform minimum chunk size: the GPU value when built with
// PADDLE_WITH_CUDA, the CPU value otherwise.
int MinChunkSize() {
#ifdef PADDLE_WITH_CUDA
  return static_cast<int>(platform::GpuMinChunkSize());
#else
  return static_cast<int>(platform::CpuMinChunkSize());
#endif  // PADDLE_WITH_CUDA
}
// Returns true when the variable described by `node` is a candidate for
// memory reuse: it is a LOD_TENSOR or LOD_TENSOR_ARRAY, is not persistable,
// has a non-empty shape, and its element count is at least MinChunkSize().
bool NodeCanReused(const VarDesc& node) {
  const auto type = node.GetType();
  // only these types holds bulk of gpu memory
  if (type != proto::VarType::LOD_TENSOR &&
      type != proto::VarType::LOD_TENSOR_ARRAY) {
    return false;
  }
  // persistable variable is parameter
  if (node.Persistable()) {
    return false;
  }
  // shape < min_chunk_size is meaningless.
  // further more, fetched loss always has size = 1
  // which should not be reused.
  const auto shape = node.GetShape();
  if (shape.empty()) {
    return false;
  }
  // Accumulate in 64 bits: an `int` product overflows for large tensors and
  // could make a huge variable look smaller than the minimum chunk size.
  const int64_t numel = std::abs(
      std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                      std::multiplies<int64_t>()));
  return numel >= MinChunkSize();
}
// Returns true when any attribute of `desc` holds a BlockDesc pointer or a
// vector of BlockDesc pointers.
bool OpHasSubBlock(OpDesc* desc) {
  const AttributeMap& attrs = desc->GetAttrMap();
  return std::any_of(
      attrs.begin(), attrs.end(), [](const AttributeMap::value_type& attr) {
        return attr.second.type() == typeid(BlockDesc*) ||             // NOLINT
               attr.second.type() == typeid(std::vector<BlockDesc*>);  // NOLINT
      });
}
// Builds the control flow graph for `graph`: orders the ops with
// SortOpLikeDescOrder, then links them and records their uses/defs through
// ConnectNodes.
ControlFlowGraph::ControlFlowGraph(const ir::Graph& graph)
    : ops_(SortOpLikeDescOrder(graph)) {
  ConnectNodes();
}
// Fills predecessors_/successors_ from the producer/consumer links of the
// graph and uses_/defs_ from each op's non-control input/output vars.
void ControlFlowGraph::BuildCFGGraph() {
  // FIXME(dzh): same effect with ConnectNodes, but use the control
  // link to build dependency graph, it goes wrong in transformer.
  for (ir::Node* op : ops_) {
    for (ir::Node* input_var : op->inputs) {
      if (!input_var->inputs.empty()) {
        PADDLE_ENFORCE(
            input_var->inputs.size() == 1 && input_var->inputs[0]->IsOp(),
            "Preceding Op Node of Var Node must be unique");
        ir::Node* producer = input_var->inputs[0];
        // Only producers carrying an OpDesc become graph edges.
        if (producer->Op() != nullptr) {
          predecessors_[op].insert(producer);
          successors_[producer].insert(op);
        }
      }
      const bool is_data_input = input_var->IsVar() && !input_var->IsCtrlVar();
      if (is_data_input) {
        uses_[op].insert(input_var->Name());
      }
    }
    for (ir::Node* output_var : op->outputs) {
      // output var may be used by many op
      for (ir::Node* consumer : output_var->outputs) {
        if (consumer->Op() == nullptr) continue;
        successors_[op].insert(consumer);
        predecessors_[consumer].insert(op);
      }
      const bool is_data_output =
          output_var->IsVar() && !output_var->IsCtrlVar();
      if (is_data_output) {
        defs_[op].insert(output_var->Name());
      }
    }
  }
}
void ControlFlowGraph::ConnectNodes() {
for (size_t i = 0; i < ops_.size(); ++i) {
auto& op = ops_[i];
try {
auto& next_op = ops_.at(i + 1);
successors_[op].insert(next_op);
predecessors_[next_op].insert(op);
} catch (...) {
// do nothing
}
FilterVariables(op->inputs,
[&](ir::Node* var) { uses_[op].emplace(var->Name()); });
FilterVariables(op->outputs,
[&](ir::Node* var) { defs_[op].emplace(var->Name()); });
}
}
void ControlFlowGraph::LiveVariableAnalysis() {
// NOTE(dzh): variable liveless analysis (a.k.a reversed_ops algorithm)
// compute the liveness of for each variable though reversed_ops algorithm.
// It iterates the operators from end to begin, compute the live in/live out
// variable set for each op, then the diff between in/out will be used for
// the variable reuse. For detail refer to
// http://www.cs.cornell.edu/courses/cs4120/2013fa/lectures/lec26-fa13.pdf
std::list<ir::Node*> work_list(ops_.rbegin(), ops_.rend());
while (!work_list.empty()) {
ir::Node* op = work_list.front();
work_list.pop_front();
// get the live_in calculated before. Empty if first.
auto prev_live_in = std::move(live_in_[op]);
for (auto& s : successors_[op]) {
for (auto& var : live_in_[s]) {
live_out_[op].insert(var);
}
}
for (auto& var : uses_[op]) {
live_in_[op].insert(var);
}
for (auto& var : live_out_[op]) {
live_in_[op].insert(var);
}
for (auto& var : defs_[op]) {
live_in_[op].erase(var);
}
// If the live_in is not changed, then the liveness analysis of
// predecessors is completed.
//
// Otherwise, recalculate the predecessors liveness
if (live_in_[op] != prev_live_in) {
for (auto& pre : predecessors_[op]) {
work_list.push_back(pre);
}
}
}
for (auto* op : ops_) {
unlived_vars_[op] = std::set<std::string>();
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
void ControlFlowGraph::RenameVarInCFGGraph(const std::string& old_node,
const std::string& new_node,
int begin_idx) {
std::vector<bool> need_update(ops_.size(), false);
// update graph from begin idx to the end
for (size_t i = begin_idx; i != ops_.size(); ++i) {
auto* op = ops_[i];
if (uses_[op].find(old_node) != uses_[op].end()) {
uses_[op].erase(old_node);
uses_[op].insert(new_node);
}
if (defs_[op].find(old_node) != defs_[op].end()) {
defs_[op].erase(old_node);
defs_[op].insert(new_node);
}
if (live_in_[op].find(old_node) != live_in_[op].end()) {
live_in_[op].erase(old_node);
live_in_[op].insert(new_node);
need_update[i] = true;
}
if (live_out_[op].find(old_node) != live_out_[op].end()) {
live_out_[op].erase(old_node);
live_out_[op].insert(new_node);
need_update[i] = true;
}
}
for (size_t i = begin_idx; i < ops_.size(); ++i) {
if (!need_update[i]) continue;
auto* op = ops_[i];
for (auto& var : this->LiveIn(op)) {
if (!this->LiveOut(op).count(var)) {
unlived_vars_[op].insert(var);
}
}
}
}
// Returns the live-in variable names recorded for `op`.
// Enforces that `op` has a live-in record.
const std::set<std::string>& ControlFlowGraph::LiveIn(ir::Node* op) const {
  const auto it = live_in_.find(op);
  PADDLE_ENFORCE(
      it != live_in_.end(),
      string::Sprintf("Expect %s in live_in, but Not Found.", op->Name()));
  const std::set<std::string>& live_vars = it->second;
  return live_vars;
}
// Returns the live-out variable names recorded for `op`.
// Enforces that `op` has a live-out record.
const std::set<std::string>& ControlFlowGraph::LiveOut(ir::Node* op) const {
  const auto it = live_out_.find(op);
  PADDLE_ENFORCE(
      it != live_out_.end(),
      string::Sprintf("Expect %s in live_out, but Not Found.", op->Name()));
  const std::set<std::string>& live_vars = it->second;
  return live_vars;
}
// Returns the names of the variables recorded as used by `op`.
// Enforces that `op` has a use record.
const std::set<std::string>& ControlFlowGraph::Use(ir::Node* op) const {
  const auto it = uses_.find(op);
  PADDLE_ENFORCE(
      it != uses_.end(),
      string::Sprintf("Expect %s in use, but Not Found.", op->Name()));
  const std::set<std::string>& used_vars = it->second;
  return used_vars;
}
// Returns the names recorded in unlived_vars_ for `op` (variables live on
// entry to the op but not on exit). Enforces that `op` has such a record.
const std::set<std::string>& ControlFlowGraph::Unlived(ir::Node* op) const {
  auto it = unlived_vars_.find(op);
  PADDLE_ENFORCE(
      it != unlived_vars_.end(),
      string::Sprintf("Expect %s in unlived_set, but Not Found.", op->Name()));
  return it->second;
}
// Read-only access to the ordered op list.
const std::vector<ir::Node*>& ControlFlowGraph::Ops() const {
  return ops_;
}
// Mutable access to the ordered op list.
std::vector<ir::Node*>& ControlFlowGraph::Ops() {
  return ops_;
}
// in ssa-graph, different version nodes have same name,
// this function get the latest version var before target op
// It may return nullptr, such as data node.
//
// Scans the outputs of every op that precedes `op` in ops_ and returns the
// last output that has a VarDesc and is named `name`. Enforces that every
// visited output is a non-null var node.
ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
                                          ir::Node* op) const {
  ir::Node* latest = nullptr;
  for (auto it = ops_.begin(); it != ops_.end() && *it != op; ++it) {
    for (ir::Node* output : (*it)->outputs) {
      PADDLE_ENFORCE((output != nullptr && output->IsVar()),
                     "Output is empty!");
      const bool is_match = output->Var() && output->Name() == name;
      if (is_match) {
        latest = output;
      }
    }
  }
  return latest;
}
} // namespace details
} // namespace framework
} // namespace paddle