|
|
|
@ -29,8 +29,6 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/operators/detail/simple_block_queue.h"
|
|
|
|
|
#include "paddle/string/printf.h"
|
|
|
|
|
|
|
|
|
|
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -95,7 +93,6 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
|
auto fan_in = Attr<int>("Fanin");
|
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
|
|
|
|
|
auto *program = block->Program();
|
|
|
|
@ -103,38 +100,50 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
bool exit_flag = false;
|
|
|
|
|
size_t barrier_size = param_count * fan_in;
|
|
|
|
|
while (!exit_flag) {
|
|
|
|
|
// Get from multiple trainers, we don't care about the order in which
|
|
|
|
|
// the gradients arrive, just add suffix 0~n and merge the gradient.
|
|
|
|
|
rpc_service_->SetCond(0);
|
|
|
|
|
for (size_t i = 0; i < barrier_size; ++i) {
|
|
|
|
|
size_t recv_var_cnt = 0;
|
|
|
|
|
int batch_barrier = 0;
|
|
|
|
|
while (batch_barrier != fan_in) {
|
|
|
|
|
const detail::MessageWithName &v = rpc_service_->Get();
|
|
|
|
|
auto grad_var_name = v.first;
|
|
|
|
|
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
|
|
|
|
|
LOG(INFO) << "received terminate message and exit";
|
|
|
|
|
exit_flag = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
} else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
|
|
|
|
|
VLOG(3) << "recv batch barrier message";
|
|
|
|
|
batch_barrier++;
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
if (fan_in > 1) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
// receive a variable
|
|
|
|
|
recv_var_cnt++;
|
|
|
|
|
auto it =
|
|
|
|
|
std::find(grad_list.begin(), grad_list.end(), grad_var_name);
|
|
|
|
|
std::string param_var_name;
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
|
|
|
|
|
if (fan_in > 1) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
detail::DeserializeFromMessage(v.second, dev_ctx, var);
|
|
|
|
|
}
|
|
|
|
|
detail::DeserializeFromMessage(v.second, dev_ctx, var);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
|
|
|
|
|
// TODO(Yancey1989): merge SelectedRows variables here
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
rpc_service_->WaitClientGet(barrier_size);
|
|
|
|
|
rpc_service_->WaitClientGet(recv_var_cnt);
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|