Use module name and raw data filename as the local filename

avx_docs
Yi Wang 8 years ago
parent 37e2b92089
commit 91115ab6de

@ -1,7 +1,7 @@
import requests
import hashlib
import os
import shutil
import urllib2
__all__ = ['DATA_HOME', 'download', 'md5file']
@ -11,31 +11,6 @@ if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
def download(url, package_name, md5):
filename = os.path.split(url)[-1]
assert DATA_HOME is not None
filepath = os.path.join(DATA_HOME, md5)
if not os.path.exists(filepath):
os.makedirs(filepath)
__full_file__ = os.path.join(filepath, filename)
def __file_ok__():
if not os.path.exists(__full_file__):
return False
md5_hash = hashlib.md5()
with open(__full_file__, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest() == md5
while not __file_ok__():
response = urllib2.urlopen(url)
with open(__full_file__, mode='wb') as of:
shutil.copyfileobj(fsrc=response, fdst=of)
return __full_file__
def md5file(fname):
hash_md5 = hashlib.md5()
f = open(fname, "rb")
@ -43,3 +18,18 @@ def md5file(fname):
hash_md5.update(chunk)
f.close()
return hash_md5.hexdigest()
def download(url, module_name, md5sum):
dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1])
if not (os.path.exists(filename) and md5file(filename) == md5sum):
# If file doesn't exist or MD5 doesn't match, then download.
r = requests.get(url, stream=True)
with open(filename, 'w') as f:
shutil.copyfileobj(r.raw, f)
return filename

@ -5,12 +5,18 @@ import tempfile
class TestCommon(unittest.TestCase):
def test_md5file(self):
_, temp_path =tempfile.mkstemp()
f = open(temp_path, 'w')
f.write("Hello\n")
f.close()
with open(temp_path, 'w') as f:
f.write("Hello\n")
self.assertEqual(
'09f7e02f1290be211da707a266f153b3',
paddle.v2.dataset.common.md5file(temp_path))
def test_download(self):
yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
self.assertEqual(
paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
paddle.v2.dataset.common.download(
yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save