@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
return type_size * std::abs(size);
// Returns the size of the variable held by node `n`, as computed by the
// VarDesc overload of NodeSize.
size_t NodeSize(ir::Node* n) {
VarDesc* desc = nullptr;
// Some ops do not have a block pointer: only look the variable up through
// FindVarDescInBlock when the first input node carries an op; otherwise use
// the VarDesc stored on the node itself.
if (n->inputs[0]->Op() != nullptr) {
desc = FindVarDescInBlock(n);
} else {
desc = n->Var();
return NodeSize(*desc);
// Size of the variable attached to node `n`; the actual computation is done by
// the VarDesc overload of NodeSize.
size_t NodeSize(ir::Node* n) {
  auto* var_desc = n->Var();
  return NodeSize(*var_desc);
}
std::string DebugStringImpl(VarDesc* var) {
std::stringstream ss;
@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
// Builds the debug string for variable node `var` by passing its VarDesc to
// DebugStringImpl.
std::string DebugString(ir::Node* var) {
return DebugStringImpl(FindVarDescInBlock(var));
return DebugStringImpl(GetVarDesc(var));
// NOTE(dzh): when relying on the ir node, if a large node has been reused
// by a smaller node, then the next time it appears in the pool it will
// have the small size. Find the original node shape from the BlockDesc.
VarDesc* FindVarDescInBlock(ir::Node* n) {
VarDesc* GetVarDesc(ir::Node* n) {
// Only non-control variable nodes with exactly one input are accepted.
PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
BlockDesc* block = n->inputs[0]->Op()->Block();
string::Sprintf("Block do not has var %s", n->Name()));
return block->FindVar(n->Name());
return n->Var();
struct NodeComparator {
bool operator()(ir::Node* lhs, ir::Node* rhs) const {
auto* lhs_desc = FindVarDescInBlock(lhs);
auto* rhs_desc = FindVarDescInBlock(rhs);
if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
auto* lhs_desc = GetVarDesc(lhs);
auto* rhs_desc = GetVarDesc(rhs);
// match data type
if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
return false;
@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) {
auto* var_desc = FindVarDescInBlock(var);
auto* var_desc = var->Var();
auto var_shape = var_desc->GetShape();
int batch_size = static_cast<int>(var_shape[0]);
@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) {
Iter it = nodes_.begin();
while (it != nodes_.end()) {
auto& prev = it->front();
auto* cache_desc = FindVarDescInBlock(prev);
auto* cache_desc = GetVarDesc(prev);
int cache_batch_size = cache_desc->GetShape()[0];
if ((cache_batch_size == -1 && batch_size == -1) ||
(cache_batch_size != -1 && batch_size != -1)) {
@ -336,10 +325,16 @@ int MinChunkSize() {
bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType();
// only these types hold the bulk of GPU memory
if (!(type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::LOD_TENSOR_ARRAY)) {
return false;
// FIXME(liuwei1031): no good way has been found to test the SELECTED_ROWS
// and LOD_TENSOR_ARRAY re-use logic, so re-use of those types is
// disabled in version 1.4.
// if (!(type == proto::VarType::LOD_TENSOR ||
// type == proto::VarType::SELECTED_ROWS ||
// type == proto::VarType::LOD_TENSOR_ARRAY)) {
// return false;
// }
if (type != proto::VarType::LOD_TENSOR) return false;
// a persistable variable is a parameter
if (node.Persistable()) {
return false;