@ -14,7 +14,7 @@ limitations under the License. */
# include <unistd.h>
# include <string>
# include <thread>
# include <thread> // NOLINT
# include "gtest/gtest.h"
# include "paddle/fluid/framework/op_registry.h"
@ -37,11 +37,11 @@ namespace m = paddle::operators::math;
std : : unique_ptr < f : : OperatorBase > listen_and_serv_op ;
int selected_port ;
void InitTensorsInScope ( f : : Scope & scope , p : : CPUPlace & plac e) {
void InitTensorsInScope ( const p : : CPUPlace & place , f : : Scope * scop e) {
p : : CPUDeviceContext ctx ( place ) ;
for ( int i = 0 ; i < 2 ; + + i ) {
auto var_name = paddle : : string : : Sprintf ( " x%d " , i ) ;
auto var = scope . Var ( var_name ) ;
auto var = scope - > Var ( var_name ) ;
auto tensor = var - > GetMutable < f : : LoDTensor > ( ) ;
tensor - > Resize ( { 10 , 10 } ) ;
float * expect = tensor - > mutable_data < float > ( place ) ;
@ -50,20 +50,20 @@ void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
}
}
auto out_var = scope . Var ( " Out " ) ;
auto out_var = scope - > Var ( " Out " ) ;
auto out_tensor = out_var - > GetMutable < f : : LoDTensor > ( ) ;
out_tensor - > Resize ( { 10 , 10 } ) ;
out_tensor - > mutable_data < float > ( place ) ; // allocate
}
void InitSelectedRowsInScope ( f : : Scope & scope , p : : CPUPlace & plac e) {
void InitSelectedRowsInScope ( const p : : CPUPlace & place , f : : Scope * scop e) {
p : : CPUDeviceContext ctx ( place ) ;
int64_t height = 10 ;
int64_t row_numel = 10 ;
m : : SetConstant < p : : CPUDeviceContext , float > set_one ;
// init x0
std : : vector < int64_t > rows0 { 0 , 4 , 7 } ;
auto x0_var = scope . Var ( " x0 " ) ;
auto x0_var = scope - > Var ( " x0 " ) ;
auto x0 = x0_var - > GetMutable < f : : SelectedRows > ( ) ;
x0 - > set_rows ( rows0 ) ;
x0 - > set_height ( height ) ;
@ -74,7 +74,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
// init x1
std : : vector < int64_t > rows1 { 2 , 9 } ;
auto x1_var = scope . Var ( " x1 " ) ;
auto x1_var = scope - > Var ( " x1 " ) ;
auto x1 = x1_var - > GetMutable < f : : SelectedRows > ( ) ;
x1 - > set_rows ( rows1 ) ;
x1 - > set_height ( height ) ;
@ -83,7 +83,7 @@ void InitSelectedRowsInScope(f::Scope &scope, p::CPUPlace &place) {
f : : make_ddim ( { static_cast < int64_t > ( rows1 . size ( ) ) , row_numel } ) , place ) ;
set_one ( ctx , x1_value , 1.0 ) ;
auto out_var = scope . Var ( " Out " ) ;
auto out_var = scope - > Var ( " Out " ) ;
auto out = out_var - > GetMutable < f : : SelectedRows > ( ) ;
auto out_value = out - > mutable_value ( ) ;
out - > set_height ( height ) ;
@ -117,15 +117,16 @@ void StartServerNet(bool is_sparse) {
f : : Scope scope ;
p : : CPUPlace place ;
if ( is_sparse ) {
InitSelectedRowsInScope ( sco pe, plac e) ;
InitSelectedRowsInScope ( plac e, & sco pe) ;
} else {
InitTensorsInScope ( sco pe, plac e) ;
InitTensorsInScope ( plac e, & sco pe) ;
}
// sub program run in listen_and_serv_op, for simple test we use sum
f : : ProgramDesc program ;
const auto & root_block = program . Block ( 0 ) ;
auto * optimize_block = program . AppendBlock ( root_block ) ;
auto * prefetch_block = program . AppendBlock ( root_block ) ;
// X for server side tensors, RX for received tensers, must be of same shape.
AddOp ( " sum " , { { " X " , { " x0 " , " x1 " } } } , { { " Out " , { " Out " } } } , { } , optimize_block ) ;
@ -135,6 +136,7 @@ void StartServerNet(bool is_sparse) {
attrs . insert ( { " ParamList " , std : : vector < std : : string > ( { " Out " } ) } ) ;
attrs . insert ( { " GradList " , std : : vector < std : : string > ( { " x1 " } ) } ) ;
attrs . insert ( { " OptimizeBlock " , optimize_block } ) ;
attrs . insert ( { " PrefetchBlock " , prefetch_block } ) ;
listen_and_serv_op =
f : : OpRegistry : : CreateOp ( " listen_and_serv " , { { " X " , { " x1 " } } } , { } , attrs ) ;
LOG ( INFO ) < < " selected port before run " < < selected_port ;
@ -148,7 +150,7 @@ TEST(SendRecvOp, CPUDense) {
// local net
f : : Scope scope ;
p : : CPUPlace place ;
InitTensorsInScope ( sco pe, plac e) ;
InitTensorsInScope ( plac e, & sco pe) ;
// create rpc client var
scope . Var ( " RPC_CLIENT_VAR " ) ;
@ -191,7 +193,7 @@ TEST(SendRecvOp, CPUSparse) {
f : : Scope scope ;
p : : CPUPlace place ;
p : : CPUDeviceContext ctx ( place ) ;
InitSelectedRowsInScope ( sco pe, plac e) ;
InitSelectedRowsInScope ( plac e, & sco pe) ;
scope . Var ( " RPC_CLIENT_VAR " ) ;
f : : AttributeMap attrs ;
selected_port = static_cast < paddle : : operators : : ListenAndServOp * > (