single pserver workable version

del_some_in_makelist
typhoonzero 7 years ago
parent 2b47fb3d25
commit 40d0fff2e5

@ -69,6 +69,9 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList"); auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
size_t param_count = param_list.size(); size_t param_count = param_list.size();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) {
// TODO(typhoonzero): get from multiple trainers.
for (size_t i = 0; i < param_count; ++i) { for (size_t i = 0; i < param_count; ++i) {
// blocking get one var from client. // blocking get one var from client.
const detail::TensorWithName &v = rpc_service_->Get(); const detail::TensorWithName &v = rpc_service_->Get();
@ -106,6 +109,7 @@ class RecvOp : public framework::OperatorBase {
out.second = out_var->Get<framework::LoDTensor>(); out.second = out_var->Get<framework::LoDTensor>();
rpc_service_->Push(out); rpc_service_->Push(out);
} }
} // while(true)
} }
protected: protected:

@ -93,7 +93,7 @@ class Executor(object):
dtype=var.dtype, dtype=var.dtype,
type=var.type, type=var.type,
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=True) persistable=var.persistable)
def _optimize_distributed(self, optimize_ops, program, params_and_grads, def _optimize_distributed(self, optimize_ops, program, params_and_grads,
**kwargs): **kwargs):

Loading…
Cancel
Save