|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <stdlib.h>
|
|
|
|
|
#include <unistd.h>
|
|
|
|
|
#include <chrono> // NOLINT
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
@ -26,6 +27,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/distributed/distributed.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/rpc_client.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/rpc_server.h"
|
|
|
|
@ -35,6 +37,7 @@ namespace platform = paddle::platform;
|
|
|
|
|
namespace distributed = paddle::operators::distributed;
|
|
|
|
|
|
|
|
|
|
USE_NO_KERNEL_OP(lookup_sparse_table_read);
|
|
|
|
|
USE_NO_KERNEL_OP(checkpoint_notify);
|
|
|
|
|
USE_OP(scale);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<distributed::RPCServer> g_rpc_service;
|
|
|
|
@ -122,7 +125,7 @@ void StartServer(const std::string& rpc_name) {
|
|
|
|
|
|
|
|
|
|
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
|
|
|
|
|
|
|
|
|
|
distributed::HeartBeatMonitor::Init(2, true, "w@grad");
|
|
|
|
|
// distributed::HeartBeatMonitor::Init(1, true, "w@grad");
|
|
|
|
|
|
|
|
|
|
g_req_handler->SetRPCServer(g_rpc_service.get());
|
|
|
|
|
|
|
|
|
@ -232,3 +235,110 @@ TEST(SENDANDRECV, CPU) {
|
|
|
|
|
g_rpc_service.reset(nullptr);
|
|
|
|
|
g_req_handler.reset(nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StartCheckpointServer(const std::string& rpc_name) {
|
|
|
|
|
framework::ProgramDesc program;
|
|
|
|
|
framework::Scope scope;
|
|
|
|
|
platform::CPUPlace place;
|
|
|
|
|
framework::Executor exe(place);
|
|
|
|
|
platform::CPUDeviceContext ctx(place);
|
|
|
|
|
|
|
|
|
|
std::vector<distributed::SparseMeta> metas;
|
|
|
|
|
|
|
|
|
|
auto meta = distributed::SparseMeta();
|
|
|
|
|
meta.name = "embedding.block0";
|
|
|
|
|
meta.value_names = {"Param"};
|
|
|
|
|
meta.value_dims = {64};
|
|
|
|
|
meta.mode = distributed::Mode::training;
|
|
|
|
|
meta.grad_name = "embedding@Grad";
|
|
|
|
|
meta.cached_varnames = {"kSparseIds"};
|
|
|
|
|
meta.initializer_attrs = {"fill_constant&1.0"};
|
|
|
|
|
meta.entry = "none";
|
|
|
|
|
|
|
|
|
|
metas.push_back(meta);
|
|
|
|
|
distributed::LargeScaleKV::Init(metas);
|
|
|
|
|
|
|
|
|
|
auto* ins = distributed::LargeScaleKV::GetInstance();
|
|
|
|
|
ins->Get("embedding.block0")->Init({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string,
|
|
|
|
|
std::shared_ptr<framework::ExecutorPrepareContext>>
|
|
|
|
|
prefetch_var_name_to_prepared;
|
|
|
|
|
|
|
|
|
|
g_req_handler->SetProgram(&program);
|
|
|
|
|
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
|
|
|
|
|
g_req_handler->SetDevCtx(&ctx);
|
|
|
|
|
g_req_handler->SetScope(&scope);
|
|
|
|
|
g_req_handler->SetExecutor(&exe);
|
|
|
|
|
|
|
|
|
|
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
|
|
|
|
|
|
|
|
|
|
g_req_handler->SetRPCServer(g_rpc_service.get());
|
|
|
|
|
|
|
|
|
|
std::thread server_thread(
|
|
|
|
|
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
|
|
|
|
|
|
|
|
|
|
server_thread.join();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Test that drives checkpoint notification against a local RPC server.
//
// It starts a checkpoint server on an automatically selected local port,
// sends two checkpoint notifications directly through the RPC client
// (mode 0, then mode 1), runs the `checkpoint_notify` operator with mode 2,
// and finally shuts the server down and releases the global handler and
// service.
TEST(LARGE_SCALE_CHECKPOINT, CPU) {
  // Blank out proxy variables -- presumably so the loopback RPC traffic is not
  // routed through a proxy.
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);

  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;

  g_req_handler.reset(new distributed::RequestCheckpointHandler(
      distributed::DistributedMode::kAsync));
  // Port 0 in the address: the actual port is read back via GetSelectedPort().
  g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));

  distributed::RPCClient* rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
  PADDLE_ENFORCE_NE(rpc_client, nullptr,
                    platform::errors::InvalidArgument(
                        "Client Start Fail, Check Your Code & Env"));

  // Serve on a background thread and wait until the server reports ready.
  std::thread server_thread(StartCheckpointServer,
                            distributed::kRequestCheckpoint);
  g_rpc_service->WaitServerReady();

  const int selected_port = g_rpc_service->GetSelectedPort();
  const std::string endpoint =
      paddle::string::Sprintf("127.0.0.1:%d", selected_port);

  // First notification: "base" directory, mode 0.
  const auto base_save_path =
      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/base",
                              "embedding", "embedding.block0");
  int notify_mode = 0;
  rpc_client->AsyncCheckpointNotify(endpoint, base_save_path,
                                    "embedding.block0", notify_mode);
  rpc_client->Wait();

  // Second notification: "delta" directory, mode 1.
  const auto delta_save_path =
      paddle::string::Sprintf("%s/%s/%s", "/tmp/large_scale_table/delta",
                              "embedding", "embedding.block0");
  notify_mode = 1;
  rpc_client->AsyncCheckpointNotify(endpoint, delta_save_path,
                                    "embedding.block0", notify_mode);
  rpc_client->Wait();

  // Third notification goes through the checkpoint_notify operator, mode 2.
  paddle::framework::AttributeMap op_attrs;

  std::vector<std::string> endpoint_list = {endpoint};
  op_attrs["endpoints"] = endpoint_list;
  op_attrs["dirname"] = std::string("/tmp/large_scale_table/delta1");
  op_attrs["varname"] = std::string("embedding");
  op_attrs["mode"] = 2;
  std::vector<std::string> slice_names = {"embedding.block0"};
  op_attrs["slice_varnames"] = slice_names;
  std::vector<std::string> remote_names = {"embedding.block0"};
  op_attrs["remote_varnames"] = remote_names;

  auto notify_op = framework::OpRegistry::CreateOp("checkpoint_notify", {}, {},
                                                   op_attrs, true);
  notify_op->Run(scope, place);

  // Stop the server and join its thread before releasing the globals.
  g_rpc_service->ShutDown();
  server_thread.join();
  LOG(INFO) << "begin reset";
  g_rpc_service.reset(nullptr);
  g_req_handler.reset(nullptr);
}
|
|
|
|
|