|
|
|
@ -15,13 +15,13 @@ limitations under the License. */
|
|
|
|
#include <limits>
|
|
|
|
#include <limits>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/detail/grpc_server.h"
|
|
|
|
#include "paddle/fluid/operators/distributed/grpc_server.h"
|
|
|
|
|
|
|
|
|
|
|
|
using ::grpc::ServerAsyncResponseWriter;
|
|
|
|
using ::grpc::ServerAsyncResponseWriter;
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
namespace detail {
|
|
|
|
namespace distributed {
|
|
|
|
enum CallStatus { PROCESS = 0, FINISH };
|
|
|
|
enum CallStatus { PROCESS = 0, FINISH };
|
|
|
|
|
|
|
|
|
|
|
|
// reference:
|
|
|
|
// reference:
|
|
|
|
@ -74,7 +74,7 @@ class RequestSend final : public RequestBase {
|
|
|
|
request_.reset(new VariableResponse(request_handler->scope(),
|
|
|
|
request_.reset(new VariableResponse(request_handler->scope(),
|
|
|
|
request_handler->dev_ctx(),
|
|
|
|
request_handler->dev_ctx(),
|
|
|
|
!request_handler->sync_mode()));
|
|
|
|
!request_handler->sync_mode()));
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
|
|
|
|
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
@ -106,7 +106,7 @@ class RequestGet final : public RequestBase {
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
::grpc::ServerCompletionQueue* cq,
|
|
|
|
RequestHandler* request_handler, int req_id)
|
|
|
|
RequestHandler* request_handler, int req_id)
|
|
|
|
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
|
|
|
|
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
|
|
|
|
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
|
|
|
|
auto method_id = static_cast<int>(distributed::GrpcMethod::kGetVariable);
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
method_id, &ctx_, &request_, &responder_, cq_, cq_,
|
|
|
|
method_id, &ctx_, &request_, &responder_, cq_, cq_,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
@ -150,7 +150,8 @@ class RequestPrefetch final : public RequestBase {
|
|
|
|
local_scope_(nullptr) {
|
|
|
|
local_scope_(nullptr) {
|
|
|
|
request_.reset(new VariableResponse(request_handler->scope(),
|
|
|
|
request_.reset(new VariableResponse(request_handler->scope(),
|
|
|
|
request_handler->dev_ctx(), true));
|
|
|
|
request_handler->dev_ctx(), true));
|
|
|
|
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
|
|
|
|
int method_id =
|
|
|
|
|
|
|
|
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
service_->RequestAsyncUnary(
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
|
|
|
|
@ -354,6 +355,6 @@ void AsyncGRPCServer::HandleRequest(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
} // namespace distributed
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|