|
|
|
@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size
|
|
|
|
|
|
|
|
|
|
from src.models import Monitor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def switch_precision(net, data_type, config):
|
|
|
|
|
if config.platform == "Ascend":
|
|
|
|
|
net.to_float(data_type)
|
|
|
|
@ -29,17 +30,18 @@ def switch_precision(net, data_type, config):
|
|
|
|
|
if isinstance(cell, nn.Dense):
|
|
|
|
|
cell.to_float(mstype.float32)
|
|
|
|
|
|
|
|
|
|
def context_device_init(config):
|
|
|
|
|
|
|
|
|
|
def context_device_init(config):
|
|
|
|
|
if config.platform == "CPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
|
|
|
|
|
|
|
|
|
|
elif config.platform == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
|
|
|
|
|
init("nccl")
|
|
|
|
|
context.set_auto_parallel_context(device_num=get_group_size(),
|
|
|
|
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|
if config.run_distribute:
|
|
|
|
|
init("nccl")
|
|
|
|
|
context.set_auto_parallel_context(device_num=get_group_size(),
|
|
|
|
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
gradients_mean=True)
|
|
|
|
|
|
|
|
|
|
elif config.platform == "Ascend":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
|
|
|
|
@ -53,6 +55,7 @@ def context_device_init(config):
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Only support CPU, GPU and Ascend.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_context(config):
|
|
|
|
|
if config.platform == "CPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform,
|
|
|
|
@ -64,6 +67,7 @@ def set_context(config):
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
|
device_target=config.platform, save_graphs=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def config_ckpoint(config, lr, step_size):
|
|
|
|
|
cb = None
|
|
|
|
|
if config.platform in ("CPU", "GPU") or config.rank_id == 0:
|
|
|
|
@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size):
|
|
|
|
|
ckpt_save_dir = config.save_checkpoint_path
|
|
|
|
|
|
|
|
|
|
if config.platform == "GPU":
|
|
|
|
|
ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
|
|
|
|
|
if config.run_distribute:
|
|
|
|
|
ckpt_save_dir += "ckpt_" + str(get_rank()) + "/"
|
|
|
|
|
else:
|
|
|
|
|
ckpt_save_dir += "ckpt_" + "/"
|
|
|
|
|
|
|
|
|
|
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
|
|
|
|
cb += [ckpt_cb]
|
|
|
|
|