diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py index 7faca2e3bd..e5e8c8b49f 100644 --- a/model_zoo/official/nlp/bert/src/config.py +++ b/model_zoo/official/nlp/bert/src/config.py @@ -57,6 +57,7 @@ large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which Functional Relative Posetional Encoding as an effective positional encoding scheme). ''' if cfg.bert_network == 'base': + cfg.batch_size = 64 bert_net_cfg = BertConfig( seq_length=128, vocab_size=21128, @@ -75,6 +76,7 @@ if cfg.bert_network == 'base': compute_type=mstype.float16 ) if cfg.bert_network == 'nezha': + cfg.batch_size = 96 bert_net_cfg = BertConfig( seq_length=128, vocab_size=21128, @@ -93,6 +95,7 @@ if cfg.bert_network == 'nezha': compute_type=mstype.float16 ) if cfg.bert_network == 'large': + cfg.batch_size = 24 bert_net_cfg = BertConfig( seq_length=512, vocab_size=30522,