|
|
|
@ -26,7 +26,6 @@ from src.config import cfg
|
|
|
|
|
from src.dataset import create_bert_dataset
|
|
|
|
|
from src.lr_generator import get_bert_lr, get_bert_damping
|
|
|
|
|
from src.model_thor import Model
|
|
|
|
|
from src.thor_for_bert_arg import THOR
|
|
|
|
|
from src.utils import LossCallBack, BertLearningRate
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.communication.management as D
|
|
|
|
@ -66,10 +65,15 @@ def run_pretrain():
|
|
|
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
|
|
|
|
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
from src.thor_for_bert_arg import THOR
|
|
|
|
|
else:
|
|
|
|
|
from src.thor_for_bert import THOR
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
|
|
|
|
device_id=args_opt.device_id, save_graphs=False)
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
context.set_context(variable_memory_max_size="30GB")
|
|
|
|
|
context.set_context(max_call_depth=3000)
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
if args_opt.device_target == 'Ascend':
|
|
|
|
|