|
|
|
@ -59,7 +59,7 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
import paddle.distributed.fleet.base.role_maker as role_maker
|
|
|
|
|
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
|
|
|
|
fleet.init(role)
|
|
|
|
|
default_util = fleet.util
|
|
|
|
|
default_util = fleet.util()
|
|
|
|
|
self.assertEqual(default_util, None)
|
|
|
|
|
|
|
|
|
|
def test_set_user_defined_util(self):
|
|
|
|
@ -76,8 +76,8 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
|
|
|
|
fleet.init(role)
|
|
|
|
|
my_util = UserDefinedUtil()
|
|
|
|
|
fleet.util = my_util
|
|
|
|
|
user_id = fleet.util.get_user_id()
|
|
|
|
|
fleet.set_util(my_util)
|
|
|
|
|
user_id = fleet.util().get_user_id()
|
|
|
|
|
self.assertEqual(user_id, 10)
|
|
|
|
|
|
|
|
|
|
def test_fs(self):
|
|
|
|
@ -88,97 +88,6 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
self.assertFalse(fs.need_upload_download())
|
|
|
|
|
fleet_util._set_file_system(fs)
|
|
|
|
|
|
|
|
|
|
def test_barrier(self):
    """Call ``fleet_util.barrier("worker")`` with a hand-built Gloo context.

    The test prints a warning and returns early when the optional
    ``netifaces`` package cannot be imported.
    """
    try:
        # Imported only as an availability probe; the name is not used below.
        import netifaces  # noqa: F401
    except ImportError:
        # Catch only the missing-package case so that unrelated failures
        # (and KeyboardInterrupt/SystemExit) are not silently swallowed.
        print("warning: no netifaces, skip test_barrier")
        return

    # Gloo context configured as rank 0 of a world of size 1.
    gloo = fluid.core.Gloo()
    gloo.set_rank(0)
    gloo.set_size(1)
    gloo.set_prefix("123")
    gloo.set_iface("lo")
    gloo.set_hdfs_store("./tmp_test_fleet_barrier", "", "")
    gloo.init()

    role = role_maker.UserDefinedRoleMaker(
        is_collective=False,
        init_gloo=False,
        current_id=0,
        role=role_maker.Role.SERVER,
        worker_endpoints=["127.0.0.1:6003"],
        server_endpoints=["127.0.0.1:6001"])
    # The role maker is created with init_gloo=False, so the Gloo object
    # built above is attached to it by hand through its private attributes.
    role._node_type_comm = gloo
    role._role_is_generated = True
    fleet_util._set_role_maker(role)

    fleet_util.barrier("worker")
|
|
|
|
|
|
|
|
|
|
def test_all_reduce(self):
    """Call ``fleet_util.all_reduce`` with a hand-built Gloo context.

    The test prints a warning and returns early when the optional
    ``netifaces`` package cannot be imported. The reduced value is only
    printed; no assertion is made on it.
    """
    try:
        # Imported only as an availability probe; the name is not used below.
        import netifaces  # noqa: F401
    except ImportError:
        # Catch only the missing-package case so that unrelated failures
        # (and KeyboardInterrupt/SystemExit) are not silently swallowed.
        print("warning: no netifaces, skip test_all_reduce")
        return

    # Gloo context configured as rank 0 of a world of size 1.
    gloo = fluid.core.Gloo()
    gloo.set_rank(0)
    gloo.set_size(1)
    gloo.set_prefix("123")
    gloo.set_iface("lo")
    gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "")
    gloo.init()

    role = role_maker.UserDefinedRoleMaker(
        is_collective=False,
        init_gloo=False,
        current_id=0,
        role=role_maker.Role.WORKER,
        worker_endpoints=["127.0.0.1:6003"],
        server_endpoints=["127.0.0.1:6001"])
    # The role maker is created with init_gloo=False, so the Gloo object
    # built above is attached to it by hand through its private attributes.
    role._node_type_comm = gloo
    role._role_is_generated = True
    fleet_util._set_role_maker(role)

    output = fleet_util.all_reduce(1, "sum", comm_world="server")
    print(output)

    # NOTE(review): this assertion was already disabled; confirm whether it
    # can be re-enabled.
    # self.assertEqual(output, 1)
|
|
|
|
|
|
|
|
|
|
def test_all_gather(self):
    """Call ``fleet_util.all_gather`` with a hand-built Gloo context.

    The test prints a warning and returns early when the optional
    ``netifaces`` package cannot be imported. The gathered value is only
    printed; the single active assertion checks that calling
    ``all_gather`` with ``"test"`` as its second positional argument raises.
    """
    try:
        # Imported only as an availability probe; the name is not used below.
        import netifaces  # noqa: F401
    except ImportError:
        # Catch only the missing-package case so that unrelated failures
        # (and KeyboardInterrupt/SystemExit) are not silently swallowed.
        print("warning: no netifaces, skip test_all_gather")
        return

    # Gloo context configured as rank 0 of a world of size 1.
    gloo = fluid.core.Gloo()
    gloo.set_rank(0)
    gloo.set_size(1)
    gloo.set_prefix("123")
    gloo.set_iface("lo")
    # NOTE(review): this is the same store path that test_all_reduce uses
    # ("./tmp_test_fleet_reduce") -- confirm whether sharing it is intended.
    gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "")
    gloo.init()

    role = role_maker.UserDefinedRoleMaker(
        is_collective=False,
        init_gloo=False,
        current_id=0,
        role=role_maker.Role.SERVER,
        worker_endpoints=["127.0.0.1:6003"],
        server_endpoints=["127.0.0.1:6001"])
    # The role maker is created with init_gloo=False, so the Gloo object
    # built above is attached to it by hand through its private attributes
    # (both the node-type and the "all" communicator slots).
    role._node_type_comm = gloo
    role._all_comm = gloo
    role._role_is_generated = True
    fleet_util._set_role_maker(role)

    output = fleet_util.all_gather(1, comm_world="all")
    print(output)
    # NOTE(review): this assertion was already disabled; confirm whether it
    # can be re-enabled.
    # self.assertTrue(len(output) == 1 and output[0] == 1)
    self.assertRaises(Exception, fleet_util.all_gather, 1, "test")
|
|
|
|
|
|
|
|
|
|
def download_files(self):
|
|
|
|
|
path = download(self.proto_data_url, self.module_name,
|
|
|
|
|
self.proto_data_md5)
|
|
|
|
|