From 6fbb124b1d050e8218ec4b284b60721fb7556e60 Mon Sep 17 00:00:00 2001 From: zhouyaqiang Date: Thu, 3 Dec 2020 20:21:37 +0800 Subject: [PATCH] fix hccl init of resnext50 --- model_zoo/official/cv/resnext50/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py index 27e782da9c..7a7aaa6157 100644 --- a/model_zoo/official/cv/resnext50/train.py +++ b/model_zoo/official/cv/resnext50/train.py @@ -147,6 +147,8 @@ def parse_args(cloud_args=None): args.lr_epochs = list(map(int, args.lr_epochs.split(','))) args.image_size = list(map(int, args.image_size.split(','))) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) # init distributed if args.is_distributed: init() @@ -190,8 +192,6 @@ def merge_args(args, cloud_args): def train(cloud_args=None): """training process""" args = parse_args(cloud_args) - context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target=args.platform, save_graphs=False) if os.getenv('DEVICE_ID', "not_set").isdigit(): context.set_context(device_id=int(os.getenv('DEVICE_ID')))