|
|
|
@ -27,12 +27,17 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/operators/detail/grpc_server.h"
|
|
|
|
|
#include "paddle/operators/detail/sendrecvop_utils.h"
|
|
|
|
|
#include "paddle/operators/detail/simple_block_queue.h"
|
|
|
|
|
#include "paddle/string/printf.h"
|
|
|
|
|
|
|
|
|
|
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// Integer codes naming the phases of the RPC service's condition state.
// NOTE(review): presumably these are the values meant for
// rpc_service_->SetCond(); the calls visible in this chunk still pass the
// literals 0 and 1 -- confirm and consider using these names there.
constexpr int kCondStart = 0;
|
|
|
|
|
constexpr int kCondRunning = 1;
|
|
|
|
|
constexpr int kCondDone = 2;
|
|
|
|
|
|
|
|
|
|
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
|
|
|
|
|
service->RunSyncUpdate();
|
|
|
|
|
VLOG(4) << "RunServer thread end";
|
|
|
|
@ -77,42 +82,41 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
if (grads_counter_.find(varname) == grads_counter_.end()) {
|
|
|
|
|
grads_counter_[varname] = 0;
|
|
|
|
|
}
|
|
|
|
|
char ret[256];
|
|
|
|
|
snprintf(ret, sizeof(ret), "%s.trainer_%d", varname.c_str(),
|
|
|
|
|
grads_counter_[varname]++);
|
|
|
|
|
return std::string(ret);
|
|
|
|
|
return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const override {
|
|
|
|
|
// FIXME(typhoonzero): no new scopes for every run.
|
|
|
|
|
framework::Scope &recv_scope = scope.NewScope();
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(dev_place);
|
|
|
|
|
framework::Scope &recv_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
// FIXME(Yancey1989): initialize rpc server with lazy mode.
|
|
|
|
|
rpc_service_->SetScope(&recv_scope);
|
|
|
|
|
rpc_service_->SetDevCtx(&dev_ctx);
|
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
|
auto trainer_count = Attr<int>("Trainers");
|
|
|
|
|
auto fan_in = Attr<int>("Fanin");
|
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
|
|
|
|
|
rpc_service_->Reset();
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::proto::ProgramDesc program_desc;
|
|
|
|
|
program_desc.ParseFromString(program_str);
|
|
|
|
|
framework::ProgramDesc program(program_desc);
|
|
|
|
|
framework::Executor executor(dev_place);
|
|
|
|
|
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
bool exit_flag = false;
|
|
|
|
|
VLOG(4) << "param_count:" << param_count
|
|
|
|
|
<< " trainer_count:" << trainer_count;
|
|
|
|
|
int64_t barrier_size = param_count * fan_in;
|
|
|
|
|
while (!exit_flag) {
|
|
|
|
|
// TODO(gognwb): simplify this loop.
|
|
|
|
|
// Get from multiple trainers, we don't care about order in which
|
|
|
|
|
// the gradient arrives, just add suffix 0~n then average the gradient.
|
|
|
|
|
for (size_t i = 0; i < param_count * trainer_count; ++i) {
|
|
|
|
|
// blocking get one var from client.
|
|
|
|
|
// Get from multiple trainers, we don't care about the order in which
|
|
|
|
|
// the gradients arrive, just add suffix 0~n and merge the gradients.
|
|
|
|
|
rpc_service_->SetCond(0);
|
|
|
|
|
for (size_t i = 0; i < barrier_size; ++i) {
|
|
|
|
|
const detail::MessageWithName &v = rpc_service_->Get();
|
|
|
|
|
auto grad_var_name = v.first;
|
|
|
|
|
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
|
|
|
|
|
VLOG(4) << "received LISTEN_TERMINATE_MESSAGE and RunOp.Run() exit";
|
|
|
|
|
LOG(INFO) << "received terminate message and exit";
|
|
|
|
|
exit_flag = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -121,49 +125,31 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "grad have no paired param found!\"" << grad_var_name
|
|
|
|
|
<< "\"";
|
|
|
|
|
LOG(ERROR) << "grad have no paired param:" << grad_var_name;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
|
|
|
|
|
auto *merged_grad = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (merged_grad == nullptr) {
|
|
|
|
|
auto *ptr = recv_scope.Var(grad_var_name);
|
|
|
|
|
CreateTensorFromMessageType(ptr, v.second.type());
|
|
|
|
|
VLOG(3) << "Create Variable " << grad_var_name
|
|
|
|
|
<< " on recv scope, which pointer is " << ptr << " type is "
|
|
|
|
|
<< v.second.type();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (trainer_count > 1) {
|
|
|
|
|
if (fan_in > 1) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *var = recv_scope.Var(grad_var_name);
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "can not find server side var: " << grad_var_name;
|
|
|
|
|
PADDLE_THROW("can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
detail::DeserializeFromMessage(v.second, dev_ctx, var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (exit_flag) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rpc_service_->Reset();
|
|
|
|
|
|
|
|
|
|
std::string program_str = Attr<std::string>("OptimizeProgram");
|
|
|
|
|
framework::proto::ProgramDesc program_desc;
|
|
|
|
|
program_desc.ParseFromString(program_str);
|
|
|
|
|
framework::ProgramDesc program(program_desc);
|
|
|
|
|
framework::Executor executor(dev_place);
|
|
|
|
|
// Run sub graph to get optimized tensor
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(program, &recv_scope, 0, /*global_block*/
|
|
|
|
|
false /*create_local_scope*/, false /*create_vars*/);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rpc_service_->Done();
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
rpc_service_->WaitClientGet(barrier_size);
|
|
|
|
|
grads_counter_.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
@ -199,7 +185,7 @@ This operator will recv tensor from send_op
|
|
|
|
|
"GradList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which param to optimize.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<int>("Trainers", "type int",
|
|
|
|
|
AddAttr<int>("Fanin", "type int",
|
|
|
|
|
"Number of trainers in the current cluster job")
|
|
|
|
|
.SetDefault(1);
|
|
|
|
|
}
|
|
|
|
|