From b1a1e24e0790d37319ebc6f4ca1243409dab938d Mon Sep 17 00:00:00 2001
From: zhouyaqiang <zhouyaqiang2@huawei.com>
Date: Sun, 28 Jun 2020 15:34:56 +0800
Subject: [PATCH] add resnext50

---
 model_zoo/resnext50/README.md                 | 128 ++++++++
 model_zoo/resnext50/eval.py                   | 243 +++++++++++++++
 .../resnext50/scripts/run_distribute_train.sh |  55 ++++
 model_zoo/resnext50/scripts/run_eval.sh       |  24 ++
 .../resnext50/scripts/run_standalone_train.sh |  30 ++
 model_zoo/resnext50/src/__init__.py           |   0
 model_zoo/resnext50/src/backbone/__init__.py  |  16 +
 model_zoo/resnext50/src/backbone/resnet.py    | 273 +++++++++++++++++
 model_zoo/resnext50/src/config.py             |  45 +++
 model_zoo/resnext50/src/crossentropy.py       |  41 +++
 model_zoo/resnext50/src/dataset.py            | 155 ++++++++++
 model_zoo/resnext50/src/head.py               |  42 +++
 .../resnext50/src/image_classification.py     |  85 ++++++
 model_zoo/resnext50/src/linear_warmup.py      |  21 ++
 model_zoo/resnext50/src/utils/__init__.py     |   0
 model_zoo/resnext50/src/utils/cunstom_op.py   | 108 +++++++
 model_zoo/resnext50/src/utils/logging.py      |  82 +++++
 .../resnext50/src/utils/optimizers__init__.py |  39 +++
 model_zoo/resnext50/src/utils/sampler.py      |  53 ++++
 model_zoo/resnext50/src/utils/var_init.py     | 213 +++++++++++++
 .../src/warmup_cosine_annealing_lr.py         |  40 +++
 model_zoo/resnext50/src/warmup_step_lr.py     |  56 ++++
 model_zoo/resnext50/train.py                  | 289 ++++++++++++++++++
 23 files changed, 2038 insertions(+)
 create mode 100644 model_zoo/resnext50/README.md
 create mode 100644 model_zoo/resnext50/eval.py
 create mode 100644 model_zoo/resnext50/scripts/run_distribute_train.sh
 create mode 100644 model_zoo/resnext50/scripts/run_eval.sh
 create mode 100644 model_zoo/resnext50/scripts/run_standalone_train.sh
 create mode 100644 model_zoo/resnext50/src/__init__.py
 create mode 100644 model_zoo/resnext50/src/backbone/__init__.py
 create mode 100644 model_zoo/resnext50/src/backbone/resnet.py
 create mode 100644 model_zoo/resnext50/src/config.py
 create mode 100644 model_zoo/resnext50/src/crossentropy.py
 create mode 100644 model_zoo/resnext50/src/dataset.py
 create mode 100644 model_zoo/resnext50/src/head.py
 create mode 100644 model_zoo/resnext50/src/image_classification.py
 create mode 100644 model_zoo/resnext50/src/linear_warmup.py
 create mode 100644 model_zoo/resnext50/src/utils/__init__.py
 create mode 100644 model_zoo/resnext50/src/utils/cunstom_op.py
 create mode 100644 model_zoo/resnext50/src/utils/logging.py
 create mode 100644 model_zoo/resnext50/src/utils/optimizers__init__.py
 create mode 100644 model_zoo/resnext50/src/utils/sampler.py
 create mode 100644 model_zoo/resnext50/src/utils/var_init.py
 create mode 100644 model_zoo/resnext50/src/warmup_cosine_annealing_lr.py
 create mode 100644 model_zoo/resnext50/src/warmup_step_lr.py
 create mode 100644 model_zoo/resnext50/train.py

diff --git a/model_zoo/resnext50/README.md b/model_zoo/resnext50/README.md
new file mode 100644
index 0000000000..c44844eecc
--- /dev/null
+++ b/model_zoo/resnext50/README.md
@@ -0,0 +1,128 @@
+# ResNeXt50 Example
+
+## Description
+
+This is an example of training ResNeXt50 on the ImageNet dataset with MindSpore.
+
+## Requirements
+
+- Install [MindSpore](http://www.mindspore.cn/install/en).
+- Download the ImageNet2012 dataset.
+
+## Structure
+
+```shell
+.
+└─resnext50
+  ├─README.md
+  ├─scripts
+    ├─run_standalone_train.sh         # launch standalone training (1p)
+    ├─run_distribute_train.sh         # launch distributed training (8p)
+    └─run_eval.sh                     # launch evaluation
+  ├─src
+    ├─backbone
+      ├─__init__.py                   # initialize
+      ├─resnet.py                     # resnext50 backbone
+    ├─utils
+      ├─__init__.py                   # initialize
+      ├─cunstom_op.py                 # network operations
+      ├─logging.py                    # print log
+      ├─optimizers__init__.py         # get parameters
+      ├─sampler.py                    # distributed sampler
+      ├─var_init.py                   # weight initialization
+    ├─__init__.py                     # initialize
+    ├─config.py                       # parameter configuration
+    ├─crossentropy.py                 # CrossEntropy loss function
+    ├─dataset.py                      # data preprocessing
+    ├─head.py                         # common head
+    ├─image_classification.py         # get resnet
+    ├─linear_warmup.py                # linear warmup learning rate
+    ├─warmup_cosine_annealing_lr.py   # cosine annealing learning rate
+    ├─warmup_step_lr.py               # warmup step learning rate
+  ├─eval.py                           # eval net
+  └─train.py                          # train net
+```
+
+## Parameter Configuration
+
+Parameters for both training and evaluation can be set in config.py.
+
+```
+"image_size": '224,224',                  # image size
+"num_classes": 1000,                      # dataset class number
+"per_batch_size": 128,                    # batch size of input tensor
+"lr": 0.4,                                # base learning rate
+"lr_scheduler": 'cosine_annealing',       # learning rate mode
+"lr_epochs": '30,60,90,120',              # epochs at which lr changes
+"lr_gamma": 0.1,                          # lr decay factor of the exponential lr_scheduler
+"eta_min": 0,                             # eta_min in the cosine_annealing scheduler
+"T_max": 150,                             # T_max in the cosine_annealing scheduler
+"max_epoch": 150,                         # max number of epochs to train the model
+"backbone": 'resnext50',                  # backbone network
+"warmup_epochs": 1,                       # warmup epochs
+"weight_decay": 0.0001,                   # weight decay
+"momentum": 0.9,                          # momentum
+"is_dynamic_loss_scale": 0,               # whether to use dynamic loss scale
+"loss_scale": 1024,                       # loss scale
+"label_smooth": 1,                        # whether to use label smoothing
+"label_smooth_factor": 0.1,               # label smoothing factor
+"ckpt_interval": 1250,                    # checkpoint save interval
+"ckpt_path": 'outputs/',                  # checkpoint save location
+"is_save_on_master": 1,                   # save checkpoints on the master rank only
+"rank": 0,                                # local rank of distributed training
+"group_size": 1                           # world size of distributed training
+```
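+
+As a quick check, the values above can be read back from config.py. A minimal sketch (assuming it is run from the model_zoo/resnext50 directory, the same way train.py and eval.py import it):
+
+```python
+from src.config import config
+
+# config is an EasyDict, so entries are available as attributes
+print(config.backbone, config.lr, config.lr_scheduler)  # resnext50 0.4 cosine_annealing
+print(config.image_size, config.num_classes)            # 224,224 1000
+```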
+
+## Running the example
+
+### Train
+
+#### Usage
+
+```
+# distributed training example (8p)
+sh run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH DATA_PATH
+# standalone training example
+sh run_standalone_train.sh DEVICE_ID DATA_PATH
+```
+
+#### Launch
+
+```bash
+# distributed training example (8p)
+sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /ImageNet/train
+# standalone training example
+sh scripts/run_standalone_train.sh 0 /ImageNet_Original/train
+```
+
+#### Result
+
+Training results are printed to the log file, and checkpoint files are saved under the configured ckpt_path (outputs/ by default).
+
+### Evaluation
+
+#### Usage
+
+```
+# Evaluation
+sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH
+```
+
+#### Launch
+
+```bash
+# Evaluation with checkpoint
+sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt
+```
+
+> Checkpoints are produced during the training process.
+
+#### Result
+
+The evaluation result is stored in the script path, where you can find results like the following in the log file.
+ 
+```
+acc=78.16%(TOP1)
+acc=93.88%(TOP5)
+```
\ No newline at end of file
diff --git a/model_zoo/resnext50/eval.py b/model_zoo/resnext50/eval.py
new file mode 100644
index 0000000000..ff5c83843e
--- /dev/null
+++ b/model_zoo/resnext50/eval.py
@@ -0,0 +1,243 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Eval"""
+import os
+import time
+import argparse
+import datetime
+import glob
+import numpy as np
+import mindspore.nn as nn
+
+from mindspore import Tensor, context
+from mindspore.communication.management import init, get_rank, get_group_size, release
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+
+from src.utils.logging import get_logger
+from src.image_classification import get_network
+from src.dataset import classification_dataset
+from src.config import config
+
+devid = int(os.getenv('DEVICE_ID'))
+context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                    device_target="Ascend", save_graphs=False, device_id=devid)
+
+
+
+class ParameterReduce(nn.Cell):
+    """ParameterReduce"""
+    def __init__(self):
+        super(ParameterReduce, self).__init__()
+        self.cast = P.Cast()
+        self.reduce = P.AllReduce()
+
+    def construct(self, x):
+        one = self.cast(F.scalar_to_array(1.0), mstype.float32)
+        out = x * one
+        ret = self.reduce(out)
+        return ret
+
+
+def parse_args(cloud_args=None):
+    """parse_args"""
+    parser = argparse.ArgumentParser('mindspore classification test')
+
+    # dataset related
+    parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')
+    parser.add_argument('--per_batch_size', default=32, type=int, help='batch size per npu')
+    # network related
+    parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
+    parser.add_argument('--pretrained', default='', type=str, help='full path of the pretrained model to load. '
+                        'If it is a directory, every checkpoint in it will be tested')
+
+    # logging related
+    parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
+    parser.add_argument('--is_distributed', type=int, default=0, help='whether to run distributed (multi-device) evaluation')
+
+    # roma obs
+    parser.add_argument('--train_url', type=str, default="", help='train url')
+
+    args, _ = parser.parse_known_args()
+    args = merge_args(args, cloud_args)
+    args.image_size = config.image_size
+    args.num_classes = config.num_classes
+    args.backbone = config.backbone
+    args.rank = config.rank
+    args.group_size = config.group_size
+
+    args.image_size = list(map(int, args.image_size.split(',')))
+
+    return args
+
+
+def get_top5_acc(top5_arg, gt_class):
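+    """Count the samples whose ground-truth class appears in their top-5 predictions."""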
+    sub_count = 0
+    for top5, gt in zip(top5_arg, gt_class):
+        if gt in top5:
+            sub_count += 1
+    return sub_count
+
+def merge_args(args, cloud_args):
+    """merge_args"""
+    args_dict = vars(args)
+    if isinstance(cloud_args, dict):
+        for key in cloud_args.keys():
+            val = cloud_args[key]
+            if key in args_dict and val:
+                arg_type = type(args_dict[key])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key] = val
+    return args
+
+def test(cloud_args=None):
+    """test"""
+    args = parse_args(cloud_args)
+
+    # init distributed
+    if args.is_distributed:
+        init()
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+
+    args.outputs_dir = os.path.join(args.log_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+
+    args.logger = get_logger(args.outputs_dir, args.rank)
+    args.logger.save_args(args)
+
+    # network
+    args.logger.important_info('start create network')
+    if os.path.isdir(args.pretrained):
+        models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
+        print(models)
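+        # sort checkpoints by the step/epoch index parsed from the file name, newest first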
+        if args.graph_ckpt:
+            f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
+        else:
+            f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
+        args.models = sorted(models, key=f)
+    else:
+        args.models = [args.pretrained,]
+
+    for model in args.models:
+        de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
+                                            per_batch_size=args.per_batch_size,
+                                            max_epoch=1, rank=args.rank, group_size=args.group_size,
+                                            mode='eval')
+        eval_dataloader = de_dataset.create_tuple_iterator()
+        network = get_network(args.backbone, args.num_classes)
+        if network is None:
+            raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))
+
+        param_dict = load_checkpoint(model)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith('moments.'):
+                continue
+            elif key.startswith('network.'):
+                param_dict_new[key[8:]] = values
+            else:
+                param_dict_new[key] = values
+
+        load_param_into_net(network, param_dict_new)
+        args.logger.info('load model {} success'.format(model))
+
+        # must add: cast the network to fp16 for evaluation
+        network.add_flags_recursive(fp16=True)
+
+        img_tot = 0
+        top1_correct = 0
+        top5_correct = 0
+        network.set_train(False)
+        t_end = time.time()
+        it = 0
+        for data, gt_classes in eval_dataloader:
+            output = network(Tensor(data, mstype.float32))
+            output = output.asnumpy()
+
+            top1_output = np.argmax(output, (-1))
+            top5_output = np.argsort(output)[:, -5:]
+
+            t1_correct = np.equal(top1_output, gt_classes).sum()
+            top1_correct += t1_correct
+            top5_correct += get_top5_acc(top5_output, gt_classes)
+            img_tot += args.per_batch_size
+
+            if args.rank == 0 and it == 0:
+                t_end = time.time()
+                it = 1
+        if args.rank == 0:
+            time_used = time.time() - t_end
+            fps = (img_tot - args.per_batch_size) * args.group_size / time_used
+            args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
+        results = [[top1_correct], [top5_correct], [img_tot]]
+        args.logger.info('before results={}'.format(results))
+        if args.is_distributed:
+            model_md5 = model.replace('/', '')
+            tmp_dir = '/cache'
+            if not os.path.exists(tmp_dir):
+                os.mkdir(tmp_dir)
+            top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
+            top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
+            img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
+            np.save(top1_correct_npy, top1_correct)
+            np.save(top5_correct_npy, top5_correct)
+            np.save(img_tot_npy, img_tot)
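+            # poll the shared /cache directory until every rank has written its partial results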
+            while True:
+                rank_ok = True
+                for other_rank in range(args.group_size):
+                    top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
+                    top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
+                    img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
+                    if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
+                       not os.path.exists(img_tot_npy):
+                        rank_ok = False
+                if rank_ok:
+                    break
+
+            top1_correct_all = 0
+            top5_correct_all = 0
+            img_tot_all = 0
+            for other_rank in range(args.group_size):
+                top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
+                top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
+                img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
+                top1_correct_all += np.load(top1_correct_npy)
+                top5_correct_all += np.load(top5_correct_npy)
+                img_tot_all += np.load(img_tot_npy)
+            results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
+            results = np.array(results)
+        else:
+            results = np.array(results)
+
+        args.logger.info('after results={}'.format(results))
+        top1_correct = results[0, 0]
+        top5_correct = results[1, 0]
+        img_tot = results[2, 0]
+        acc1 = 100.0 * top1_correct / img_tot
+        acc5 = 100.0 * top5_correct / img_tot
+        args.logger.info('after allreduce eval: top1_correct={}, tot={},'
+                         'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
+        args.logger.info('after allreduce eval: top5_correct={}, tot={},'
+                         'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
+    if args.is_distributed:
+        release()
+
+
+if __name__ == "__main__":
+    test()
diff --git a/model_zoo/resnext50/scripts/run_distribute_train.sh b/model_zoo/resnext50/scripts/run_distribute_train.sh
new file mode 100644
index 0000000000..226cfe3cb6
--- /dev/null
+++ b/model_zoo/resnext50/scripts/run_distribute_train.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$2
+export RANK_TABLE_FILE=$1
+export RANK_SIZE=8
+PATH_CHECKPOINT=""
+if [ $# == 3 ]
+then
+	PATH_CHECKPOINT=$3
+fi
+
+cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
+echo "the number of logical core" $cores
+avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
+core_gap=`expr $avg_core_per_rank \- 1`
+echo "avg_core_per_rank" $avg_core_per_rank
+echo "core_gap" $core_gap
+for((i=0;i<RANK_SIZE;i++))
+do
+    start=`expr $i \* $avg_core_per_rank`
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    export DEPLOY_MODE=0
+    export GE_USE_STATIC_MEMORY=1
+    end=`expr $start \+ $core_gap`
+    cmdopt=$start"-"$end
+
+    rm -rf LOG$i
+    mkdir ./LOG$i
+    cp  *.py ./LOG$i
+    cd ./LOG$i || exit
+    echo "start training for rank $i, device $DEVICE_ID"
+
+    env > env.log
+    taskset -c $cmdopt python ../train.py  \
+    --is_distribute=1 \
+    --device_id=$DEVICE_ID \
+    --pretrained=$PATH_CHECKPOINT \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
+    cd ../
+done
diff --git a/model_zoo/resnext50/scripts/run_eval.sh b/model_zoo/resnext50/scripts/run_eval.sh
new file mode 100644
index 0000000000..610faa874e
--- /dev/null
+++ b/model_zoo/resnext50/scripts/run_eval.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=$1    # eval.py reads the device id from the DEVICE_ID environment variable
+DATA_DIR=$2
+PATH_CHECKPOINT=$3
+
+python eval.py  \
+    --device_id=$DEVICE_ID \
+    --pretrained=$PATH_CHECKPOINT \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
diff --git a/model_zoo/resnext50/scripts/run_standalone_train.sh b/model_zoo/resnext50/scripts/run_standalone_train.sh
new file mode 100644
index 0000000000..ca5d8206f3
--- /dev/null
+++ b/model_zoo/resnext50/scripts/run_standalone_train.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=$1
+DATA_DIR=$2
+PATH_CHECKPOINT=""
+if [ $# == 3 ]
+then
+  PATH_CHECKPOINT=$3
+fi
+
+python train.py  \
+    --is_distribute=0 \
+    --device_id=$DEVICE_ID \
+    --pretrained=$PATH_CHECKPOINT \
+    --data_dir=$DATA_DIR > log.txt 2>&1 &
+
diff --git a/model_zoo/resnext50/src/__init__.py b/model_zoo/resnext50/src/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/model_zoo/resnext50/src/backbone/__init__.py b/model_zoo/resnext50/src/backbone/__init__.py
new file mode 100644
index 0000000000..b757d07410
--- /dev/null
+++ b/model_zoo/resnext50/src/backbone/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""resnet"""
+from .resnet import *
diff --git a/model_zoo/resnext50/src/backbone/resnet.py b/model_zoo/resnext50/src/backbone/resnet.py
new file mode 100644
index 0000000000..5b69f9e1f5
--- /dev/null
+++ b/model_zoo/resnext50/src/backbone/resnet.py
@@ -0,0 +1,273 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+ResNeXt, based on ResNet.
+"""
+import mindspore.nn as nn
+from mindspore.ops.operations import TensorAdd, Split, Concat
+from mindspore.ops import operations as P
+from mindspore.common.initializer import TruncatedNormal
+
+from src.utils.cunstom_op import SEBlock, GroupConv
+
+
+__all__ = ['ResNet', 'resnext50']
+
+
+def weight_variable(shape, factor=0.1):
+    return TruncatedNormal(0.02)
+
+
+def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
+                     padding=padding, pad_mode="pad", group=groups)
+
+
+def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
+                     padding=padding, pad_mode="pad", group=groups)
+
+
+def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
+                     padding=padding, pad_mode="pad", group=groups)
+
+
+class _DownSample(nn.Cell):
+    """
+    Downsample for ResNext-ResNet.
+
+    Args:
+        in_channels (int): Input channels.
+        out_channels (int): Output channels.
+        stride (int): Stride size for the 1*1 convolutional layer.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> _DownSample(32, 64, 2)
+    """
+    def __init__(self, in_channels, out_channels, stride):
+        super(_DownSample, self).__init__()
+        self.conv = conv1x1(in_channels, out_channels, stride=stride, padding=0)
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def construct(self, x):
+        out = self.conv(x)
+        out = self.bn(out)
+        return out
+
+class BasicBlock(nn.Cell):
+    """
+    ResNet basic block definition.
+
+    Args:
+        in_channels (int): Input channels.
+        out_channels (int): Output channels.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>>BasicBlock(32, 256, stride=2)
+    """
+    expansion = 1
+
+    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = P.ReLU()
+        self.conv2 = conv3x3(out_channels, out_channels, stride=1)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.use_se = use_se
+        if self.use_se:
+            self.se = SEBlock(out_channels)
+
+        self.down_sample_flag = False
+        if down_sample is not None:
+            self.down_sample = down_sample
+            self.down_sample_flag = True
+
+        self.add = TensorAdd()
+
+    def construct(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.use_se:
+            out = self.se(out)
+
+        if self.down_sample_flag:
+            identity = self.down_sample(x)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+        return out
+
+class Bottleneck(nn.Cell):
+    """
+    ResNet Bottleneck block definition.
+
+    Args:
+        in_channels (int): Input channels.
+        out_channels (int): Output channels.
+        stride (int): Stride size for the initial convolutional layer. Default: 1.
+
+    Returns:
+        Tensor, the ResNet unit's output.
+
+    Examples:
+        >>>Bottleneck(3, 256, stride=2)
+    """
+    expansion = 4
+
+    def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
+                 base_width=64, groups=1, use_se=False, **kwargs):
+        super(Bottleneck, self).__init__()
+
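+        # ResNeXt widens the bottleneck: width = out_channels * (base_width / 64) * groups (e.g. 32x4d)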
+        width = int(out_channels * (base_width / 64.0)) * groups
+        self.groups = groups
+        self.conv1 = conv1x1(in_channels, width, stride=1)
+        self.bn1 = nn.BatchNorm2d(width)
+        self.relu = P.ReLU()
+
+        self.conv3x3s = nn.CellList()
+
+        self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
+        self.op_split = Split(axis=1, output_num=self.groups)
+        self.op_concat = Concat(axis=1)
+
+        self.bn2 = nn.BatchNorm2d(width)
+        self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1)
+        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
+
+        self.use_se = use_se
+        if self.use_se:
+            self.se = SEBlock(out_channels * self.expansion)
+
+        self.down_sample_flag = False
+        if down_sample is not None:
+            self.down_sample = down_sample
+            self.down_sample_flag = True
+
+        self.cast = P.Cast()
+        self.add = TensorAdd()
+
+    def construct(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.use_se:
+            out = self.se(out)
+
+        if self.down_sample_flag:
+            identity = self.down_sample(x)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+        return out
+
+class ResNet(nn.Cell):
+    """
+    ResNet architecture.
+
+    Args:
+        block (cell): Block for network.
+        layers (list): Numbers of blocks in the different layers.
+        width_per_group (int): Width of every group.
+        groups (int): Groups number.
+
+    Returns:
+        Tuple, output tensor tuple.
+
+    Examples:
+        >>>ResNet()
+    """
+    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False):
+        super(ResNet, self).__init__()
+        self.in_channels = 64
+        self.groups = groups
+        self.base_width = width_per_group
+
+        self.conv = conv7x7(3, self.in_channels, stride=2, padding=3)
+        self.bn = nn.BatchNorm2d(self.in_channels)
+        self.relu = P.ReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
+
+        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se)
+
+        self.out_channels = 512 * block.expansion
+        self.cast = P.Cast()
+
+    def construct(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
+        """_make_layer"""
+        down_sample = None
+        if stride != 1 or self.in_channels != out_channels * block.expansion:
+            down_sample = _DownSample(self.in_channels,
+                                      out_channels * block.expansion,
+                                      stride=stride)
+
+        layers = []
+        layers.append(block(self.in_channels,
+                            out_channels,
+                            stride=stride,
+                            down_sample=down_sample,
+                            base_width=self.base_width,
+                            groups=self.groups,
+                            use_se=use_se))
+        self.in_channels = out_channels * block.expansion
+        for _ in range(1, blocks_num):
+            layers.append(block(self.in_channels, out_channels,
+                                base_width=self.base_width, groups=self.groups, use_se=use_se))
+
+        return nn.SequentialCell(layers)
+
+    def get_out_channels(self):
+        return self.out_channels
+
+
+def resnext50():
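+    """ResNeXt-50 (32x4d): Bottleneck blocks [3, 4, 6, 3] with 32 groups and a base width of 4."""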
+    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32)
diff --git a/model_zoo/resnext50/src/config.py b/model_zoo/resnext50/src/config.py
new file mode 100644
index 0000000000..c1a12aa14e
--- /dev/null
+++ b/model_zoo/resnext50/src/config.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""config"""
+from easydict import EasyDict as ed
+
+config = ed({
+    "image_size": '224,224',
+    "num_classes": 1000,
+
+    "lr": 0.4,
+    "lr_scheduler": 'cosine_annealing',
+    "lr_epochs": '30,60,90,120',
+    "lr_gamma": 0.1,
+    "eta_min": 0,
+    "T_max": 150,
+    "max_epoch": 150,
+    "backbone": 'resnext50',
+    "warmup_epochs": 1,
+
+    "weight_decay": 0.0001,
+    "momentum": 0.9,
+    "is_dynamic_loss_scale": 0,
+    "loss_scale": 1024,
+    "label_smooth": 1,
+    "label_smooth_factor": 0.1,
+
+    "ckpt_interval": 1250,
+    "ckpt_path": 'outputs/',
+    "is_save_on_master": 1,
+
+    "rank": 0,
+    "group_size": 1
+})
diff --git a/model_zoo/resnext50/src/crossentropy.py b/model_zoo/resnext50/src/crossentropy.py
new file mode 100644
index 0000000000..a0e509a51e
--- /dev/null
+++ b/model_zoo/resnext50/src/crossentropy.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+define loss function for network.
+"""
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+import mindspore.nn as nn
+
+class CrossEntropy(_Loss):
+    """
+    the redefined loss function with SoftmaxCrossEntropyWithLogits.
+    """
+    def __init__(self, smooth_factor=0., num_classes=1000):
+        super(CrossEntropy, self).__init__()
+        self.onehot = P.OneHot()
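+        # label smoothing: the target class gets 1 - smooth_factor, the remaining classes share smooth_factor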
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+
+    def construct(self, logit, label):
+        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+        loss = self.ce(logit, one_hot_label)
+        loss = self.mean(loss, 0)
+        return loss
diff --git a/model_zoo/resnext50/src/dataset.py b/model_zoo/resnext50/src/dataset.py
new file mode 100644
index 0000000000..9608e3c790
--- /dev/null
+++ b/model_zoo/resnext50/src/dataset.py
@@ -0,0 +1,155 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+dataset processing.
+"""
+import os
+from mindspore.common import dtype as mstype
+import mindspore.dataset as de
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.transforms.vision.c_transforms as V_C
+from PIL import Image, ImageFile
+from src.utils.sampler import DistributedSampler
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+class TxtDataset():
+    """
+    create txt dataset.
+
+    Args:
+    Returns:
+        de_dataset.
+    """
+    def __init__(self, root, txt_name):
+        super(TxtDataset, self).__init__()
+        self.imgs = []
+        self.labels = []
+        fin = open(txt_name, "r")
+        for line in fin:
+            img_name, label = line.strip().split(' ')
+            self.imgs.append(os.path.join(root, img_name))
+            self.labels.append(int(label))
+        fin.close()
+
+    def __getitem__(self, index):
+        img = Image.open(self.imgs[index]).convert('RGB')
+        return img, self.labels[index]
+
+    def __len__(self):
+        return len(self.imgs)
+
+
+def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
+                           mode='train',
+                           input_mode='folder',
+                           root='',
+                           num_parallel_workers=None,
+                           shuffle=None,
+                           sampler=None,
+                           class_indexing=None,
+                           drop_remainder=True,
+                           transform=None,
+                           target_transform=None):
+    """
+    A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
+    If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
+    are written into a textfile.
+
+    Args:
+        data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
+            Or path of the textfile that contains every image's path of the dataset.
+        image_size (list): Size of the input images.
+        per_batch_size (int): the batch size of every step during training.
+        max_epoch (int): the number of epochs.
+        rank (int): The shard ID within num_shards (default=None).
+        group_size (int): Number of shards that the dataset should be divided
+            into (default=None).
+        mode (str): "train" or others. Default: "train".
+        input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
+        root (str): the images path for "input_mode="txt"". Default: "".
+        num_parallel_workers (int): Number of workers to read the data. Default: None.
+        shuffle (bool): Whether or not to perform shuffle on the dataset
+            (default=None, performs shuffle).
+        sampler (Sampler): Object used to choose samples from the dataset. Default: None.
+        class_indexing (dict): A str-to-int mapping from folder name to index
+            (default=None, the folder names will be sorted
+            alphabetically and each class will be given a
+            unique index starting from 0).
+
+    Examples:
+        >>> from src.dataset import classification_dataset
+        >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
+        >>> dataset_dir = "/path/to/imagefolder_directory"
+        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
+        >>>                                     per_batch_size=64, max_epoch=100,
+        >>>                                     rank=0, group_size=4)
+        >>> # Path of the textfile that contains every image's path of the dataset.
+        >>> dataset_dir = "/path/to/dataset/images/train.txt"
+        >>> images_dir = "/path/to/dataset/images"
+        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
+        >>>                                     per_batch_size=64, max_epoch=100,
+        >>>                                     rank=0, group_size=4,
+        >>>                                     input_mode="txt", root=images_dir)
+    """
+
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
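+    # ImageNet per-channel mean/std, scaled to the 0-255 range of the decoded images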
+
+    if transform is None:
+        if mode == 'train':
+            transform_img = [
+                V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+                V_C.RandomHorizontalFlip(prob=0.5),
+                V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
+                V_C.Normalize(mean=mean, std=std),
+                V_C.HWC2CHW()
+            ]
+        else:
+            transform_img = [
+                V_C.Decode(),
+                V_C.Resize((256, 256)),
+                V_C.CenterCrop(image_size),
+                V_C.Normalize(mean=mean, std=std),
+                V_C.HWC2CHW()
+            ]
+    else:
+        transform_img = transform
+
+    if target_transform is None:
+        transform_label = [C.TypeCast(mstype.int32)]
+    else:
+        transform_label = target_transform
+
+    if input_mode == 'folder':
+        de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers,
+                                             shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
+                                             num_shards=group_size, shard_id=rank)
+    else:
+        dataset = TxtDataset(root, data_dir)
+        sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
+        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
+        de_dataset.set_dataset_size(len(sampler))
+
+    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
+    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+
+    columns_to_project = ["image", "label"]
+    de_dataset = de_dataset.project(columns=columns_to_project)
+
+    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
+    de_dataset = de_dataset.repeat(max_epoch)
+
+    return de_dataset
diff --git a/model_zoo/resnext50/src/head.py b/model_zoo/resnext50/src/head.py
new file mode 100644
index 0000000000..a7bd85c906
--- /dev/null
+++ b/model_zoo/resnext50/src/head.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+common architecture.
+"""
+import mindspore.nn as nn
+from src.utils.cunstom_op import GlobalAvgPooling
+
+__all__ = ['CommonHead']
+
+class CommonHead(nn.Cell):
+    """
+    common head architecture definition.
+
+    Args:
+        num_classes (int): Number of classes.
+        out_channels (int): Output channels.
+
+    Returns:
+        Tensor, output tensor.
+    """
+    def __init__(self, num_classes, out_channels):
+        super(CommonHead, self).__init__()
+        self.avgpool = GlobalAvgPooling()
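+        # the dense classifier is flagged to compute in fp16; eval.py additionally casts the whole network to fp16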
+        self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True)
+
+    def construct(self, x):
+        x = self.avgpool(x)
+        x = self.fc(x)
+        return x
diff --git a/model_zoo/resnext50/src/image_classification.py b/model_zoo/resnext50/src/image_classification.py
new file mode 100644
index 0000000000..d8003ad200
--- /dev/null
+++ b/model_zoo/resnext50/src/image_classification.py
@@ -0,0 +1,85 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Image classification.
+"""
+import math
+import mindspore.nn as nn
+from mindspore.common import initializer as init
+import src.backbone as backbones
+import src.head as heads
+from src.utils.var_init import default_recurisive_init, KaimingNormal
+
+
+class ImageClassificationNetwork(nn.Cell):
+    """
+    architecture of image classification network.
+
+    Args:
+    Returns:
+        Tensor, output tensor.
+    """
+    def __init__(self, backbone, head):
+        super(ImageClassificationNetwork, self).__init__()
+        self.backbone = backbone
+        self.head = head
+
+    def construct(self, x):
+        x = self.backbone(x)
+        x = self.head(x)
+        return x
+
+class Resnet(ImageClassificationNetwork):
+    """
+    Resnet architecture.
+    Args:
+        backbone_name (string): backbone.
+        num_classes (int): number of classes.
+    Returns:
+        Resnet.
+    """
+    def __init__(self, backbone_name, num_classes):
+        self.backbone_name = backbone_name
+        backbone = backbones.__dict__[self.backbone_name]()
+        out_channels = backbone.get_out_channels()
+        head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
+        super(Resnet, self).__init__(backbone, head)
+
+        default_recurisive_init(self)
+
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Conv2d):
+                cell.weight.default_input = init.initializer(
+                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
+                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+            elif isinstance(cell, nn.BatchNorm2d):
+                cell.gamma.default_input = init.initializer('ones', cell.gamma.default_input.shape).to_tensor()
+                cell.beta.default_input = init.initializer('zeros', cell.beta.default_input.shape).to_tensor()
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, backbones.resnet.Bottleneck):
+                cell.bn3.gamma.default_input = init.initializer('zeros', cell.bn3.gamma.default_input.shape).to_tensor()
+            elif isinstance(cell, backbones.resnet.BasicBlock):
+                cell.bn2.gamma.default_input = init.initializer('zeros', cell.bn2.gamma.default_input.shape).to_tensor()
+
+
+
+def get_network(backbone_name, num_classes):
+    if backbone_name in ['resnext50']:
+        return Resnet(backbone_name, num_classes)
+    return None
diff --git a/model_zoo/resnext50/src/linear_warmup.py b/model_zoo/resnext50/src/linear_warmup.py
new file mode 100644
index 0000000000..af0bac631a
--- /dev/null
+++ b/model_zoo/resnext50/src/linear_warmup.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+linear warm up learning rate.
+"""
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr = float(init_lr) + lr_inc * current_step
+    return lr
diff --git a/model_zoo/resnext50/src/utils/__init__.py b/model_zoo/resnext50/src/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/model_zoo/resnext50/src/utils/cunstom_op.py b/model_zoo/resnext50/src/utils/cunstom_op.py
new file mode 100644
index 0000000000..cbe89a1610
--- /dev/null
+++ b/model_zoo/resnext50/src/utils/cunstom_op.py
@@ -0,0 +1,108 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network operations
+"""
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+
+class GlobalAvgPooling(nn.Cell):
+    """
+    global average pooling feature map.
+
+    Args:
+         mean (tuple): means for each channel.
+    """
+    def __init__(self):
+        super(GlobalAvgPooling, self).__init__()
+        self.mean = P.ReduceMean(True)
+        self.shape = P.Shape()
+        self.reshape = P.Reshape()
+
+    def construct(self, x):
+        x = self.mean(x, (2, 3))
+        b, c, _, _ = self.shape(x)
+        x = self.reshape(x, (b, c))
+        return x
+
+
+class SEBlock(nn.Cell):
+    """
+    squeeze and excitation block.
+
+    Args:
+        channel (int): number of feature maps.
+        reduction (int): channel reduction ratio of the squeeze step.
+    """
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+
+        self.avg_pool = GlobalAvgPooling()
+        self.fc1 = nn.Dense(channel, channel // reduction)
+        self.relu = P.ReLU()
+        self.fc2 = nn.Dense(channel // reduction, channel)
+        self.sigmoid = P.Sigmoid()
+        self.reshape = P.Reshape()
+        self.shape = P.Shape()
+        self.sum = P.Sum()
+        self.cast = P.Cast()
+
+    def construct(self, x):
+        b, c = self.shape(x)
+        y = self.avg_pool(x)
+
+        y = self.reshape(y, (b, c))
+        y = self.fc1(y)
+        y = self.relu(y)
+        y = self.fc2(y)
+        y = self.sigmoid(y)
+        y = self.reshape(y, (b, c, 1, 1))
+        return x * y
+
+class GroupConv(nn.Cell):
+    """
+    group convolution operation.
+
+    Args:
+        in_channels (int): Input channels of feature map.
+        out_channels (int): Output channels of feature map.
+        kernel_size (int): Size of convolution kernel.
+        stride (int): Stride size for the group convolution layer.
+        pad_mode (str): Padding mode. Default: "pad".
+        pad (int): Padding size. Default: 0.
+        groups (int): Number of groups the channels are split into. Default: 1.
+        has_bias (bool): Whether each convolution has a bias. Default: False.
+
+    Returns:
+        tensor, output tensor.
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
+        super(GroupConv, self).__init__()
+        assert in_channels % groups == 0 and out_channels % groups == 0
+        self.groups = groups
+        self.convs = nn.CellList()
+        self.op_split = P.Split(axis=1, output_num=self.groups)
+        self.op_concat = P.Concat(axis=1)
+        self.cast = P.Cast()
+        for _ in range(groups):
+            self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups,
+                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
+                                        padding=pad, pad_mode=pad_mode, group=1))
+
+    def construct(self, x):
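+        # split the input along the channel axis, convolve each group separately, then concatenate the outputs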
+        features = self.op_split(x)
+        outputs = ()
+        for i in range(self.groups):
+            outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
+        out = self.op_concat(outputs)
+        return out
diff --git a/model_zoo/resnext50/src/utils/logging.py b/model_zoo/resnext50/src/utils/logging.py
new file mode 100644
index 0000000000..ac37bec4ec
--- /dev/null
+++ b/model_zoo/resnext50/src/utils/logging.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+get logger.
+"""
+import logging
+import os
+import sys
+from datetime import datetime
+
+class LOGGER(logging.Logger):
+    """
+    set up logging file.
+
+    Args:
+        logger_name (string): logger name.
+        rank (int): rank id; only ranks that are multiples of 8 also log to the console.
+    """
+    def __init__(self, logger_name, rank=0):
+        super(LOGGER, self).__init__(logger_name)
+        if rank % 8 == 0:
+            console = logging.StreamHandler(sys.stdout)
+            console.setLevel(logging.INFO)
+            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
+            console.setFormatter(formatter)
+            self.addHandler(console)
+
+    def setup_logging_file(self, log_dir, rank=0):
+        """set up log file"""
+        self.rank = rank
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
+        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
+        self.log_fn = os.path.join(log_dir, log_name)
+        fh = logging.FileHandler(self.log_fn)
+        fh.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
+        fh.setFormatter(formatter)
+        self.addHandler(fh)
+
+    def info(self, msg, *args, **kwargs):
+        if self.isEnabledFor(logging.INFO):
+            self._log(logging.INFO, msg, args, **kwargs)
+
+    def save_args(self, args):
+        self.info('Args:')
+        args_dict = vars(args)
+        for key in args_dict.keys():
+            self.info('--> %s: %s', key, args_dict[key])
+        self.info('')
+
+    def important_info(self, msg, *args, **kwargs):
+        if self.isEnabledFor(logging.INFO) and self.rank == 0:
+            line_width = 2
+            important_msg = '\n'
+            important_msg += ('*'*70 + '\n')*line_width
+            important_msg += ('*'*line_width + '\n')*2
+            important_msg += '*'*line_width + ' '*8 + msg + '\n'
+            important_msg += ('*'*line_width + '\n')*2
+            important_msg += ('*'*70 + '\n')*line_width
+            self.info(important_msg, *args, **kwargs)
+
+
+def get_logger(path, rank):
+    logger = LOGGER("mindversion", rank)
+    logger.setup_logging_file(path, rank)
+    return logger
diff --git a/model_zoo/resnext50/src/utils/optimizers__init__.py b/model_zoo/resnext50/src/utils/optimizers__init__.py
new file mode 100644
index 0000000000..d4683959b5
--- /dev/null
+++ b/model_zoo/resnext50/src/utils/optimizers__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+optimizer parameters.
+"""
+def get_param_groups(network):
+    """Split trainable parameters into weight-decay and no-weight-decay groups."""
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith('.bias'):
+            # bias parameters do not use weight decay
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.gamma') or parameter_name.endswith('.beta'):
+            # BatchNorm gamma/beta do not use weight decay
+            no_decay_params.append(x)
+        else:
+            decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
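The two groups returned above are meant to be passed straight to the optimizer, which is how `train.py` uses them. A minimal sketch with a toy network (assumes MindSpore is installed and the snippet runs from `model_zoo/resnext50`):

```python
# Sketch: bias and BatchNorm gamma/beta parameters skip weight decay; all other weights use 1e-4.
import mindspore.nn as nn
from mindspore.nn.optim import Momentum
from src.utils.optimizers__init__ import get_param_groups

net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Dense(8, 10)])  # toy network
groups = get_param_groups(net)
assert groups[0]['weight_decay'] == 0.0        # bias/gamma/beta live in the no-decay group
opt = Momentum(params=groups, learning_rate=0.05, momentum=0.9, weight_decay=1e-4)
```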
diff --git a/model_zoo/resnext50/src/utils/sampler.py b/model_zoo/resnext50/src/utils/sampler.py
new file mode 100644
index 0000000000..5b68f8325e
--- /dev/null
+++ b/model_zoo/resnext50/src/utils/sampler.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+choose samples from the dataset
+"""
+import math
+import numpy as np
+
+class DistributedSampler():
+    """
+    sampling the dataset.
+
+    Args:
+    Returns:
+        num_samples, number of samples.
+    """
+    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
+        self.dataset = dataset
+        self.rank = rank
+        self.group_size = group_size
+        self.dataset_length = len(self.dataset)
+        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
+        self.total_size = self.num_samples * self.group_size
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self):
+        if self.shuffle:
+            self.seed = (self.seed + 1) & 0xffffffff
+            np.random.seed(self.seed)
+            indices = np.random.permutation(self.dataset_length).tolist()
+        else:
+            indices = list(range(self.dataset_length))
+
+        indices += indices[:(self.total_size - len(indices))]
+        indices = indices[self.rank::self.group_size]
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
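A standalone check of the sharding logic above; only `len()` of the dataset is used, so a plain list is enough (assumes the snippet runs from `model_zoo/resnext50`):

```python
# Sketch: 10 samples sharded across 4 ranks -> ceil(10 / 4) == 3 indices per rank,
# padded by repeating the first two indices so every rank gets the same count.
from src.utils.sampler import DistributedSampler

dataset = list(range(10))
shards = [list(DistributedSampler(dataset, rank=r, group_size=4, shuffle=False)) for r in range(4)]
assert all(len(shard) == 3 for shard in shards)
assert sorted(i for shard in shards for i in shard) == sorted(list(range(10)) + [0, 1])
```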
diff --git a/model_zoo/resnext50/src/utils/var_init.py b/model_zoo/resnext50/src/utils/var_init.py
new file mode 100644
index 0000000000..51fc109990
--- /dev/null
+++ b/model_zoo/resnext50/src/utils/var_init.py
@@ -0,0 +1,213 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Initialize.
+"""
+import math
+from functools import reduce
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common import initializer as init
+
+def _calculate_gain(nonlinearity, param=None):
+    r"""
+    Return the recommended gain value for the given nonlinearity function.
+
+    The values are as follows:
+    ================= ====================================================
+    nonlinearity      gain
+    ================= ====================================================
+    Linear / Identity :math:`1`
+    Conv{1,2,3}D      :math:`1`
+    Sigmoid           :math:`1`
+    Tanh              :math:`\frac{5}{3}`
+    ReLU              :math:`\sqrt{2}`
+    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
+    ================= ====================================================
+
+    Args:
+        nonlinearity: the non-linear function
+        param: optional parameter for the non-linear function
+
+    Examples:
+        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
+    """
+    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+        return 1
+    if nonlinearity == 'tanh':
+        return 5.0 / 3
+    if nonlinearity == 'relu':
+        return math.sqrt(2.0)
+    if nonlinearity == 'leaky_relu':
+        if param is None:
+            negative_slope = 0.01
+        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(param))
+        return math.sqrt(2.0 / (1 + negative_slope ** 2))
+
+    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+def _assignment(arr, num):
+    """Assign the value of `num` to `arr`."""
+    if arr.shape == ():
+        arr = arr.reshape((1))
+        arr[:] = num
+        arr = arr.reshape(())
+    else:
+        if isinstance(num, np.ndarray):
+            arr[:] = num[:]
+        else:
+            arr[:] = num
+    return arr
+
+def _calculate_in_and_out(arr):
+    """
+    Calculate n_in and n_out.
+
+    Args:
+        arr (Array): Input array.
+
+    Returns:
+        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
+    """
+    dim = len(arr.shape)
+    if dim < 2:
+        raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
+
+    n_in = arr.shape[1]
+    n_out = arr.shape[0]
+
+    if dim > 2:
+        counter = reduce(lambda x, y: x * y, arr.shape[2:])
+        n_in *= counter
+        n_out *= counter
+    return n_in, n_out
+
+def _select_fan(array, mode):
+    mode = mode.lower()
+    valid_modes = ['fan_in', 'fan_out']
+    if mode not in valid_modes:
+        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
+
+    fan_in, fan_out = _calculate_in_and_out(array)
+    return fan_in if mode == 'fan_in' else fan_out
+
+class KaimingInit(init.Initializer):
+    r"""
+    Base Class. Initialize the array with He kaiming algorithm.
+
+    Args:
+        a: the negative slope of the rectifier used after this layer (only
+            used with ``'leaky_relu'``)
+        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+            preserves the magnitude of the variance of the weights in the
+            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+            backwards pass.
+        nonlinearity: the non-linear function, recommended to use only with
+            ``'relu'`` or ``'leaky_relu'`` (default).
+    """
+    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+        super(KaimingInit, self).__init__()
+        self.mode = mode
+        self.gain = _calculate_gain(nonlinearity, a)
+    def _initialize(self, arr):
+        pass
+
+
+class KaimingUniform(KaimingInit):
+    r"""
+    Initialize the array with He kaiming uniform algorithm. The resulting tensor will
+    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
+
+    .. math::
+        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
+
+    Input:
+        arr (Array): The array to be assigned.
+
+    Returns:
+        Array, assigned array.
+
+    Examples:
+        >>> w = np.empty((3, 5))
+        >>> KaimingUniform(mode='fan_in', nonlinearity='relu')._initialize(w)
+    """
+
+    def _initialize(self, arr):
+        fan = _select_fan(arr, self.mode)
+        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
+        np.random.seed(0)
+        data = np.random.uniform(-bound, bound, arr.shape)
+
+        _assignment(arr, data)
+
+
+class KaimingNormal(KaimingInit):
+    r"""
+    Initialize the array with He kaiming normal algorithm. The resulting tensor will
+    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
+
+    .. math::
+        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
+
+    Input:
+        arr (Array): The array to be assigned.
+
+    Returns:
+        Array, assigned array.
+
+    Examples:
+        >>> w = np.empty((3, 5))
+        >>> KaimingNormal(mode='fan_out', nonlinearity='relu')._initialize(w)
+    """
+
+    def _initialize(self, arr):
+        fan = _select_fan(arr, self.mode)
+        std = self.gain / math.sqrt(fan)
+        np.random.seed(0)
+        data = np.random.normal(0, std, arr.shape)
+
+        _assignment(arr, data)
+
+
+def default_recurisive_init(custom_cell):
+    """default_recurisive_init"""
+    for _, cell in custom_cell.cells_and_names():
+        if isinstance(cell, nn.Conv2d):
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.default_input.shape,
+                                                         cell.weight.default_input.dtype).to_tensor()
+            if cell.bias is not None:
+                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                bound = 1 / math.sqrt(fan_in)
+                np.random.seed(0)
+                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
+                                                 cell.bias.default_input.dtype)
+        elif isinstance(cell, nn.Dense):
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.default_input.shape,
+                                                         cell.weight.default_input.dtype).to_tensor()
+            if cell.bias is not None:
+                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
+                bound = 1 / math.sqrt(fan_in)
+                np.random.seed(0)
+                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
+                                                 cell.bias.default_input.dtype)
+        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
+            pass
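A quick numeric check of the Kaiming-uniform bound that `default_recurisive_init` relies on, calling the module's private helpers directly for illustration (assumes MindSpore is installed, since importing `var_init` pulls it in):

```python
# Sketch: for a conv weight of shape (out, in, k, k), fan_in = in * k * k and the uniform
# bound is gain * sqrt(3 / fan_in), where gain = sqrt(2 / (1 + a^2)) for leaky_relu.
import math
import numpy as np
from src.utils.var_init import KaimingUniform, _calculate_in_and_out

w = np.zeros((64, 32, 3, 3), dtype=np.float32)
fan_in, _ = _calculate_in_and_out(w)             # 32 * 3 * 3 = 288
init_fn = KaimingUniform(a=math.sqrt(5))         # the setting default_recurisive_init uses
init_fn._initialize(w)                           # fills w in place
bound = init_fn.gain * math.sqrt(3.0 / fan_in)
assert np.abs(w).max() <= bound
```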
diff --git a/model_zoo/resnext50/src/warmup_cosine_annealing_lr.py b/model_zoo/resnext50/src/warmup_cosine_annealing_lr.py
new file mode 100644
index 0000000000..5d9fce9af4
--- /dev/null
+++ b/model_zoo/resnext50/src/warmup_cosine_annealing_lr.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+warm up cosine annealing learning rate.
+"""
+import math
+import numpy as np
+
+from .linear_warmup import linear_warmup_lr
+
+
+def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
+    """warm up cosine annealing learning rate."""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+
+    lr_each_step = []
+    for i in range(total_steps):
+        last_epoch = i // steps_per_epoch
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
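After the warmup phase, the schedule above is the standard cosine annealing curve; a small self-contained sketch of just that term (the numbers are placeholders, not necessarily the shipped config values):

```python
# Sketch: cosine decay from base_lr at epoch 0 down to eta_min at epoch T_max.
import math
import numpy as np

base_lr, eta_min, T_max = 0.05, 0.0, 150
epochs = np.arange(T_max + 1)
lrs = eta_min + (base_lr - eta_min) * (1.0 + np.cos(math.pi * epochs / T_max)) / 2.0
assert math.isclose(lrs[0], base_lr) and math.isclose(lrs[T_max], eta_min, abs_tol=1e-12)
```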
diff --git a/model_zoo/resnext50/src/warmup_step_lr.py b/model_zoo/resnext50/src/warmup_step_lr.py
new file mode 100644
index 0000000000..d8e85ab610
--- /dev/null
+++ b/model_zoo/resnext50/src/warmup_step_lr.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+warm up step learning rate.
+"""
+from collections import Counter
+import numpy as np
+
+from .linear_warmup import linear_warmup_lr
+
+
+def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
+    """warmup_step_lr"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+    milestones = lr_epochs
+    milestones_steps = []
+    for milestone in milestones:
+        milestones_step = milestone * steps_per_epoch
+        milestones_steps.append(milestones_step)
+
+    lr_each_step = []
+    lr = base_lr
+    milestones_steps_counter = Counter(milestones_steps)
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = lr * gamma**milestones_steps_counter[i]
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
+
+def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
+    """Decay lr by `gamma` at each epoch in `milestones`, with no warmup."""
+    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
+
+def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
+    """Decay lr by `gamma` every `epoch_size` epochs."""
+    lr_epochs = []
+    for i in range(1, max_epoch):
+        if i % epoch_size == 0:
+            lr_epochs.append(i)
+    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
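A small sketch of the step schedule with no warmup: the rate drops by `gamma` at every milestone epoch boundary (assumes the snippet runs from `model_zoo/resnext50` so that `src` is importable; the values are placeholders):

```python
# Sketch: lr is 0.05 for epochs 0-1, 0.005 for epochs 2-3 and 0.0005 for epochs 4-5
# when steps_per_epoch is 10 and the milestones are epochs 2 and 4.
import numpy as np
from src.warmup_step_lr import multi_step_lr

lrs = multi_step_lr(lr=0.05, milestones=[2, 4], steps_per_epoch=10, max_epoch=6, gamma=0.1)
assert np.allclose(lrs[:20], 0.05) and np.allclose(lrs[20:40], 0.005) and np.allclose(lrs[40:], 0.0005)
```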
diff --git a/model_zoo/resnext50/train.py b/model_zoo/resnext50/train.py
new file mode 100644
index 0000000000..29ccd9b00c
--- /dev/null
+++ b/model_zoo/resnext50/train.py
@@ -0,0 +1,289 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train ImageNet."""
+import os
+import time
+import argparse
+import datetime
+
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore import ParallelMode
+from mindspore.nn.optim import Momentum
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.train.callback import ModelCheckpoint
+from mindspore.train.callback import CheckpointConfig, Callback
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.model import Model
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
+
+from src.dataset import classification_dataset
+from src.crossentropy import CrossEntropy
+from src.warmup_step_lr import warmup_step_lr
+from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
+from src.utils.logging import get_logger
+from src.utils.optimizers__init__ import get_param_groups
+from src.image_classification import get_network
+from src.config import config
+
+devid = int(os.getenv('DEVICE_ID', '0'))
+context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                    device_target="Ascend", save_graphs=False, device_id=devid)
+
+class BuildTrainNetwork(nn.Cell):
+    """build training network"""
+    def __init__(self, network, criterion):
+        super(BuildTrainNetwork, self).__init__()
+        self.network = network
+        self.criterion = criterion
+
+    def construct(self, input_data, label):
+        output = self.network(input_data)
+        loss = self.criterion(output, label)
+        return loss
+
+class ProgressMonitor(Callback):
+    """monitor loss and time"""
+    def __init__(self, args):
+        super(ProgressMonitor, self).__init__()
+        self.me_epoch_start_time = 0
+        self.me_epoch_start_step_num = 0
+        self.args = args
+        self.ckpt_history = []
+
+    def begin(self, run_context):
+        self.args.logger.info('start network train...')
+
+    def epoch_begin(self, run_context):
+        pass
+
+    def epoch_end(self, run_context, *me_args):
+        cb_params = run_context.original_args()
+        me_step = cb_params.cur_step_num - 1
+
+        real_epoch = me_step // self.args.steps_per_epoch
+        time_used = time.time() - self.me_epoch_start_time
+        fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used
+        self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f} '
+                              'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
+
+        if self.args.rank_save_ckpt_flag:
+            import glob
+            ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
+            for ckpt in ckpts:
+                ckpt_fn = os.path.basename(ckpt)
+                if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
+                    continue
+                if ckpt in self.ckpt_history:
+                    continue
+                self.ckpt_history.append(ckpt)
+                self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
+                                      'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))
+
+
+        self.me_epoch_start_step_num = me_step
+        self.me_epoch_start_time = time.time()
+
+    def step_begin(self, run_context):
+        pass
+
+    def step_end(self, run_context, *me_args):
+        pass
+
+    def end(self, run_context):
+        self.args.logger.info('end network train...')
+
+
+def parse_args(cloud_args=None):
+    """parameters"""
+    parser = argparse.ArgumentParser('mindspore classification training')
+
+    # dataset related
+    parser.add_argument('--data_dir', type=str, default='', help='train data dir')
+    parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu')
+    # network related
+    parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
+
+    # distributed related
+    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
+    # roma obs
+    parser.add_argument('--train_url', type=str, default="", help='train url')
+
+    args, _ = parser.parse_known_args()
+    args = merge_args(args, cloud_args)
+    args.image_size = config.image_size
+    args.num_classes = config.num_classes
+    args.lr = config.lr
+    args.lr_scheduler = config.lr_scheduler
+    args.lr_epochs = config.lr_epochs
+    args.lr_gamma = config.lr_gamma
+    args.eta_min = config.eta_min
+    args.T_max = config.T_max
+    args.max_epoch = config.max_epoch
+    args.backbone = config.backbone
+    args.warmup_epochs = config.warmup_epochs
+    args.weight_decay = config.weight_decay
+    args.momentum = config.momentum
+    args.is_dynamic_loss_scale = config.is_dynamic_loss_scale
+    args.loss_scale = config.loss_scale
+    args.label_smooth = config.label_smooth
+    args.label_smooth_factor = config.label_smooth_factor
+    args.ckpt_interval = config.ckpt_interval
+    args.ckpt_path = config.ckpt_path
+    args.is_save_on_master = config.is_save_on_master
+    args.rank = config.rank
+    args.group_size = config.group_size
+    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
+    args.image_size = list(map(int, args.image_size.split(',')))
+
+    return args
+
+def merge_args(args, cloud_args):
+    """Override parsed args with non-empty cloud_args values, cast to the original argument type."""
+    args_dict = vars(args)
+    if isinstance(cloud_args, dict):
+        for key in cloud_args.keys():
+            val = cloud_args[key]
+            if key in args_dict and val:
+                arg_type = type(args_dict[key])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key] = val
+    return args
+
+def train(cloud_args=None):
+    """training process"""
+    args = parse_args(cloud_args)
+
+    # init distributed
+    if args.is_distributed:
+        init()
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+
+    if args.is_dynamic_loss_scale == 1:
+        args.loss_scale = 1  # with dynamic loss scaling, the loss scale must not also be set in the Momentum optimizer
+
+    # choose whether only the master rank saves checkpoints or every rank does; compatible with model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+
+    # dataloader
+    de_dataset = classification_dataset(args.data_dir, args.image_size,
+                                        args.per_batch_size, args.max_epoch,
+                                        args.rank, args.group_size)
+    de_dataset.map_model = 4  # !!!important
+    args.steps_per_epoch = de_dataset.get_dataset_size()
+
+    args.logger.save_args(args)
+
+    # network
+    args.logger.important_info('start create network')
+    # get network and init
+    network = get_network(args.backbone, args.num_classes)
+    if network is None:
+        raise NotImplementedError('backbone {} is not implemented'.format(args.backbone))
+    network.add_flags_recursive(fp16=True)
+    # loss
+    if not args.label_smooth:
+        args.label_smooth_factor = 0.0
+    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor,
+                             num_classes=args.num_classes)
+
+    # load pretrain model
+    if os.path.isfile(args.pretrained):
+        param_dict = load_checkpoint(args.pretrained)
+        param_dict_new = {}
+        for key, values in param_dict.items():
+            if key.startswith('moments.'):
+                continue
+            elif key.startswith('network.'):
+                param_dict_new[key[8:]] = values
+            else:
+                param_dict_new[key] = values
+        load_param_into_net(network, param_dict_new)
+        args.logger.info('loaded pretrained model {} successfully'.format(args.pretrained))
+
+    # lr scheduler
+    if args.lr_scheduler == 'exponential':
+        lr = warmup_step_lr(args.lr,
+                            args.lr_epochs,
+                            args.steps_per_epoch,
+                            args.warmup_epochs,
+                            args.max_epoch,
+                            gamma=args.lr_gamma,
+                            )
+    elif args.lr_scheduler == 'cosine_annealing':
+        lr = warmup_cosine_annealing_lr(args.lr,
+                                        args.steps_per_epoch,
+                                        args.warmup_epochs,
+                                        args.max_epoch,
+                                        args.T_max,
+                                        args.eta_min)
+    else:
+        raise NotImplementedError(args.lr_scheduler)
+
+    # optimizer
+    opt = Momentum(params=get_param_groups(network),
+                   learning_rate=Tensor(lr),
+                   momentum=args.momentum,
+                   weight_decay=args.weight_decay,
+                   loss_scale=args.loss_scale)
+
+
+    criterion.add_flags_recursive(fp32=True)
+
+    # package training process, adjust lr + forward + backward + optimizer
+    train_net = BuildTrainNetwork(network, criterion)
+    if args.is_distributed:
+        parallel_mode = ParallelMode.DATA_PARALLEL
+    else:
+        parallel_mode = ParallelMode.STAND_ALONE
+    if args.is_dynamic_loss_scale == 1:
+        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
+    else:
+        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
+
+    # Model api changed since TR5_branch 2020/03/09
+    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                      parameter_broadcast=True, mirror_mean=True)
+    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)
+
+    # checkpoint save
+    progress_cb = ProgressMonitor(args)
+    callbacks = [progress_cb,]
+    if args.rank_save_ckpt_flag:
+        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
+                                       keep_checkpoint_max=ckpt_max_num)
+        ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                                  directory=args.outputs_dir,
+                                  prefix='{}'.format(args.rank))
+        callbacks.append(ckpt_cb)
+
+    model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
+
+
+if __name__ == "__main__":
+    train()
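A minimal sketch of how `cloud_args` overrides flow through `parse_args`/`merge_args`; importing `train` configures the Ascend device context, so this assumes an Ascend environment, and the values shown are placeholders:

```python
# Sketch: cloud-side values are cast back to the CLI argument's type; falsy values are ignored.
import os
os.environ.setdefault('DEVICE_ID', '0')          # train.py reads DEVICE_ID at import time

from train import parse_args

args = parse_args(cloud_args={'per_batch_size': '64', 'data_dir': '/path/to/imagenet/train'})
assert args.per_batch_size == 64                 # the string '64' was cast back to int
```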