@@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) {
   return type_size * std::abs(size);
 }
 
-size_t NodeSize(ir::Node* n) {
-  VarDesc* desc = nullptr;
-  // some op do not have block pointer
-  if (n->inputs[0]->Op() != nullptr) {
-    desc = FindVarDescInBlock(n);
-  } else {
-    desc = n->Var();
-  }
-  return NodeSize(*desc);
-}
+size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); }
 
 std::string DebugStringImpl(VarDesc* var) {
   std::stringstream ss;
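
For readers outside the Paddle codebase, here is a minimal, self-contained sketch of the byte-size computation that the NodeSize overloads above delegate to. ApproxNodeBytes and its parameters are illustrative names, not Paddle helpers; a dynamic batch dimension is recorded as -1 (see the batch-size handling later in this file), which appears to be why the real code takes std::abs of the element count.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Element size times |number of elements|; a -1 batch dimension makes the
// raw product negative, so the sign is dropped before multiplying.
size_t ApproxNodeBytes(const std::vector<int64_t>& shape, size_t type_size) {
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return type_size * static_cast<size_t>(numel < 0 ? -numel : numel);
}

// Example: ApproxNodeBytes({-1, 128, 768}, sizeof(float)) sizes one sample.
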
@@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) {
 }
 
 std::string DebugString(ir::Node* var) {
-  return DebugStringImpl(FindVarDescInBlock(var));
+  return DebugStringImpl(GetVarDesc(var));
 }
 
 // NOTE(dzh): based ir node, if a large node has been reused
 // by a small size node, then next time it appear in pool, it will
 // have the small size. Find the original node shap from blockdesc.
-VarDesc* FindVarDescInBlock(ir::Node* n) {
+VarDesc* GetVarDesc(ir::Node* n) {
   PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1);
-  BlockDesc* block = n->inputs[0]->Op()->Block();
-  PADDLE_ENFORCE(block->HasVar(n->Name()),
-                 string::Sprintf("Block do not has var %s", n->Name()));
-  return block->FindVar(n->Name());
+  return n->Var();
 }
 
 struct NodeComparator {
   bool operator()(ir::Node* lhs, ir::Node* rhs) const {
-    auto* lhs_desc = FindVarDescInBlock(lhs);
-    auto* rhs_desc = FindVarDescInBlock(rhs);
+    if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false;
+    auto* lhs_desc = GetVarDesc(lhs);
+    auto* rhs_desc = GetVarDesc(rhs);
     // match data type
     if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) {
       return false;
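
A condensed view of the two early-exit checks visible in this hunk. VarMeta and PassesEarlyChecks are placeholder names (not the Paddle VarDesc API); the sketch only covers the checks shown above, the shape comparison that follows in NodeComparator is outside this hunk.

#include <cstdint>

struct VarMeta {
  int var_type;   // e.g. LOD_TENSOR vs. LOD_TENSOR_ARRAY
  int data_type;  // e.g. FP32 vs. FP64
};

bool PassesEarlyChecks(const VarMeta& lhs, const VarMeta& rhs) {
  if (lhs.var_type != rhs.var_type) return false;    // variable-type check added by this diff
  if (lhs.data_type != rhs.data_type) return false;  // "match data type"
  return true;  // remaining shape checks continue in the real comparator
}
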
@@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) {
     return;
   }
 
-  auto* var_desc = FindVarDescInBlock(var);
+  auto* var_desc = var->Var();
   auto var_shape = var_desc->GetShape();
   int batch_size = static_cast<int>(var_shape[0]);
@@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) {
   Iter it = nodes_.begin();
   while (it != nodes_.end()) {
     auto& prev = it->front();
-    auto* cache_desc = FindVarDescInBlock(prev);
+    auto* cache_desc = GetVarDesc(prev);
     int cache_batch_size = cache_desc->GetShape()[0];
     if ((cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1)) {
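
The condition at the end of this hunk encodes the pool-scan bucketing rule: a cached entry and the incoming variable are only compared when both have a runtime (-1) batch dimension or both have static ones. A sketch of that predicate, with an illustrative helper name rather than a Paddle API:

// True when the two variables fall into the same batch-size bucket and
// are therefore worth comparing for reuse.
bool SameBatchKind(int cache_batch_size, int batch_size) {
  return (cache_batch_size == -1 && batch_size == -1) ||
         (cache_batch_size != -1 && batch_size != -1);
}
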
@@ -336,10 +325,16 @@ int MinChunkSize() {
 bool NodeCanReused(const VarDesc& node) {
   auto type = node.GetType();
   // only these types holds bulk of gpu memory
-  if (!(type == proto::VarType::LOD_TENSOR ||
-        type == proto::VarType::LOD_TENSOR_ARRAY)) {
-    return false;
-  }
+  // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and
+  // LOD_TENSOR_ARRAY re-use logic,
+  // disable them in version 1.4
+  // if (!(type == proto::VarType::LOD_TENSOR ||
+  //       type == proto::VarType::SELECTED_ROWS ||
+  //       type == proto::VarType::LOD_TENSOR_ARRAY)) {
+  //   return false;
+  // }
+  if (type != proto::VarType::LOD_TENSOR) return false;
   // persistable variable is parameter
   if (node.Persistable()) {
     return false;
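
After this change, the reuse gate admits only LOD_TENSOR variables; SELECTED_ROWS and LOD_TENSOR_ARRAY are commented out per the FIXME until their reuse logic can be tested, and persistable variables (parameters) are still excluded. A condensed, self-contained sketch of that gate, using illustrative enum and function names rather than the Paddle proto types:

enum class VarKind { kLodTensor, kSelectedRows, kLodTensorArray, kOther };

// Only non-persistable LOD_TENSOR variables remain reuse candidates in 1.4.
bool ReuseCandidate(VarKind kind, bool persistable) {
  if (kind != VarKind::kLodTensor) return false;  // other kinds disabled for now
  if (persistable) return false;                  // persistable variables are parameters
  return true;
}
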