From 1e52582e23ed611828451acae0552e7056f8d936 Mon Sep 17 00:00:00 2001 From: wandongdong Date: Wed, 13 May 2020 19:13:36 +0800 Subject: [PATCH] add pretrained and update launch --- example/mobilenetv2_imagenet2012/launch.py | 7 +++++-- example/mobilenetv2_imagenet2012/train.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/example/mobilenetv2_imagenet2012/launch.py b/example/mobilenetv2_imagenet2012/launch.py index 22c4af0c31..48c8159664 100644 --- a/example/mobilenetv2_imagenet2012/launch.py +++ b/example/mobilenetv2_imagenet2012/launch.py @@ -130,9 +130,11 @@ def main(): log_files = [] env = os.environ.copy() env['RANK_SIZE'] = str(args.nproc_per_node) + cur_path = os.getcwd() for rank_id in range(0, args.nproc_per_node): + os.chdir(cur_path) device_id = visible_devices[rank_id] - device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id)) + device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) env['RANK_ID'] = str(rank_id) env['DEVICE_ID'] = str(device_id) if args.nproc_per_node > 1: @@ -141,11 +143,12 @@ def main(): if os.path.exists(device_dir): shutil.rmtree(device_dir) os.mkdir(device_dir) + os.chdir(device_dir) cmd = [sys.executable, '-u'] cmd.append(args.training_script) cmd.extend(args.training_script_args) log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') - process = subprocess.Popen(cmd, stdout=log_file, env=env) + process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) processes.append(process) cmds.append(cmd) log_files.append(log_file) diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py index d22e97a290..513d22ef56 100644 --- a/example/mobilenetv2_imagenet2012/train.py +++ b/example/mobilenetv2_imagenet2012/train.py @@ -37,6 +37,7 @@ from mindspore.train.model import Model, ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset.engine as de from mindspore.communication.management import init @@ -46,6 +47,7 @@ de.config.set_seed(1) parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') args_opt = parser.parse_args() device_id = int(os.getenv('DEVICE_ID')) @@ -166,6 +168,9 @@ if __name__ == '__main__': dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=epoch_size, batch_size=config.batch_size) step_size = dataset.get_dataset_size() + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr,