|
|
|
@ -34,6 +34,9 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/string/printf.h"
|
|
|
|
|
#include "paddle/fluid/string/split.h"
|
|
|
|
|
|
|
|
|
|
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
|
|
|
|
|
#define STEP_COUNTER "@PS_STEP_COUNTER@"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace distributed {
|
|
|
|
|
|
|
|
|
@ -377,6 +380,37 @@ void Communicator::RpcProfilerControl() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
|
|
|
|
|
Scope *send_scope) {
|
|
|
|
|
if (batches == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto &table_id = ctx.table_id;
|
|
|
|
|
size_t request_call_num = _worker_ptr->get_server_nums();
|
|
|
|
|
|
|
|
|
|
auto &var_name = STEP_COUNTER;
|
|
|
|
|
auto *out_var = send_scope->Var(var_name);
|
|
|
|
|
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
|
|
|
|
|
data[0] = static_cast<int64_t>(batches);
|
|
|
|
|
VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
|
|
|
|
|
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
|
|
|
|
|
request_call_num, [this, request_call_num](void *done) {
|
|
|
|
|
int ret = 0;
|
|
|
|
|
auto *closure = (DownpourBrpcClosure *)done;
|
|
|
|
|
for (size_t i = 0; i < request_call_num; ++i) {
|
|
|
|
|
if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
|
|
|
|
|
ret = -1;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
closure->set_promise_value(ret);
|
|
|
|
|
});
|
|
|
|
|
auto status = _worker_ptr->push_global_step(table_id, data, closure);
|
|
|
|
|
status.wait();
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncCommunicator::RecvThread() {
|
|
|
|
|
if (!independent_recv_) return;
|
|
|
|
|
VLOG(3) << "Independent RecvThread Start and Wait";
|
|
|
|
@ -465,10 +499,16 @@ void AsyncCommunicator::SendByCommunicator() {
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_nums; i++) {
|
|
|
|
|
auto &var_name = varnames[i];
|
|
|
|
|
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
|
|
|
|
|
if (var_name == STEP_COUNTER) {
|
|
|
|
|
MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
|
|
|
|
|
} else {
|
|
|
|
|
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx.is_sparse) {
|
|
|
|
|
if (ctx.is_tensor_table) {
|
|
|
|
|
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
|
|
|
|
|
} else if (ctx.is_sparse) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
varnames.size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
@ -599,8 +639,18 @@ bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
|
|
|
|
|
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
|
|
|
|
|
|
|
|
|
|
auto table_name = var_tables[0];
|
|
|
|
|
if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end())
|
|
|
|
|
if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (table_name == STEP_COUNTER) {
|
|
|
|
|
VLOG(3) << "send step_counter into queue";
|
|
|
|
|
auto tmp_var = std::make_shared<Variable>();
|
|
|
|
|
auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->Resize(framework::make_ddim({1}));
|
|
|
|
|
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
|
|
|
|
|
out_d[0] = 1;
|
|
|
|
|
send_varname_to_queue_[table_name]->Push(tmp_var);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|