|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
#include "gflags/gflags.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/detail/macros.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
|
|
|
|
|
#include "paddle/fluid/operators/listen_and_serv_op.h"
|
|
|
|
@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
|
|
|
|
|
|
|
|
|
|
void ListenAndServOp::RunSyncLoop(
|
|
|
|
|
framework::Executor *executor, framework::ProgramDesc *program,
|
|
|
|
|
framework::Scope *recv_scope,
|
|
|
|
|
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
|
|
|
|
|
const std::vector<int> &prefetch_block_id_list,
|
|
|
|
|
const int checkpoint_point_block_id) const {
|
|
|
|
|
const int checkpoint_point_block_id,
|
|
|
|
|
const std::vector<std::string> &recv_varnames) const {
|
|
|
|
|
VLOG(2) << "RunSyncLoop";
|
|
|
|
|
size_t num_blocks = program->Size();
|
|
|
|
|
auto optimize_blocks =
|
|
|
|
@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
|
|
|
|
|
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
|
|
|
|
|
|
|
|
|
|
// reset received sparse vars to avoid reuse it in the next mini-batch
|
|
|
|
|
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
|
|
|
|
|
->ResetSparseVarRecorder();
|
|
|
|
|
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
|
|
|
|
|
!rpc_service_->NeedResetAllVars());
|
|
|
|
|
|
|
|
|
|
rpc_service_->SetCond(distributed::kRequestGet);
|
|
|
|
|
rpc_service_->WaitBarrier(distributed::kRequestGet);
|
|
|
|
@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop(
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ListenAndServOp::ResetReceivedVars(
|
|
|
|
|
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
|
|
|
|
|
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
|
|
|
|
|
for (auto &varname : recv_varnames) {
|
|
|
|
|
auto var = recv_scope->FindVar(varname);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
VLOG(2) << "can not find var " << varname << " in received scope";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
|
}
|
|
|
|
|
if (!only_sparse_vars) {
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
|
|
|
|
|
static_cast<float>(0));
|
|
|
|
|
} else if (var->IsType<framework::Tensor>()) {
|
|
|
|
|
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
|
|
|
|
|
static_cast<float>(0));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"received var should be in [SelectedRows, LoDTensor, Tensor]");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
framework::ProgramDesc *program,
|
|
|
|
|
framework::Scope *recv_scope) const {
|
|
|
|
@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
|
|
|
|
|
bool sync_mode = Attr<bool>("sync_mode");
|
|
|
|
|
auto fan_in = Attr<int>("Fanin");
|
|
|
|
|
auto inputs = Inputs("X");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!rpc_service_);
|
|
|
|
|
std::string endpoint = Attr<std::string>("endpoint");
|
|
|
|
@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
// Write to a file of server selected port for python use.
|
|
|
|
|
SavePort();
|
|
|
|
|
if (sync_mode) {
|
|
|
|
|
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
|
|
|
|
|
checkpoint_block_id);
|
|
|
|
|
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
|
|
|
|
|
prefetch_block_id_list, checkpoint_block_id, inputs);
|
|
|
|
|
} else {
|
|
|
|
|
RunAsyncLoop(&executor, program, &recv_scope);
|
|
|
|
|
}
|
|
|
|
|