diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py
index c5c7de678e..3af9a83f6a 100644
--- a/python/paddle/utils/download.py
+++ b/python/paddle/utils/download.py
@@ -140,6 +140,33 @@ def _map_path(url, root_dir):
     return osp.join(root_dir, fpath)
 
 
+def _get_unique_endpoints(trainer_endpoints):
+    """Pick one endpoint per distinct host from ``trainer_endpoints``.
+
+    Args:
+        trainer_endpoints (list[str]): Endpoints formatted as "host:port".
+
+    Returns:
+        set[str]: For each host, the endpoint that sorts first among the
+            endpoints sharing that host.
+    """
+    hosts = set()
+    unique_endpoints = set()
+    # Iterate in sorted order so the selection does not depend on the order
+    # in which each card's environment variables list the endpoints.
+    # sorted() leaves the caller's list untouched.
+    for endpoint in sorted(trainer_endpoints):
+        # Strip only the trailing ":port" so hosts that themselves contain
+        # colons (e.g. bracketed IPv6 literals) are not truncated.
+        host = endpoint.rsplit(":", 1)[0]
+        if host in hosts:
+            continue
+        hosts.add(host)
+        unique_endpoints.add(endpoint)
+    logger.info("unique_endpoints {}".format(unique_endpoints))
+    return unique_endpoints
+
+
 def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
     """ Download from given url to root_dir.
     if file or directory specified by url is exists under
@@ -161,17 +188,20 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
     assert is_url(url), "downloading from {} not a url".format(url)
     # parse path after download to decompress under root_dir
     fullpath = _map_path(url, root_dir)
-
+    # Mainly used to solve the problem of downloading data from different
+    # machines in the case of multiple machines. Different ips will download
+    # data, and the same ip will only download data once.
+    unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
     if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
         logger.info("Found {}".format(fullpath))
     else:
-        if ParallelEnv().local_rank == 0:
+        if ParallelEnv().current_endpoint in unique_endpoints:
             fullpath = _download(url, root_dir, md5sum)
         else:
             while not os.path.exists(fullpath):
                 time.sleep(1)
 
-    if ParallelEnv().local_rank == 0:
+    if ParallelEnv().current_endpoint in unique_endpoints:
         if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
             fullpath = _decompress(fullpath)
 