|
|
|
@ -207,18 +207,19 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
framework::BlockDesc *prefetch_block) const {
|
|
|
|
|
VLOG(3) << "RunAsyncLoop in";
|
|
|
|
|
// grad name to block id
|
|
|
|
|
std::unordered_map<std::string, int32_t> grad_to_id;
|
|
|
|
|
std::unordered_map<std::string, int32_t> grad_to_block_id;
|
|
|
|
|
std::unordered_map<int32_t, std::string> id_to_grad;
|
|
|
|
|
|
|
|
|
|
auto grad_to_id_str = Attr<std::vector<std::string>>("grad_to_id");
|
|
|
|
|
for (auto &grad_and_id : grad_to_id_str) {
|
|
|
|
|
auto grad_to_block_id_str =
|
|
|
|
|
Attr<std::vector<std::string>>("grad_to_block_id");
|
|
|
|
|
for (auto &grad_and_id : grad_to_block_id_str) {
|
|
|
|
|
std::vector<std::string> pieces;
|
|
|
|
|
split(grad_and_id, ':', &pieces);
|
|
|
|
|
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
|
|
|
|
|
PADDLE_ENFORCE_EQ(pieces.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
|
|
|
|
|
int block_id = std::stoi(pieces[1]);
|
|
|
|
|
grad_to_id[pieces[0]] = block_id;
|
|
|
|
|
grad_to_block_id[pieces[0]] = block_id;
|
|
|
|
|
id_to_grad[block_id] = pieces[0];
|
|
|
|
|
}
|
|
|
|
|
size_t num_blocks = program->Size();
|
|
|
|
@ -232,9 +233,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
auto optimize_prepared = executor->Prepare(*program, block_list);
|
|
|
|
|
std::unordered_map<std::string,
|
|
|
|
|
std::shared_ptr<framework::ExecutorPrepareContext>>
|
|
|
|
|
grad_to_prepared;
|
|
|
|
|
grad_to_prepared_block;
|
|
|
|
|
for (size_t i = 0; i < block_list.size(); ++i) {
|
|
|
|
|
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
|
|
|
|
|
grad_to_prepared_block[id_to_grad[block_list[i]]] = optimize_prepared[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "RunAsyncLoop into while";
|
|
|
|
@ -253,8 +254,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(),
|
|
|
|
|
&(v.second->GetLocalScope()));
|
|
|
|
|
AsyncExecuteBlock(executor, grad_to_prepared_block[recv_var_name].get(),
|
|
|
|
|
v.second->GetMutableLocalScope());
|
|
|
|
|
// TODO(qiao): explain why
|
|
|
|
|
if (var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
@ -328,7 +329,7 @@ from send_op and send back variables to recv_op.
|
|
|
|
|
.SetDefault("127.0.0.1:6164")
|
|
|
|
|
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"grad_to_id",
|
|
|
|
|
"grad_to_block_id",
|
|
|
|
|
"['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
|
|
|
|
|
"a map from grad name to it's optimize block id")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|