|
|
@ -92,7 +92,7 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
def main():
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
|
devid, rank_id, rank_size = 0, 0, 1
|
|
|
|
rank_id, rank_size = 0, 1
|
|
|
|
|
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
context.set_context(mode=context.GRAPH_MODE)
|
|
|
|
|
|
|
|
|
|
|
@ -101,10 +101,7 @@ def main():
|
|
|
|
init("nccl")
|
|
|
|
init("nccl")
|
|
|
|
context.set_context(device_target='GPU')
|
|
|
|
context.set_context(device_target='GPU')
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
init()
|
|
|
|
raise ValueError("Only supported GPU training.")
|
|
|
|
devid = int(os.getenv('DEVICE_ID'))
|
|
|
|
|
|
|
|
context.set_context(
|
|
|
|
|
|
|
|
device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
|
|
|
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
rank_id = get_rank()
|
|
|
|
rank_id = get_rank()
|
|
|
|
rank_size = get_group_size()
|
|
|
|
rank_size = get_group_size()
|
|
|
@ -113,6 +110,8 @@ def main():
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if args.GPU:
|
|
|
|
if args.GPU:
|
|
|
|
context.set_context(device_target='GPU')
|
|
|
|
context.set_context(device_target='GPU')
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError("Only supported GPU training.")
|
|
|
|
|
|
|
|
|
|
|
|
net = efficientnet_b0(num_classes=cfg.num_classes,
|
|
|
|
net = efficientnet_b0(num_classes=cfg.num_classes,
|
|
|
|
drop_rate=cfg.drop,
|
|
|
|
drop_rate=cfg.drop,
|
|
|
|