@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
std : : vector < Scope * > local_scopes_ ;
std : : vector < Scope * > param_scopes_ ;
Scope g_scope_ ;
std : : unique_ptr < OpHandleBase > op_handle_ ;
std : : vector < std : : unique_ptr < VarHandleBase > > vars_ ;
OpHandleBase * op_handle_ ;
std : : vector < VarHandleBase * > vars_ ;
std : : vector < std : : unique_ptr < ir : : Node > > nodes_ ;
std : : vector < p : : Place > place_list_ ;
bool use_gpu_ ;
# ifdef PADDLE_WITH_CUDA
@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
}
void InitBroadcastOp ( size_t input_scope_idx ) {
nodes_ . clear ( ) ;
for ( size_t j = 0 ; j < place_list_ . size ( ) ; + + j ) {
local_scopes_ . push_back ( & ( g_scope_ . NewScope ( ) ) ) ;
Scope & local_scope = local_scopes_ . back ( ) - > NewScope ( ) ;
@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
}
param_scopes_ [ input_scope_idx ] - > Var ( " input " ) ;
std: : unique_ptr < ir : : Node > n =
ir : : CreateNodeForTest ( " node0 " , ir : : Node : : Type : : kOperation ) ;
nodes_. emplace_back (
ir : : CreateNodeForTest ( " node0 " , ir : : Node : : Type : : kOperation ) ) ;
if ( use_gpu_ ) {
# ifdef PADDLE_WITH_CUDA
op_handle_ . reset ( new BroadcastOpHandle ( n . get ( ) , local_scopes_ ,
place_list_ , nccl_ctxs_ . get ( ) ) ) ;
op_handle_ = new BroadcastOpHandle ( n odes_. back ( ) . get ( ) , local_scopes_ ,
place_list_ , nccl_ctxs_ . get ( ) ) ;
# else
PADDLE_THROW ( " CUDA is not support. " ) ;
# endif
} else {
# ifdef PADDLE_WITH_CUDA
op_handle_ . reset ( new BroadcastOpHandle ( n . get ( ) , local_scopes_ ,
place_list_ , nccl_ctxs_ . get ( ) ) ) ;
op_handle_ = new BroadcastOpHandle ( n odes_. back ( ) . get ( ) , local_scopes_ ,
place_list_ , nccl_ctxs_ . get ( ) ) ;
# else
op_handle_ . reset (
new BroadcastOpHandle ( n . get ( ) , local_scopes_ , place_list_ ) ) ;
op_handle_ = new BroadcastOpHandle ( nodes_ . back ( ) . get ( ) , local_scopes_ ,
place_list_ ) ;
# endif
}
std: : unique_ptr < ir : : Node > v =
ir : : CreateNodeForTest ( " node1 " , ir : : Node : : Type : : kVariable ) ;
auto * in_var_handle = new VarHandle ( v . get ( ) , 1 , input_scope_idx , " input " ,
place_list_ [ input_scope_idx ] ) ;
nodes_. emplace_back (
ir : : CreateNodeForTest ( " node1 " , ir : : Node : : Type : : kVariable ) ) ;
auto * in_var_handle = new VarHandle ( nodes_. back ( ) . get ( ) , 1 , input_scope_idx ,
" input " , place_list_ [ input_scope_idx ] ) ;
vars_ . emplace_back ( in_var_handle ) ;
op_handle_ - > AddInput ( in_var_handle ) ;
// add dummy var
std: : unique_ptr < ir : : Node > v2 =
ir : : CreateNodeForTest ( " node2 " , ir : : Node : : Type : : kVariable ) ;
vars_ . emplace_back ( new DummyVarHandle ( v2 . get ( ) ) ) ;
nodes_. emplace_back (
ir : : CreateNodeForTest ( " node2 " , ir : : Node : : Type : : kVariable ) ) ;
vars_ . emplace_back ( new DummyVarHandle ( nodes_. back ( ) . get ( ) ) ) ;
DummyVarHandle * dummy_var_handle =
static_cast < DummyVarHandle * > ( vars_ . back ( ) .get ( ) );
static_cast < DummyVarHandle * > ( vars_ . back ( ) );
dummy_var_handle - > ClearGeneratedOp ( ) ;
op_handle_ - > AddInput ( dummy_var_handle ) ;
@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
if ( ! use_gpu_ ) {
op_handle_ - > SetDeviceContext ( place_list_ [ j ] , ctxs_ [ j ] . get ( ) ) ;
}
std: : unique_ptr < ir : : Node > v3 =
ir : : CreateNodeForTest ( " node3 " , ir : : Node : : Type : : kVariable ) ;
nodes_. emplace_back (
ir : : CreateNodeForTest ( " node3 " , ir : : Node : : Type : : kVariable ) ) ;
VarHandle * out_var_handle =
new VarHandle ( v3 . get ( ) , 2 , j , " out " , place_list_ [ j ] ) ;
new VarHandle ( nodes_. back ( ) . get ( ) , 2 , j , " out " , place_list_ [ j ] ) ;
vars_ . emplace_back ( out_var_handle ) ;
op_handle_ - > AddOutput ( out_var_handle ) ;
}
// add dummy var
std: : unique_ptr < ir : : Node > v4 =
ir : : CreateNodeForTest ( " node4 " , ir : : Node : : Type : : kVariable ) ;
vars_ . emplace_back ( new DummyVarHandle ( v4 . get ( ) ) ) ;
nodes_. emplace_back (
ir : : CreateNodeForTest ( " node4 " , ir : : Node : : Type : : kVariable ) ) ;
vars_ . emplace_back ( new DummyVarHandle ( nodes_. back ( ) . get ( ) ) ) ;
DummyVarHandle * out_dummy_var_handle =
static_cast < DummyVarHandle * > ( vars_ . back ( ) .get ( ) );
static_cast < DummyVarHandle * > ( vars_ . back ( ) );
out_dummy_var_handle - > ClearGeneratedOp ( ) ;
op_handle_ - > AddOutput ( out_dummy_var_handle ) ;
}