|
|
|
@ -22,7 +22,6 @@ import tempfile
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from paddle.dataset.common import download, DATA_HOME
|
|
|
|
|
from paddle.distributed.fleet.base.util_factory import fleet_util
|
|
|
|
|
import paddle.distributed.fleet.base.role_maker as role_maker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -59,8 +58,7 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
import paddle.distributed.fleet.base.role_maker as role_maker
|
|
|
|
|
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
|
|
|
|
fleet.init(role)
|
|
|
|
|
default_util = fleet.util()
|
|
|
|
|
self.assertEqual(default_util, None)
|
|
|
|
|
self.assertNotEqual(fleet.util, None)
|
|
|
|
|
|
|
|
|
|
def test_set_user_defined_util(self):
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
@ -76,17 +74,19 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
|
|
|
|
|
fleet.init(role)
|
|
|
|
|
my_util = UserDefinedUtil()
|
|
|
|
|
fleet.set_util(my_util)
|
|
|
|
|
user_id = fleet.util().get_user_id()
|
|
|
|
|
fleet.util = my_util
|
|
|
|
|
user_id = fleet.util.get_user_id()
|
|
|
|
|
self.assertEqual(user_id, 10)
|
|
|
|
|
|
|
|
|
|
def test_fs(self):
|
|
|
|
|
from paddle.distributed.fleet.utils.fs import LocalFS
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
from paddle.distributed.fleet.utils import LocalFS
|
|
|
|
|
|
|
|
|
|
fs = LocalFS()
|
|
|
|
|
dirs, files = fs.ls_dir("test_tmp")
|
|
|
|
|
dirs, files = fs.ls_dir("./")
|
|
|
|
|
self.assertFalse(fs.need_upload_download())
|
|
|
|
|
fleet_util._set_file_system(fs)
|
|
|
|
|
fleet.util._set_file_system(fs)
|
|
|
|
|
|
|
|
|
|
def download_files(self):
|
|
|
|
|
path = download(self.proto_data_url, self.module_name,
|
|
|
|
@ -98,7 +98,8 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
return unzip_folder
|
|
|
|
|
|
|
|
|
|
def test_get_file_shard(self):
|
|
|
|
|
self.assertRaises(Exception, fleet_util.get_file_shard, "files")
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
self.assertRaises(Exception, fleet.util.get_file_shard, "files")
|
|
|
|
|
try:
|
|
|
|
|
import netifaces
|
|
|
|
|
except:
|
|
|
|
@ -112,18 +113,20 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
role=role_maker.Role.WORKER,
|
|
|
|
|
worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
|
|
|
|
|
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
|
|
|
|
|
fleet_util._set_role_maker(role)
|
|
|
|
|
files = fleet_util.get_file_shard(["1", "2", "3"])
|
|
|
|
|
fleet.init(role)
|
|
|
|
|
|
|
|
|
|
files = fleet.util.get_file_shard(["1", "2", "3"])
|
|
|
|
|
self.assertTrue(len(files) == 2 and "1" in files and "2" in files)
|
|
|
|
|
|
|
|
|
|
def test_program_type_trans(self):
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
data_dir = self.download_files()
|
|
|
|
|
program_dir = os.path.join(data_dir, self.pruned_dir)
|
|
|
|
|
text_program = "pruned_main_program.pbtxt"
|
|
|
|
|
binary_program = "pruned_main_program.bin"
|
|
|
|
|
text_to_binary = fleet_util._program_type_trans(program_dir,
|
|
|
|
|
text_to_binary = fleet.util._program_type_trans(program_dir,
|
|
|
|
|
text_program, True)
|
|
|
|
|
binary_to_text = fleet_util._program_type_trans(program_dir,
|
|
|
|
|
binary_to_text = fleet.util._program_type_trans(program_dir,
|
|
|
|
|
binary_program, False)
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
os.path.exists(os.path.join(program_dir, text_to_binary)))
|
|
|
|
@ -131,6 +134,7 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
os.path.exists(os.path.join(program_dir, binary_to_text)))
|
|
|
|
|
|
|
|
|
|
def test_prams_check(self):
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
data_dir = self.download_files()
|
|
|
|
|
|
|
|
|
|
class config:
|
|
|
|
@ -160,11 +164,11 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
# test saved var's shape
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match"
|
|
|
|
|
|
|
|
|
|
self.assertRaises(Exception, fleet_util._params_check)
|
|
|
|
|
self.assertRaises(Exception, fleet.util._params_check)
|
|
|
|
|
|
|
|
|
|
# test program.proto without feed_op and fetch_op
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.no_feed_fetch"
|
|
|
|
|
results = fleet_util._params_check(conf)
|
|
|
|
|
results = fleet.util._params_check(conf)
|
|
|
|
|
self.assertTrue(len(results) == 1)
|
|
|
|
|
np.testing.assert_array_almost_equal(
|
|
|
|
|
results[0], np.array(
|
|
|
|
@ -172,11 +176,11 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
# test feed_var's shape
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match"
|
|
|
|
|
self.assertRaises(Exception, fleet_util._params_check)
|
|
|
|
|
self.assertRaises(Exception, fleet.util._params_check)
|
|
|
|
|
|
|
|
|
|
# test correct case with feed_vars_filelist
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.pbtxt"
|
|
|
|
|
results = fleet_util._params_check(conf)
|
|
|
|
|
results = fleet.util._params_check(conf)
|
|
|
|
|
self.assertTrue(len(results) == 1)
|
|
|
|
|
np.testing.assert_array_almost_equal(
|
|
|
|
|
results[0], np.array(
|
|
|
|
@ -186,13 +190,14 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
conf.feed_config.feeded_vars_filelist = None
|
|
|
|
|
# test feed var with lod_level >= 2
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.feed_lod2"
|
|
|
|
|
self.assertRaises(Exception, fleet_util._params_check)
|
|
|
|
|
self.assertRaises(Exception, fleet.util._params_check)
|
|
|
|
|
|
|
|
|
|
conf.dump_program_filename = "pruned_main_program.pbtxt"
|
|
|
|
|
results = fleet_util._params_check(conf)
|
|
|
|
|
results = fleet.util._params_check(conf)
|
|
|
|
|
self.assertTrue(len(results) == 1)
|
|
|
|
|
|
|
|
|
|
def test_proto_check(self):
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
data_dir = self.download_files()
|
|
|
|
|
|
|
|
|
|
class config:
|
|
|
|
@ -210,7 +215,7 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
"pruned_main_program.save_var_shape_not_match"))
|
|
|
|
|
conf.is_text_pruned_program = True
|
|
|
|
|
conf.draw = False
|
|
|
|
|
res = fleet_util._proto_check(conf)
|
|
|
|
|
res = fleet.util._proto_check(conf)
|
|
|
|
|
self.assertFalse(res)
|
|
|
|
|
|
|
|
|
|
# test match
|
|
|
|
@ -222,10 +227,11 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
else:
|
|
|
|
|
conf.draw = True
|
|
|
|
|
conf.draw_out_name = "pruned_check"
|
|
|
|
|
res = fleet_util._proto_check(conf)
|
|
|
|
|
res = fleet.util._proto_check(conf)
|
|
|
|
|
self.assertTrue(res)
|
|
|
|
|
|
|
|
|
|
def test_visualize(self):
|
|
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
|
|
if sys.platform == 'win32' or sys.platform == 'sys.platform':
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
@ -234,10 +240,10 @@ class TestFleetUtil(unittest.TestCase):
|
|
|
|
|
data_dir,
|
|
|
|
|
os.path.join(self.train_dir, "join_main_program.pbtxt"))
|
|
|
|
|
is_text = True
|
|
|
|
|
program = fleet_util._load_program(program_path, is_text)
|
|
|
|
|
program = fleet.util._load_program(program_path, is_text)
|
|
|
|
|
output_dir = os.path.join(data_dir, self.train_dir)
|
|
|
|
|
output_filename = "draw_prog"
|
|
|
|
|
fleet_util._visualize_graphviz(program, output_dir, output_filename)
|
|
|
|
|
fleet.util._visualize_graphviz(program, output_dir, output_filename)
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
os.path.exists(
|
|
|
|
|
os.path.join(output_dir, output_filename + ".dot")))
|
|
|
|
|