@@ -23,6 +23,7 @@ import pandas as pd
 import mindspore.dataset.engine as de
+import mindspore.common.dtype as mstype


 class DataType(Enum):
     """
     Enumerate supported dataset format.
@@ -169,8 +170,41 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
     return ds


+def _padding_func(batch_size, manual_shape, target_column, field_size=39):
+    """
+    Return a map function that pads each batch from field_size to target_column columns.
+    """
+    if manual_shape:
+        generate_concat_offset = [item[0] + item[1] for item in manual_shape]
+        part_size = int(target_column / len(generate_concat_offset))
+        filled_value = []
+        for i in range(field_size, target_column):
+            filled_value.append(generate_concat_offset[i // part_size] - 1)
+        print("Filled value:", filled_value)
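+        # Each padded column is filled with the last valid id of the vocab slice it
+        # falls into (slice offset + slice size - 1), computed from manual_shape above.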
+
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+
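+            # Pad feat_ids with filled_value and feat_vals with zeros for the extra
+            # target_column - field_size columns.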
+            x_id = np.ones((batch_size, target_column - field_size),
+                           dtype=np.int32) * filled_value
+            x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
+            mask = np.concatenate(
+                [y, np.zeros((batch_size, target_column - field_size), dtype=np.float32)], axis=1)
+            return (x_id, mask, z)
+    else:
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+            return (x, y, z)
+    return padding_func
+
+
 def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                    line_per_sample=1000, rank_size=None, rank_id=None):
+                    line_per_sample=1000, rank_size=None, rank_id=None,
+                    manual_shape=None, target_column=40):
     """
     get_tf_dataset
     """
@@ -189,21 +223,22 @@ def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
         ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
                                 num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
     else:
-        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
+        ds = de.TFRecordDataset(dataset_files=dataset_files,
+                                shuffle=shuffle, schema=schema, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample),
                   drop_remainder=True)
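+    # Each TFRecord row packs line_per_sample samples, so batching
+    # batch_size / line_per_sample rows yields batch_size samples after the reshape below.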
-    ds = ds.map(operations=(lambda x, y, z: (
-        np.array(x).flatten().reshape(batch_size, 39),
-        np.array(y).flatten().reshape(batch_size, 39),
-        np.array(z).flatten().reshape(batch_size, 1))),
+    ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
-    #if train_mode:
+    # if train_mode:
     ds = ds.repeat(epochs)
     return ds


 def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
-                            line_per_sample=1000, rank_size=None, rank_id=None):
+                            line_per_sample=1000, rank_size=None, rank_id=None,
+                            manual_shape=None, target_column=40):
     """
     Get dataset with mindrecord format.
@@ -233,9 +268,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
                         columns_list=['feat_ids', 'feat_vals', 'label'],
                         shuffle=shuffle, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
-    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
-                                             np.array(y).flatten().reshape(batch_size, 39),
-                                             np.array(z).flatten().reshape(batch_size, 1))),
+    ds = ds.map(_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'],
                 num_parallel_workers=8)
@@ -243,18 +276,84 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
     return ds


+def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
+    """
+    Compute the per-worker vocab slice sizes and offsets (manual_shape) and the total vocab size.
+    """
+    # Vocabulary size of each of the original 39 fields.
+    individual_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
+                         10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
+                         5, 21762, 14, 15, 15030, 61, 12220]
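+    # Extend the list to target_column_number entries (padded columns get vocab size 1),
+    # then merge consecutive columns into worker_size parts.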
+    new_vocabs = individual_vocabs + [1] * \
+        (target_column_number - len(individual_vocabs))
+    part_size = int(target_column_number / worker_size)
+
+    # According to the number of workers, merge consecutive fields into the same part
+    new_vocab_size = []
+    for i in range(0, target_column_number, part_size):
+        new_vocab_size.append(sum(new_vocabs[i: i + part_size]))
+
+    index_offsets = [0]
+
+    # The original (unscaled) per-part feature counts are used to calculate the offsets
+    features = [item for item in new_vocab_size]
+
+    # If per_vocab_size is given, use it as the vocab size of every part
+    if per_vocab_size is not None:
+        new_vocab_size = [per_vocab_size] * worker_size
+    else:
+        # Expand the vocabulary of each part by the multiplier k
+        if multiply is True:
+            cur_sum = sum(new_vocab_size)
+            k = total_vocab_size / cur_sum
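+            # Scale each part by k, then round up to a multiple of worker_size and of 8.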
+            new_vocab_size = [
+                math.ceil(int(item * k) / worker_size) * worker_size for item in new_vocab_size]
+            new_vocab_size = [(item // 8 + 1) * 8 for item in new_vocab_size]
+        else:
+            if total_vocab_size > sum(new_vocab_size):
+                new_vocab_size[-1] = total_vocab_size - \
+                    sum(new_vocab_size[:-1])
+                new_vocab_size = [item for item in new_vocab_size]
+            else:
+                raise ValueError(
+                    "Please provide the correct vocab size, got {}".format(total_vocab_size))
+
+    for i in range(worker_size - 1):
+        off = index_offsets[i] + features[i]
+        index_offsets.append(off)
+
+    print("the offset: ", index_offsets)
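+    # manual_shape holds one (slice_vocab_size, index_offset) pair per worker.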
+    manual_shape = tuple(
+        ((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
+    vocab_total = sum(new_vocab_size)
+    return manual_shape, vocab_total
+
+
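+# compute_manual_shape pads config.field_size up to a multiple of worker_size and stores
+# the resulting manual_shape and total vocab size back into the config.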
+def compute_manual_shape(config, worker_size):
+    target_column = (config.field_size // worker_size + 1) * worker_size
+    config.field_size = target_column
+    manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
+                                                per_vocab_size=None, multiply=False)
+    config.manual_shape = manual_shape
+    config.vocab_size = int(vocab_total)
+
+
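+# Typical use (assumption, not shown in this diff): call compute_manual_shape(config, rank_size)
+# first, then pass config.manual_shape and the padded config.field_size as target_column here.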
 def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                   data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
+                   data_type=DataType.TFRECORD, line_per_sample=1000,
+                   rank_size=None, rank_id=None, manual_shape=None, target_column=40):
     """
     create_dataset
     """
     if data_type == DataType.TFRECORD:
         return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
-                               line_per_sample, rank_size=rank_size, rank_id=rank_id)
+                               line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                               manual_shape=manual_shape, target_column=target_column)
     if data_type == DataType.MINDRECORD:
-        return _get_mindrecord_dataset(data_dir, train_mode, epochs,
-                                       batch_size, line_per_sample,
-                                       rank_size, rank_id)
+        return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
+                                       line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                                       manual_shape=manual_shape, target_column=target_column)

     if rank_size > 1:
         raise RuntimeError("please use tfrecord dataset.")