|
|
|
@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
|
|
|
|
|
return type_size * std::abs(size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t NodeSize(ir::Node* n) {
|
|
|
|
|
VarDesc* desc = nullptr;
|
|
|
|
|
// some op do not have block pointer
|
|
|
|
|
if (n->inputs[0]->Op() != nullptr) {
|
|
|
|
|
desc = FindVarDescInBlock(n);
|
|
|
|
|
} else {
|
|
|
|
|
desc = n->Var();
|
|
|
|
|
}
|
|
|
|
|
return NodeSize(*desc);
|
|
|
|
|
}
|
|
|
|
|
size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
|
|
|
|
|
|
|
|
|
|
std::string DebugStringImpl(VarDesc* var) {
|
|
|
|
|
std::stringstream ss;
|
|
|
|
@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string DebugString(ir::Node* var) {
|
|
|
|
|
return DebugStringImpl(FindVarDescInBlock(var));
|
|
|
|
|
return DebugStringImpl(GetVarDesc(var));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(dzh): based ir node, if a large node has been reused
|
|
|
|
|
// by a small size node, then next time it appear in pool, it will
|
|
|
|
|
// have the small size. Find the original node shap from blockdesc.
|
|
|
|
|
VarDesc* FindVarDescInBlock(ir::Node* n) {
|
|
|
|
|
VarDesc* GetVarDesc(ir::Node* n) {
|
|
|
|
|
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
|
|
|
|
|
BlockDesc* block = n->inputs[0]->Op()->Block();
|
|
|
|
|
PADDLE_ENFORCE(block->HasVar(n->Name()),
|
|
|
|
|
string::Sprintf("Block do not has var %s", n->Name()));
|
|
|
|
|
return block->FindVar(n->Name());
|
|
|
|
|
return n->Var();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct NodeComparator {
|
|
|
|
|
bool operator()(ir::Node* lhs, ir::Node* rhs) const {
|
|
|
|
|
auto* lhs_desc = FindVarDescInBlock(lhs);
|
|
|
|
|
auto* rhs_desc = FindVarDescInBlock(rhs);
|
|
|
|
|
if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
|
|
|
|
|
auto* lhs_desc = GetVarDesc(lhs);
|
|
|
|
|
auto* rhs_desc = GetVarDesc(rhs);
|
|
|
|
|
// match data type
|
|
|
|
|
if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
|
|
|
|
|
return false;
|
|
|
|
@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* var_desc = FindVarDescInBlock(var);
|
|
|
|
|
auto* var_desc = var->Var();
|
|
|
|
|
auto var_shape = var_desc->GetShape();
|
|
|
|
|
int batch_size = static_cast<int>(var_shape[0]);
|
|
|
|
|
|
|
|
|
@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) {
|
|
|
|
|
Iter it = nodes_.begin();
|
|
|
|
|
while (it != nodes_.end()) {
|
|
|
|
|
auto& prev = it->front();
|
|
|
|
|
auto* cache_desc = FindVarDescInBlock(prev);
|
|
|
|
|
auto* cache_desc = GetVarDesc(prev);
|
|
|
|
|
int cache_batch_size = cache_desc->GetShape()[0];
|
|
|
|
|
if ((cache_batch_size == -1 && batch_size == -1) ||
|
|
|
|
|
(cache_batch_size != -1 && batch_size != -1)) {
|
|
|
|
@ -336,10 +325,16 @@ int MinChunkSize() {
|
|
|
|
|
bool NodeCanReused(const VarDesc& node) {
|
|
|
|
|
auto type = node.GetType();
|
|
|
|
|
// only these types holds bulk of gpu memory
|
|
|
|
|
if (!(type == proto::VarType::LOD_TENSOR ||
|
|
|
|
|
type == proto::VarType::LOD_TENSOR_ARRAY)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
|
|
|
|
|
// LOD_TENSOR_ARRAY re-use logic,
|
|
|
|
|
// disable them in version 1.4
|
|
|
|
|
// if (!(type == proto::VarType::LOD_TENSOR ||
|
|
|
|
|
// type == proto::VarType::SELECTED_ROWS ||
|
|
|
|
|
// type == proto::VarType::LOD_TENSOR_ARRAY)) {
|
|
|
|
|
// return false;
|
|
|
|
|
// }
|
|
|
|
|
if (type != proto::VarType::LOD_TENSOR) return false;
|
|
|
|
|
|
|
|
|
|
// persistable variable is parameter
|
|
|
|
|
if (node.Persistable()) {
|
|
|
|
|
return false;
|
|
|
|
|