@@ -23,6 +23,7 @@ from lr_generator import get_lr
from config import config
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.model_zoo.mobilenet import mobilenet_v2
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
@@ -110,16 +111,17 @@ class Monitor(Callback):
if __name__ == '__main__':
    if run_distribute:
        context.set_context(enable_hccl=True)
        context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          parameter_broadcast=True, mirror_mean=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
        init()
    else:
        context.set_context(enable_hccl=False)

    epoch_size = config.epoch_size
    net = mobilenet_v2(num_classes=config.num_classes)
    net.add_flags_recursive(fp16=True)
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Dense):
            cell.add_flags_recursive(fp32=True)
    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')

    print("train args: ", args_opt, "\ncfg: ", config,
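
The hunk above relies on run_distribute, rank_size and rank_id, which are defined earlier in train.py, outside the lines shown. Purely as a hedged sketch (the environment-variable names are an assumption, not taken from this diff), that setup typically looks like this on Ascend:

import os

# Assumed setup, not part of this diff: the HCCL launcher on Ascend usually
# exports the device count and the current rank as environment variables.
rank_size = int(os.getenv('RANK_SIZE', '1'))
rank_id = int(os.getenv('RANK_ID', '0'))
run_distribute = rank_size > 1
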
@@ -135,8 +137,7 @@ if __name__ == '__main__':
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                   config.weight_decay, config.loss_scale)

-    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level='O0',
-                  keep_batchnorm_fp32=False)
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)

    cb = None
    if rank_id == 0:
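
The loss_scale object handed to Model is likewise constructed outside the hunks shown here. A minimal sketch of a typical construction, assuming MindSpore's FixedLossScaleManager and the same config.loss_scale value already passed to Momentum (an assumption, not read from this diff):

from mindspore.train.loss_scale_manager import FixedLossScaleManager
from config import config

# Assumed construction; the real call sits outside the hunks above. A fixed
# scale with drop_overflow_update=False pairs with passing config.loss_scale
# to the Momentum optimizer, as in the hunk above.
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)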