From b079e34e732741615e38bf04867f090ba13ef0cf Mon Sep 17 00:00:00 2001
From: qujianwei
Date: Tue, 11 Aug 2020 09:49:30 +0800
Subject: [PATCH] add se block for resnet50

---
 model_zoo/official/cv/resnet/eval.py          |   7 +-
 .../cv/resnet/scripts/run_distribute_train.sh |  11 +-
 .../official/cv/resnet/scripts/run_eval.sh    |  11 +-
 .../cv/resnet/scripts/run_standalone_train.sh |  11 +-
 model_zoo/official/cv/resnet/src/config.py    |  28 ++-
 model_zoo/official/cv/resnet/src/dataset.py   |  54 ++++-
 .../official/cv/resnet/src/lr_generator.py    |  12 ++
 model_zoo/official/cv/resnet/src/resnet.py    | 202 ++++++++++++++----
 model_zoo/official/cv/resnet/train.py         |  22 +-
 9 files changed, 286 insertions(+), 72 deletions(-)

diff --git a/model_zoo/official/cv/resnet/eval.py b/model_zoo/official/cv/resnet/eval.py
index 7ad67289fe..f7f0b593ae 100755
--- a/model_zoo/official/cv/resnet/eval.py
+++ b/model_zoo/official/cv/resnet/eval.py
@@ -38,17 +38,20 @@ de.config.set_seed(1)
 
 if args_opt.net == "resnet50":
     from src.resnet import resnet50 as resnet
-
     if args_opt.dataset == "cifar10":
         from src.config import config1 as config
         from src.dataset import create_dataset1 as create_dataset
     else:
         from src.config import config2 as config
         from src.dataset import create_dataset2 as create_dataset
-else:
+elif args_opt.net == "resnet101":
     from src.resnet import resnet101 as resnet
     from src.config import config3 as config
     from src.dataset import create_dataset3 as create_dataset
+else:
+    from src.resnet import se_resnet50 as resnet
+    from src.config import config4 as config
+    from src.dataset import create_dataset4 as create_dataset
 
 if __name__ == '__main__':
     target = args_opt.device_target
diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
index 58e2cb1c6c..7117805f1a 100755
--- a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
@@ -16,13 +16,13 @@
 
 if [ $# != 4 ] && [ $# != 5 ]
 then
-    echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
+    echo "Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
 exit 1
 fi
 
-if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
+if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
 then
-    echo "error: the selected net is neither resnet50 nor resnet101"
+    echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
 exit 1
 fi
 
@@ -38,6 +38,11 @@ then
 exit 1
 fi
 
+if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
+then
+    echo "error: training se-resnet50 with the cifar10 dataset is unsupported now!"
+exit 1
+fi
 
 get_real_path(){
     if [ "${1:0:1}" == "/" ]; then
diff --git a/model_zoo/official/cv/resnet/scripts/run_eval.sh b/model_zoo/official/cv/resnet/scripts/run_eval.sh
index 496b3c1e2b..745e072336 100755
--- a/model_zoo/official/cv/resnet/scripts/run_eval.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_eval.sh
@@ -16,13 +16,13 @@
 
 if [ $# != 4 ]
 then
-    echo "Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
+    echo "Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
 exit 1
 fi
 
-if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
+if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
 then
-    echo "error: the selected net is neither resnet50 nor resnet101"
+    echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
 exit 1
 fi
 
@@ -38,6 +38,11 @@ then
 exit 1
 fi
 
+if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
+then
+    echo "error: evaluating se-resnet50 with the cifar10 dataset is unsupported now!"
+exit 1
+fi
 
 get_real_path(){
     if [ "${1:0:1}" == "/" ]; then
diff --git a/model_zoo/official/cv/resnet/scripts/run_standalone_train.sh b/model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
index 2272dbd88b..0c11cceda4 100755
--- a/model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
@@ -16,13 +16,13 @@
 
 if [ $# != 3 ] && [ $# != 4 ]
 then
-    echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
+    echo "Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
 exit 1
 fi
 
-if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
+if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
 then
-    echo "error: the selected net is neither resnet50 nor resnet101"
+    echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
 exit 1
 fi
 
@@ -38,6 +38,11 @@ then
 exit 1
 fi
 
+if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
+then
+    echo "error: training se-resnet50 with the cifar10 dataset is unsupported now!"
+exit 1
+fi
 
 get_real_path(){
     if [ "${1:0:1}" == "/" ]; then
diff --git a/model_zoo/official/cv/resnet/src/config.py b/model_zoo/official/cv/resnet/src/config.py
index 7b1759fde0..96675ef4be 100755
--- a/model_zoo/official/cv/resnet/src/config.py
+++ b/model_zoo/official/cv/resnet/src/config.py
@@ -50,12 +50,12 @@ config2 = ed({
     "keep_checkpoint_max": 10,
     "save_checkpoint_path": "./",
     "warmup_epochs": 0,
-    "lr_decay_mode": "cosine",
+    "lr_decay_mode": "linear",
     "use_label_smooth": True,
     "label_smooth_factor": 0.1,
     "lr_init": 0,
-    "lr_max": 0.1
-
+    "lr_max": 0.1,
+    "lr_end": 0.0
 })
 
 # config for resent101, imagenet2012
@@ -77,3 +77,25 @@ config3 = ed({
     "label_smooth_factor": 0.1,
     "lr": 0.1
 })
+
+# config for se-resnet50, imagenet2012
+config4 = ed({
+    "class_num": 1001,
+    "batch_size": 32,
+    "loss_scale": 1024,
+    "momentum": 0.9,
+    "weight_decay": 1e-4,
+    "epoch_size": 28,
+    "pretrain_epoch_size": 1,
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 4,
+    "keep_checkpoint_max": 10,
+    "save_checkpoint_path": "./",
+    "warmup_epochs": 3,
+    "lr_decay_mode": "cosine",
+    "use_label_smooth": True,
+    "label_smooth_factor": 0.1,
+    "lr_init": 0.0,
+    "lr_max": 0.3,
+    "lr_end": 0.0001
+})
diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py
index d4a8969ed1..9d39b5ab77 100755
--- a/model_zoo/official/cv/resnet/src/dataset.py
+++ b/model_zoo/official/cv/resnet/src/dataset.py
@@ -22,7 +22,6 @@ import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 from mindspore.communication.management import init, get_rank, get_group_size
 
-
 def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
     """
     create a train or evaluate cifar10 dataset for resnet50
@@ -191,6 +190,59 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
 
     return ds
 
+def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
+    """
+    create a train or eval imagenet2012 dataset for se-resnet50
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: Ascend
+
+    Returns:
+        dataset
+    """
+    if target == "Ascend":
+        device_num, rank_id = _get_rank_info()
+    if device_num == 1:
+        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True)
+    else:
+        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True,
+                                     num_shards=device_num, shard_id=rank_id)
+    image_size = 224
+    mean = [123.68, 116.78, 103.94]
+    std = [1.0, 1.0, 1.0]
+
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.Normalize(mean=mean, std=std),
+            C.HWC2CHW()
+        ]
+    else:
+        trans = [
+            C.Decode(),
+            C.Resize(292),
+            C.CenterCrop(256),
+            C.Normalize(mean=mean, std=std),
+            C.HWC2CHW()
+        ]
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="image", num_parallel_workers=12, operations=trans)
+    ds = ds.map(input_columns="label", num_parallel_workers=12, operations=type_cast_op)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    ds = ds.repeat(repeat_num)
+
+    return ds
 
 def _get_rank_info():
     """
diff --git a/model_zoo/official/cv/resnet/src/lr_generator.py b/model_zoo/official/cv/resnet/src/lr_generator.py
index 2af8971715..3c02cde2db 100755
--- a/model_zoo/official/cv/resnet/src/lr_generator.py
+++ b/model_zoo/official/cv/resnet/src/lr_generator.py
@@ -62,6 +62,18 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
             if lr < 0.0:
                 lr = 0.0
             lr_each_step.append(lr)
+    elif lr_decay_mode == 'cosine':
+        decay_steps = total_steps - warmup_steps
+        for i in range(total_steps):
+            if i < warmup_steps:
+                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
+                lr = float(lr_init) + lr_inc * (i + 1)
+            else:
+                linear_decay = (total_steps - i) / decay_steps
+                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
+                decayed = linear_decay * cosine_decay + 0.00001
+                lr = lr_max * decayed
+            lr_each_step.append(lr)
     else:
         for i in range(total_steps):
             if i < warmup_steps:
diff --git a/model_zoo/official/cv/resnet/src/resnet.py b/model_zoo/official/cv/resnet/src/resnet.py
index 0e21222d21..73ef4b5cec 100755
--- a/model_zoo/official/cv/resnet/src/resnet.py
+++ b/model_zoo/official/cv/resnet/src/resnet.py
@@ -15,32 +15,53 @@
 """ResNet."""
 import numpy as np
 import mindspore.nn as nn
+import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
-
+from scipy.stats import truncnorm
+
+def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
+    fan_in = in_channel * kernel_size * kernel_size
+    scale = 1.0
+    scale /= max(1., fan_in)
+    stddev = (scale ** 0.5) / .87962566103423978
+    mu, sigma = 0, stddev
+    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
+    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
+    return Tensor(weight, dtype=mstype.float32)
 
 def _weight_variable(shape, factor=0.01):
     init_value = np.random.randn(*shape).astype(np.float32) * factor
     return Tensor(init_value)
 
-def _conv3x3(in_channel, out_channel, stride=1):
-    weight_shape = (out_channel, in_channel, 3, 3)
-    weight = _weight_variable(weight_shape)
+def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
+    else:
+        weight_shape = (out_channel, in_channel, 3, 3)
+        weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
 
-def _conv1x1(in_channel, out_channel, stride=1):
-    weight_shape = (out_channel, in_channel, 1, 1)
-    weight = _weight_variable(weight_shape)
+def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
+    else:
+        weight_shape = (out_channel, in_channel, 1, 1)
+        weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel,
                      kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
 
-def _conv7x7(in_channel, out_channel, stride=1):
-    weight_shape = (out_channel, in_channel, 7, 7)
-    weight = _weight_variable(weight_shape)
+def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
+    else:
+        weight_shape = (out_channel, in_channel, 7, 7)
+        weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel,
                      kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
@@ -55,9 +76,13 @@ def _bn_last(channel):
                           gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
 
-def _fc(in_channel, out_channel):
-    weight_shape = (out_channel, in_channel)
-    weight = _weight_variable(weight_shape)
+def _fc(in_channel, out_channel, use_se=False):
+    if use_se:
+        weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
+        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
+    else:
+        weight_shape = (out_channel, in_channel)
+        weight = _weight_variable(weight_shape)
     return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
 
@@ -69,6 +94,8 @@ class ResidualBlock(nn.Cell):
         in_channel (int): Input channel.
         out_channel (int): Output channel.
         stride (int): Stride size for the first convolutional layer. Default: 1.
+        use_se (bool): enable SE-ResNet50 net. Default: False.
+        se_block (bool): use se block in SE-ResNet50 net. Default: False.
 
     Returns:
         Tensor, output tensor.
@@ -81,19 +108,30 @@
     def __init__(self,
                  in_channel,
                  out_channel,
-                 stride=1):
+                 stride=1,
+                 use_se=False, se_block=False):
         super(ResidualBlock, self).__init__()
-
+        self.stride = stride
+        self.use_se = use_se
+        self.se_block = se_block
         channel = out_channel // self.expansion
-        self.conv1 = _conv1x1(in_channel, channel, stride=1)
+        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
         self.bn1 = _bn(channel)
-
-        self.conv2 = _conv3x3(channel, channel, stride=stride)
-        self.bn2 = _bn(channel)
-
-        self.conv3 = _conv1x1(channel, out_channel, stride=1)
+        if self.use_se and self.stride != 1:
+            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
+                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
+        else:
+            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
+            self.bn2 = _bn(channel)
+
+        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
         self.bn3 = _bn_last(out_channel)
-
+        if self.se_block:
+            self.se_global_pool = P.ReduceMean(keep_dims=False)
+            self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
+            self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
+            self.se_sigmoid = nn.Sigmoid()
+            self.se_mul = P.Mul()
         self.relu = nn.ReLU()
 
         self.down_sample = False
@@ -103,8 +141,17 @@
         self.down_sample_layer = None
 
         if self.down_sample:
-            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
-                                                        _bn(out_channel)])
+            if self.use_se:
+                if stride == 1:
+                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
+                                                                         stride, use_se=self.use_se), _bn(out_channel)])
+                else:
+                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
+                                                                _conv1x1(in_channel, out_channel, 1,
+                                                                         use_se=self.use_se), _bn(out_channel)])
+            else:
+                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
+                                                                     use_se=self.use_se), _bn(out_channel)])
         self.add = P.TensorAdd()
 
     def construct(self, x):
@@ -113,13 +160,23 @@
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
+        if self.use_se and self.stride != 1:
+            out = self.e2(out)
+        else:
+            out = self.conv2(out)
+            out = self.bn2(out)
+            out = self.relu(out)
         out = self.conv3(out)
         out = self.bn3(out)
+        if self.se_block:
+            out_se = out
+            out = self.se_global_pool(out, (2, 3))
+            out = self.se_dense_0(out)
+            out = self.relu(out)
+            out = self.se_dense_1(out)
+            out = self.se_sigmoid(out)
+            out = F.reshape(out, F.shape(out) + (1, 1))
+            out = self.se_mul(out, out_se)
 
         if self.down_sample:
             identity = self.down_sample_layer(identity)
@@ -141,6 +198,8 @@ class ResNet(nn.Cell):
         out_channels (list): Output channel in each layer.
         strides (list): Stride size in each layer.
         num_classes (int): The number of classes that the training images are belonging to.
+        use_se (bool): enable SE-ResNet50 net. Default: False.
+        se_block (bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
 
     Returns:
         Tensor, output tensor.
@@ -159,43 +218,60 @@
                  in_channels,
                  out_channels,
                  strides,
-                 num_classes):
+                 num_classes,
+                 use_se=False):
         super(ResNet, self).__init__()
 
         if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
             raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
-
-        self.conv1 = _conv7x7(3, 64, stride=2)
+        self.use_se = use_se
+        self.se_block = False
+        if self.use_se:
+            self.se_block = True
+
+        if self.use_se:
+            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
+            self.bn1_0 = _bn(32)
+            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
+            self.bn1_1 = _bn(32)
+            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
+        else:
+            self.conv1 = _conv7x7(3, 64, stride=2)
         self.bn1 = _bn(64)
         self.relu = P.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
-
         self.layer1 = self._make_layer(block,
                                        layer_nums[0],
                                        in_channel=in_channels[0],
                                        out_channel=out_channels[0],
-                                       stride=strides[0])
+                                       stride=strides[0],
+                                       use_se=self.use_se)
         self.layer2 = self._make_layer(block,
                                        layer_nums[1],
                                        in_channel=in_channels[1],
                                        out_channel=out_channels[1],
-                                       stride=strides[1])
+                                       stride=strides[1],
+                                       use_se=self.use_se)
         self.layer3 = self._make_layer(block,
                                        layer_nums[2],
                                        in_channel=in_channels[2],
                                        out_channel=out_channels[2],
-                                       stride=strides[2])
+                                       stride=strides[2],
+                                       use_se=self.use_se,
+                                       se_block=self.se_block)
         self.layer4 = self._make_layer(block,
                                        layer_nums[3],
                                        in_channel=in_channels[3],
                                        out_channel=out_channels[3],
-                                       stride=strides[3])
+                                       stride=strides[3],
+                                       use_se=self.use_se,
+                                       se_block=self.se_block)
 
         self.mean = P.ReduceMean(keep_dims=True)
         self.flatten = nn.Flatten()
-        self.end_point = _fc(out_channels[3], num_classes)
+        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
 
-    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
         """
         Make stage network of ResNet.
 
@@ -205,7 +281,7 @@
             in_channel (int): Input channel.
             out_channel (int): Output channel.
             stride (int): Stride size for the first convolutional layer.
-
+            se_block (bool): use se block in SE-ResNet50 net. Default: False.
         Returns:
             SequentialCell, the output layer.
 
@@ -214,17 +290,31 @@
         """
         layers = []
 
-        resnet_block = block(in_channel, out_channel, stride=stride)
+        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
         layers.append(resnet_block)
-
-        for _ in range(1, layer_num):
-            resnet_block = block(out_channel, out_channel, stride=1)
+        if se_block:
+            for _ in range(1, layer_num - 1):
+                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
+                layers.append(resnet_block)
+            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
             layers.append(resnet_block)
-
+        else:
+            for _ in range(1, layer_num):
+                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
+                layers.append(resnet_block)
         return nn.SequentialCell(layers)
 
     def construct(self, x):
-        x = self.conv1(x)
+        if self.use_se:
+            x = self.conv1_0(x)
+            x = self.bn1_0(x)
+            x = self.relu(x)
+            x = self.conv1_1(x)
+            x = self.bn1_1(x)
+            x = self.relu(x)
+            x = self.conv1_2(x)
+        else:
+            x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
         c1 = self.maxpool(x)
@@ -261,6 +351,26 @@ def resnet50(class_num=10):
                   [1, 2, 2, 2],
                   class_num)
 
+def se_resnet50(class_num=1001):
+    """
+    Get SE-ResNet50 neural network.
+
+    Args:
+        class_num (int): Class number.
+
+    Returns:
+        Cell, cell instance of SE-ResNet50 neural network.
+
+    Examples:
+        >>> net = se_resnet50(1001)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 6, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num,
+                  use_se=True)
 
 def resnet101(class_num=1001):
     """
diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py
index 3d65d10392..b85760ad00 100755
--- a/model_zoo/official/cv/resnet/train.py
+++ b/model_zoo/official/cv/resnet/train.py
@@ -50,17 +50,21 @@ de.config.set_seed(1)
 
 if args_opt.net == "resnet50":
     from src.resnet import resnet50 as resnet
-
     if args_opt.dataset == "cifar10":
         from src.config import config1 as config
         from src.dataset import create_dataset1 as create_dataset
     else:
         from src.config import config2 as config
         from src.dataset import create_dataset2 as create_dataset
-else:
+elif args_opt.net == "resnet101":
    from src.resnet import resnet101 as resnet
     from src.config import config3 as config
     from src.dataset import create_dataset3 as create_dataset
+else:
+    from src.resnet import se_resnet50 as resnet
+    from src.config import config4 as config
+    from src.dataset import create_dataset4 as create_dataset
+
 
 if __name__ == '__main__':
     target = args_opt.device_target
@@ -74,7 +78,7 @@ if __name__ == '__main__':
             context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
             context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                               mirror_mean=True)
-            if args_opt.net == "resnet50":
+            if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
                 auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
             else:
                 auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
@@ -112,14 +116,10 @@ if __name__ == '__main__':
                                                          cell.weight.dtype)
 
     # init lr
-    if args_opt.net == "resnet50":
-        if args_opt.dataset == "cifar10":
-            lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
-                        warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
-                        lr_decay_mode='poly')
-        else:
-            lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
-                        total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
+    if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
+        lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
+                    warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
+                    lr_decay_mode=config.lr_decay_mode)
     else:
         lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
                                         config.pretrain_epoch_size * step_size)
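
Usage: the new se-resnet50 option is only wired up for imagenet2012 (the guards added in the scripts reject the cifar10 combination). A minimal way to exercise the new code paths with the updated scripts, with all dataset, rank-table, and checkpoint paths as placeholders:

    # single-device training
    sh run_standalone_train.sh se-resnet50 imagenet2012 /path/to/imagenet/train
    # distributed training
    sh run_distribute_train.sh se-resnet50 imagenet2012 /path/to/rank_table.json /path/to/imagenet/train
    # evaluation
    sh run_eval.sh se-resnet50 imagenet2012 /path/to/imagenet/val /path/to/checkpoint.ckpt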
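For reference, the SE branch added to ResidualBlock.construct is the standard squeeze-and-excitation recalibration: a global average pool over H and W, a bottleneck FC pair with reduction factor 4 (se_dense_0/se_dense_1), a sigmoid gate, and a channel-wise rescale of the block output. A minimal NumPy sketch of the same computation (function and parameter names are illustrative, not part of the patch):

    import numpy as np

    def se_recalibrate(out, w0, b0, w1, b1):
        # squeeze: global average pool over H, W -> (N, C),
        # like P.ReduceMean(keep_dims=False) over axes (2, 3)
        s = out.mean(axis=(2, 3))
        # excite: FC (C -> C/4) + ReLU, then FC (C/4 -> C) + sigmoid
        s = np.maximum(s @ w0.T + b0, 0)
        s = 1.0 / (1.0 + np.exp(-(s @ w1.T + b1)))
        # scale: broadcast (N, C) -> (N, C, 1, 1) and reweight channels,
        # mirroring F.reshape(out, F.shape(out) + (1, 1)) followed by se_mul
        return out * s[:, :, None, None]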
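Similarly, the new 'cosine' branch in get_lr combines a linear warmup with a cosine-modulated linear decay. A standalone sketch that reproduces the schedule config4 selects (constants copied from the patch; the steps_per_epoch default is a placeholder that depends on batch size and device count):

    import math

    def se_resnet50_lr(lr_init=0.0, lr_max=0.3, warmup_epochs=3, total_epochs=28, steps_per_epoch=5004):
        total_steps = total_epochs * steps_per_epoch
        warmup_steps = warmup_epochs * steps_per_epoch
        decay_steps = total_steps - warmup_steps
        lr_each_step = []
        for i in range(total_steps):
            if i < warmup_steps:
                # linear ramp from lr_init to lr_max over the warmup epochs
                lr = lr_init + (lr_max - lr_init) / warmup_steps * (i + 1)
            else:
                # linear decay modulated by a cosine term, floored slightly above zero
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                lr = lr_max * (linear_decay * cosine_decay + 0.00001)
            lr_each_step.append(lr)
        return lr_each_step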