add parallel strategy searching optimization in model_zoo's resnet

pull/11483/head
Xiaoda Zhang 5 years ago
parent da92f1affb
commit 2646654a30

@@ -27,6 +27,7 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.common import set_seed
+from mindspore.parallel import set_algo_parameters
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
 from src.lr_generator import get_lr, warmup_cosine_annealing_lr
@@ -82,6 +83,7 @@ if __name__ == '__main__':
             context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
             context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
+            set_algo_parameters(elementwise_op_strategy_follow=True)
             if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
                 context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
             else:
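
Background on the added call (not part of the commit): set_algo_parameters tunes the strategy-searching algorithm of MindSpore's auto-parallel, and elementwise_op_strategy_follow=True constrains element-wise operators to keep their sharding strategies consistent with neighboring operators, which is meant to shrink the search space. The sketch below shows the typical call order in a standalone script; the 8-device count and AUTO_PARALLEL mode are illustrative assumptions, while the resnet script passes device_num and parallel_mode from its own arguments as in the hunk above.

# Minimal sketch, assuming an already launched 8-device distributed job (Ascend or GPU).
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.parallel import set_algo_parameters

context.set_context(mode=context.GRAPH_MODE)
init()  # initialize the communication backend for the distributed job
context.set_auto_parallel_context(device_num=8,
                                  parallel_mode=ParallelMode.AUTO_PARALLEL,
                                  gradients_mean=True)
# Constrain strategy searching; must run before the model is built and compiled.
set_algo_parameters(elementwise_op_strategy_follow=True)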
