|
|
|
@ -93,6 +93,10 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
"server program should have at least 2 blocks");
|
|
|
|
|
|
|
|
|
|
framework::Executor executor(dev_place);
|
|
|
|
|
std::vector<int> block_list;
|
|
|
|
|
for (int blkid = 1; blkid < num_blocks; ++blkid)
|
|
|
|
|
block_list.push_back(blkid);
|
|
|
|
|
auto prepared = executor.Prepare(*program, block_list);
|
|
|
|
|
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
bool exit_flag = false;
|
|
|
|
@ -143,11 +147,12 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
std::vector<std::future<void>> fs;
|
|
|
|
|
// block0 contains only listen_and_serv op, start run from block1.
|
|
|
|
|
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
|
|
|
|
|
fs.push_back(
|
|
|
|
|
framework::Async([&executor, &program, &recv_scope, blkid]() {
|
|
|
|
|
fs.push_back(framework::Async(
|
|
|
|
|
[&executor, &program, &recv_scope, &prepared, blkid]() {
|
|
|
|
|
int run_block = blkid; // thread local
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(*program, &recv_scope, run_block, false, false);
|
|
|
|
|
executor.RunPreparedContext(prepared[run_block].get(),
|
|
|
|
|
&recv_scope, false, false);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
@ -157,7 +162,9 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
// Run global block at final step, or block1 if there are only 2 blocks
|
|
|
|
|
if (num_blocks >= 2) {
|
|
|
|
|
try {
|
|
|
|
|
executor.Run(*program, &recv_scope, num_blocks - 1, false, false);
|
|
|
|
|
// executor.Run(program, &recv_scope, num_blocks - 1, false, false);
|
|
|
|
|
executor.RunPreparedContext(prepared[num_blocks - 1].get(),
|
|
|
|
|
&recv_scope, false, false);
|
|
|
|
|
} catch (std::exception &e) {
|
|
|
|
|
LOG(ERROR) << "run sub program error " << e.what();
|
|
|
|
|
}
|
|
|
|
@ -172,14 +179,11 @@ class ListenAndServOp : public framework::OperatorBase {
|
|
|
|
|
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
|
|
|
|
|
}
|
|
|
|
|
rpc_service_->SetCond(1);
|
|
|
|
|
// FIXME(typhoonzero): use another condition to sync wait clients get.
|
|
|
|
|
// NOTE: does not consider barrier request retry in here, we may use
|
|
|
|
|
// global barrier id to resolve this.
|
|
|
|
|
rpc_service_->WaitClientGet(fan_in);
|
|
|
|
|
sparse_vars.clear();
|
|
|
|
|
} // while(true)
|
|
|
|
|
|
|
|
|
|
// for (int i = 0; i < num_blocks; ++i) {
|
|
|
|
|
// delete blk_ctx_list[i];
|
|
|
|
|
// }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|