From 17d6f1c2f9da65a9a83af34027cef75964251506 Mon Sep 17 00:00:00 2001
From: tronzhang <6517937+tronzhang@user.noreply.gitee.com>
Date: Tue, 8 Dec 2020 22:17:03 +0800
Subject: [PATCH] add option for graph kernel and mixed precision

---
 model_zoo/official/nlp/bert/run_pretrain.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 5bb32913f5..13366f2553 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -91,6 +91,12 @@ def _get_optimizer(args_opt, network):
     return optimizer
 
 
+def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
+    """Judge whether is suitable to enable graph kernel."""
+    return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
+        cfg.bert_network == 'base' and cfg.batch_size == 32 and cfg.optimizer == 'AdamWeightDecay'
+
+
 def run_pretrain():
     """pre-train bert_clue"""
     parser = argparse.ArgumentParser(description='bert pre_training')
@@ -121,6 +127,8 @@ def run_pretrain():
     parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
+    parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"],
+                        help="Accelerate by graph kernel, default is auto.")
 
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
@@ -145,10 +153,17 @@ def run_pretrain():
         rank = 0
         device_num = 1
 
-    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
+    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
+
+    if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True)
+
+    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
+            not is_auto_enable_graph_kernel:
         logger.warning('Gpu only support fp32 temporarily, run with fp32.')
         bert_net_cfg.compute_type = mstype.float32
+
 
     if args_opt.accumulation_steps > 1:
         logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
         logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps))