From a4de9ba0eb8f168c5fcdd8e168088c6b77d1bd58 Mon Sep 17 00:00:00 2001 From: yanghaitao1 Date: Wed, 23 Sep 2020 21:33:06 -0400 Subject: [PATCH] fix tinybert failure when running 1p --- model_zoo/official/nlp/tinybert/src/dataset.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/tinybert/src/dataset.py b/model_zoo/official/nlp/tinybert/src/dataset.py index 5829846043..e4e065a9f1 100644 --- a/model_zoo/official/nlp/tinybert/src/dataset.py +++ b/model_zoo/official/nlp/tinybert/src/dataset.py @@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0, else: columns_list = ["input_ids", "input_mask", "segment_ids"] + shard_equal_rows = True + shuffle = (do_shuffle == "true") + if device_num == 1: + shard_equal_rows = False + shuffle = False + if data_type == DataType.MINDRECORD: ds = de.MindDataset(data_files, columns_list=columns_list, shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank) else: ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list, - shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, - shard_equal_rows=True) + shuffle=shuffle, num_shards=device_num, shard_id=rank, + shard_equal_rows=shard_equal_rows) + if device_num == 1 and shuffle is True: + ds = ds.shuffle(10000) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(operations=type_cast_op, input_columns="segment_ids")