|
|
|
@ -91,7 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StartServer() {
|
|
|
|
|
void StartServer(const std::string& rpc_name) {
|
|
|
|
|
framework::ProgramDesc program;
|
|
|
|
|
framework::Scope scope;
|
|
|
|
|
platform::CPUPlace place;
|
|
|
|
@ -107,14 +107,14 @@ void StartServer() {
|
|
|
|
|
std::shared_ptr<framework::ExecutorPrepareContext>>
|
|
|
|
|
prefetch_var_name_to_prepared;
|
|
|
|
|
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
|
|
|
|
|
|
|
|
|
|
g_req_handler->SetProgram(&program);
|
|
|
|
|
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
|
|
|
|
|
g_req_handler->SetDevCtx(&ctx);
|
|
|
|
|
g_req_handler->SetScope(&scope);
|
|
|
|
|
g_req_handler->SetExecutor(&exe);
|
|
|
|
|
|
|
|
|
|
g_rpc_service->RegisterRPC(distributed::kRequestPrefetch,
|
|
|
|
|
g_req_handler.get());
|
|
|
|
|
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
|
|
|
|
|
g_req_handler->SetRPCServer(g_rpc_service.get());
|
|
|
|
|
|
|
|
|
|
std::thread server_thread(
|
|
|
|
@ -129,7 +129,7 @@ TEST(PREFETCH, CPU) {
|
|
|
|
|
distributed::RPCClient* client =
|
|
|
|
|
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
|
|
|
|
|
|
|
|
|
|
std::thread server_thread(StartServer);
|
|
|
|
|
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
|
|
|
|
|
g_rpc_service->WaitServerReady();
|
|
|
|
|
|
|
|
|
|
int port = g_rpc_service->GetSelectedPort();
|
|
|
|
@ -162,3 +162,24 @@ TEST(PREFETCH, CPU) {
|
|
|
|
|
g_rpc_service.reset(nullptr);
|
|
|
|
|
g_req_handler.reset(nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(COMPLETE, CPU) {
|
|
|
|
|
g_req_handler.reset(new distributed::RequestSendHandler(true));
|
|
|
|
|
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
|
|
|
|
|
distributed::RPCClient* client =
|
|
|
|
|
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
|
|
|
|
|
PADDLE_ENFORCE(client != nullptr);
|
|
|
|
|
std::thread server_thread(StartServer, distributed::kRequestSend);
|
|
|
|
|
g_rpc_service->WaitServerReady();
|
|
|
|
|
int port = g_rpc_service->GetSelectedPort();
|
|
|
|
|
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
|
|
|
|
|
client->AsyncSendComplete(ep);
|
|
|
|
|
client->Wait();
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
|
|
|
|
|
|
|
|
|
|
g_rpc_service->ShutDown();
|
|
|
|
|
server_thread.join();
|
|
|
|
|
g_rpc_service.reset(nullptr);
|
|
|
|
|
g_req_handler.reset(nullptr);
|
|
|
|
|
}
|
|
|
|
|