From 752035795f8943748db352982f723c52fdec3cd6 Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Wed, 14 Oct 2020 22:10:16 +0800 Subject: [PATCH] update trantransformer scripts --- model_zoo/official/nlp/transformer/src/dataset.py | 10 +++++++--- model_zoo/official/nlp/transformer/train.py | 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/nlp/transformer/src/dataset.py b/model_zoo/official/nlp/transformer/src/dataset.py index ac6ca9479f..75f9e42f67 100644 --- a/model_zoo/official/nlp/transformer/src/dataset.py +++ b/model_zoo/official/nlp/transformer/src/dataset.py @@ -17,10 +17,10 @@ import mindspore.common.dtype as mstype import mindspore.dataset as de import mindspore.dataset.transforms.c_transforms as deC -from .config import transformer_net_cfg +from .config import transformer_net_cfg, transformer_net_cfg_gpu de.config.set_seed(1) def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None, - bucket_boundaries=None): + bucket_boundaries=None, device_target="Ascend"): """create dataset""" def batch_per_bucket(bucket_len, dataset_path): dataset_path = dataset_path + "_" + str(bucket_len) + "_00" @@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask") # apply batch operations - ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) + if device_target == "Ascend": + ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True) + else: + ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True) + ds = ds.repeat(epoch_count) return ds diff --git a/model_zoo/official/nlp/transformer/train.py b/model_zoo/official/nlp/transformer/train.py index 4a59f2db1a..5ac2fa316e 100644 --- a/model_zoo/official/nlp/transformer/train.py +++ b/model_zoo/official/nlp/transformer/train.py @@ -146,7 +146,8 @@ def run_transformer_train(): dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, rank_id=rank_id, do_shuffle=args.do_shuffle, dataset_path=args.data_path, - bucket_boundaries=args.bucket_boundaries) + bucket_boundaries=args.bucket_boundaries, + device_target=args.device_target) if args.device_target == "Ascend": netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True) else: