|
|
|
@ -45,20 +45,6 @@ static void split(const std::string &str, char sep,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Submits one prepared block to framework::Async and blocks until that task
// has finished.
//
// executor: executor whose RunPreparedContext() is invoked.
// prepared: prepared context passed to RunPreparedContext().
// scope:    scope passed to RunPreparedContext().
//
// A std::exception thrown by RunPreparedContext() is logged inside the task
// and is not propagated to the caller.
static void AsyncExecuteBlock(framework::Executor *executor,
                              framework::ExecutorPrepareContext *prepared,
                              framework::Scope *scope) {
  // Capture the pointers by value, not by reference: the task must not hold
  // references to this function's parameters, otherwise removing the wait()
  // below (see TODO) would leave the task reading dangling references.
  std::future<void> future = framework::Async([executor, prepared, scope]() {
    try {
      executor->RunPreparedContext(prepared, scope, false, false);
    } catch (const std::exception &e) {
      LOG(ERROR) << "run sub program error " << e.what();
    }
  });
  // TODO(qiao) maybe we can remove this
  future.wait();
}
|
|
|
|
|
|
|
|
|
|
static void ParallelExecuteBlocks(
|
|
|
|
|
const std::vector<size_t> ¶llel_blkids, framework::Executor *executor,
|
|
|
|
|
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
|
|
|
|
@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
|
|
|
|
|
} // while(true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void AsyncUpdateThread(
|
|
|
|
|
const std::string &var_name, const bool &exit_flag,
|
|
|
|
|
const std::shared_ptr<detail::ReceivedQueue> &queue,
|
|
|
|
|
framework::Executor *executor,
|
|
|
|
|
framework::ExecutorPrepareContext *prepared) {
|
|
|
|
|
VLOG(3) << "update thread for " << var_name << " started";
|
|
|
|
|
while (!exit_flag) {
|
|
|
|
|
const detail::ReceivedMessage v = queue->Pop();
|
|
|
|
|
auto recv_var_name = v.first;
|
|
|
|
|
auto var = v.second->GetVar();
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
auto fs = framework::Async([var_name, &executor, &v, prepared] {
|
|
|
|
|
try {
|
|
|
|
|
executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(),
|
|
|
|
|
false, false);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
fs.wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
framework::ProgramDesc *program,
|
|
|
|
|
framework::Scope *recv_scope,
|
|
|
|
|
framework::BlockDesc *prefetch_block) const {
|
|
|
|
|
framework::ProgramDesc *program) const {
|
|
|
|
|
VLOG(3) << "RunAsyncLoop in";
|
|
|
|
|
// grad name to block id
|
|
|
|
|
std::unordered_map<std::string, int32_t> grad_to_block_id;
|
|
|
|
|
std::unordered_map<int32_t, std::string> id_to_grad;
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
|
|
|
|
|
grad_to_queue;
|
|
|
|
|
|
|
|
|
|
auto grad_to_block_id_str =
|
|
|
|
|
Attr<std::vector<std::string>>("grad_to_block_id");
|
|
|
|
@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
|
|
|
|
|
int block_id = std::stoi(pieces[1]);
|
|
|
|
|
grad_to_block_id[pieces[0]] = block_id;
|
|
|
|
|
grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
|
|
|
|
|
id_to_grad[block_id] = pieces[0];
|
|
|
|
|
}
|
|
|
|
|
size_t num_blocks = program->Size();
|
|
|
|
@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "RunAsyncLoop into while";
|
|
|
|
|
bool exit_flag = false;
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "start async optimize threads";
|
|
|
|
|
std::vector<std::future<void>> fs;
|
|
|
|
|
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
|
|
|
|
|
std::string grad_name = iter->first;
|
|
|
|
|
VLOG(3) << "create async update thread for " << grad_name;
|
|
|
|
|
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
|
|
|
|
|
&grad_to_queue, &grad_to_prepared_ctx]() {
|
|
|
|
|
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
|
|
|
|
|
executor, grad_to_prepared_ctx[grad_name].get());
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "RunAsyncLoop into while";
|
|
|
|
|
while (!exit_flag) {
|
|
|
|
|
const detail::ReceivedMessage v = rpc_service_->Get();
|
|
|
|
|
auto recv_var_name = v.first;
|
|
|
|
@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "received grad: " << recv_var_name;
|
|
|
|
|
auto var = v.second->GetVar();
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(),
|
|
|
|
|
v.second->GetMutableLocalScope());
|
|
|
|
|
grad_to_queue[recv_var_name]->Push(v);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (exit_flag) {
|
|
|
|
@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
if (sync_mode) {
|
|
|
|
|
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
|
|
|
|
|
} else {
|
|
|
|
|
RunAsyncLoop(&executor, program, &recv_scope, prefetch_block);
|
|
|
|
|
RunAsyncLoop(&executor, program);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|