!6831 fix RuntimeError in TinyBert

Merge pull request !6831 from yanghaitao/yht_tiny_bert
pull/6831/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ab30741cdd

@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
else:
columns_list = ["input_ids", "input_mask", "segment_ids"]
shard_equal_rows = True
shuffle = (do_shuffle == "true")
if device_num == 1:
shard_equal_rows = False
shuffle = False
if data_type == DataType.MINDRECORD:
ds = de.MindDataset(data_files, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
else:
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
shuffle=shuffle, num_shards=device_num, shard_id=rank,
shard_equal_rows=shard_equal_rows)
if device_num == 1 and shuffle is True:
ds = ds.shuffle(10000)
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="segment_ids")

Loading…
Cancel
Save