|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
import hashlib
|
|
|
|
|
import unittest
|
|
|
|
|
import os
|
|
|
|
|
import io
|
|
|
|
|
import numpy as np
|
|
|
|
|
import time
|
|
|
|
|
import sys
|
|
|
|
@ -23,10 +24,9 @@ from PIL import Image
|
|
|
|
|
import math
|
|
|
|
|
from paddle.dataset.common import download
|
|
|
|
|
import tarfile
|
|
|
|
|
from six.moves import StringIO
|
|
|
|
|
import argparse
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
|
|
random.seed(0)
|
|
|
|
|
np.random.seed(0)
|
|
|
|
|
|
|
|
|
|
DATA_DIM = 224
|
|
|
|
@ -34,7 +34,7 @@ SIZE_FLOAT32 = 4
|
|
|
|
|
SIZE_INT64 = 8
|
|
|
|
|
FULL_SIZE_BYTES = 30106000008
|
|
|
|
|
FULL_IMAGES = 50000
|
|
|
|
|
TARGET_HASH = '22d2e0008dca693916d9595a5ea3ded8'
|
|
|
|
|
TARGET_HASH = '0be07c2c23296b97dad83c626682c66a'
|
|
|
|
|
FOLDER_NAME = "ILSVRC2012/"
|
|
|
|
|
VALLIST_TAR_NAME = "ILSVRC2012/val_list.txt"
|
|
|
|
|
CHUNK_SIZE = 8192
|
|
|
|
@ -55,8 +55,8 @@ def crop_image(img, target_size, center):
|
|
|
|
|
width, height = img.size
|
|
|
|
|
size = target_size
|
|
|
|
|
if center == True:
|
|
|
|
|
w_start = (width - size) / 2
|
|
|
|
|
h_start = (height - size) / 2
|
|
|
|
|
w_start = (width - size) // 2
|
|
|
|
|
h_start = (height - size) // 2
|
|
|
|
|
else:
|
|
|
|
|
w_start = np.random.randint(0, width - size + 1)
|
|
|
|
|
h_start = np.random.randint(0, height - size + 1)
|
|
|
|
@ -95,11 +95,9 @@ def download_concat(cache_folder, zip_path):
|
|
|
|
|
file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1])
|
|
|
|
|
file_names.append(file_name)
|
|
|
|
|
print("Downloaded part {0}\n".format(file_name))
|
|
|
|
|
if not os.path.exists(zip_path):
|
|
|
|
|
with open(zip_path, "w+") as outfile:
|
|
|
|
|
for fname in file_names:
|
|
|
|
|
with open(fname) as infile:
|
|
|
|
|
outfile.write(infile.read())
|
|
|
|
|
with open(zip_path, "wb") as outfile:
|
|
|
|
|
for fname in file_names:
|
|
|
|
|
shutil.copyfileobj(open(fname, 'rb'), outfile)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_processbar(done_percentage):
|
|
|
|
@ -114,12 +112,12 @@ def check_integrity(filename, target_hash):
|
|
|
|
|
print('\nThe binary file exists. Checking file integrity...\n')
|
|
|
|
|
md = hashlib.md5()
|
|
|
|
|
count = 0
|
|
|
|
|
onepart = FULL_SIZE_BYTES / CHUNK_SIZE / 100
|
|
|
|
|
with open(filename) as ifs:
|
|
|
|
|
onepart = FULL_SIZE_BYTES // CHUNK_SIZE // 100
|
|
|
|
|
with open(filename, 'rb') as ifs:
|
|
|
|
|
while True:
|
|
|
|
|
buf = ifs.read(CHUNK_SIZE)
|
|
|
|
|
if count % onepart == 0:
|
|
|
|
|
done = count / onepart
|
|
|
|
|
done = count // onepart
|
|
|
|
|
print_processbar(done)
|
|
|
|
|
count = count + 1
|
|
|
|
|
if not buf:
|
|
|
|
@ -142,28 +140,26 @@ def convert_Imagenet_tar2bin(tar_file, output_file):
|
|
|
|
|
for tarInfo in tar:
|
|
|
|
|
if tarInfo.isfile() and tarInfo.name != VALLIST_TAR_NAME:
|
|
|
|
|
dataset[tarInfo.name] = tar.extractfile(tarInfo).read()
|
|
|
|
|
|
|
|
|
|
with open(output_file, "w+b") as ofs:
|
|
|
|
|
ofs.seek(0)
|
|
|
|
|
num = np.array(int(FULL_IMAGES)).astype('int64')
|
|
|
|
|
ofs.write(num.tobytes())
|
|
|
|
|
|
|
|
|
|
per_percentage = FULL_IMAGES / 100
|
|
|
|
|
per_percentage = FULL_IMAGES // 100
|
|
|
|
|
|
|
|
|
|
val_info = tar.getmember(VALLIST_TAR_NAME)
|
|
|
|
|
val_list = tar.extractfile(val_info).read().decode("utf-8")
|
|
|
|
|
lines = val_list.splitlines()
|
|
|
|
|
idx = 0
|
|
|
|
|
for imagedata in dataset.values():
|
|
|
|
|
img = Image.open(StringIO(imagedata))
|
|
|
|
|
img = Image.open(io.BytesIO(imagedata))
|
|
|
|
|
img = process_image(img)
|
|
|
|
|
np_img = np.array(img)
|
|
|
|
|
ofs.write(np_img.astype('float32').tobytes())
|
|
|
|
|
if idx % per_percentage == 0:
|
|
|
|
|
print_processbar(idx / per_percentage)
|
|
|
|
|
print_processbar(idx // per_percentage)
|
|
|
|
|
idx = idx + 1
|
|
|
|
|
|
|
|
|
|
val_info = tar.getmember(VALLIST_TAR_NAME)
|
|
|
|
|
val_list = tar.extractfile(val_info).read()
|
|
|
|
|
|
|
|
|
|
lines = val_list.split('\n')
|
|
|
|
|
val_dict = {}
|
|
|
|
|
for line_idx, line in enumerate(lines):
|
|
|
|
|
if line_idx == FULL_IMAGES:
|
|
|
|
|