|
|
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <unistd.h>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
@ -22,6 +23,7 @@ limitations under the License. */
|
|
|
|
|
using ::grpc::ServerAsyncResponseWriter;
|
|
|
|
|
|
|
|
|
|
DECLARE_bool(rpc_disable_reuse_port);
|
|
|
|
|
DECLARE_int32(rpc_retry_bind_port);
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void AsyncGRPCServer::StartServer() {
|
|
|
|
|
::grpc::ServerBuilder builder;
|
|
|
|
|
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
|
|
|
|
|
&selected_port_);
|
|
|
|
|
|
|
|
|
|
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
|
|
|
|
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
|
|
|
|
if (FLAGS_rpc_disable_reuse_port) {
|
|
|
|
|
builder.SetOption(
|
|
|
|
|
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
|
|
|
|
|
}
|
|
|
|
|
builder.RegisterService(&service_);
|
|
|
|
|
for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) {
|
|
|
|
|
::grpc::ServerBuilder builder;
|
|
|
|
|
std::unique_ptr<GrpcService::AsyncService> service(
|
|
|
|
|
new GrpcService::AsyncService());
|
|
|
|
|
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
|
|
|
|
|
&selected_port_);
|
|
|
|
|
|
|
|
|
|
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
|
|
|
|
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
|
|
|
|
if (FLAGS_rpc_disable_reuse_port) {
|
|
|
|
|
builder.SetOption(
|
|
|
|
|
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
|
|
|
|
|
}
|
|
|
|
|
builder.RegisterService(service.get());
|
|
|
|
|
|
|
|
|
|
for (auto t : rpc_call_map_) {
|
|
|
|
|
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server_ = builder.BuildAndStart();
|
|
|
|
|
if (selected_port_ != 0) {
|
|
|
|
|
LOG(INFO) << "Server listening on " << bind_address_
|
|
|
|
|
<< " successful, selected port: " << selected_port_;
|
|
|
|
|
service_.reset(service.release());
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LOG(WARNING) << "Server listening on " << bind_address_
|
|
|
|
|
<< " failed, selected port: " << selected_port_
|
|
|
|
|
<< ", retry after 3 seconds!";
|
|
|
|
|
|
|
|
|
|
for (auto t : rpc_call_map_) {
|
|
|
|
|
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
|
|
|
|
|
sleep(3);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
server_ = builder.BuildAndStart();
|
|
|
|
|
LOG(INFO) << "Server listening on " << bind_address_
|
|
|
|
|
<< " selected port: " << selected_port_;
|
|
|
|
|
PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s",
|
|
|
|
|
bind_address_);
|
|
|
|
|
|
|
|
|
|
std::function<void(const std::string&, int)> f =
|
|
|
|
|
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
|
|
|
|
@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
|
|
|
|
|
|
|
|
|
|
RequestBase* b = nullptr;
|
|
|
|
|
if (rpc_name == kRequestSend) {
|
|
|
|
|
b = new RequestSend(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestSend(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestGet) {
|
|
|
|
|
b = new RequestGet(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestGet(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
|
|
|
|
|
} else if (rpc_name == kRequestGetNoBarrier) {
|
|
|
|
|
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestGetMonomerVariable) {
|
|
|
|
|
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
|
|
|
|
|
b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id,
|
|
|
|
|
this);
|
|
|
|
|
} else if (rpc_name == kRequestGetMonomerBarrier) {
|
|
|
|
|
b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id,
|
|
|
|
|
b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id,
|
|
|
|
|
this);
|
|
|
|
|
} else if (rpc_name == kRequestPrefetch) {
|
|
|
|
|
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestCheckpoint) {
|
|
|
|
|
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
} else if (rpc_name == kRequestNotify) {
|
|
|
|
|
b = new RequestNotify(&service_, cq.get(), handler, req_id);
|
|
|
|
|
b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(false, "not supported rpc");
|
|
|
|
|
}
|
|
|
|
|