|
|
|
@ -49,12 +49,14 @@ def train():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
|
|
|
|
device_target="Ascend", device_id=args.device_id)
|
|
|
|
|
# init multicards training
|
|
|
|
|
args.rank = 0
|
|
|
|
|
args.group_size = 1
|
|
|
|
|
if device_num > 1:
|
|
|
|
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
|
|
|
|
init()
|
|
|
|
|
args.rank = get_rank()
|
|
|
|
|
args.group_size = get_group_size()
|
|
|
|
|
args.rank = get_rank()
|
|
|
|
|
args.group_size = get_group_size()
|
|
|
|
|
|
|
|
|
|
# dataset
|
|
|
|
|
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
|
|
|
|