add timeout and http store in communication (#23436)
* add timeout and http store in communication, add revert and confirm in fleet * test=develop revert-24314-dev/fix_err_msg
parent
1fc6cc502a
commit
1034ca316f
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,187 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Http Server."""
|
||||
|
||||
import logging
|
||||
import BaseHTTPServer
|
||||
import SimpleHTTPServer
|
||||
import time
|
||||
import threading
|
||||
import socket
|
||||
|
||||
|
||||
def get_logger(name, level, fmt):
    """
    get the logger called name, set to level and writing to 'http.log'.

    The log file lives in the current working directory and is truncated
    when the handler is first attached (mode 'w'). Calling this function
    again for the same name reuses the handler that is already attached
    instead of stacking a second one, which would write every record twice
    and keep an extra file descriptor open.

    Args:
        name(str): logger name passed to logging.getLogger
        level(int): logging level, e.g. logging.INFO
        fmt(str): format string handed to logging.Formatter

    Returns:
        logger(logging.Logger): the configured logger
    """
    import os  # local import: only needed to normalise the log path

    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter(fmt=fmt)

    # FileHandler stores the absolute path of its file in baseFilename.
    log_path = os.path.abspath('http.log')
    for attached in logger.handlers:
        if (isinstance(attached, logging.FileHandler) and
                attached.baseFilename == log_path):
            # Already wired up by an earlier call: just refresh the format.
            attached.setFormatter(formatter)
            return logger

    handler = logging.FileHandler('http.log', mode='w')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger
|
||||
|
||||
|
||||
# Logger shared by the kv http server code in this module; records carry a
# timestamp and the level name.
_http_server_logger = get_logger(
    name=__name__,
    level=logging.INFO,
    fmt='%(asctime)s-%(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """
    kv handler class for kv http server,
    it defines the way to get/set/delete kv in server.

    Request paths must split on '/' into exactly three parts; the first
    part is ignored and the other two are used as scope and key (e.g.
    "/scope/key"). Values are kept in self.server.kv, guarded by
    self.server.kv_lock; DELETE requests are recorded in
    self.server.delete_kv, guarded by self.server.delete_kv_lock.
    """

    def _parse_scope_key(self):
        """
        split the request path into scope and key.

        A 400 response is sent when the path does not have exactly three
        '/'-separated parts.

        Returns:
            ret(tuple|None): (scope, key), or None if 400 was already sent
        """
        paths = self.path.split('/')
        # Exactly three parts are unpacked below, so a longer path has to be
        # rejected here as well; otherwise the unpacking raises ValueError.
        if len(paths) != 3:
            print('len of request path must be 3: ' + self.path)
            self.send_status_code(400)
            return None
        _, scope, key = paths
        return scope, key

    def do_GET(self):
        """
        get method for kv handler, get value according to key.

        Replies 400 for a malformed path, 404 when nothing is stored under
        scope/key, otherwise 200 with the stored value as the body.
        """
        scope_key = self._parse_scope_key()
        if scope_key is None:
            return
        scope, key = scope_key
        log_str = "GET " + self.address_string() + self.path
        with self.server.kv_lock:
            value = self.server.kv.get(scope, {}).get(key)
        if value is None:
            log_str += ' , key not found: ' + key
            self.send_status_code(404)
        else:
            log_str += ' , key found: ' + key
            self.send_response(200)
            self.send_header("Content-Length", str(len(value)))
            self.end_headers()
            self.wfile.write(value)
        _http_server_logger.info(log_str)

    def do_PUT(self):
        """
        put method for kv handler, set value according to key.

        Replies 400 for a malformed path or an unusable Content-Length
        header, 404 when reading the body fails, otherwise stores the body
        under scope/key and replies 200.
        """
        scope_key = self._parse_scope_key()
        if scope_key is None:
            return
        scope, key = scope_key
        log_str = "PUT " + self.address_string() + self.path
        try:
            content_length = int(self.headers['Content-Length'])
            # read() with a negative size reads until EOF, which would keep
            # the handler waiting on a connection that stays open.
            if content_length < 0:
                raise ValueError("negative Content-Length")
        except (KeyError, TypeError, ValueError):
            print("receive error invalid Content-Length")
            self.send_status_code(400)
            return
        try:
            value = self.rfile.read(content_length)
        except Exception:
            # 404 is kept as the reply for a failed read so that existing
            # clients see the same status as before.
            print("receive error invalid request")
            self.send_status_code(404)
            return
        with self.server.kv_lock:
            self.server.kv.setdefault(scope, {})[key] = value
        self.send_status_code(200)
        _http_server_logger.info(log_str)

    def do_DELETE(self):
        """
        delete method for kv handler, record that key was deleted in scope.

        The key is appended to self.server.delete_kv[scope]; the entry in
        self.server.kv is not touched by this method.
        """
        scope_key = self._parse_scope_key()
        if scope_key is None:
            return
        scope, key = scope_key
        log_str = "DELETE " + self.address_string() + self.path
        with self.server.delete_kv_lock:
            self.server.delete_kv.setdefault(scope, []).append(key)
        self.send_status_code(200)
        _http_server_logger.info(log_str)

    def log_message(self, format, *args):
        """
        ignore all logging messages in kv handler.
        """
        pass

    def send_status_code(self, code):
        """
        send status code back to client, with an empty body.

        Args:
            code(int): http status code
        """
        self.send_response(code)
        self.send_header("Content-Length", 0)
        self.end_headers()
|
||||
|
||||
|
||||
class KVHTTPServer(BaseHTTPServer.HTTPServer, object):
    """
    it is a http server storing kv pairs.

    The store itself is two plain dicts, each protected by its own lock,
    which request handlers reach through their server attribute.
    """

    def __init__(self, port, handler):
        """
        Init.

        Args:
            port(int): port to listen on; the server binds to ('', port)
            handler(class): request handler class, e.g. KVHandler
        """
        super(KVHTTPServer, self).__init__(('', port), handler)
        # scope -> list of keys for which a DELETE request was received
        self.delete_kv_lock = threading.Lock()
        self.delete_kv = {}
        # scope -> {key: value}
        self.kv_lock = threading.Lock()
        self.kv = {}

    def get_deleted_size(self, key):
        """
        get the number of keys deleted in scope key.

        Args:
            key(str): the scope to look up

        Returns:
            ret(int): how many deletes were recorded, 0 for an unknown scope
        """
        with self.delete_kv_lock:
            # delete_kv holds the deleted keys themselves; callers compare
            # the result with an expected count, so report the length.
            return len(self.delete_kv.get(key, ()))
|
||||
|
||||
|
||||
class KVServer(object):
    """
    it is a server storing kv pairs, has a http server inside.
    """

    def __init__(self, port, size=None):
        """
        Init.

        Args:
            port(int): port the inner http server listens on
            size(dict|None): scope -> number of deletes expected before
                the server should stop; None means nothing is expected
        """
        self.http_server = KVHTTPServer(port, KVHandler)
        self.listen_thread = None
        # The argument used to be dropped (self.size was always {}), which
        # made shoud_stop() return True unconditionally.
        self.size = {} if size is None else size

    def start(self):
        """
        start server until user calls stop to let it quit.
        """
        self.listen_thread = threading.Thread(
            target=self.http_server.serve_forever)
        self.listen_thread.start()

    def stop(self):
        """
        stop server and clear its resources.

        It is safe to call this when start() was never called.
        """
        # shutdown() waits for the serve_forever() loop to exit, so it must
        # only be called when that loop was actually started.
        if self.listen_thread is not None:
            self.http_server.shutdown()
            self.listen_thread.join()
            self.listen_thread = None
        self.http_server.server_close()

    def shoud_stop(self):
        """
        return whether the server should stop.

        Returns:
            ret(bool): True when, for every scope in self.size, the number
                of recorded deletes equals the expected number
        """
        for key, expected in self.size.items():
            deleted = self.http_server.get_deleted_size(key)
            # get_deleted_size may hand back the collection of deleted keys
            # instead of a count; compare by its length in that case.
            try:
                deleted = len(deleted)
            except TypeError:
                pass
            if deleted != expected:
                return False
        return True
|
@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test cloud role maker."""
|
||||
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import unittest
|
||||
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
|
||||
|
||||
|
||||
class TestCloudRoleMaker(unittest.TestCase):
    """
    Test cases for PaddleCloudRoleMaker.
    """

    def setUp(self):
        """Set up, set envs."""
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ[
            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"

    def test_pslib_1(self):
        """
        Test cases for pslib.

        Builds a GeneralRoleMaker for a single trainer, initialises fleet
        with it, builds a small network, and then tries the pslib-only
        steps (distributed optimizer, run_server, starting the kv server).
        The test is reported as skipped when fleet or netifaces cannot be
        imported, or when the pslib-only steps raise.
        """
        try:
            import paddle.fluid as fluid
            from paddle.fluid.incubate.fleet.parameter_server.pslib import \
                fleet
            # PSLib is not used below; it is imported only to check that the
            # name can be imported.
            from paddle.fluid.incubate.fleet.parameter_server.pslib import \
                PSLib
            from paddle.fluid.incubate.fleet.base.role_maker import \
                GeneralRoleMaker
        except ImportError as err:
            self.skipTest("no fleet, skip test_pslib_1: %s" % err)

        try:
            import netifaces
        except ImportError:
            self.skipTest("no netifaces, skip test_pslib_1")

        os.environ["POD_IP"] = "127.0.0.1"
        os.environ["PADDLE_PORT"] = "36001"
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
        os.environ["PADDLE_TRAINER_ID"] = "0"

        role_maker = GeneralRoleMaker(
            init_timeout_seconds=100,
            run_timeout_seconds=100,
            http_ip_port="127.0.0.1:36003")
        role_maker.generate_role()

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        fleet.init(role_maker)

        train_program = fluid.Program()
        startup_program = fluid.Program()
        scope = fluid.Scope()
        with fluid.program_guard(train_program, startup_program):
            show = fluid.layers.data(
                name="show",
                shape=[-1, 1],
                dtype="float32",
                lod_level=1,
                append_batch_size=False)
            fc = fluid.layers.fc(input=show, size=1, act=None)
            label = fluid.layers.data(
                name="click",
                shape=[-1, 1],
                dtype="int64",
                lod_level=1,
                append_batch_size=False)
            label_cast = fluid.layers.cast(label, dtype='float32')
            cost = fluid.layers.log_loss(fc, label_cast)

        try:
            adam = fluid.optimizer.Adam(learning_rate=0.000005)
            adam = fleet.distributed_optimizer(adam)
            adam.minimize([cost], [scope])
            fleet.run_server()
            # "running" is False, so the kv server is told not to keep going.
            http_server_d = {}
            http_server_d["running"] = False
            size_d = {}
            # name-mangled access to the private __start_kv_server method
            role_maker._GeneralRoleMaker__start_kv_server(http_server_d,
                                                          size_d)
        except Exception as err:
            # Best effort on purpose: builds without pslib support cannot
            # run these steps, so report a skip instead of a failure.
            self.skipTest("do not support pslib test, skip: %s" % err)
|
||||
|
||||
|
||||
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,196 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Test cloud role maker."""
|
||||
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import unittest
|
||||
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
|
||||
|
||||
|
||||
class TestCloudRoleMaker(unittest.TestCase):
    """
    Test cases for the kv http server classes (KVHandler, KVHTTPServer,
    KVServer) that are imported from fleet utils.
    """

    def setUp(self):
        """Set up, set envs."""
        os.environ["PADDLE_TRAINERS_NUM"] = "2"
        os.environ[
            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"

    def test_pslib_1(self):
        """
        Test cases for pslib.

        Drives the handler methods directly through fake subclasses, so no
        socket is opened. The test is reported as skipped when fleet, the
        http server module or netifaces cannot be imported.
        """
        import threading

        try:
            # Only the three http_server names are used below; the other
            # imports just check that fleet can be imported.
            import paddle.fluid as fluid
            from paddle.fluid.incubate.fleet.parameter_server.pslib import \
                fleet
            from paddle.fluid.incubate.fleet.parameter_server.pslib import \
                PSLib
            from paddle.fluid.incubate.fleet.base.role_maker import \
                GeneralRoleMaker
            from paddle.fluid.incubate.fleet.utils.http_server import \
                KVHandler
            from paddle.fluid.incubate.fleet.utils.http_server import \
                KVServer
            from paddle.fluid.incubate.fleet.utils.http_server import \
                KVHTTPServer
        except ImportError as err:
            self.skipTest("no fleet, skip test_pslib_1: %s" % err)

        try:
            import netifaces
        except ImportError:
            self.skipTest("no netifaces, skip test_pslib_1")

        class FakeStream(object):
            """
            it is a fake stream only for test.
            """

            def write(self, a):
                """
                write a to stream, do nothing

                Args:
                    a(str): the string to write
                """
                pass

            def read(self, b):
                """
                read data of len b from stream, do nothing

                Args:
                    b(int): the len to read

                Returns:
                    c(str): the result

                Raises:
                    ValueError: when b is 0, to exercise the error path
                """
                if b == 0:
                    raise ValueError("this is only for test")
                return "fake"

        class TmpKVHandler(KVHandler):
            """
            it is a fake handler only for this test case.
            """

            def __init__(self, server):
                """Init, without touching any socket."""
                self.path = "a/b/c"
                self.server = server
                self.wfile = FakeStream()
                self.rfile = FakeStream()
                self.headers = {}
                self.headers['Content-Length'] = 0
                # status codes passed to send_response, in call order
                self.codes = []

            def address_string(self):
                """
                fake address string, it will do nothing.
                """
                return "123"

            def send_response(self, code):
                """
                fake send response, it only records the code.

                Args:
                    code(int): status code
                """
                self.codes.append(code)

            def send_header(self, a, b):
                """
                fake send header, it will do nothing.

                Args:
                    a(str): header name
                    b(str): header value
                """
                pass

            def end_headers(self):
                """
                fake end header, it will do nothing.
                """
                pass

        class TmpServer(KVHTTPServer):
            """
            it is a fake server only for this test case.
            """

            def __init__(self):
                """Init, the parent init is skipped so no port is bound."""
                self.delete_kv_lock = threading.Lock()
                self.delete_kv = {}
                self.kv_lock = threading.Lock()
                self.kv = {}

        class TmpS(KVServer):
            """
            it is a fake server only for this test case.
            """

            def __init__(self):
                """Init, expects 999 deletes in scope "a"."""
                self.http_server = TmpServer()
                self.listen_thread = None
                self.size = {}
                self.size["a"] = 999

        s = TmpServer()
        h = TmpKVHandler(s)

        # nothing stored yet
        h.do_GET()
        self.assertEqual(h.codes, [404])

        # a path with only two parts is rejected by every method
        h.path = "a/b"
        h.do_GET()
        h.do_PUT()
        h.do_DELETE()
        self.assertEqual(h.codes[1:], [400, 400, 400])
        self.assertEqual(s.kv, {})
        self.assertEqual(s.delete_kv, {})

        # stored value is found
        h.path = "a/b/c"
        s.kv["b"] = {}
        s.kv["b"]["c"] = "456"
        h.do_GET()
        self.assertEqual(h.codes[-1], 200)

        # Content-Length is 0, so FakeStream.read raises and nothing is stored
        h.path = "a/d/e"
        h.do_PUT()
        self.assertEqual(h.codes[-1], 404)
        self.assertNotIn("d", s.kv)

        h.headers['Content-Length'] = 1
        h.do_PUT()
        self.assertEqual(h.codes[-1], 200)
        self.assertEqual(s.kv["d"]["e"], "fake")

        h.do_DELETE()
        self.assertEqual(h.codes[-1], 200)
        self.assertIn("e", s.delete_kv["d"])

        h.log_message("666")
        self.assertEqual(s.get_deleted_size("haha"), 0)

        # no delete was recorded in scope "a", but 999 are expected
        s1 = TmpS()
        self.assertFalse(s1.shoud_stop())
|
||||
|
||||
|
||||
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue