|
|
|
@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
auto epmap = Attr<std::vector<std::string>>("epmap");
|
|
|
|
|
int sync_send = Attr<int>("sync_mode");
|
|
|
|
|
auto trainer_id = Attr<int>("trainer_id");
|
|
|
|
|
|
|
|
|
|
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
|
|
|
|
|
auto height_sections = Attr<std::vector<int64_t>>("sections");
|
|
|
|
@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
|
|
|
|
|
/*
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
|
|
|
|
|
height_sections);
|
|
|
|
|
height_sections, trainer_id);
|
|
|
|
|
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
|
|
|
|
|
*/
|
|
|
|
|
VLOG(3) << "send " << ins[0];
|
|
|
|
@ -63,8 +64,7 @@ class SendOp : public framework::OperatorBase {
|
|
|
|
|
auto& ctx = *pool.Get(place);
|
|
|
|
|
|
|
|
|
|
distributed::RPCClient* rpc_client =
|
|
|
|
|
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
|
|
|
|
|
Attr<int>("trainer_id"));
|
|
|
|
|
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
|
|
|
|
|
|
|
|
|
|
std::vector<distributed::VarHandlePtr> rets;
|
|
|
|
|
for (size_t i = 0; i < ins.size(); i++) {
|
|
|
|
|