@ -30,8 +30,15 @@ constexpr int kMaxRePassTimes = 10000;
constexpr size_t kMaxOneInNodes = 1000 ;
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later
constexpr int kMaxRecursiveDepth = 20 ;
struct DuringPassNodeSets {
std : : unordered_set < Node * > nodes_seen ;
std : : unordered_set < NodePtr > nodes_deleted ;
std : : unordered_set < NodePtr > nodes_re_pass ;
std : : unordered_set < NodePtr > nodes_re_pass_immediately ;
std : : unordered_set < NodePtr > nodes_last ;
} ;
void GetAllNodesNoInputEdge ( const ComputeGraphPtr & graph , std : : queue < NodePtr > & input_edge_nodes ,
void GetAllNodesNoInputEdge ( const ComputeGraphPtr & graph , std : : deq ue< NodePtr > & input_edge_nodes ,
std : : unordered_set < Node * > & nodes_seen , std : : unordered_set < NodePtr > & nodes_last ) {
nodes_last . clear ( ) ;
for ( auto & node : graph - > GetDirectNode ( ) ) {
@ -40,7 +47,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue<NodePtr> &i
}
size_t in_nums = node - > GetInNodes ( ) . size ( ) ;
if ( in_nums = = 0 ) {
input_edge_nodes . push ( node ) ;
input_edge_nodes . push _back ( node ) ;
nodes_seen . insert ( node . get ( ) ) ;
} else if ( in_nums > kMaxOneInNodes ) {
nodes_last . insert ( node ) ;
@ -48,7 +55,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue<NodePtr> &i
}
}
void AddNextIterNodes ( const Node : : Vistor < NodePtr > & nodes , std : : que ue< NodePtr > & nodes_to_pass ,
void AddNextIterNodes ( const Node : : Vistor < NodePtr > & nodes , std : : deq ue< NodePtr > & nodes_to_pass ,
std : : unordered_set < Node * > & nodes_seen , std : : unordered_set < NodePtr > & nodes_last ) {
for ( auto & node : nodes ) {
if ( node = = nullptr ) {
@ -60,13 +67,30 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::queue<NodePtr> &n
bool all_in_nodes_seen = node - > IsAllInNodesSeen ( nodes_seen ) ;
if ( all_in_nodes_seen & & nodes_seen . insert ( node . get ( ) ) . second ) {
nodes_to_pass . push ( node ) ;
nodes_to_pass . push _back ( node ) ;
}
}
}
Status RunPasses ( NodePtr & node , const NamesToPass & names_to_passes , std : : unordered_set < NodePtr > & nodes_re_pass ,
std : : unordered_set < NodePtr > & nodes_deleted , std : : unordered_set < Node * > & nodes_seen ) {
void PushToRePassIfSeen ( NodePtr & node , const std : : pair < std : : string , BaseNodePass * > & name_to_pass ,
std : : unordered_set < Node * > & nodes_seen , std : : unordered_set < NodePtr > & nodes_to_re_pass ,
std : : unordered_set < NodePtr > & nodes_re_pass ) {
for ( const auto & node_to_re_pass : nodes_to_re_pass ) {
if ( node_to_re_pass = = nullptr ) {
GELOGW ( " Found null re-pass node when executing %s on node %s type %s " , name_to_pass . first . c_str ( ) ,
node - > GetName ( ) . c_str ( ) , node - > GetType ( ) . c_str ( ) ) ;
continue ;
}
if ( nodes_seen . count ( node_to_re_pass . get ( ) ) > 0 | | node_to_re_pass - > IsAllInNodesSeen ( nodes_seen ) ) {
GELOGD ( " The node %s will be re-pass. " , node_to_re_pass - > GetName ( ) . c_str ( ) ) ;
nodes_re_pass . insert ( node_to_re_pass ) ;
} else {
GELOGD ( " The node %s are not all seen, don't set repass this time " , node_to_re_pass - > GetName ( ) . c_str ( ) ) ;
}
}
}
Status RunPasses ( NodePtr & node , const NamesToPass & names_to_passes , DuringPassNodeSets & during_pass_node_set ) {
if ( node = = nullptr ) {
GELOGE ( FAILED , " parameter is null. " ) ;
return FAILED ;
@ -90,22 +114,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder
}
auto nodes_to_re_pass = name_to_pass . second - > GetNodesNeedRePass ( ) ;
for ( const auto & node_to_re_pass : nodes_to_re_pass ) {
if ( node_to_re_pass = = nullptr ) {
GELOGW ( " Found null re-pass node when executing %s on node %s type %s " , name_to_pass . first . c_str ( ) ,
node - > GetName ( ) . c_str ( ) , node - > GetType ( ) . c_str ( ) ) ;
continue ;
}
if ( nodes_seen . count ( node_to_re_pass . get ( ) ) > 0 | | node_to_re_pass - > IsAllInNodesSeen ( nodes_seen ) ) {
GELOGD ( " The node %s will be re-pass later " , node_to_re_pass - > GetName ( ) . c_str ( ) ) ;
nodes_re_pass . insert ( node_to_re_pass ) ;
} else {
GELOGD ( " The node %s are not all seen, don't set repass this time " , node_to_re_pass - > GetName ( ) . c_str ( ) ) ;
}
}
PushToRePassIfSeen ( node , name_to_pass , during_pass_node_set . nodes_seen , nodes_to_re_pass ,
during_pass_node_set . nodes_re_pass ) ;
auto nodes_to_re_pass_immediately = name_to_pass . second - > GetNodesNeedRePassImmediately ( ) ;
PushToRePassIfSeen ( node , name_to_pass , during_pass_node_set . nodes_seen , nodes_to_re_pass_immediately ,
during_pass_node_set . nodes_re_pass_immediately ) ;
auto nodes_deleted_by_pass = name_to_pass . second - > GetNodesDeleted ( ) ;
nodes_deleted. insert ( nodes_deleted_by_pass . begin ( ) , nodes_deleted_by_pass . end ( ) ) ;
during_pass_node_set . nodes_deleted . insert ( nodes_deleted_by_pass . begin ( ) , nodes_deleted_by_pass . end ( ) ) ;
if ( nodes_deleted_by_pass . count ( node ) > 0 ) {
GELOGD ( " The node %s was deleted by pass %s, stop the remain passes " , node - > GetName ( ) . c_str ( ) ,
name_to_pass . first . c_str ( ) ) ;
@ -181,36 +198,33 @@ Status GEPass::Run(const NamesToPass &names_to_passes) {
Status GEPass : : RunPassesOneGraph ( const NamesToPass & names_to_passes ) {
GELOGD ( " Begin to run pass on graph, passes count %zu " , names_to_passes . size ( ) ) ;
std : : queue < NodePtr > nodes ;
std : : unordered_set < Node * > nodes_seen ;
std : : unordered_set < NodePtr > nodes_deleted ;
std : : unordered_set < NodePtr > nodes_re_pass ;
std : : unordered_set < NodePtr > nodes_last ;
GetAllNodesNoInputEdge ( graph_ , nodes , nodes_seen , nodes_last ) ;
std : : deque < NodePtr > nodes ;
DuringPassNodeSets during_pass_node_set ;
GetAllNodesNoInputEdge ( graph_ , nodes , during_pass_node_set . nodes_seen , during_pass_node_set . nodes_last ) ;
GELOGD ( " Start points count %zu " , nodes . size ( ) ) ;
int re_pass_times = 0 ;
do {
for ( auto & node : nodes_re_pass) {
nodes . push ( node ) ;
nodes_seen. insert ( node . get ( ) ) ;
for ( auto & node : during_pass_node_set. nodes_re_pass) {
nodes . push _back ( node ) ;
during_pass_node_set. nodes_seen. insert ( node . get ( ) ) ;
}
nodes_re_pass. clear ( ) ;
during_pass_node_set. nodes_re_pass. clear ( ) ;
while ( ! nodes . empty ( ) ) {
NodePtr node = nodes . front ( ) ;
nodes . pop ( ) ;
nodes . pop _front ( ) ;
( void ) nodes_re_pass. erase ( node ) ;
( void ) during_pass_node_set. nodes_re_pass. erase ( node ) ;
GE_IF_BOOL_EXEC ( node = = nullptr , GELOGW ( " node is null " ) ; continue ) ;
if ( nodes_deleted. count ( node ) > 0 ) {
if ( during_pass_node_set. nodes_deleted. count ( node ) > 0 ) {
GELOGD ( " The node %s was deleted before, skip it. " , node - > GetName ( ) . c_str ( ) ) ;
continue ;
}
AddNextIterNodes ( node - > GetOutNodes ( ) , nodes , nodes_seen, nodes_last ) ;
AddNextIterNodes ( node - > GetOutNodes ( ) , nodes , during_pass_node_set. nodes_seen, during_pass_node_set . nodes_last ) ;
auto ret = RunPasses ( node , names_to_passes , nodes_re_pass, nodes_deleted , nodes_seen ) ;
auto ret = RunPasses ( node , names_to_passes , during_pass_node_set ) ;
if ( ret ! = SUCCESS ) {
GELOGE ( ret , " Failed to process passes on node %s type %s, error code: %u " ,
node - > GetName ( ) . c_str ( ) , node - > GetType ( ) . c_str ( ) , ret ) ;
@ -227,7 +241,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
if ( has_sub_graph ) {
GELOGD ( " There are subgraphs on node %s, run passes for for the second time " , node - > GetName ( ) . c_str ( ) ) ;
SetFlagOption ( kOptimizeAfterSubGraph , names_to_passes ) ;
ret = RunPasses ( node , names_to_passes , nodes_re_pass, nodes_deleted , nodes_seen ) ;
ret = RunPasses ( node , names_to_passes , during_pass_node_set ) ;
if ( ret ! = SUCCESS ) {
GELOGE ( ret , " Failed to process passes on node %s type %s, error code: %u " ,
node - > GetName ( ) . c_str ( ) , node - > GetType ( ) . c_str ( ) , ret ) ;
@ -239,16 +253,21 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) {
// should be called each time at the begin of the iteration
ClearOption ( names_to_passes ) ;
}
for ( const auto & node : during_pass_node_set . nodes_re_pass_immediately ) {
GELOGD ( " The node %s will be re-pass immediately. " , node - > GetName ( ) . c_str ( ) ) ;
nodes . push_front ( node ) ;
}
during_pass_node_set . nodes_re_pass_immediately . clear ( ) ;
}
for ( auto & node : nodes_last ) {
bool all_in_nodes_seen = node - > IsAllInNodesSeen ( nodes_seen ) ;
if ( all_in_nodes_seen & & nodes_seen . insert ( node . get ( ) ) . second ) {
nodes . push ( node ) ;
for ( auto & node : during_pass_node_set. nodes_last) {
bool all_in_nodes_seen = node - > IsAllInNodesSeen ( during_pass_node_set. nodes_seen) ;
if ( all_in_nodes_seen & & during_pass_node_set. nodes_seen. insert ( node . get ( ) ) . second ) {
nodes . push _back ( node ) ;
}
}
nodes_last. clear ( ) ;
} while ( ( ! nodes_re_pass. empty ( ) | | ! nodes . empty ( ) ) & & + + re_pass_times < kMaxRePassTimes ) ;
during_pass_node_set. nodes_last. clear ( ) ;
} while ( ( ! during_pass_node_set. nodes_re_pass. empty ( ) | | ! nodes . empty ( ) ) & & + + re_pass_times < kMaxRePassTimes ) ;
if ( re_pass_times = = kMaxRePassTimes ) {
GELOGW ( " re_pass_times should not come to %d " , kMaxRePassTimes ) ;