fix resnet50 distribute bug

pull/1943/head
zhaoting 5 years ago
parent b16a552d41
commit e03fbc0b98

@ -15,6 +15,7 @@
"""train_imagenet.""" """train_imagenet."""
import os import os
import argparse import argparse
import numpy as np
from dataset import create_dataset from dataset import create_dataset
from lr_generator import get_lr from lr_generator import get_lr
from config import config from config import config
@ -45,6 +46,7 @@ if __name__ == '__main__':
target = args_opt.device_target target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
np.random.seed(1)
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend": if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))

@ -15,6 +15,7 @@
"""train_imagenet.""" """train_imagenet."""
import os import os
import argparse import argparse
import numpy as np
from dataset import create_dataset from dataset import create_dataset
from lr_generator import get_lr from lr_generator import get_lr
from config import config from config import config
@ -48,6 +49,7 @@ if __name__ == '__main__':
target = args_opt.device_target target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path ckpt_save_dir = config.save_checkpoint_path
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
np.random.seed(1)
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
if target == "Ascend": if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))

Loading…
Cancel
Save