@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                             shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                             shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
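The effect of the max(new_repeat_count, repeat_count) guard is easiest to see numerically. A minimal sketch, assuming new_repeat_count is derived as int(repeat_count * ori_dataset_size // new_size) in the elided lines between the two hunks; the helper function below is hypothetical, for illustration only:

# Hypothetical helper mirroring the repeat-count arithmetic in this file
# (assumption: new_repeat_count = int(repeat_count * ori_dataset_size // new_size)).
def repeat_counts(repeat_count, ori_dataset_size, data_sink_steps, batch_size):
    # With data sinking enabled, one "epoch" of the pipeline is shrunk to
    # data_sink_steps * batch_size rows.
    new_size = data_sink_steps * batch_size
    new_repeat_count = int(repeat_count * ori_dataset_size // new_size)
    return new_repeat_count, max(new_repeat_count, repeat_count)

# Normal case: sink size smaller than the dataset, so max() is a no-op.
print(repeat_counts(repeat_count=3, ori_dataset_size=1000, data_sink_steps=10, batch_size=32))
# -> (9, 9)

# Degenerate case: sink size exceeds the dataset, the floor division
# collapses new_repeat_count to 0, and max() restores the requested count.
print(repeat_counts(repeat_count=3, ori_dataset_size=100, data_sink_steps=10, batch_size=32))
# -> (0, 3)

In ordinary sink-mode runs the guard changes nothing; it only matters when data_sink_steps * batch_size exceeds the dataset size, where the unguarded ds.repeat(new_repeat_count) would repeat fewer times than the requested number of epochs.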