From 132dae4849ce49cbfe3ef7e900dcb323e1b34d09 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Tue, 22 Sep 2020 18:52:58 +0800 Subject: [PATCH] modify transformer hub file --- model_zoo/official/nlp/transformer/mindspore_hub_conf.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/model_zoo/official/nlp/transformer/mindspore_hub_conf.py b/model_zoo/official/nlp/transformer/mindspore_hub_conf.py index a37f0261f3..25bd220d3e 100644 --- a/model_zoo/official/nlp/transformer/mindspore_hub_conf.py +++ b/model_zoo/official/nlp/transformer/mindspore_hub_conf.py @@ -41,8 +41,16 @@ def create_network(name, *args, **kwargs): Create transformer network for large. ''' if name == 'transformer_large': + if "batch_size" in kwargs: + transformer_net_cfg_large.batch_size = kwargs["batch_size"] if "seq_length" in kwargs: transformer_net_cfg_large.seq_length = kwargs["seq_length"] + if "vocab_size" in kwargs: + transformer_net_cfg_large.vocab_size = kwargs["vocab_size"] is_training = kwargs.get("is_training", False) + if not is_training: + transformer_net_cfg_large.batch_size = 1 + transformer_net_cfg_large.hidden_dropout_prob = 0. + transformer_net_cfg_large.attention_probs_dropout_prob = 0. return TransformerModel(transformer_net_cfg_large, is_training, *args) raise NotImplementedError(f"{name} is not implemented in the repo")