|
|
|
@ -19,6 +19,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/framework/lod_tensor.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
#include <sys/time.h>
|
|
|
|
|
#include <future>
|
|
|
|
|
#include "paddle/operators/detail/grpc_client.h"
|
|
|
|
|
|
|
|
|
@ -42,28 +43,35 @@ class SendOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto& ctx = *pool.Get(place);
|
|
|
|
|
|
|
|
|
|
auto client_var_name = Output("RPCClient");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
|
|
|
|
|
"Can not find variable '%s' in the scope.",
|
|
|
|
|
client_var_name);
|
|
|
|
|
auto* client_var = scope.FindVar(client_var_name);
|
|
|
|
|
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < ins.size(); i++) {
|
|
|
|
|
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
|
|
|
|
|
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
|
|
|
|
|
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(client_.Wait());
|
|
|
|
|
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
|
|
|
|
|
for (auto& ep : endpoints) {
|
|
|
|
|
VLOG(3) << "batch barrier, ep: " << ep;
|
|
|
|
|
client_.AsyncSendBatchBarrier(ep);
|
|
|
|
|
rpc_client->AsyncSendBatchBarrier(ep);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(client_.Wait());
|
|
|
|
|
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < outs.size(); i++) {
|
|
|
|
|
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
|
|
|
|
|
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
|
|
|
|
|
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(client_.Wait());
|
|
|
|
|
PADDLE_ENFORCE(rpc_client->Wait());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mutable detail::RPCClient client_;
|
|
|
|
|
// mutable detail::RPCClient client_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
@ -73,6 +81,9 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
|
|
|
|
|
AddOutput("Out", "(Tensor) Output tensor to be received from server")
|
|
|
|
|
.AsDuplicable();
|
|
|
|
|
AddOutput("RPCClient",
|
|
|
|
|
"(RPCClient) The RPC client object which is"
|
|
|
|
|
"initialized at most once.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Send operator
|
|
|
|
|
|
|
|
|
|