From 1be7ad52bb877c1286963a3c164ab8923bc41477 Mon Sep 17 00:00:00 2001
From: chenhaozhe
Date: Mon, 8 Jun 2020 19:23:36 +0800
Subject: [PATCH] fix bert scripts

---
 model_zoo/bert/run_pretrain.py                             | 2 ++
 model_zoo/bert/src/dataset.py                              | 3 ++-
 tests/st/networks/models/bert/src/bert_for_pre_training.py | 3 +--
 tests/st/networks/models/bert/src/config.py                | 6 +++---
 tests/st/networks/models/bert/src/dataset.py               | 3 ++-
 5 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py
index 01e7a24753..ab3d7d63ba 100644
--- a/model_zoo/bert/run_pretrain.py
+++ b/model_zoo/bert/run_pretrain.py
@@ -19,6 +19,7 @@ python run_pretrain.py
 
 import os
 import argparse
+import numpy
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore.train.model import Model
@@ -142,4 +143,5 @@ def run_pretrain():
     model = Model(netwithgrads)
     model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 if __name__ == '__main__':
+    numpy.random.seed(0)
     run_pretrain()
diff --git a/model_zoo/bert/src/dataset.py b/model_zoo/bert/src/dataset.py
index 1828fac454..7985ca8559 100644
--- a/model_zoo/bert/src/dataset.py
+++ b/model_zoo/bert/src/dataset.py
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                             shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                             shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
diff --git a/tests/st/networks/models/bert/src/bert_for_pre_training.py b/tests/st/networks/models/bert/src/bert_for_pre_training.py
index 600512b4a7..976f1a3c43 100644
--- a/tests/st/networks/models/bert/src/bert_for_pre_training.py
+++ b/tests/st/networks/models/bert/src/bert_for_pre_training.py
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
 
 clip_grad = C.MultitypeFuncGraph("clip_grad")
 
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
 
 
diff --git a/tests/st/networks/models/bert/src/config.py b/tests/st/networks/models/bert/src/config.py
index d1062b78ee..812f0c2f18 100644
--- a/tests/st/networks/models/bert/src/config.py
+++ b/tests/st/networks/models/bert/src/config.py
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
diff --git a/tests/st/networks/models/bert/src/dataset.py b/tests/st/networks/models/bert/src/dataset.py
index 1828fac454..7985ca8559 100644
--- a/tests/st/networks/models/bert/src/dataset.py
+++ b/tests/st/networks/models/bert/src/dataset.py
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                             shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                             shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
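
Notes on the fixes (anything below that is not quoted from the hunks above
is a best-effort reconstruction and is labeled as such):

1) Reproducibility: run_pretrain.py now seeds numpy's global RNG before
   training starts, presumably because parts of the data pipeline draw on
   it. A minimal sketch; pairing it with de.config.set_seed, the dataset
   module's own seed setter, is an editorial suggestion and not part of
   the patch:

       import numpy
       import mindspore.dataset as de

       numpy.random.seed(0)   # pin numpy-backed randomness
       de.config.set_seed(0)  # pin mindspore.dataset's own shuffle seed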
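2) Dataset repeats: both copies of dataset.py now repeat the pipeline
   max(new_repeat_count, repeat_count) times instead of new_repeat_count
   alone. A self-contained sketch of why the max() matters; the rescaling
   formula is an assumption based on the surrounding script, not quoted
   from it:

       def effective_repeats(ori_dataset_size, new_size, repeat_count):
           # assumed rescaling: keep the total sample count roughly
           # constant after the pipeline is resized for data sinking
           new_repeat_count = repeat_count * ori_dataset_size // new_size
           # the fix: never repeat fewer times than requested, so integer
           # division rounding down cannot starve the pipeline
           return max(new_repeat_count, repeat_count)

       # e.g. 1000 source rows resized to 4096 rows for sinking, one epoch:
       print(effective_repeats(1000, 4096, 1))  # 1000 // 4096 == 0; max() restores 1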
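3) Gradient clipping: bert_for_pre_training.py drops the module-level
   _nn_clip_by_norm cell and constructs nn.ClipByNorm() inside _clip_grad,
   so the cell is no longer created as an import-time side effect. The
   function after the patch, assembled from the two hunks; the register
   decorator and the clip_type guard are assumed from the usual shape of
   this file, not shown in the hunks:

       import mindspore.nn as nn
       from mindspore.ops import composite as C
       from mindspore.ops import functional as F

       clip_grad = C.MultitypeFuncGraph("clip_grad")

       @clip_grad.register("Number", "Number", "Tensor")
       def _clip_grad(clip_type, clip_value, grad):
           """Clip one gradient by value (type 0) or by L2 norm (type 1)."""
           if clip_type not in (0, 1):
               return grad
           dt = F.dtype(grad)
           if clip_type == 0:
               new_grad = C.clip_by_value(grad,
                                          F.cast(F.tuple_to_array((-clip_value,)), dt),
                                          F.cast(F.tuple_to_array((clip_value,)), dt))
           else:
               # cell built per call rather than shared at import time
               new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
           return new_grad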
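4) Vocabulary sizes: config.py now uses vocab_size=21128 for 'base' and
   'nezha' (the published Chinese BERT vocab) and 30522 for 'large' (the
   published uncased English vocab), matching the vocab.txt files from the
   original BERT releases. A quick way to check a local vocab file; the
   path here is hypothetical:

       # BERT vocab files hold one token per line, so the line count is
       # the vocab size
       with open("vocab.txt", encoding="utf-8") as f:
           print(sum(1 for _ in f))  # expect 21128 (Chinese) or 30522 (uncased English)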