fix the default value prefetch_var_name_to_block_id

wangkuiyi-patch-1
qiaolongfei 7 years ago
parent 16658f7b59
commit 2b9ff39f5f

@ -340,7 +340,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch block to run on server side.");
"prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}

@ -530,19 +530,23 @@ class DistributeTranspiler:
else:
assert len(prefetch_var_name_to_block_id) == 0
attrs = {
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
}
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={'X': recv_inputs},
outputs={},
attrs={
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"prefetch_var_name_to_block_id": prefetch_var_name_to_block_id,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
})
attrs=attrs)
pserver_program.sync_with_cpp()
return pserver_program

Loading…
Cancel
Save