|
|
|
@ -17,10 +17,10 @@
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.dataset as de
|
|
|
|
|
import mindspore.dataset.transforms.c_transforms as deC
|
|
|
|
|
from .config import transformer_net_cfg
|
|
|
|
|
from .config import transformer_net_cfg, transformer_net_cfg_gpu
|
|
|
|
|
# Pin the global mindspore.dataset RNG seed so shuffle order is reproducible
# across runs (the dataset pipeline below supports do_shuffle).
de.config.set_seed(1)
|
|
|
|
|
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None,
|
|
|
|
|
bucket_boundaries=None):
|
|
|
|
|
bucket_boundaries=None, device_target="Ascend"):
|
|
|
|
|
"""create dataset"""
|
|
|
|
|
def batch_per_bucket(bucket_len, dataset_path):
|
|
|
|
|
dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
|
|
|
|
@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
|
|
|
|
|
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
|
|
|
|
|
|
|
|
|
|
# apply batch operations
|
|
|
|
|
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
|
|
|
|
|
else:
|
|
|
|
|
ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True)
|
|
|
|
|
|
|
|
|
|
ds = ds.repeat(epoch_count)
|
|
|
|
|
return ds
|
|
|
|
|
|
|
|
|
|