# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Download and unpack the MNIST and CIFAR datasets into local directories."""
import gzip
import os
import shutil
import tarfile
import urllib
import urllib.request
import zlib

try:
    # Not used by anything in this module. The name is kept bound for code
    # that may reach it through this module, but a missing package no longer
    # prevents the downloader from being imported.
    import jsbeautifier  # noqa: F401
except ImportError:
    jsbeautifier = None

MNIST_URL = "http://yann.lecun.com/exdb/mnist/"
CIFAR_URL = "https://www.cs.toronto.edu/~kriz/"

# Everything the gzip/tar helpers below can raise for a bad or unreadable
# archive.
_UNPACK_ERRORS = (OSError, EOFError, tarfile.TarError, zlib.error)


def _ensure_directory(path):
    """Create *path* (including parents) if needed.

    A failure is reported on stdout but not raised, so callers keep the
    original best-effort behaviour; a later download into the missing
    directory will fail on its own.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % path)


def _remove_if_present(path):
    """Best-effort delete of *path*, used only while cleaning up after a failure."""
    try:
        os.remove(path)
    except OSError:
        # Cleanup must not mask the error that is about to be raised.
        pass


def _gunzip(gz_path):
    """Decompress ``name.gz`` to ``name`` and delete the archive.

    Raises:
        OSError, EOFError, zlib.error: if the name does not end in ``.gz`` or
            the data cannot be decompressed. A partial output is removed.
    """
    if not gz_path.endswith(".gz"):
        raise OSError("unknown suffix, expected .gz: " + gz_path)
    out_path = gz_path[:-3]
    try:
        with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
    except (OSError, EOFError, zlib.error):
        _remove_if_present(out_path)
        raise
    os.remove(gz_path)


def _untar(tar_path, target_directory):
    """Extract *tar_path* (compression auto-detected) into *target_directory*."""
    with tarfile.open(tar_path) as archive:
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters: refuse members that
            # would be written outside target_directory.
            archive.extractall(target_directory, filter="data")
        else:
            archive.extractall(target_directory)


def create_data_cache_dir():
    """Return ``<cwd>/data_cache``, creating it if it does not exist yet."""
    target_directory = os.path.join(os.getcwd(), "data_cache")
    _ensure_directory(target_directory)
    return target_directory


def download_and_uncompress(files, source_url, target_directory, is_tar=False):
    """Download each name in *files* from *source_url* and unpack it.

    Args:
        files: iterable of file names; each is appended to *source_url* and
            stored under *target_directory*.
        source_url: URL prefix the file names are appended to.
        target_directory: existing directory that receives the downloads.
        is_tar: if true the file is extracted as a tar archive into
            *target_directory* and the archive is kept; otherwise it is
            gunzipped in place and the ``.gz`` file is removed.

    A file is skipped when either the download itself or the same path with
    its last three characters removed already exists.

    Raises:
        SystemError: if a downloaded file cannot be unpacked. The download is
            deleted first so the next run fetches it again.
    """
    for name in files:
        url = source_url + name
        target_file = os.path.join(target_directory, name)

        # check if file already downloaded
        if os.path.exists(target_file) or os.path.exists(target_file[:-3]):
            print("Using cached dataset at ", target_file)
            continue

        try:
            urllib.request.urlretrieve(url, target_file)
        except BaseException:
            # A truncated file (network error, Ctrl-C) would otherwise be
            # mistaken for a cached download on the next run.
            _remove_if_present(target_file)
            raise

        try:
            if is_tar:
                print("extracting from local tar file " + target_file)
                _untar(target_file, target_directory)
            else:
                print("unzipping " + target_file)
                _gunzip(target_file)
        except _UNPACK_ERRORS as err:
            print("Failed to uncompress ", target_file, " removing")
            _remove_if_present(target_file)
            # exit with error so that build script will fail
            raise SystemError("failed to uncompress " + target_file) from err


def download_mnist(target_directory=None):
    """Download the four MNIST idx files into ``<target_directory>/mnist``.

    Args:
        target_directory: parent directory; defaults to the data cache
            directory under the current working directory.

    Returns:
        Tuple ``(mnist_directory, path)`` where *path* is
        ``<mnist_directory>/datasetSchema.json``. That file is not created
        here.
    """
    if target_directory is None:
        target_directory = create_data_cache_dir()

    # create mnist directory
    target_directory = os.path.join(target_directory, "mnist")
    _ensure_directory(target_directory)

    files = ['train-images-idx3-ubyte.gz',
             'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz',
             't10k-labels-idx1-ubyte.gz']
    download_and_uncompress(files, MNIST_URL, target_directory, is_tar=False)

    return target_directory, os.path.join(target_directory, "datasetSchema.json")


def download_cifar(target_directory, files, directory_from_tar):
    """Download one CIFAR tarball and return the directory holding its data.

    Args:
        target_directory: destination directory, or None to use the data
            cache directory.
        files: name of the tarball below ``CIFAR_URL``.
        directory_from_tar: name of the directory the tarball unpacks to.

    Returns:
        With ``target_directory=None`` the data is left where tar put it and
        ``<data_cache>/<directory_from_tar>`` is returned. With an explicit
        *target_directory* the extracted entries are moved directly into it,
        the emptied tar directory is removed, and *target_directory* itself
        is returned.

    Raises:
        SystemError: if unpacking fails or the extracted entries cannot be
            moved. In the latter case the tarball is deleted.
    """
    # Remember this before defaulting: once the cache directory is filled in,
    # "was a directory specified?" can no longer be read off the variable.
    use_cache_dir = target_directory is None
    if use_cache_dir:
        target_directory = create_data_cache_dir()
    else:
        _ensure_directory(target_directory)

    download_and_uncompress([files], CIFAR_URL, target_directory, is_tar=True)

    tar_dir_full_path = os.path.join(target_directory, directory_from_tar)
    if use_cache_dir:
        # change target directory to directory after tar
        return tar_dir_full_path

    # if target dir was specified move data from directory created by tar
    # and put data into target dir
    if os.path.isdir(tar_dir_full_path):
        print("copy files back to target_directory")
        try:
            for entry in os.listdir(tar_dir_full_path):
                shutil.move(os.path.join(tar_dir_full_path, entry),
                            os.path.join(target_directory, entry))
            shutil.rmtree(tar_dir_full_path)
        except OSError as err:
            print("error when moving files out of ", tar_dir_full_path)
            download_file = os.path.join(target_directory, files)
            print("removing " + download_file)
            _remove_if_present(download_file)
            # exit with error so that build script will fail
            raise SystemError(
                "failed to move files out of " + tar_dir_full_path) from err

    return target_directory


def download_cifar10(target_directory=None):
    """Download the CIFAR-10 binary archive; see ``download_cifar``."""
    return download_cifar(target_directory, "cifar-10-binary.tar.gz", "cifar-10-batches-bin")


def download_cifar100(target_directory=None):
    """Download the CIFAR-100 binary archive; see ``download_cifar``."""
    return download_cifar(target_directory, "cifar-100-binary.tar.gz", "cifar-100-binary")


def download_all_for_test(cwd):
    """Download the datasets the tests need into directories below *cwd*."""
    download_mnist(os.path.join(cwd, "testMnistData"))


# Download all datasets to existing test directories
if __name__ == "__main__":
    download_all_for_test(os.getcwd())