fix launch bug, add RANK_TABLE_FILE, and remove hccl context

pull/818/head
wandongdong 5 years ago
parent 09b2dcb3fb
commit 20c79c3f7c

@@ -15,7 +15,6 @@
 """launch train script"""
 import os
 import sys
-import subprocess
 import json
 from argparse import ArgumentParser
@@ -125,25 +124,19 @@ def main():
     sys.stdout.flush()
     # spawn the processes
-    current_env = os.environ.copy()
-    current_env["RANK_SIZE"] = str(args.nproc_per_node)
-    if args.nproc_per_node > 1:
-        current_env["MINDSPORE_HCCL_CONFIG_PATH"] = table_fn
-    processes = []
-    cmds = []
     for rank_id in range(0, args.nproc_per_node):
-        current_env["RANK_ID"] = str(rank_id)
-        current_env["DEVICE_ID"] = visible_devices[rank_id]
-        cmd = [sys.executable, "-u"]
-        cmd.append(args.training_script)
-        cmd.extend(args.training_script_args)
-        process = subprocess.Popen(cmd, env=current_env)
-        processes.append(process)
-        cmds.append(cmd)
-    for process, cmd in zip(processes, cmds):
-        process.wait()
-        if process.returncode != 0:
-            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
+        device_id = visible_devices[rank_id]
+        device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
+        rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(args.nproc_per_node,
+                                                                                                    rank_id, device_id)
+        if args.nproc_per_node > 1:
+            rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn)
+            rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn)
+        rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(dir=device_dir,
+                                                                                             script=args.training_script
+                                                                                             )
+        rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id)
+        os.system(rank_process)
 
 if __name__ == "__main__":
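For reference, here is a standalone sketch of what the new launcher does per rank. The inputs below (device list, rank table path, script name, script arguments) are hypothetical placeholders, not values from the diff; the real launcher reads them from its argparse arguments. Each rank gets RANK_SIZE, RANK_ID and DEVICE_ID exported, both MINDSPORE_HCCL_CONFIG_PATH and the newly added RANK_TABLE_FILE pointed at the rank table, and the training script run in the background inside a fresh device{rank_id} directory with output redirected to a per-rank log:

import os

# Hypothetical inputs for illustration only; the launcher derives these from argparse.
nproc_per_node = 2
visible_devices = ['0', '1']
table_fn = '/path/to/rank_table_2p.json'
training_script = 'train.py'
training_script_args = ['--dataset_path', '/path/to/dataset']

for rank_id in range(nproc_per_node):
    device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id))
    cmd = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(
        nproc_per_node, rank_id, visible_devices[rank_id])
    if nproc_per_node > 1:
        # Both variables point at the same HCCL rank table; RANK_TABLE_FILE is the newly added one.
        cmd += 'export MINDSPORE_HCCL_CONFIG_PATH={} && export RANK_TABLE_FILE={} && '.format(table_fn, table_fn)
    cmd += 'rm -rf {d} && mkdir {d} && cd {d} && python {s} '.format(d=device_dir, s=training_script)
    cmd += ' '.join(training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id)
    print(cmd)  # the real launcher passes this string to os.system()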

@@ -23,6 +23,7 @@ from lr_generator import get_lr
 from config import config
 from mindspore import context
 from mindspore import Tensor
+from mindspore import nn
 from mindspore.model_zoo.mobilenet import mobilenet_v2
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.nn.optim.momentum import Momentum
@@ -110,16 +111,17 @@ class Monitor(Callback):
 if __name__ == '__main__':
     if run_distribute:
-        context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           parameter_broadcast=True, mirror_mean=True)
         auto_parallel_context().set_all_reduce_fusion_split_indices([140])
         init()
-    else:
-        context.set_context(enable_hccl=False)
     epoch_size = config.epoch_size
     net = mobilenet_v2(num_classes=config.num_classes)
+    net.add_flags_recursive(fp16=True)
+    for _, cell in net.cells_and_names():
+        if isinstance(cell, nn.Dense):
+            cell.add_flags_recursive(fp32=True)
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
     print("train args: ", args_opt, "\ncfg: ", config,
@@ -135,8 +137,7 @@ if __name__ == '__main__':
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                    config.weight_decay, config.loss_scale)
-    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0',
-                  keep_batchnorm_fp32=False)
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
     cb = None
     if rank_id == 0:
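With precision now fixed via the cell flags above, Model no longer needs the amp_level='O0' / keep_batchnorm_fp32=False overrides. A minimal sketch of the simplified construction, with a placeholder network, learning rate and loss-scale value (not taken from the diff), assuming the MindSpore API of this era:

import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum

net = nn.Dense(16, 10)  # placeholder network standing in for mobilenet_v2
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9, 0.0, 1024.0)
loss_scale = FixedLossScaleManager(1024.0, drop_overflow_update=False)

# Mixed precision is applied through the cell flags, so no amp arguments are passed here.
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)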
