Retry when binding the address fails. (#20642)

revert-20712-fix_depthwise_conv
gongweibao 5 years ago committed by GitHub
parent 3e831b6083
commit f3f52fc1e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <unistd.h>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
@ -22,6 +23,7 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
DECLARE_bool(rpc_disable_reuse_port); DECLARE_bool(rpc_disable_reuse_port);
DECLARE_int32(rpc_retry_bind_port);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -452,25 +454,42 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
}; };
void AsyncGRPCServer::StartServer() { void AsyncGRPCServer::StartServer() {
::grpc::ServerBuilder builder; for (int i = 0; i < FLAGS_rpc_retry_bind_port; i++) {
builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(), ::grpc::ServerBuilder builder;
&selected_port_); std::unique_ptr<GrpcService::AsyncService> service(
new GrpcService::AsyncService());
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); &selected_port_);
if (FLAGS_rpc_disable_reuse_port) {
builder.SetOption( builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
} if (FLAGS_rpc_disable_reuse_port) {
builder.RegisterService(&service_); builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
}
builder.RegisterService(service.get());
for (auto t : rpc_call_map_) {
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
}
server_ = builder.BuildAndStart();
if (selected_port_ != 0) {
LOG(INFO) << "Server listening on " << bind_address_
<< " successful, selected port: " << selected_port_;
service_.reset(service.release());
break;
}
LOG(WARNING) << "Server listening on " << bind_address_
<< " failed, selected port: " << selected_port_
<< ", retry after 3 seconds!";
for (auto t : rpc_call_map_) { sleep(3);
rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
} }
server_ = builder.BuildAndStart(); PADDLE_ENFORCE_NE(selected_port_, 0, "can't bind to address:%s",
LOG(INFO) << "Server listening on " << bind_address_ bind_address_);
<< " selected port: " << selected_port_;
std::function<void(const std::string&, int)> f = std::function<void(const std::string&, int)> f =
std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this, std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
@ -547,24 +566,24 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
RequestBase* b = nullptr; RequestBase* b = nullptr;
if (rpc_name == kRequestSend) { if (rpc_name == kRequestSend) {
b = new RequestSend(&service_, cq.get(), handler, req_id); b = new RequestSend(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) { } else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id); b = new RequestGet(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) { } else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id); b = new RequestGetNoBarrier(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) { } else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id, b = new RequestGetMonomerVariable(service_.get(), cq.get(), handler, req_id,
this); this);
} else if (rpc_name == kRequestGetMonomerBarrier) { } else if (rpc_name == kRequestGetMonomerBarrier) {
b = new RequestGetMonomerBarrier(&service_, cq.get(), handler, req_id, b = new RequestGetMonomerBarrier(service_.get(), cq.get(), handler, req_id,
this); this);
} else if (rpc_name == kRequestPrefetch) { } else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); b = new RequestPrefetch(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) { } else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) { } else if (rpc_name == kRequestNotify) {
b = new RequestNotify(&service_, cq.get(), handler, req_id); b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else { } else {
PADDLE_ENFORCE(false, "not supported rpc"); PADDLE_ENFORCE(false, "not supported rpc");
} }

@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
@ -67,7 +68,7 @@ class AsyncGRPCServer final : public RPCServer {
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
GrpcService::AsyncService service_; std::unique_ptr<GrpcService::AsyncService> service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
// condition of the sub program // condition of the sub program

@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not."); DEFINE_bool(rpc_disable_reuse_port, false, "Disable SO_REUSEPORT or not.");
DEFINE_int32(rpc_retry_bind_port, 3,
"Retry to bind the address if address is already used.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {

@ -192,6 +192,7 @@ def __bootstrap__():
read_env_flags.append('rpc_get_thread_num') read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_prefetch_thread_num')
read_env_flags.append('rpc_disable_reuse_port') read_env_flags.append('rpc_disable_reuse_port')
read_env_flags.append('rpc_retry_bind_port')
read_env_flags.append('worker_update_interval_secs') read_env_flags.append('worker_update_interval_secs')

@ -846,6 +846,7 @@ class TestDistBase(unittest.TestCase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_rpc_deadline": "30000", # 5sec to fail fast "FLAGS_rpc_deadline": "30000", # 5sec to fail fast
"FLAGS_rpc_retry_bind_port": "50",
"FLAGS_cudnn_deterministic": "1", "FLAGS_cudnn_deterministic": "1",
"http_proxy": "", "http_proxy": "",
"NCCL_P2P_DISABLE": "1", "NCCL_P2P_DISABLE": "1",

@ -105,7 +105,7 @@ def gen_complete_file_flag(flag_file):
class TestListenAndServOp(unittest.TestCase): class TestListenAndServOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.ps_timeout = 5 self.ps_timeout = 200
self.ip = "127.0.0.1" self.ip = "127.0.0.1"
self.port = "0" self.port = "0"
self.trainers = 1 self.trainers = 1

Loading…
Cancel
Save