|
|
|
@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
|
|
|
|
|
import mindspore.dataset.transforms.c_transforms as C
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
|
|
|
|
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
|
|
|
|
@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
|
|
|
|
|
"""create train dataset"""
|
|
|
|
|
# apply repeat operations
|
|
|
|
|
repeat_count = bert_train_cfg.epoch_size
|
|
|
|
|
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
|
|
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
|
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
|
|
|
|
|
ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
|
|
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
|
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
|
|
|
|
|
type_cast_op = C.TypeCast(mstype.int32)
|
|
|
|
|
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
|
|
|
|
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
|
|
|
|