From c6d261b27787bf605d473b2d41466950ee5b3c69 Mon Sep 17 00:00:00 2001
From: yoonlee666
Date: Mon, 30 Mar 2020 16:21:57 +0800
Subject: [PATCH] add bert script to master

---
 example/Bert_NEZHA/config.py                 | 55 ----------------
 example/Bert_NEZHA_cnwiki/config.py          | 57 ++++++++++++++++
 .../main.py => Bert_NEZHA_cnwiki/train.py}   | 66 +++++++------------
 3 files changed, 82 insertions(+), 96 deletions(-)
 delete mode 100644 example/Bert_NEZHA/config.py
 create mode 100644 example/Bert_NEZHA_cnwiki/config.py
 rename example/{Bert_NEZHA/main.py => Bert_NEZHA_cnwiki/train.py} (57%)

diff --git a/example/Bert_NEZHA/config.py b/example/Bert_NEZHA/config.py
deleted file mode 100644
index 2f3b22fe50..0000000000
--- a/example/Bert_NEZHA/config.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-network config setting, will be used in main.py
-"""
-
-from easydict import EasyDict as edict
-import mindspore.common.dtype as mstype
-from mindspore.model_zoo.Bert_NEZHA import BertConfig
-bert_cfg = edict({
-    'epoch_size': 10,
-    'num_warmup_steps': 0,
-    'start_learning_rate': 1e-4,
-    'end_learning_rate': 1,
-    'decay_steps': 1000,
-    'power': 10.0,
-    'save_checkpoint_steps': 2000,
-    'keep_checkpoint_max': 10,
-    'checkpoint_prefix': "checkpoint_bert",
-    'DATA_DIR' = "/your/path/examples.tfrecord"
-    'SCHEMA_DIR' = "/your/path/datasetSchema.json"
-    'bert_config': BertConfig(
-        batch_size=16,
-        seq_length=128,
-        vocab_size=21136,
-        hidden_size=1024,
-        num_hidden_layers=24,
-        num_attention_heads=16,
-        intermediate_size=4096,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.0,
-        attention_probs_dropout_prob=0.0,
-        max_position_embeddings=512,
-        type_vocab_size=2,
-        initializer_range=0.02,
-        use_relative_positions=True,
-        input_mask_from_dataset=True,
-        token_type_ids_from_dataset=True,
-        dtype=mstype.float32,
-        compute_type=mstype.float16,
-    )
-})
diff --git a/example/Bert_NEZHA_cnwiki/config.py b/example/Bert_NEZHA_cnwiki/config.py
new file mode 100644
index 0000000000..a704d9a264
--- /dev/null
+++ b/example/Bert_NEZHA_cnwiki/config.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+network config setting, will be used in train.py
+"""
+
+from easydict import EasyDict as edict
+import mindspore.common.dtype as mstype
+from mindspore.model_zoo.Bert_NEZHA import BertConfig
+bert_train_cfg = edict({
+    'epoch_size': 10,
+    'num_warmup_steps': 0,
+    'start_learning_rate': 1e-4,
+    'end_learning_rate': 0.0,
+    'decay_steps': 1000,
+    'power': 10.0,
+    'save_checkpoint_steps': 2000,
+    'keep_checkpoint_max': 10,
+    'checkpoint_prefix': "checkpoint_bert",
+    # please add your own dataset path
+    'DATA_DIR': "/your/path/examples.tfrecord",
+    # please add your own dataset schema path
+    'SCHEMA_DIR': "/your/path/datasetSchema.json"
+})
+bert_net_cfg = BertConfig(
+    batch_size=16,
+    seq_length=128,
+    vocab_size=21136,
+    hidden_size=1024,
+    num_hidden_layers=24,
+    num_attention_heads=16,
+    intermediate_size=4096,
+    hidden_act="gelu",
+    hidden_dropout_prob=0.0,
+    attention_probs_dropout_prob=0.0,
+    max_position_embeddings=512,
+    type_vocab_size=2,
+    initializer_range=0.02,
+    use_relative_positions=True,
+    input_mask_from_dataset=True,
+    token_type_ids_from_dataset=True,
+    dtype=mstype.float32,
+    compute_type=mstype.float16,
+)
diff --git a/example/Bert_NEZHA/main.py b/example/Bert_NEZHA_cnwiki/train.py
similarity index 57%
rename from example/Bert_NEZHA/main.py
rename to example/Bert_NEZHA_cnwiki/train.py
index a5500f25a9..87f425e21c 100644
--- a/example/Bert_NEZHA/main.py
+++ b/example/Bert_NEZHA_cnwiki/train.py
@@ -14,7 +14,8 @@
 # ============================================================================
 
 """
-NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language model currently based on BERT developed by Huawei.
+NEZHA (NEural contextualiZed representation for CHinese lAnguage understanding) is the Chinese pretrained language
+model currently based on BERT developed by Huawei.
 1. Prepare data
 Following the data preparation as in BERT, run command as below to get dataset for training:
 python ./create_pretraining_data.py \
@@ -28,35 +29,29 @@ Following the data preparation as in BERT, run command as below to get dataset f
 --random_seed=12345 \
 --dupe_factor=5
 2. Pretrain
-First, prepare the distributed training environment, then adjust configurations in config.py, finally run main.py.
+First, prepare the distributed training environment, then adjust configurations in config.py, finally run train.py.
""" import os -import pytest import numpy as np -from numpy import allclose -from config import bert_cfg as cfg -import mindspore.common.dtype as mstype +from config import bert_train_cfg, bert_net_cfg import mindspore.dataset.engine.datasets as de import mindspore._c_dataengine as deMap from mindspore import context from mindspore.common.tensor import Tensor from mindspore.train.model import Model -from mindspore.train.callback import Callback -from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell from mindspore.nn.optim import Lamb -from mindspore import log as logger _current_dir = os.path.dirname(os.path.realpath(__file__)) -DATA_DIR = [cfg.DATA_DIR] -SCHEMA_DIR = cfg.SCHEMA_DIR -def me_de_train_dataset(batch_size): - """test me de train dataset""" +def create_train_dataset(batch_size): + """create train dataset""" # apply repeat operations - repeat_count = cfg.epoch_size - ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", - "next_sentence_labels", "masked_lm_positions", - "masked_lm_ids", "masked_lm_weights"]) + repeat_count = bert_train_cfg.epoch_size + ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) type_cast_op = deMap.TypeCastOp("int32") ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) @@ -69,43 +64,32 @@ def me_de_train_dataset(batch_size): ds = ds.repeat(repeat_count) return ds - def weight_variable(shape): """weight variable""" np.random.seed(1) ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) return Tensor(ones) - -class ModelCallback(Callback): - def __init__(self): - super(ModelCallback, self).__init__() - self.loss_list = [] - - def step_end(self, run_context): - cb_params = run_context.original_args() - self.loss_list.append(cb_params.net_outputs.asnumpy()[0]) - logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) - -def test_bert_tdt(): - """test bert tdt""" +def train_bert(): + """train bert""" context.set_context(mode=context.GRAPH_MODE) context.set_context(device_target="Ascend") context.set_context(enable_task_sink=True) context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=True) - parallel_callback = ModelCallback() - ds = me_de_train_dataset(cfg.bert_config.batch_size) - config = cfg.bert_config - netwithloss = BertNetworkWithLoss(config, True) - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=cfg.decay_steps, start_learning_rate=cfg.start_learning_rate, - end_learning_rate=cfg.end_learning_rate, power=cfg.power, warmup_steps=cfg.num_warmup_steps, decay_filter=lambda x: False) + ds = create_train_dataset(bert_net_cfg.batch_size) + netwithloss = BertNetworkWithLoss(bert_net_cfg, True) + optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps, + start_learning_rate=bert_train_cfg.start_learning_rate, + end_learning_rate=bert_train_cfg.end_learning_rate, power=bert_train_cfg.power, + warmup_steps=bert_train_cfg.num_warmup_steps, decay_filter=lambda x: False) netwithgrads = BertTrainOneStepCell(netwithloss, 
optimizer=optimizer) netwithgrads.set_train(True) model = Model(netwithgrads) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix=cfg.checkpoint_prefix, config=config_ck) - model.train(ds.get_repeat_count(), ds, callbacks=[parallel_callback, ckpoint_cb], dataset_sink_mode=False) + config_ck = CheckpointConfig(save_checkpoint_steps=bert_train_cfg.save_checkpoint_steps, + keep_checkpoint_max=bert_train_cfg.keep_checkpoint_max) + ckpoint_cb = ModelCheckpoint(prefix=bert_train_cfg.checkpoint_prefix, config=config_ck) + model.train(ds.get_repeat_count(), ds, callbacks=[LossMonitor(), ckpoint_cb], dataset_sink_mode=False) if __name__ == '__main__': - test_bert_tdt() + train_bert()
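
Note on the schedule fields in config.py: the Lamb optimizer above is driven by start_learning_rate,
end_learning_rate, decay_steps, power and num_warmup_steps from bert_train_cfg. The exact curve is
computed inside MindSpore's Lamb implementation and is not shown in this patch; the sketch below is
only an illustration of how polynomial-decay-with-warmup parameters like these are conventionally
combined, and the helper name polynomial_decay_lr is hypothetical, not part of the patch.

    # Illustration only: one conventional polynomial-decay-with-warmup schedule
    # built from the bert_train_cfg field values above. MindSpore's Lamb optimizer
    # may differ in detail.
    def polynomial_decay_lr(step, start_lr=1e-4, end_lr=0.0, decay_steps=1000,
                            power=10.0, warmup_steps=0):
        """Learning rate for a given global training step (hypothetical helper)."""
        if warmup_steps > 0 and step < warmup_steps:
            # linear warmup from 0 toward start_lr
            return start_lr * float(step + 1) / warmup_steps
        # progress through the decay phase, clamped to [0, 1]
        progress = min(max(step - warmup_steps, 0) / float(decay_steps), 1.0)
        return (start_lr - end_lr) * (1.0 - progress) ** power + end_lr

    if __name__ == '__main__':
        for step in (0, 100, 500, 1000, 2000):
            print(step, polynomial_decay_lr(step))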
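
Before train.py can run, DATA_DIR and SCHEMA_DIR in config.py must point at the TFRecord file and
dataset schema produced by the data-preparation step described in the train.py docstring. A small
pre-flight check along the following lines can catch path mistakes early; it is a hypothetical
snippet that only reuses values defined in config.py and is not part of the patch.

    # Hypothetical pre-flight check: confirm the configured dataset files exist
    # and echo the main training settings before launching train.py.
    import os

    from config import bert_train_cfg, bert_net_cfg

    for key in ('DATA_DIR', 'SCHEMA_DIR'):
        path = bert_train_cfg[key]
        status = 'found' if os.path.exists(path) else 'MISSING'
        print('%s: %s (%s)' % (key, path, status))
    print('batch size:', bert_net_cfg.batch_size)
    print('epochs (dataset repeat count):', bert_train_cfg.epoch_size)
    print('checkpoints: every %d steps, keep at most %d' %
          (bert_train_cfg.save_checkpoint_steps, bert_train_cfg.keep_checkpoint_max))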