# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import cPickle as pickle
import random
import collections

def save_file(data, filename):
    """
    Save data into pickle format.
    data: the data to save.
    filename: the output filename.
    """
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

def save_list(l, outfile):
    """
    Save a list of strings into a text file, with one line per string.
    l: the list of strings to save.
    outfile: the output file.
    """
    with open(outfile, "w") as f:
        f.write("\n".join(l))

def exclude_pattern(f):
    """
    Return whether f matches the exclude pattern, i.e. whether the
    file starts with '.' or ends with '~'.
    """
    return f.startswith(".") or f.endswith("~")

def list_dirs(path):
    """
    Return a list of directories in path. Exclude all the directories that
    start with '.'.
    path: the base directory to search over.
    """
    return [
        os.path.join(path, d) for d in next(os.walk(path))[1]
        if not exclude_pattern(d)
    ]

def list_images(path, exts=set(["jpg", "png", "bmp", "jpeg"])):
    """
    Return a list of images in path.
    path: the base directory to search over.
    exts: the extensions of the images to find.
    """
    return [
        os.path.join(path, d) for d in os.listdir(path)
        if os.path.isfile(os.path.join(path, d)) and not exclude_pattern(d)
        and os.path.splitext(d)[-1][1:] in exts
    ]

def list_files(path):
    """
    Return a list of files in path.
    path: the base directory to search over.
    """
    return [
        os.path.join(path, d) for d in os.listdir(path)
        if os.path.isfile(os.path.join(path, d)) and not exclude_pattern(d)
    ]

def get_label_set_from_dir(path):
    """
    Return a dictionary of the labels and label ids from a path.
    Assume each directory in the path corresponds to a unique label.
    The keys of the dictionary are the label names.
    The values of the dictionary are the label ids.
    """
    dirs = list_dirs(path)
    return dict([(os.path.basename(d), i) for i, d in enumerate(sorted(dirs))])

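
# For example, given a hypothetical layout with the subdirectories
#     path/cat/... and path/dog/...
# get_label_set_from_dir(path) returns {"cat": 0, "dog": 1}, since the
# directories are sorted before being numbered.
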
class Label(object):
    """
    A class of label data.
    """
    def __init__(self, label, name):
        """
        label: the id of the label.
        name: the name of the label.
        """
        self.label = label
        self.name = name

    def convert_to_paddle_format(self):
        """
        Convert the label into the paddle batch format.
        """
        return int(self.label)

    def __hash__(self):
        return hash(self.label)

class Dataset(object):
    """
    A class to represent a dataset. A dataset contains a set of items.
    Each item contains multiple slots of data.
    For example, in an image classification dataset, each item contains two
    slots: the first slot is an image, and the second slot is a label.
    """
    def __init__(self, data, keys):
        """
        data: a list of data items.
              Each item is a tuple containing multiple slots of data.
              Each slot is an object with a convert_to_paddle_format method.
        keys: a list of keys for all the slots.
        """
        self.data = data
        self.keys = keys

    def check_valid(self):
        # Every item must have exactly one slot per key.
        for d in self.data:
            assert len(d) == len(self.keys)

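    # Example item for a two-slot image classification dataset, assuming a
    # hypothetical Image class that also exposes convert_to_paddle_format:
    #     data = [(Image(...), Label(3, "cat")), ...], keys = ["images", "labels"]
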
    def permute(self, key_id, num_per_batch):
        """
        Permute data for batching. Two modes are supported:
        1. If key_id is None, the permutation is completely random.
        2. If key_id is not None, the data are permuted so that the key
           specified by key_id is evenly distributed across batches. See
           the comments of permute_by_key for details.
        """
        if key_id is None:
            self.uniform_permute()
        else:
            self.permute_by_key(key_id, num_per_batch)

    def uniform_permute(self):
        """
        Permute the data randomly.
        """
        random.shuffle(self.data)

    def permute_by_key(self, key_id, num_per_batch):
        """
        Permute the data so that the key specified by key_id is evenly
        distributed across batches: each pass of the loop below takes
        ceil(num_per_batch / num_key_values) items from every key value,
        so each chunk of num_per_batch data contains a roughly equal
        number of items per key value, until a key value runs out of data.
        """
        # Store the indices of the data grouped by the key value
        # specified by key_id.
        keyvalue_indices = collections.defaultdict(list)
        for idx in range(len(self.data)):
            keyvalue_indices[self.data[idx][key_id].label].append(idx)
        for k in keyvalue_indices:
            random.shuffle(keyvalue_indices[k])

        num_data_per_key_batch = \
            int(math.ceil(num_per_batch / float(len(keyvalue_indices))))

        if num_data_per_key_batch < 2:
            raise Exception("The number of data in a batch is too small")

        permuted_data = []
        keyvalue_readpointer = collections.defaultdict(int)
        while len(permuted_data) < len(self.data):
            for k in keyvalue_indices:
                begin_idx = keyvalue_readpointer[k]
                end_idx = min(begin_idx + num_data_per_key_batch,
                              len(keyvalue_indices[k]))
                for idx in range(begin_idx, end_idx):
                    permuted_data.append(self.data[keyvalue_indices[k][idx]])
                keyvalue_readpointer[k] = end_idx
        self.data = permuted_data

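# Worked example for Dataset.permute_by_key (assumed numbers): with three key
# values holding 100, 200, and 300 items and num_per_batch = 150, each pass of
# the while-loop appends up to ceil(150 / 3) = 50 items from every key value,
# so the 100-item key is exhausted after two passes while the larger keys keep
# contributing until all 600 items are placed.
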
class DataBatcher(object):
    """
    A class that is used to create batches for both training and testing
    datasets.
    """
    def __init__(self, train_data, test_data, label_set):
        """
        train_data, test_data: each one is a Dataset object representing
                               training and testing data, respectively.
        label_set: a dictionary storing the mapping from label name to
                   label id.
        """
        self.train_data = train_data
        self.test_data = test_data
        self.label_set = label_set
        self.num_per_batch = 5000
        assert self.train_data.keys == self.test_data.keys

    def create_batches_and_list(self, output_path, train_list_name,
                                test_list_name, label_set_name):
        """
        Create batches for both training and testing objects.
        It also creates train.list and test.list to indicate the lists
        of the batch files for training and testing data, respectively.
        """
        train_list = self.create_batches(self.train_data, output_path,
                                         "train_", self.num_per_batch)
        test_list = self.create_batches(self.test_data, output_path, "test_",
                                        self.num_per_batch)
        save_list(train_list, os.path.join(output_path, train_list_name))
        save_list(test_list, os.path.join(output_path, test_list_name))
        save_file(self.label_set, os.path.join(output_path, label_set_name))

    def create_batches(self, data, output_path,
                       prefix="", num_data_per_batch=5000):
        """
        Create batches for a Dataset object.
        data: the Dataset object to process.
        output_path: the output path of the batches.
        prefix: the prefix of each batch.
        num_data_per_batch: number of data in each batch.
        """
        num_batches = int(math.ceil(len(data.data) / float(num_data_per_batch)))
        batch_names = []
        data.check_valid()
        num_slots = len(data.keys)
        for i in range(num_batches):
            batch_name = os.path.join(output_path, prefix + "batch_%03d" % i)
            out_data = dict([(k, []) for k in data.keys])
            begin_idx = i * num_data_per_batch
            end_idx = min((i + 1) * num_data_per_batch, len(data.data))
            for j in range(begin_idx, end_idx):
                for slot_id in range(num_slots):
                    out_data[data.keys[slot_id]].append(
                        data.data[j][slot_id].convert_to_paddle_format())
            save_file(out_data, batch_name)
            batch_names.append(batch_name)
        return batch_names

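# Note on the on-disk format implied by DataBatcher.create_batches: each batch
# file is a pickled dict mapping every key in data.keys to the list of
# convert_to_paddle_format() outputs for the items in that batch, and
# train.list / test.list simply enumerate the batch file paths, one per line.
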
class DatasetCreater(object):
    """
    A virtual class for creating datasets.
    The derived class needs to implement the following methods:
    - create_dataset()
    - create_meta_file()
    """
    def __init__(self, data_path):
        """
        data_path: the path to store the training data and batches.
        train_dir_name: relative training data directory.
        test_dir_name: relative testing data directory.
        batch_dir_name: relative batch directory.
        num_per_batch: the number of data in a batch.
        meta_filename: the filename of the meta file.
        train_list_name: training batch list name.
        test_list_name: testing batch list name.
        label_set_name: label set file name.
        overwrite: whether to overwrite the files if the batches are already
                   in the given path.
        """
        self.data_path = data_path
        self.train_dir_name = 'train'
        self.test_dir_name = 'test'
        self.batch_dir_name = 'batches'
        self.num_per_batch = 50000
        self.meta_filename = "batches.meta"
        self.train_list_name = "train.list"
        self.test_list_name = "test.list"
        self.label_set_name = "labels.pkl"
        self.output_path = os.path.join(self.data_path, self.batch_dir_name)
        self.overwrite = False
        self.permutate_key = "labels"
        self.from_list = False

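    # Hypothetical sketch of the contract, inferred from create_batches below:
    # a derived class implements create_dataset_from_dir / create_dataset_from_list
    # so that each returns a (Dataset, label_set) tuple and sets self.keys to the
    # dataset's slot names, and implements create_meta_file to write the
    # self.meta_filename file.
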
    def create_meta_file(self, data):
        """
        Create a meta file from training data.
        data: training data given in a Dataset format.
        """
        raise NotImplementedError

    def create_dataset(self, path):
        """
        Create a Dataset object from a path.
        It uses a file list to determine the dataset if self.from_list is
        True. Otherwise, it uses the directory structure to determine the
        dataset.
        path: the path of the dataset.
        return a tuple of a Dataset object and a mapping from label name
               to label id.
        """
        if self.from_list:
            return self.create_dataset_from_list(path)
        else:
            return self.create_dataset_from_dir(path)

    def create_dataset_from_list(self, path):
        """
        Create a Dataset object from a path.
        It uses a file list to determine the dataset.
        path: the path of the dataset.
        return a tuple of a Dataset object and a mapping from label name
               to label id.
        """
        raise NotImplementedError

    def create_dataset_from_dir(self, path):
        """
        Create a Dataset object from a path.
        It uses the directory structure to determine the dataset, where
        each subdirectory corresponds to one label.
        path: the path of the dataset.
        return a tuple of a Dataset object and a mapping from label name
               to label id.
        """
        raise NotImplementedError

    def create_batches(self):
        """
        Create batches and the meta file.
        """
        train_path = os.path.join(self.data_path, self.train_dir_name)
        test_path = os.path.join(self.data_path, self.test_dir_name)
        out_path = os.path.join(self.data_path, self.batch_dir_name)
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        if (self.overwrite or
                not os.path.exists(os.path.join(out_path, self.train_list_name))):
            train_data, train_label_set = self.create_dataset(train_path)
            test_data, test_label_set = self.create_dataset(test_path)

            # self.keys is assumed to be set by create_dataset in the
            # derived class; it lists the slot names of the dataset.
            train_data.permute(self.keys.index(self.permutate_key),
                               self.num_per_batch)

            assert train_label_set == test_label_set
            data_batcher = DataBatcher(train_data, test_data, train_label_set)
            data_batcher.num_per_batch = self.num_per_batch
            data_batcher.create_batches_and_list(self.output_path,
                                                 self.train_list_name,
                                                 self.test_list_name,
                                                 self.label_set_name)
            self.num_classes = len(train_label_set.keys())
            self.create_meta_file(train_data)
        return out_path
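

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): build a tiny
    # label-only Dataset, permute it by label, and write batches to a
    # temporary directory. The sizes and names here are arbitrary.
    import tempfile

    items = [(Label(i % 3, "class_%d" % (i % 3)),) for i in range(30)]
    train = Dataset(items[:24], ["labels"])
    test = Dataset(items[24:], ["labels"])
    train.permute(0, 12)  # 12 per batch over 3 labels -> 4 items per label

    label_set = {"class_0": 0, "class_1": 1, "class_2": 2}
    batcher = DataBatcher(train, test, label_set)
    batcher.num_per_batch = 12
    out_dir = tempfile.mkdtemp()
    batcher.create_batches_and_list(out_dir, "train.list", "test.list",
                                    "labels.pkl")
    print "batches written to", out_dir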