@ -17,6 +17,8 @@
# include <deque>
# include <deque>
# include <iterator>
# include <iterator>
# include <memory>
# include <memory>
# include <queue>
# include <sstream>
# include <stack>
# include <stack>
# include <string>
# include <string>
# include <unordered_map>
# include <unordered_map>
@ -148,12 +150,14 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
view_ . Build ( graph . get ( ) ) ;
view_ . Build ( graph . get ( ) ) ;
InitSSAGraphNodes ( ) ;
InitSSAGraphNodes ( ) ;
auto cnt = 0 ;
for ( auto * op : view_ . AllOps ( ) ) {
for ( auto * op : view_ . AllOps ( ) ) {
VLOG ( 4 ) < < " Handle op " < < cnt + + < < " : " < < op - > Name ( ) ;
if ( FLAGS_enable_inplace_whitelist & & ! whitelist_ . count ( op - > Name ( ) ) )
if ( FLAGS_enable_inplace_whitelist & & ! whitelist_ . count ( op - > Name ( ) ) )
continue ;
continue ;
TryInplaceOpInputOutput ( op , graph . get ( ) ) ;
TryInplaceOpInputOutput ( op , graph . get ( ) ) ;
}
}
graph - > ResolveHazard ( var_nodes_ ) ;
// graph->ResolveHazard(var_nodes_);
return graph ;
return graph ;
}
}
@ -264,13 +268,10 @@ void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
void InplacePass : : TryInplaceOpInputOutput ( ir : : Node * op ,
void InplacePass : : TryInplaceOpInputOutput ( ir : : Node * op ,
ir : : Graph * graph ) const {
ir : : Graph * graph ) const {
VLOG ( 4 ) < < " Try to inplace op " < < op - > Name ( ) ;
VLOG ( 4 ) < < " Try to inplace op " < < op - > Name ( ) ;
// FIXME(liuwei1031): Graph is not aware of the existence of BlockDescs and
// PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
// ProgramDescs.
// "op_desc is nullptr");
// The operations related to BlockDesc or ProgramDesc should perform on Graph
// or Node directly!
PADDLE_ENFORCE ( op - > Op ( ) ! = nullptr & & op - > Op ( ) - > Block ( ) ! = nullptr ,
" op_desc is nullptr " ) ;
// some pre-requirments need to meet if the op want to inplaced.
// some pre-requirments need to meet if the op want to inplaced.
PADDLE_ENFORCE ( op - > Op ( ) ! = nullptr , " op_desc is nullptr " ) ;
auto * op_desc = op - > Op ( ) ;
auto * op_desc = op - > Op ( ) ;
auto & infer_inplace =
auto & infer_inplace =
@ -281,21 +282,58 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
PADDLE_ENFORCE ( static_cast < bool > ( infer_inplace ) ,
PADDLE_ENFORCE ( static_cast < bool > ( infer_inplace ) ,
" %s's infer_inplace has not been registered " , op_desc - > Type ( ) ) ;
" %s's infer_inplace has not been registered " , op_desc - > Type ( ) ) ;
auto * block = op_desc - > Block ( ) ;
auto in_to_outs = infer_inplace ( * op_desc ) ;
auto in_to_outs = infer_inplace ( * op_desc , block ) ;
auto & all_ops = view_ . AllOps ( ) ;
auto & all_ops = view_ . AllOps ( ) ;
auto cursor = std : : find ( all_ops . begin ( ) , all_ops . end ( ) , op ) ;
auto cursor = std : : find ( all_ops . begin ( ) , all_ops . end ( ) , op ) ;
size_t idx = std : : distance ( all_ops . begin ( ) , cursor ) ;
size_t idx = std : : distance ( all_ops . begin ( ) , cursor ) ;
for ( auto & pair : in_to_outs ) {
for ( auto & pair : in_to_outs ) {
auto & in_var_name = pair . first ;
auto & in_para_name = pair . first ;
auto & out_var_name = pair . second ;
auto & out_para_name = pair . second ;
auto input_vars = op - > Op ( ) - > Input ( in_para_name ) ;
if ( ! input_vars . size ( ) ) {
VLOG ( 4 ) < < " Parameter " < < in_para_name < < " is empty skip "
< < in_para_name < < " => " < < out_para_name < < " pair " ;
continue ;
}
auto output_vars = op - > Op ( ) - > Output ( out_para_name ) ;
if ( ! output_vars . size ( ) ) {
VLOG ( 4 ) < < " Parameter " < < out_para_name < < " is empty skip "
< < in_para_name < < " => " < < out_para_name < < " pair " ;
continue ;
}
auto in_var_name = input_vars . at ( 0 ) ;
auto out_var_name = output_vars . at ( 0 ) ;
auto * in_node = view_ . GetNodeByName ( in_var_name , op - > inputs ) ;
auto * in_node = view_ . GetNodeByName ( in_var_name , op - > inputs ) ;
auto * out_node = view_ . GetNodeByName ( out_var_name , op - > outputs ) ;
auto * out_node = view_ . GetNodeByName ( out_var_name , op - > outputs ) ;
VLOG ( 4 ) < < " Try to inplace " < < in_var_name < < " with " < < out_var_name ;
bool can_replace = true ;
if ( in_var_name = = out_var_name ) {
can_replace = false ;
VLOG ( 4 ) < < " SKIP: Input variable " < < in_var_name < < " & Output variable "
< < out_var_name < < " are the same " ;
} else if ( ! NodeCanReused ( in_node ) ) {
can_replace = false ;
VLOG ( 4 ) < < " SKIP: Input varialbe " < < in_var_name < < " cannot be reused " ;
} else if ( ! NodeCanReused ( out_node ) ) {
can_replace = false ;
VLOG ( 4 ) < < " SKIP: Output variable " < < out_var_name
< < " cannot be reused " ;
} else if ( details : : NodeSize ( * in_node - > Var ( ) ) ! =
details : : NodeSize ( * out_node - > Var ( ) ) ) {
can_replace = false ;
VLOG ( 4 ) < < " SKIP: Input and Output varialbe size not match " ;
}
if ( ! can_replace ) continue ;
// 2. there is no external pending op on the input node
// 2. there is no external pending op on the input node
if ( view_ . PendingOpsOnVar ( in_node ) . size ( ) > 1 ) {
// if (view_.PendingOpsOnVar(in_node).size() > 1) {
if ( in_node - > outputs . size ( ) > 1 & & ! view_ . CheckDeps ( in_node , op ) ) {
VLOG ( 4 ) < < string : : Sprintf (
VLOG ( 4 ) < < string : : Sprintf (
" Skiped pair %s => %s. %s input has external dependency. "
" Skiped pair %s => %s. %s input has external dependency. "
" inplace such pair will overwrite the memory. " ,
" inplace such pair will overwrite the memory. " ,
@ -342,6 +380,97 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
}
}
}
}
void GraphView : : TopoSort ( ir : : Graph * graph ) {
//
ops_ . clear ( ) ;
auto deps_num = [ ] ( ir : : Node * op ) {
auto cnt = 0 ;
for ( auto & var : op - > inputs )
if ( var - > inputs . size ( ) > 0 ) + + cnt ;
return cnt ;
} ;
std : : queue < std : : pair < ir : : Node * , uint32_t > > ready_ops ;
int level = 0 ;
auto nodes = graph - > Nodes ( ) ;
std : : unordered_map < ir : : Node * , uint32_t > deps_map ;
for ( auto & node : nodes ) {
if ( node - > IsOp ( ) & & node - > Op ( ) ! = nullptr ) {
deps_map [ node ] = deps_num ( node ) ;
if ( 0 = = deps_map [ node ] ) {
ready_ops . push ( { node , level } ) ;
}
}
}
while ( ! ready_ops . empty ( ) ) {
auto item = ready_ops . front ( ) ;
ready_ops . pop ( ) ;
ops_ . emplace_back ( item . first ) ;
// record level when pop from queue
op_level_ [ item . first ] = item . second ;
for ( auto node : item . first - > outputs ) {
for ( auto op : node - > outputs ) {
- - deps_map [ op ] ;
if ( deps_map [ op ] = = 0 ) ready_ops . push ( { op , item . second + 1 } ) ;
}
}
}
bool all_ops_checked = true ;
for ( auto & node : nodes ) {
if ( node - > IsOp ( ) & & node - > Op ( ) ! = nullptr & & deps_map [ node ] > 0 ) {
all_ops_checked = false ;
break ;
}
}
PADDLE_ENFORCE ( all_ops_checked , " All ops deps should be 0 after analysis " ) ;
}
// return true if current op node depeneds on all other op that use the same
// variable node
bool GraphView : : CheckDeps ( ir : : Node * var , ir : : Node * current_op ) const {
// get op list that rely on the same variable
auto op_list = var - > outputs ;
for ( auto & op : op_list ) {
if ( op = = current_op ) continue ;
VLOG ( 4 ) < < " GraphView::CheckDeps : " < < op - > Name ( ) < < " & "
< < current_op - > Name ( ) ;
if ( ! CheckOpDeps ( op , current_op ) ) return false ;
VLOG ( 4 ) < < " " ;
}
return true ;
}
// check if op2 depends on op1's output
bool GraphView : : CheckOpDeps ( ir : : Node * op1 , ir : : Node * op2 ) const {
auto print_op = [ & ] ( ir : : Node * op , const char * name ) {
std : : ostringstream os ;
os < < " " < < name < < " : " < < op - > Name ( ) < < " " ;
os < < " Input args : " ;
for ( auto & arg : op - > inputs ) os < < arg - > Name ( ) < < " " ;
os < < " Output args : " ;
for ( auto & arg : op - > outputs ) os < < arg - > Name ( ) < < " " ;
os < < " Level : " < < op_level_ . at ( op ) ;
VLOG ( 4 ) < < os . str ( ) ;
} ;
print_op ( op1 , " OP1 " ) ;
print_op ( op2 , " OP2 " ) ;
if ( op1 = = op2 ) return true ;
if ( op_level_ . at ( op1 ) > = op_level_ . at ( op2 ) ) return false ;
for ( auto & var : op2 - > inputs )
if ( var - > inputs . size ( ) > 0 & & CheckOpDeps ( op1 , var - > inputs [ 0 ] ) ) return true ;
return false ;
}
ir : : Node * GraphView : : GetNodeByName ( const std : : string & name ,
ir : : Node * GraphView : : GetNodeByName ( const std : : string & name ,
const std : : vector < ir : : Node * > & nodes ) const {
const std : : vector < ir : : Node * > & nodes ) const {
// nodes should be op->inputs/outputs
// nodes should be op->inputs/outputs
@ -387,22 +516,7 @@ void GraphView::Build(ir::Graph* g) {
// Because we insert some new created node. Which may have data race between
// Because we insert some new created node. Which may have data race between
// nodes.
// nodes.
// resolve data harzards depends on the var nodes in right order.
// resolve data harzards depends on the var nodes in right order.
ops_ = SortOpLikeDescOrder ( * g ) ;
TopoSort ( g ) ;
// 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph.
std : : unordered_set < std : : string > all_vars ;
for ( auto & node : g - > Nodes ( ) ) {
if ( node - > IsVar ( ) ) continue ;
for ( auto & out : node - > outputs ) {
if ( out - > IsCtrlVar ( ) | | out - > Var ( ) = = nullptr ) continue ;
if ( all_vars . count ( out - > Name ( ) ) ) {
dup_nodes_ . emplace ( out - > Name ( ) ) ;
} else {
all_vars . emplace ( out - > Name ( ) ) ;
}
}
}
// 2. track the nodes which used by parameter server.
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// these node can not be inplaced, otherwise trainer