@@ -22,6 +22,7 @@ limitations under the License. */
 #include "gflags/gflags.h"
 
 #include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/listen_and_serv_op.h"
@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
 
 void ListenAndServOp::RunSyncLoop(
     framework::Executor *executor, framework::ProgramDesc *program,
-    framework::Scope *recv_scope,
+    framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
     const std::vector<int> &prefetch_block_id_list,
     const int checkpoint_point_block_id) const {
   VLOG(2) << "RunSyncLoop";
@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
   rpc_service_->SetCond(distributed::kRequestGet);
   rpc_service_->WaitBarrier(distributed::kRequestGet);
   rpc_service_->ResetBarrierCounter();
+
   while (true) {
     rpc_service_->Profiler().OneStep();
     // Get from multiple trainers, we don't care about the order in which
@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop(
                           recv_scope);
     VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
 
-    // reset received sparse vars to avoid reuse it in the next mini-batch
-    dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
-        ->ResetSparseVarRecorder();
+    ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
 
     rpc_service_->SetCond(distributed::kRequestGet);
     rpc_service_->WaitBarrier(distributed::kRequestGet);
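Not part of the patch: the two hunks above run inside the server's synchronous update loop, which blocks on a send barrier until every trainer has pushed its gradients, executes the optimize blocks, resets the received vars, and only then opens the get barrier so trainers can pull fresh parameters. Below is a minimal, framework-free sketch of that barrier handshake; all names (kTrainers, WaitBarrier, TrainerSend) are hypothetical stand-ins for rpc_service_'s internals.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

constexpr int kTrainers = 2;  // hypothetical trainer count
std::mutex mu;
std::condition_variable cv;
int arrived = 0;

// Server side: block until every trainer has checked in, then reset the
// counter (the ResetBarrierCounter() step in the hunk above).
void WaitBarrier() {
  std::unique_lock<std::mutex> lk(mu);
  cv.wait(lk, [] { return arrived == kTrainers; });
  arrived = 0;
}

// Trainer side: announce arrival at the barrier (a kRequestSend RPC).
void TrainerSend() {
  {
    std::lock_guard<std::mutex> lk(mu);
    ++arrived;
  }
  cv.notify_all();
}

int main() {
  std::vector<std::thread> trainers;
  for (int i = 0; i < kTrainers; ++i) trainers.emplace_back(TrainerSend);
  WaitBarrier();  // returns once both trainers have "sent"
  std::cout << "run optimize blocks, then reset received vars\n";
  for (auto &t : trainers) t.join();
  return 0;
}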
@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop(
   }  // while(true)
 }
 
+void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
+                                        platform::DeviceContext *dev_ctx,
+                                        bool reset_all) const {
+  for (auto &varname : sparse_vars_) {
+    auto var = recv_scope->FindVar(varname);
+    if (var == nullptr) {
+      VLOG(2) << "can not find var " << varname << " in received scope";
+      continue;
+    }
+    if (var->IsType<framework::SelectedRows>()) {
+      VLOG(3) << "reset sparse var: " << varname;
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    } else {
+      PADDLE_THROW("The type of sparse var should be SelectedRows");
+    }
+  }
+  if (UNLIKELY(reset_all)) {
+    for (auto &varname : dense_vars_) {
+      auto var = recv_scope->FindVar(varname);
+      if (var == nullptr) {
+        VLOG(2) << "can not find var " << varname << " in received scope";
+        continue;
+      }
+      if (var->IsType<framework::LoDTensor>()) {
+        math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
+                           static_cast<float>(0));
+      } else if (var->IsType<framework::Tensor>()) {
+        math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
+                           static_cast<float>(0));
+      } else {
+        PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
+      }
+    }
+  }
+}
+
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program,
                                    framework::Scope *recv_scope) const {
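Not part of the patch: the new ResetReceivedVars above replaces the old dynamic_cast through RequestSendHandler. Sparse vars are emptied every iteration by clearing their row index (the payload buffer is left in place), while dense vars are zero-filled only when the service reports that a full reset is needed. A self-contained sketch of that pattern, with hypothetical container types standing in for SelectedRows and LoDTensor:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for framework::SelectedRows / LoDTensor.
struct SparseVar { std::vector<int64_t> rows; std::vector<float> data; };
struct DenseVar  { std::vector<float> data; };

std::map<std::string, SparseVar> sparse_vars;  // mirrors sparse_vars_
std::map<std::string, DenseVar>  dense_vars;   // mirrors dense_vars_

void ResetReceivedVars(bool reset_all) {
  // Sparse vars: clearing the row index marks the var empty without
  // freeing or overwriting the payload, so the next mini-batch cannot
  // accidentally reuse stale rows.
  for (auto &kv : sparse_vars) kv.second.rows.clear();
  // Dense vars: zero-filled only on a full reset, matching the
  // UNLIKELY(reset_all) branch in the patch.
  if (reset_all) {
    for (auto &kv : dense_vars)
      std::fill(kv.second.data.begin(), kv.second.data.end(), 0.0f);
  }
}

int main() {
  sparse_vars["emb@GRAD"].rows = {3, 7, 9};   // hypothetical var names
  dense_vars["fc_w@GRAD"].data = {1.f, 2.f, 3.f};
  ResetReceivedVars(/*reset_all=*/true);
  std::cout << sparse_vars["emb@GRAD"].rows.size() << " "
            << dense_vars["fc_w@GRAD"].data[0] << "\n";  // prints: 0 0
  return 0;
}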
@@ -248,6 +284,25 @@ static void FillRequestCtx(
   h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
 }
 
+void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
+                                    const framework::Scope &scope) const {
+  for (const auto &varname : varnames) {
+    auto var = scope.FindVar(varname);
+    PADDLE_ENFORCE(var != nullptr,
+                   "Received var should be initialized in the received scope.");
+    if (var->IsType<framework::SelectedRows>()) {
+      sparse_vars_.push_back(varname);
+    } else if (var->IsType<framework::LoDTensor>() ||
+               var->IsType<framework::Tensor>()) {
+      dense_vars_.push_back(varname);
+    } else {
+      PADDLE_THROW(
+          "The type of received var should be in [SelectedRows, LoDTensor, "
+          "Tensor].");
+    }
+  }
+}
+
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const platform::Place &dev_place) const {
   // Mark this as PS that it should decide profiling by listening from trainer.
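Not part of the patch: CacheVarsType walks the op's inputs once, at server start-up, and sorts each var name into sparse_vars_ or dense_vars_ by its runtime type, so the per-iteration reset loop never has to inspect types again. A C++17 sketch of that one-time classification, using std::variant in place of Variable::IsType (all type and variable names hypothetical):

#include <iostream>
#include <map>
#include <string>
#include <variant>
#include <vector>

struct SelectedRows {};  // stand-in for framework::SelectedRows
struct LoDTensor {};     // stand-in for framework::LoDTensor
using Variable = std::variant<SelectedRows, LoDTensor>;

int main() {
  // A toy "scope" mapping received var names to typed variables.
  std::map<std::string, Variable> scope = {
      {"emb@GRAD", SelectedRows{}}, {"fc_w@GRAD", LoDTensor{}}};
  std::vector<std::string> sparse_vars, dense_vars;
  // Classify once up front; IsType<SelectedRows>() becomes
  // std::holds_alternative<SelectedRows>().
  for (const auto &kv : scope) {
    if (std::holds_alternative<SelectedRows>(kv.second))
      sparse_vars.push_back(kv.first);
    else
      dense_vars.push_back(kv.first);
  }
  std::cout << sparse_vars.size() << " sparse, " << dense_vars.size()
            << " dense\n";  // prints: 1 sparse, 1 dense
  return 0;
}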
@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
 
   bool sync_mode = Attr<bool>("sync_mode");
   auto fan_in = Attr<int>("Fanin");
+  auto inputs = Inputs("X");
 
   PADDLE_ENFORCE(!rpc_service_);
   std::string endpoint = Attr<std::string>("endpoint");
@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   signal(SIGINT, SignalHandler::StopAndExit);
   signal(SIGTERM, SignalHandler::StopAndExit);
 
+  // Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
+  // so that we can reset them at the end of each iteration.
+  // NOTE: only used in sync update
+  CacheVarsType(inputs, recv_scope);
+
   // Write to a file of server selected port for python use.
   SavePort();
   if (sync_mode) {
-    RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
-                checkpoint_block_id);
+    RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
+                prefetch_block_id_list, checkpoint_block_id);
   } else {
     RunAsyncLoop(&executor, program, &recv_scope);
   }