|
|
@ -67,11 +67,13 @@ def create_network(name, *args, **kwargs):
|
|
|
|
bert_net_cfg_base.batch_size = kwargs["batch_size"]
|
|
|
|
bert_net_cfg_base.batch_size = kwargs["batch_size"]
|
|
|
|
if "seq_length" in kwargs:
|
|
|
|
if "seq_length" in kwargs:
|
|
|
|
bert_net_cfg_base.seq_length = kwargs["seq_length"]
|
|
|
|
bert_net_cfg_base.seq_length = kwargs["seq_length"]
|
|
|
|
return BertModel(bert_net_cfg_base, *args)
|
|
|
|
is_training = kwargs.get("is_training", default=False)
|
|
|
|
|
|
|
|
return BertModel(bert_net_cfg_base, is_training, *args)
|
|
|
|
if name == 'bert_nezha':
|
|
|
|
if name == 'bert_nezha':
|
|
|
|
if "batch_size" in kwargs:
|
|
|
|
if "batch_size" in kwargs:
|
|
|
|
bert_net_cfg_nezha.batch_size = kwargs["batch_size"]
|
|
|
|
bert_net_cfg_nezha.batch_size = kwargs["batch_size"]
|
|
|
|
if "seq_length" in kwargs:
|
|
|
|
if "seq_length" in kwargs:
|
|
|
|
bert_net_cfg_nezha.seq_length = kwargs["seq_length"]
|
|
|
|
bert_net_cfg_nezha.seq_length = kwargs["seq_length"]
|
|
|
|
return BertModel(bert_net_cfg_nezha, *args)
|
|
|
|
is_training = kwargs.get("is_training", default=False)
|
|
|
|
|
|
|
|
return BertModel(bert_net_cfg_nezha, is_training, *args)
|
|
|
|
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
|
|
raise NotImplementedError(f"{name} is not implemented in the repo")
|
|
|
|