@@ -51,8 +51,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   // together since we currently cannot overlap computation and memcpy streams.
   // Should revisit it if overlapping is available.
   std::unordered_set<OpHandleBase *> delayed_ops;
-  std::unordered_set<OpHandleBase *> blocked_by_delayed_ops;
-  std::unordered_set<VarHandleBase *> delayed_vars;
 
   auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
     pending_vars.insert(&var);
@@ -122,24 +120,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     InsertPendingOp(*op);
   }
 
-  auto run_all_ready_ops = [&] {
-    for (auto *op : ready_ops) {
-      if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
-        delayed_ops.insert(op);
-        delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
-        ready_vars.Extend(op->outputs_);
-        continue;
-      }
+  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+    for (auto *op : set) {
       running_ops_++;
       RunOp(&ready_vars, op);
     }
-    ready_ops.clear();
+    set.clear();
   };
 
   // Step 3. Execution
-  while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
+  while (!pending_vars.empty()) {
     // 1. Run All Ready ops
-    run_all_ready_ops();
+    // Keep loop until all vars are ready.
+    //
+    // NOTE: DelayedOps have a lower priority. It will be scheduled after all
+    // ready_ops have been performed.
+    if (ready_ops.empty() && allow_op_delay_) {
+      run_all_ops(delayed_ops);
+    } else {
+      run_all_ops(ready_ops);
+    }
 
     // 2. Find ready variable
     bool timeout;
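Note: the rewritten loop folds the old run_all_ready_ops special-casing into a single run_all_ops(set) helper and treats delayed_ops as a low-priority bucket: ready_ops is always drained first, and delayed_ops is only drained on iterations where nothing else is ready. Below is a minimal standalone sketch of that two-level draining policy; it uses plain ints and a stub Run() in place of Paddle's OpHandleBase/RunOp and has no dependency graph, so it is illustrative only, not the executor's actual code.

#include <iostream>
#include <unordered_set>

// Stand-in for dispatching one op; the real RunOp hands the op to a thread
// pool and publishes its outputs into ready_vars when it finishes.
void Run(int op) { std::cout << "run op " << op << "\n"; }

int main() {
  std::unordered_set<int> ready_ops{1, 2};
  std::unordered_set<int> delayed_ops{3};  // e.g. multi-device transfers
  bool allow_op_delay = true;

  // Mirrors the new lambda: whichever set is passed in is drained and cleared.
  auto run_all_ops = [&](std::unordered_set<int> &set) {
    for (int op : set) Run(op);
    set.clear();
  };

  // Delayed ops are only scheduled when no ordinary op is ready.
  while (!ready_ops.empty() || !delayed_ops.empty()) {
    if (ready_ops.empty() && allow_op_delay) {
      run_all_ops(delayed_ops);
    } else {
      run_all_ops(ready_ops);
    }
    // The real loop would now wait on ready_vars and refill ready_ops /
    // delayed_ops from newly unblocked ops; this sketch has nothing to refill.
  }
  return 0;
}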
@@ -160,29 +160,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
         auto &deps = pending_ops[op];
         --deps;
         if (deps == 0) {
-          if (delayed_vars.find(ready_var) != delayed_vars.end()) {
-            blocked_by_delayed_ops.insert(op);
+          if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+            delayed_ops.insert(op);
           } else {
             ready_ops.insert(op);
           }
         }
       }
     }
-    // When there are no other ops to schedule, schedule buffered delayed
-    // ops and unblock other ops.
-    if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
-      RunDelayedOps(delayed_ops);
-      delayed_ops.clear();
-      for (auto *op : blocked_by_delayed_ops) {
-        ready_ops.insert(op);
-      }
-      blocked_by_delayed_ops.clear();
-    }
-    // Keep loop until all vars are ready.
   }
   PADDLE_ENFORCE(ready_ops.empty());
   PADDLE_ENFORCE(delayed_ops.empty());
-  PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
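Note: with this hunk, an op whose dependency count reaches zero is classified immediately: multi-device transfers go into delayed_ops (when allow_op_delay_ is set), everything else goes into ready_ops, and the separate blocked_by_delayed_ops / RunDelayedOps bookkeeping disappears. A minimal standalone sketch of that countdown-and-classify step follows; the Op struct, names, and hard-coded dependency counts are hypothetical stand-ins for Paddle's op handles.

#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical op record used only for this illustration.
struct Op {
  const char *name;
  bool is_multi_device_transfer;
};

int main() {
  bool allow_op_delay = true;
  Op copy{"copy", true};  // a multi-device transfer
  Op add{"add", false};   // an ordinary compute op

  // Remaining unsatisfied inputs per pending op, like pending_ops above.
  std::unordered_map<Op *, size_t> pending_ops{{&copy, 1}, {&add, 1}};
  std::unordered_set<Op *> ready_ops, delayed_ops;

  // Pretend the single remaining input of each op just became ready.
  std::vector<Op *> consumers{&copy, &add};
  for (Op *op : consumers) {
    auto &deps = pending_ops[op];
    --deps;
    if (deps == 0) {
      // Same classification as the new code: transfers are deferred,
      // everything else becomes runnable right away.
      if (op->is_multi_device_transfer && allow_op_delay) {
        delayed_ops.insert(op);
      } else {
        ready_ops.insert(op);
      }
    }
  }
  // At this point ready_ops holds &add and delayed_ops holds &copy.
  return 0;
}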