@@ -20,9 +20,9 @@ import shutil
 import sys
 import importlib
 import paddle.dataset
-import cPickle
+import pickle
 import glob
-import cPickle as pickle
+import pickle as pickle
 
 __all__ = [
     'DATA_HOME',
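Note: the import rename above targets Python 3, where the C-accelerated pickler is used automatically. A minimal sketch of a guarded import that would keep Python 2 compatibility, shown only for context and not part of this patch:

    try:
        import cPickle as pickle  # Python 2: explicit C implementation
    except ImportError:
        import pickle             # Python 3: C accelerator is built in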
@@ -75,13 +75,13 @@ def download(url, module_name, md5sum, save_name=None):
     retry_limit = 3
     while not (os.path.exists(filename) and md5file(filename) == md5sum):
         if os.path.exists(filename):
-            print "file md5", md5file(filename), md5sum
+            print(("file md5", md5file(filename), md5sum))
         if retry < retry_limit:
             retry += 1
         else:
             raise RuntimeError("Cannot download {0} within retry limit {1}".
                                format(url, retry_limit))
-        print "Cache file %s not found, downloading %s" % (filename, url)
+        print(("Cache file %s not found, downloading %s" % (filename, url)))
         r = requests.get(url, stream=True)
         total_length = r.headers.get('content-length')
 
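For context, the converted calls above pass a single tuple to print(), so Python 3 prints the tuple's repr rather than space-separated fields. A standalone illustration; the values are hypothetical:

    filename, expected = "train.tar.gz", "abc123"  # hypothetical values
    print("file md5", filename, expected)          # file md5 train.tar.gz abc123
    print(("file md5", filename, expected))        # ('file md5', 'train.tar.gz', 'abc123')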
@@ -104,8 +104,9 @@ def download(url, module_name, md5sum, save_name=None):
 
 
 def fetch_all():
-    for module_name in filter(lambda x: not x.startswith("__"),
-                              dir(paddle.dataset)):
+    for module_name in [
+            x for x in dir(paddle.dataset) if not x.startswith("__")
+    ]:
         if "fetch" in dir(
                 importlib.import_module("paddle.dataset.%s" % module_name)):
             getattr(
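For context, Python 3's filter() returns a lazy iterator rather than a list; the comprehension used above is an eager equivalent over the same module names. A small standalone check, which requires paddle to be importable:

    import paddle.dataset

    lazy = filter(lambda x: not x.startswith("__"), dir(paddle.dataset))
    eager = [x for x in dir(paddle.dataset) if not x.startswith("__")]
    assert list(lazy) == eager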
@@ -114,8 +115,9 @@ def fetch_all():
 
 
 def fetch_all_recordio(path):
-    for module_name in filter(lambda x: not x.startswith("__"),
-                              dir(paddle.dataset)):
+    for module_name in [
+            x for x in dir(paddle.dataset) if not x.startswith("__")
+    ]:
         if "convert" in dir(
                 importlib.import_module("paddle.dataset.%s" % module_name)) and \
                 not module_name == "common":
@@ -126,7 +128,7 @@ def fetch_all_recordio(path):
                 "convert")(ds_path)
 
 
-def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
+def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
     """
     you can call the function as:
 
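The usage example inside the docstring is cut off by the hunk context; a sketch of a typical call follows, with an illustrative dataset and file-name pattern rather than the original docstring text:

    import paddle.dataset.cifar
    import paddle.dataset.common

    # Dump the reader's output into pickle files of 1000 samples each;
    # the suffix is a printf-style pattern for the chunk index.
    paddle.dataset.common.split(
        paddle.dataset.cifar.train10(),
        line_count=1000,
        suffix="cifar-train-%05d.pickle")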
@@ -167,7 +169,7 @@ def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
 def cluster_files_reader(files_pattern,
                          trainer_count,
                          trainer_id,
-                         loader=cPickle.load):
+                         loader=pickle.load):
     """
     Create a reader that yield element from the given files, select
     a file set according trainer count and trainer_id
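A sketch of how one trainer might consume its shard of files with cluster_files_reader, assuming it returns a generator factory as other paddle reader creators do; the glob pattern and trainer ids are illustrative:

    import paddle.dataset.common

    # Trainer 0 of 4 reads every 4th file matching the pattern.
    reader = paddle.dataset.common.cluster_files_reader(
        "./output/part-*.pickle", trainer_count=4, trainer_id=0)
    for sample in reader():
        pass  # consume one deserialized element at a time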
@@ -188,7 +190,7 @@ def cluster_files_reader(files_pattern,
         my_file_list = []
         for idx, fn in enumerate(file_list):
             if idx % trainer_count == trainer_id:
-                print "append file: %s" % fn
+                print(("append file: %s" % fn))
                 my_file_list.append(fn)
         for fn in my_file_list:
             with open(fn, "r") as f:
@@ -221,7 +223,7 @@ def convert(output_path, reader, line_count, name_prefix):
         for l in lines:
             # FIXME(Yancey1989):
             # dumps with protocol: pickle.HIGHEST_PROTOCOL
-            writer.write(cPickle.dumps(l))
+            writer.write(pickle.dumps(l))
         writer.close()
 
     lines = []
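The FIXME above notes that dumps() still uses the default protocol; a minimal standalone sketch of serializing with an explicit protocol, as the comment suggests (the sample record is hypothetical):

    import pickle

    record = {"label": 3, "pixels": [0, 255, 17]}  # hypothetical sample
    data = pickle.dumps(record, protocol=pickle.HIGHEST_PROTOCOL)
    assert pickle.loads(data) == record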