From 39b446524d8a3825d7aecb90b4e3751408b3c64c Mon Sep 17 00:00:00 2001
From: yoonlee666
Date: Tue, 14 Apr 2020 10:18:31 +0800
Subject: [PATCH] fix bugs in bert example script

---
 example/Bert_NEZHA_cnwiki/train.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/example/Bert_NEZHA_cnwiki/train.py b/example/Bert_NEZHA_cnwiki/train.py
index 86e033fc9f..2610542a9a 100644
--- a/example/Bert_NEZHA_cnwiki/train.py
+++ b/example/Bert_NEZHA_cnwiki/train.py
@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
 from mindspore import context
 from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from mindspore.train.model import Model
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
@@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
     """create train dataset"""
     # apply repeat operations
     repeat_count = bert_train_cfg.epoch_size
-    ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
-                           columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
-                                         "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
+    ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
+                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
+                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
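
Note: after this patch, the affected portion of create_train_dataset in example/Bert_NEZHA_cnwiki/train.py reads roughly as below. This is a sketch reconstructed from the hunk context only; bert_train_cfg and the remaining dataset operations come from parts of train.py that this diff does not show.

    import mindspore.dataset.engine.datasets as de
    import mindspore.dataset.transforms.c_transforms as C
    import mindspore.common.dtype as mstype  # newly added import, so C.TypeCast(mstype.int32) below resolves

    def create_train_dataset(batch_size):
        """create train dataset"""
        # apply repeat operations
        repeat_count = bert_train_cfg.epoch_size
        # StorageDataset is replaced by TFRecordDataset to load the TFRecord-format pretraining data
        ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
                                columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                              "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
        # cast the masked-LM columns to int32 before they are fed to the network
        type_cast_op = C.TypeCast(mstype.int32)
        ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
        ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
        # ... remaining map/batch/repeat calls are unchanged and not shown in this hunk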