@ -39,7 +39,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
new platform : : RecordEvent ( " ThreadedSSAGraphExecutorPrepare " , nullptr ) ) ;
new platform : : RecordEvent ( " ThreadedSSAGraphExecutorPrepare " , nullptr ) ) ;
std : : unordered_map < OpHandleBase * , size_t > pending_ops ;
std : : unordered_map < OpHandleBase * , size_t > pending_ops ;
std : : unordered_set < VarHandleBase * > pending_vars ;
std : : unordered_set < VarHandleBase * > pending_vars ;
BlockingQueue < VarHandleBase * > ready_vars ;
auto ready_vars = std : : make_shared < BlockingQueue < VarHandleBase * > > ( ) ;
std : : unordered_set < OpHandleBase * > ready_ops ;
std : : unordered_set < OpHandleBase * > ready_ops ;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// streams from multiple GPUs, it's faster to buffer them and schedule
@ -51,12 +51,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for ( auto & var_map : graph_ - > Get < details : : GraphVars > ( details : : kGraphVars ) ) {
for ( auto & var_map : graph_ - > Get < details : : GraphVars > ( details : : kGraphVars ) ) {
for ( auto & name_pair : var_map ) {
for ( auto & name_pair : var_map ) {
for ( auto & version_pair : name_pair . second ) {
for ( auto & version_pair : name_pair . second ) {
InsertPendingVar ( & pending_vars , & ready_vars , version_pair . get ( ) ) ;
InsertPendingVar ( & pending_vars , ready_vars . get ( ) , version_pair . get ( ) ) ;
}
}
}
}
}
}
for ( auto & var : graph_ - > Get < details : : GraphDepVars > ( details : : kGraphDepVars ) ) {
for ( auto & var : graph_ - > Get < details : : GraphDepVars > ( details : : kGraphDepVars ) ) {
InsertPendingVar ( & pending_vars , & ready_vars , var . get ( ) ) ;
InsertPendingVar ( & pending_vars , ready_vars . get ( ) , var . get ( ) ) ;
}
}
for ( auto & op : graph_ - > Get < details : : GraphOps > ( details : : kGraphOps ) ) {
for ( auto & op : graph_ - > Get < details : : GraphOps > ( details : : kGraphOps ) ) {
@ -73,12 +73,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
FeedFetchList fetch_data ( fetch_tensors . size ( ) ) ;
FeedFetchList fetch_data ( fetch_tensors . size ( ) ) ;
InsertFetchOps ( fetch_tensors , & fetch_ops , & fetch_dependencies , & pending_ops ,
InsertFetchOps ( fetch_tensors , & fetch_ops , & fetch_dependencies , & pending_ops ,
& pending_vars , & ready_vars , & fetch_data ) ;
& pending_vars , ready_vars . get ( ) , & fetch_data ) ;
auto run_all_ops = [ & ] ( std : : unordered_set < OpHandleBase * > & set ) {
auto run_all_ops = [ & ] ( std : : unordered_set < OpHandleBase * > & set ) {
for ( auto * op : set ) {
for ( auto * op : set ) {
running_ops_ + + ;
running_ops_ + + ;
RunOp ( & ready_vars , op ) ;
RunOp ( ready_vars , op ) ;
}
}
set . clear ( ) ;
set . clear ( ) ;
} ;
} ;
@ -87,7 +87,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
run_op_futures_ . clear ( ) ;
run_op_futures_ . clear ( ) ;
exception_holder_ . Clear ( ) ;
exception_holder_ . Clear ( ) ;
event . reset ( nullptr ) ;
event . reset ( nullptr ) ;
// Step 3. Execution
// Step 3. Execution
while ( ! pending_vars . empty ( ) ) {
while ( ! pending_vars . empty ( ) ) {
// 1. Run All Ready ops
// 1. Run All Ready ops
@ -103,7 +102,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable
// 2. Find ready variable
bool timeout ;
bool timeout ;
auto cur_ready_vars = ready_vars . PopAll ( 1 , & timeout ) ;
auto cur_ready_vars = ready_vars - > PopAll ( 1 , & timeout ) ;
if ( timeout ) {
if ( timeout ) {
if ( exception_holder_ . IsCaught ( ) ) {
if ( exception_holder_ . IsCaught ( ) ) {
@ -133,7 +132,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
}
}
PADDLE_ENFORCE ( ready_ops . empty ( ) ) ;
PADDLE_ENFORCE ( ready_ops . empty ( ) ) ;
// Wait FetchOps.
// Wait FetchOps.
ClearFetchOp ( graph_ . get ( ) , & fetch_ops ) ;
ClearFetchOp ( graph_ . get ( ) , & fetch_ops ) ;
@ -206,7 +204,8 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
}
}
void ThreadedSSAGraphExecutor : : RunOp (
void ThreadedSSAGraphExecutor : : RunOp (
BlockingQueue < VarHandleBase * > * ready_var_q , details : : OpHandleBase * op ) {
const std : : shared_ptr < BlockingQueue < VarHandleBase * > > & ready_var_q ,
details : : OpHandleBase * op ) {
auto op_run = [ ready_var_q , op , this ] {
auto op_run = [ ready_var_q , op , this ] {
try {
try {
if ( VLOG_IS_ON ( 10 ) ) {
if ( VLOG_IS_ON ( 10 ) ) {