diff --git a/example/Bert_NEZHA_cnwiki/train.py b/example/Bert_NEZHA_cnwiki/train.py index 86e033fc9f..2610542a9a 100644 --- a/example/Bert_NEZHA_cnwiki/train.py +++ b/example/Bert_NEZHA_cnwiki/train.py @@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C from mindspore import context from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from mindspore.train.model import Model from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell @@ -49,9 +50,9 @@ def create_train_dataset(batch_size): """create train dataset""" # apply repeat operations repeat_count = bert_train_cfg.epoch_size - ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, - columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", - "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) + ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)