|
|
|
@ -27,17 +27,13 @@ class SimpleDataSet(Dataset):
|
|
|
|
|
global_config = config['Global']
|
|
|
|
|
dataset_config = config[mode]['dataset']
|
|
|
|
|
loader_config = config[mode]['loader']
|
|
|
|
|
if 'data_num_per_epoch' in loader_config.keys():
|
|
|
|
|
data_num_per_epoch = loader_config['data_num_per_epoch']
|
|
|
|
|
else:
|
|
|
|
|
data_num_per_epoch = None
|
|
|
|
|
|
|
|
|
|
self.delimiter = dataset_config.get('delimiter', '\t')
|
|
|
|
|
label_file_list = dataset_config.pop('label_file_list')
|
|
|
|
|
data_source_num = len(label_file_list)
|
|
|
|
|
ratio_list = dataset_config.get("ratio_list", [1.0])
|
|
|
|
|
if isinstance(ratio_list, (float, int)):
|
|
|
|
|
ratio_list = [float(ratio_list)] * len(data_source_num)
|
|
|
|
|
ratio_list = [float(ratio_list)] * int(data_source_num)
|
|
|
|
|
|
|
|
|
|
assert len(
|
|
|
|
|
ratio_list
|
|
|
|
@ -46,34 +42,26 @@ class SimpleDataSet(Dataset):
|
|
|
|
|
self.do_shuffle = loader_config['shuffle']
|
|
|
|
|
|
|
|
|
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
|
|
|
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
|
|
|
|
data_num_per_epoch)
|
|
|
|
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
|
|
|
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
|
|
|
|
if mode.lower() == "train":
|
|
|
|
|
self.shuffle_data_random()
|
|
|
|
|
self.ops = create_operators(dataset_config['transforms'], global_config)
|
|
|
|
|
|
|
|
|
|
def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None):
|
|
|
|
|
def _sample_dataset(self, datas, sample_ratio):
|
|
|
|
|
sample_num = round(len(datas) * sample_ratio)
|
|
|
|
|
|
|
|
|
|
if data_num_per_epoch is not None:
|
|
|
|
|
sample_num = int(data_num_per_epoch * sample_ratio)
|
|
|
|
|
|
|
|
|
|
nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas))
|
|
|
|
|
return list(datas) * nums + random.sample(datas, rem)
|
|
|
|
|
|
|
|
|
|
def get_image_info_list(self, file_list, ratio_list, data_num_per_epoch=None):
    """Read every label file and return the sampled lines as one list.

    Args:
        file_list: Path of a single label file, or a list of such paths.
        ratio_list: Sampling ratios indexed in parallel with
            ``file_list``; ``ratio_list[i]`` is applied to file ``i``.
        data_num_per_epoch: Optional base size forwarded to
            ``self._sample_dataset``. It is only forwarded when it is
            not ``None``.

    Returns:
        A list with the sampled lines of all files, in file order. Files
        are opened in binary mode, so each line is a ``bytes`` object
        that still carries its line terminator.
    """
    if isinstance(file_list, str):
        file_list = [file_list]

    data_lines = []
    for idx, label_path in enumerate(file_list):
        with open(label_path, "rb") as f:
            lines = f.readlines()
        # Sample after the file is closed; only the in-memory lines are needed.
        if data_num_per_epoch is None:
            sampled = self._sample_dataset(lines, ratio_list[idx])
        else:
            sampled = self._sample_dataset(lines, ratio_list[idx],
                                           data_num_per_epoch)
        data_lines.extend(sampled)
    return data_lines
|
|
|
|
|
|
|
|
|
|