|
|
|
@ -46,6 +46,8 @@ def train():
|
|
|
|
|
args = parse_args()
|
|
|
|
|
cfg = FCN8s_VOC2012_cfg
|
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
|
|
|
|
device_target="Ascend", device_id=args.device_id)
|
|
|
|
|
# init multicards training
|
|
|
|
|
if device_num > 1:
|
|
|
|
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
|
|
|
@ -54,9 +56,6 @@ def train():
|
|
|
|
|
args.rank = get_rank()
|
|
|
|
|
args.group_size = get_group_size()
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
|
|
|
|
device_target="Ascend", device_id=args.device_id)
|
|
|
|
|
|
|
|
|
|
# dataset
|
|
|
|
|
dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
|
|
|
|
|
image_std=cfg.image_std,
|
|
|
|
|