|
|
@ -147,6 +147,8 @@ def parse_args(cloud_args=None):
|
|
|
|
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
|
|
|
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
|
|
|
|
args.image_size = list(map(int, args.image_size.split(',')))
|
|
|
|
args.image_size = list(map(int, args.image_size.split(',')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
|
|
|
|
|
|
|
device_target=args.platform, save_graphs=False)
|
|
|
|
# init distributed
|
|
|
|
# init distributed
|
|
|
|
if args.is_distributed:
|
|
|
|
if args.is_distributed:
|
|
|
|
init()
|
|
|
|
init()
|
|
|
@ -190,8 +192,6 @@ def merge_args(args, cloud_args):
|
|
|
|
def train(cloud_args=None):
|
|
|
|
def train(cloud_args=None):
|
|
|
|
"""training process"""
|
|
|
|
"""training process"""
|
|
|
|
args = parse_args(cloud_args)
|
|
|
|
args = parse_args(cloud_args)
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
|
|
|
|
|
|
|
device_target=args.platform, save_graphs=False)
|
|
|
|
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
|
|
|
|
|
|
|