|
|
@ -69,6 +69,9 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
auto param_list = Attr<std::vector<std::string>>("ParamList");
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
auto grad_list = Attr<std::vector<std::string>>("GradList");
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
size_t param_count = param_list.size();
|
|
|
|
|
|
|
|
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
|
|
|
|
|
|
|
|
while (true) {
|
|
|
|
|
|
|
|
// TODO(typhoonzero): get from multiple trainers.
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
for (size_t i = 0; i < param_count; ++i) {
|
|
|
|
// blocking get one var from client.
|
|
|
|
// blocking get one var from client.
|
|
|
|
const detail::TensorWithName &v = rpc_service_->Get();
|
|
|
|
const detail::TensorWithName &v = rpc_service_->Get();
|
|
|
@ -106,6 +109,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
out.second = out_var->Get<framework::LoDTensor>();
|
|
|
|
out.second = out_var->Get<framework::LoDTensor>();
|
|
|
|
rpc_service_->Push(out);
|
|
|
|
rpc_service_->Push(out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
} // while(true)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|