@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
return real_node - > isa < ValueNode > ( ) ;
}
void SetInput ( const CNodePtr & control_depend , const int index , const FuncGraphPtr & graph , const CNodePtr & hccl_node ,
const std : : vector < AnfNodePtr > & memcpy_async_list ) {
MS_EXCEPTION_IF_NULL ( control_depend ) ;
MS_EXCEPTION_IF_NULL ( graph ) ;
MS_EXCEPTION_IF_NULL ( hccl_node ) ;
std : : vector < AnfNodePtr > make_tuple_inputs = { NewValueNode ( prim : : kPrimMakeTuple ) } ;
make_tuple_inputs . insert ( make_tuple_inputs . end ( ) , memcpy_async_list . begin ( ) , memcpy_async_list . end ( ) ) ;
make_tuple_inputs . emplace_back ( hccl_node ) ;
auto make_tuple = graph - > NewCNode ( make_tuple_inputs ) ;
MS_EXCEPTION_IF_NULL ( make_tuple ) ;
control_depend - > set_input ( IntToSize ( index ) , make_tuple ) ;
}
void DealControlForGetitem ( const CNodePtr & tuple_getitem , const FuncGraphPtr & graph , const CNodePtr & hccl_node ,
const std : : vector < AnfNodePtr > & memcpy_async_list ) {
MS_EXCEPTION_IF_NULL ( tuple_getitem ) ;
auto manager = graph - > manager ( ) ;
MS_EXCEPTION_IF_NULL ( manager ) ;
auto & node_users = manager - > node_users ( ) ;
auto iter = node_users . find ( tuple_getitem ) ;
if ( iter = = node_users . end ( ) ) {
MS_LOG ( EXCEPTION ) < < " node has no output in manager "
< < " trace: " < < trace : : DumpSourceLines ( hccl_node ) ;
}
for ( const auto & node_index : iter - > second ) {
AnfNodePtr output = node_index . first ;
MS_EXCEPTION_IF_NULL ( output ) ;
if ( AnfAlgo : : CheckPrimitiveType ( output , prim : : kPrimControlDepend ) ) {
SetInput ( output - > cast < CNodePtr > ( ) , node_index . second , graph , hccl_node , memcpy_async_list ) ;
}
}
}
void TransferControl ( const CNodePtr & hccl_node , const std : : vector < AnfNodePtr > & memcpy_async_list ,
const FuncGraphPtr & graph ) {
MS_EXCEPTION_IF_NULL ( hccl_node ) ;
MS_EXCEPTION_IF_NULL ( graph ) ;
auto manager = graph - > manager ( ) ;
MS_EXCEPTION_IF_NULL ( manager ) ;
auto & node_users = manager - > node_users ( ) ;
auto iter = node_users . find ( hccl_node ) ;
if ( iter = = node_users . end ( ) ) {
MS_LOG ( EXCEPTION ) < < " node has no output in manager "
< < " trace: " < < trace : : DumpSourceLines ( hccl_node ) ;
}
// find hccl_node's output which is a control depend
for ( const auto & node_index : iter - > second ) {
AnfNodePtr output = node_index . first ;
MS_EXCEPTION_IF_NULL ( output ) ;
if ( AnfAlgo : : CheckPrimitiveType ( output , prim : : kPrimControlDepend ) ) {
SetInput ( output - > cast < CNodePtr > ( ) , node_index . second , graph , hccl_node , memcpy_async_list ) ;
} else if ( AnfAlgo : : CheckPrimitiveType ( output , prim : : kPrimTupleGetItem ) ) {
DealControlForGetitem ( output - > cast < CNodePtr > ( ) , graph , hccl_node , memcpy_async_list ) ;
}
}
}
// NodeUsersMap: when input i of node B uses node A, the map holds one item with key A and value (B, i)
bool IsNodeOutPutUsedByOtherRealKernel ( const AnfNodeIndexSet & node_users ) {
if ( node_users . size ( ) = = 1 ) {
@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
void InsertMemcpyAsyncForHcclOp : : InsertMemcpyAsync ( const FuncGraphPtr & graph , const CNodePtr & hccl_node ) const {
MS_EXCEPTION_IF_NULL ( graph ) ;
MS_EXCEPTION_IF_NULL ( hccl_node ) ;
std : : vector < AnfNodePtr > memcpy_async_list ;
bool need_memcpy_async = false ;
std : : vector < AnfNodePtr > new_inputs = { hccl_node - > input ( 0 ) } ;
for ( size_t i = 1 ; i < hccl_node - > size ( ) ; + + i ) {
auto input = hccl_node - > input ( i ) ;
@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
if ( memcpy_async = = nullptr ) {
MS_LOG ( EXCEPTION ) < < " Create memcpy_async op failed. " ;
}
if ( AnfAlgo: : IsNodeDynamicShape ( input ) ) {
if ( input- > isa < CNode > ( ) & & AnfAlgo: : IsNodeDynamicShape ( input ) ) {
AnfAlgo : : SetNodeAttr ( kAttrIsDynamicShape , MakeValue ( true ) , memcpy_async ) ;
}
new_inputs . push_back ( memcpy_async ) ;
memcpy_async_list. push_back ( memcpy_async ) ;
need_memcpy_async = true ;
} else {
new_inputs . push_back ( input ) ;
}
}
if ( ! memcpy_async_list. empty ( ) ) {
if ( need_ memcpy_async) {
CNodePtr new_hccl_node = std : : make_shared < CNode > ( * hccl_node ) ;
new_hccl_node - > set_inputs ( new_inputs ) ;
auto manager = graph - > manager ( ) ;
@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
MS_LOG ( DEBUG ) < < " start replace new_hccl_node to old hccl_node " ;
( void ) manager - > Replace ( hccl_node , new_hccl_node ) ;
MS_LOG ( DEBUG ) < < " end replace " ;
// transfer hccl op's control to the memcpy_async
TransferControl ( new_hccl_node , memcpy_async_list , graph ) ;
}
}