From f708e4b108b8555778438ff583a78584f53b43b0 Mon Sep 17 00:00:00 2001 From: Yanjun Peng Date: Mon, 14 Dec 2020 16:50:09 +0800 Subject: [PATCH] fix xception and textrcnn --- model_zoo/official/cv/xception/train.py | 28 +++++++++---------- .../official/nlp/textrcnn/scripts/run_eval.sh | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/model_zoo/official/cv/xception/train.py b/model_zoo/official/cv/xception/train.py index 3c0e7b43f4..77595aafbf 100644 --- a/model_zoo/official/cv/xception/train.py +++ b/model_zoo/official/cv/xception/train.py @@ -98,24 +98,24 @@ if __name__ == '__main__': parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') args_opt = parser.parse_args() - # init distributed - if args_opt.is_distributed: - if os.getenv('DEVICE_ID', "not_set").isdigit(): - context.set_context(device_id=int(os.getenv('DEVICE_ID'))) - rank = get_rank() - group_size = get_group_size() - parallel_mode = ParallelMode.DATA_PARALLEL - context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True) - init() - else: - rank = 0 - group_size = 1 - context.set_context(device_id=0) - if args_opt.device_target == "Ascend": #train on Ascend context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False) + # init distributed + if args_opt.is_distributed: + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + init() + rank = get_rank() + group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True) + else: + rank = 0 + group_size = 1 + context.set_context(device_id=0) + # define network net = xception(class_num=config.class_num) net.to_float(mstype.float16) diff --git a/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh index 19c38957c4..519f427347 100644 --- a/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh +++ b/model_zoo/official/nlp/textrcnn/scripts/run_eval.sh @@ -17,4 +17,4 @@ ulimit -u unlimited BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH -python ${BASEPATH}/../eval.py > --ckpt_path $1 ./eval.log 2>&1 & +python ${BASEPATH}/../eval.py --ckpt_path $1 > ./eval.log 2>&1 &