【paddle.fleet】fleet_util move to paddle.fleet (#25805)

* test=develop,test=document_fix, remove the out args

* fleet_util move to paddle.fleet

Co-authored-by: WuHaobo <wuhaobo1994@gmail.com>
Co-authored-by: tangwei12 <tangwei12@baidu.com>
revert-24895-update_cub
123malin 5 years ago committed by GitHub
parent 0d4ce6ac5d
commit 2191a08317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,10 +54,9 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
paddle::framework::fs_remove(tmp); paddle::framework::fs_remove(tmp);
if (i == retry_times_) { if (i == retry_times_) {
VLOG(0) << "fs_open_write failed, retry times reaches limit"; VLOG(0) << "fs_open_write failed, retry times reaches limit";
// PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
// "fs_open_write failed, retry times reaches" "fs_open_write failed, retry times reaches %d limit.",
// " limit ", retry_times_));
// retry_times_));
} }
} else { } else {
break; break;
@ -143,9 +142,9 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
break; break;
} }
} }
// PADDLE_THROW(platform::errors::ExecutionTimeout( PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
VLOG(0) << "TIMEOUT self_rank = " << self_rank_ "TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
<< " pair_rank = " << last_check_rank; last_check_rank));
} }
std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_)); std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,18 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fs import *
from .http_server import KVHandler, KVHTTPServer, KVServer
__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__

File diff suppressed because it is too large Load Diff

@ -0,0 +1,195 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Http Server."""
import logging
import six
# NOTE: HTTPServer has a different name in python2 and python3
if six.PY2:
from BaseHTTPServer import HTTPServer
import SimpleHTTPServer
else:
from http.server import HTTPServer
import http.server as SimpleHTTPServer
import time
import threading
import socket
def get_logger(name, level, fmt):
logger = logging.getLogger(name)
logger.setLevel(level)
handler = logging.FileHandler('http.log', mode='w')
formatter = logging.Formatter(fmt=fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
_http_server_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
"""
kv handler class for kv http server,
it defines the way to get/set kv in server.
"""
def do_GET(self):
"""
get method for kv handler, get value according to key.
"""
log_str = "GET " + self.address_string() + self.path
paths = self.path.split('/')
if len(paths) < 3:
print('len of request path must be 3: ' + self.path)
self.send_status_code(400)
return
_, scope, key = paths
with self.server.kv_lock:
value = self.server.kv.get(scope, {}).get(key)
if value is None:
log_str += ' , key not found: ' + key
self.send_status_code(404)
else:
log_str += ' , key found: ' + key
self.send_response(200)
self.send_header("Content-Length", str(len(value)))
self.end_headers()
self.wfile.write(value)
_http_server_logger.info(log_str)
def do_PUT(self):
"""
put method for kv handler, set value according to key.
"""
log_str = "PUT " + self.address_string() + self.path
paths = self.path.split('/')
if len(paths) < 3:
print('len of request path must be 3: ' + self.path)
self.send_status_code(400)
return
_, scope, key = paths
content_length = int(self.headers['Content-Length'])
try:
value = self.rfile.read(content_length)
except:
print("receive error invalid request")
self.send_status_code(404)
return
with self.server.kv_lock:
if self.server.kv.get(scope) is None:
self.server.kv[scope] = {}
self.server.kv[scope][key] = value
self.send_status_code(200)
_http_server_logger.info(log_str)
def do_DELETE(self):
"""
delete method for kv handler, set value according to key.
"""
log_str = "DELETE " + self.address_string() + self.path
paths = self.path.split('/')
if len(paths) < 3:
print('len of request path must be 3: ' + self.path)
self.send_status_code(400)
return
_, scope, key = paths
with self.server.delete_kv_lock:
if self.server.delete_kv.get(scope) is None:
self.server.delete_kv[scope] = []
self.server.delete_kv[scope].append(key)
self.send_status_code(200)
_http_server_logger.info(log_str)
def log_message(self, format, *args):
"""
ignore all logging messages in kv handler.
"""
pass
def send_status_code(self, code):
"""
send status code back to client.
"""
self.send_response(code)
self.send_header("Content-Length", 0)
self.end_headers()
class KVHTTPServer(HTTPServer, object):
"""
it is a http server storing kv pairs.
"""
def __init__(self, port, handler):
"""Init."""
super(KVHTTPServer, self).__init__(('', port), handler)
self.delete_kv_lock = threading.Lock()
self.delete_kv = {}
self.kv_lock = threading.Lock()
self.kv = {}
def get_deleted_size(self, key):
"""
get deleted size in key.
"""
ret = 0
with self.delete_kv_lock:
ret = self.delete_kv.get(key, 0)
return ret
class KVServer:
"""
it is a server storing kv pairs, has a http server inside.
"""
def __init__(self, port, size={}):
"""Init."""
self.http_server = KVHTTPServer(port, KVHandler)
self.listen_thread = None
self.size = {}
def start(self):
"""
start server until user calls stop to let it quit.
"""
self.listen_thread = threading.Thread(
target=lambda: self.http_server.serve_forever())
self.listen_thread.start()
def stop(self):
"""
stop server and clear its resources.
"""
self.http_server.shutdown()
self.listen_thread.join()
self.http_server.server_close()
def shoud_stop(self):
"""
return whether the server should stop.
Returns:
ret(bool): whether the server should stop
"""
for key in self.size:
s = self.http_server.get_deleted_size(key)
if s != self.size.get(key, 0):
return False
return True

@ -21,7 +21,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import SGD from paddle.fluid.optimizer import SGD
from paddle.fluid.incubate.fleet.base.mode import Mode from paddle.fluid.incubate.fleet.base.mode import Mode
from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase from paddle.fleet.base.role_maker import RoleMakerBase
from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
from . import mode from . import mode
@ -209,7 +209,10 @@ class Fleet(object):
self._executor = Executor(fluid.CPUPlace()) self._executor = Executor(fluid.CPUPlace())
if role_maker and not isinstance(role_maker, RoleMakerBase): if role_maker and not isinstance(role_maker, RoleMakerBase):
raise TypeError("role_maker must be an instance of RoleMakerBase") from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase as RoleMakerBaseIncubate
if role_maker and not isinstance(role_maker, RoleMakerBaseIncubate):
raise TypeError(
"role_maker must be an instance of RoleMakerBase")
self._role_maker = role_maker self._role_maker = role_maker
self._role_maker.generate_role() self._role_maker.generate_role()

@ -345,7 +345,6 @@ if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add these tests back # FIXME(typhoonzero): add these tests back
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr")
#not need #not need
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_base") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_base")

@ -28,6 +28,7 @@ import numpy as np
import ctr_dataset_reader import ctr_dataset_reader
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
from paddle.fleet.base.util_factory import fleet_util
# Fix seed for test # Fix seed for test
fluid.default_startup_program().random_seed = 1 fluid.default_startup_program().random_seed = 1
@ -181,8 +182,14 @@ class TestDistCTR2x2(FleetDistRunnerBase):
loss_val = exe.run(program=compiled_prog, loss_val = exe.run(program=compiled_prog,
fetch_list=[self.avg_cost.name]) fetch_list=[self.avg_cost.name])
loss_val = np.mean(loss_val) loss_val = np.mean(loss_val)
print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, reduce_output = fleet_util.all_reduce(
loss_val)) np.array(loss_val), mode="sum")
loss_all_trainer = fleet_util.all_gather(float(loss_val))
loss_val = float(reduce_output) / len(loss_all_trainer)
message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
loss_val)
fleet_util.print_on_rank(message, 0)
pass_time = time.time() - pass_start pass_time = time.time() - pass_start
except fluid.core.EOFException: except fluid.core.EOFException:
self.reader.reset() self.reader.reset()

@ -21,6 +21,9 @@ import os
import sys import sys
import subprocess import subprocess
import six
import shutil
import numpy as np
import argparse import argparse
from contextlib import closing from contextlib import closing
import socket import socket
@ -29,7 +32,8 @@ import tempfile
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fleet.base.role_maker as role_maker
from paddle.fleet.base.util_factory import fleet_util
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
@ -48,18 +52,26 @@ class FleetDistRunnerBase(object):
""" """
def build_role(self, args): def build_role(self, args):
if args.role.upper() == "PSERVER": if args.role.upper() == "PSERVER":
role = role_maker.UserDefinedRoleMaker( role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=True,
path=args.gloo_path,
current_id=args.current_id, current_id=args.current_id,
role=role_maker.Role.SERVER, role=role_maker.Role.SERVER,
worker_num=args.trainers, worker_endpoints=args.trainer_endpoints.split(","),
server_endpoints=args.endpoints.split(",")) server_endpoints=args.endpoints.split(","))
else: else:
role = role_maker.UserDefinedRoleMaker( role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=True,
path=args.gloo_path,
current_id=args.current_id, current_id=args.current_id,
role=role_maker.Role.WORKER, role=role_maker.Role.WORKER,
worker_num=args.trainers, worker_endpoints=args.trainer_endpoints.split(","),
server_endpoints=args.endpoints.split(",")) server_endpoints=args.endpoints.split(","))
self.role = role
return role return role
def build_strategy(self, args): def build_strategy(self, args):
@ -114,26 +126,13 @@ class FleetDistRunnerBase(object):
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
def run_pserver(self, args): def run_pserver(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
fleet.init_server() fleet.init_server()
fleet.run_server() fleet.run_server()
def run_dataset_trainer(self, args): def run_dataset_trainer(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
out = self.do_dataset_training(fleet) out = self.do_dataset_training(fleet)
def run_pyreader_trainer(self, args): def run_pyreader_trainer(self, args):
fleet.init(self.build_role(args))
strategy = self.build_strategy(args)
avg_cost = self.net(args)
self.build_optimizer(avg_cost, strategy)
out = self.do_pyreader_training(fleet) out = self.do_pyreader_training(fleet)
def net(self, args, batch_size=4, lr=0.01): def net(self, args, batch_size=4, lr=0.01):
@ -173,10 +172,14 @@ class TestFleetBase(unittest.TestCase):
print("set begin_port:", DIST_UT_PORT) print("set begin_port:", DIST_UT_PORT)
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT, DIST_UT_PORT + 1) DIST_UT_PORT, DIST_UT_PORT + 1)
DIST_UT_PORT += 2 self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT + 2, DIST_UT_PORT + 3)
DIST_UT_PORT += 4
else: else:
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port()) self._find_free_port(), self._find_free_port())
self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable self._python_interp = sys.executable
self._geo_sgd_need_push_nums = 5 self._geo_sgd_need_push_nums = 5
@ -236,18 +239,22 @@ class TestFleetBase(unittest.TestCase):
def _run_cluster(self, model, envs): def _run_cluster(self, model, envs):
env = {'GRAD_CLIP': str(self._grad_clip_mode)} env = {'GRAD_CLIP': str(self._grad_clip_mode)}
python_path = self._python_interp python_path = self._python_interp
gloo_path = tempfile.mkdtemp()
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
python_path += " -m coverage run --branch -p" python_path += " -m coverage run --branch -p"
env.update(envs) env.update(envs)
tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format(
python_path, model, self._ps_endpoints, self._trainers, self._mode, python_path, model, self._ps_endpoints, self._tr_endpoints,
self._geo_sgd_need_push_nums, self._reader) self._trainers, self._mode, self._geo_sgd_need_push_nums,
self._reader, gloo_path)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format(
python_path, model, self._ps_endpoints, self._trainers, self._mode, python_path, model, self._ps_endpoints, self._tr_endpoints,
self._geo_sgd_need_push_nums, self._reader) self._trainers, self._mode, self._geo_sgd_need_push_nums,
self._reader, gloo_path)
# Run dist train to compare with local results # Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env) ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
@ -284,6 +291,7 @@ class TestFleetBase(unittest.TestCase):
ps0.terminate() ps0.terminate()
ps1.terminate() ps1.terminate()
shutil.rmtree(gloo_path)
return 0, 0 return 0, 0
def check_with_place(self, def check_with_place(self,
@ -313,6 +321,9 @@ def runtime_main(test_class):
parser.add_argument( parser.add_argument(
'--role', type=str, required=True, choices=['pserver', 'trainer']) '--role', type=str, required=True, choices=['pserver', 'trainer'])
parser.add_argument('--endpoints', type=str, required=False, default="") parser.add_argument('--endpoints', type=str, required=False, default="")
parser.add_argument(
'--trainer_endpoints', type=str, required=False, default="")
parser.add_argument('--gloo_path', type=str, required=False, default="")
parser.add_argument('--current_id', type=int, required=False, default=0) parser.add_argument('--current_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--mode', type=str, required=False, default='geo') parser.add_argument('--mode', type=str, required=False, default='geo')
@ -322,6 +333,13 @@ def runtime_main(test_class):
args = parser.parse_args() args = parser.parse_args()
model = test_class() model = test_class()
role = model.build_role(args)
fleet.init(role)
strategy = model.build_strategy(args)
avg_cost = model.net(args)
model.build_optimizer(avg_cost, strategy)
fleet_util._set_strategy(strategy)
fleet_util._set_role_maker(role)
if args.role == "pserver": if args.role == "pserver":
model.run_pserver(args) model.run_pserver(args)
else: else:

@ -22,7 +22,7 @@ from test_dist_fleet_base import TestFleetBase
class TestDistMnistSync2x2(TestFleetBase): class TestDistMnistSync2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "sync" self._mode = "async"
self._reader = "pyreader" self._reader = "pyreader"
def check_with_place(self, def check_with_place(self,

@ -40,10 +40,9 @@ class TestCloudRoleMaker(unittest.TestCase):
from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
from paddle.fluid.incubate.fleet.base.role_maker import \ from paddle.fluid.incubate.fleet.base.role_maker import \
GeneralRoleMaker GeneralRoleMaker
from paddle.fluid.incubate.fleet.utils.http_server import KVHandler from paddle.fleet.utils import KVHandler
from paddle.fluid.incubate.fleet.utils.http_server import KVServer from paddle.fleet.utils import KVServer
from paddle.fluid.incubate.fleet.utils.http_server import \ from paddle.fleet.utils import KVHTTPServer
KVHTTPServer
except: except:
print("warning: no fleet, skip test_pslib_4") print("warning: no fleet, skip test_pslib_4")
return return

@ -0,0 +1,171 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test cloud role maker."""
from __future__ import print_function
import os
import unittest
import paddle.fleet.base.role_maker as role_maker
class TestRoleMakerBase(unittest.TestCase):
"""
Test cases for RoleMakerBase
"""
def test_rolemaker_base(self):
role = role_maker.RoleMakerBase()
self.assertRaises(Exception, role.is_worker)
self.assertRaises(Exception, role.is_server)
self.assertRaises(Exception, role.is_first_worker)
self.assertRaises(Exception, role.worker_num)
self.assertRaises(Exception, role.server_num)
self.assertRaises(Exception, role.worker_index)
self.assertRaises(Exception, role.server_index)
self.assertRaises(Exception, role.role_id)
trainer_endpoints = role.get_trainer_endpoints()
self.assertTrue(len(trainer_endpoints) == 0)
pserver_endpoints = role.get_pserver_endpoints()
self.assertTrue(len(pserver_endpoints) == 0)
print(role.to_string())
self.assertTrue(role._all_gather(role._node_type_comm, 1) is None)
self.assertTrue(role._all_reduce(role._node_type_comm, 1) is None)
role._barrier(role._node_type_comm)
class TestCloudRoleMaker(unittest.TestCase):
"""
Test cases for PaddleCloudRoleMaker.
"""
def setUp(self):
"""Set up, set envs."""
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.2:36001"
os.environ["POD_IP"] = "127.0.0.1"
def test_tr_rolemaker(self):
"""Test tr rolenamer."""
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ID"] = "0"
try:
import netifaces
except:
print("warning: no netifaces, skip test_tr_rolemaker")
return
ro = role_maker.PaddleCloudRoleMaker(
is_collective=False, init_gloo=False)
self.assertTrue(ro.is_worker())
self.assertFalse(ro.is_server())
self.assertEqual(ro.worker_num(), 2)
self.assertTrue(ro.is_first_worker())
worker_endpoints = ro.get_trainer_endpoints()
self.assertEqual(worker_endpoints[0], '127.0.0.1:36001')
self.assertEqual(ro.role_id(), 0)
def test_tr_rolemaker_collective(self):
ro = role_maker.PaddleCloudRoleMaker(is_collective=True)
self.assertEqual(ro.worker_num(), 2)
def test_ps_rolemaker(self):
"""Test ps rolemaker."""
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
try:
import netifaces
except:
print("warning: no netifaces, skip test_ps_rolemaker")
return
ro = role_maker.PaddleCloudRoleMaker(
is_collective=False, init_gloo=False)
self.assertEqual(ro.server_index(), 0)
self.assertFalse(ro.is_worker())
self.assertTrue(ro.is_server())
self.assertEqual(ro.server_num(), 2)
pserver_endpoints = ro.get_pserver_endpoints()
self.assertEqual(pserver_endpoints[0], '127.0.0.1:36001')
self.assertTrue(ro._all_gather(ro._all_comm, 1) is None)
self.assertTrue(ro._all_reduce(ro._all_comm, 1) is None)
def test_traing_role(self):
"""Test training role."""
os.environ["TRAINING_ROLE"] = "TEST"
try:
import netifaces
except:
print("warning: no netifaces, skip test_training_role")
return
ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
self.assertRaises(ValueError, ro.generate_role)
class TestUserDefinedRoleMaker(unittest.TestCase):
"""
Test cases for UserDefinedRoleMaker.
"""
def setUp(self):
pass
def test_ps_rolemaker(self):
try:
import netifaces
except:
print("warning: no netifaces, skip test_ps_rolemaker")
return
ro = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
server_endpoints="127.0.0.1:36001,127.0.0.1:36001",
role=role_maker.Role.SERVER,
current_id=0,
worker_num=2)
self.assertEqual(ro.server_num(), 2)
ro.generate_role()
self.assertTrue(ro.is_server())
self.assertEqual(ro.role_id(), 0)
def test_tr_rolemaker(self):
try:
import netifaces
except:
print("warning: no netifaces, skip test_tr_rolemaker")
return
ro = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
server_endpoints="127.0.0.1:36001,127.0.0.1:36001",
role=role_maker.Role.WORKER,
current_id=0,
worker_num=2)
self.assertIn("127.0.0.1:36001", ro.get_pserver_endpoints())
self.assertTrue(ro.is_worker())
self.assertEqual(ro.role_id(), 0)
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

@ -20,9 +20,7 @@ import os
import sys import sys
import inspect import inspect
from paddle.fluid.incubate.fleet.utils.fs import LocalFS, FS from paddle.fleet.utils import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError
class FSTest(unittest.TestCase): class FSTest(unittest.TestCase):

@ -19,9 +19,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet, T
import os import os
import sys import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS from paddle.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError
java_home = os.environ["JAVA_HOME"] java_home = os.environ["JAVA_HOME"]

@ -21,3 +21,4 @@ prettytable
objgraph objgraph
astor astor
pathlib pathlib
netifaces

@ -152,6 +152,7 @@ packages=['paddle',
'paddle.fleet.dataset', 'paddle.fleet.dataset',
'paddle.fleet.metrics', 'paddle.fleet.metrics',
'paddle.fleet.proto', 'paddle.fleet.proto',
'paddle.fleet.utils',
'paddle.framework', 'paddle.framework',
'paddle.fluid', 'paddle.fluid',
'paddle.fluid.dygraph', 'paddle.fluid.dygraph',

Loading…
Cancel
Save