|
|
|
@ -48,13 +48,15 @@ static void split(const std::string &str, char sep,
|
|
|
|
|
static void AsyncExecuteBlock(framework::Executor *executor,
|
|
|
|
|
framework::ExecutorPrepareContext *prepared,
|
|
|
|
|
framework::Scope *scope) {
|
|
|
|
|
framework::Async([&executor, &prepared, &scope]() {
|
|
|
|
|
std::future<void> future = framework::Async([&executor, &prepared, &scope]() {
|
|
|
|
|
try {
|
|
|
|
|
executor->RunPreparedContext(prepared, scope, false, false);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
// TODO(qiao) maybe we can remove this
|
|
|
|
|
future.wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void ParallelExecuteBlocks(
|
|
|
|
@ -203,6 +205,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
framework::ProgramDesc *program,
|
|
|
|
|
framework::Scope *recv_scope,
|
|
|
|
|
framework::BlockDesc *prefetch_block) const {
|
|
|
|
|
VLOG(3) << "RunAsyncLoop in";
|
|
|
|
|
// grad name to block id
|
|
|
|
|
std::unordered_map<std::string, int32_t> grad_to_id;
|
|
|
|
|
std::unordered_map<int32_t, std::string> id_to_grad;
|
|
|
|
@ -210,7 +213,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
|
|
|
|
|
for (auto &grad_and_id : grad_to_id_str) {
|
|
|
|
|
std::vector<std::string> pieces;
|
|
|
|
|
split(grad_and_id, ' ', &pieces);
|
|
|
|
|
split(grad_and_id, ':', &pieces);
|
|
|
|
|
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(pieces.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
|
|
|
|
|
int block_id = std::stoi(pieces[1]);
|
|
|
|
@ -223,14 +227,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
|
|
|
|
|
std::vector<int> block_list;
|
|
|
|
|
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
|
|
|
|
|
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
|
|
|
|
|
block_list.push_back(blkid);
|
|
|
|
|
}
|
|
|
|
|
block_list.push_back(blkid);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(),
|
|
|
|
|
"grad num should be equal to optimize block num");
|
|
|
|
|
auto optimize_prepared = executor->Prepare(*program, block_list);
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string,
|
|
|
|
|
std::shared_ptr<framework::ExecutorPrepareContext>>
|
|
|
|
|
grad_to_prepared;
|
|
|
|
@ -238,6 +237,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "RunAsyncLoop into while";
|
|
|
|
|
bool exit_flag = false;
|
|
|
|
|
while (!exit_flag) {
|
|
|
|
|
const detail::ReceivedMessage v = rpc_service_->Get();
|
|
|
|
@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
|
|
|
|
|
recv_scope);
|
|
|
|
|
&(v.second->GetLocalScope()));
|
|
|
|
|
// TODO(qiao): explain why
|
|
|
|
|
if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
|