From 87612bfcf20812e5f32653a561339154348fb5c3 Mon Sep 17 00:00:00 2001 From: wangmin Date: Thu, 13 Aug 2020 21:11:42 +0800 Subject: [PATCH] THOR optimizer for GPU --- model_zoo/official/cv/resnet_thor/README.md | 70 +- model_zoo/official/cv/resnet_thor/eval.py | 71 +- .../scripts/run_distribute_train.sh | 2 +- .../scripts/run_distribute_train_gpu.sh | 46 ++ .../cv/resnet_thor/scripts/run_eval.sh | 13 +- .../cv/resnet_thor/scripts/run_eval_gpu.sh | 65 ++ .../official/cv/resnet_thor/src/config.py | 41 +- .../cv/resnet_thor/src/crossentropy.py | 3 - .../src/{dataset_imagenet.py => dataset.py} | 68 +- .../cv/resnet_thor/src/dataset_helper.py | 167 +++-- .../cv/resnet_thor/src/grad_reducer_thor.py | 6 +- .../official/cv/resnet_thor/src/model_thor.py | 662 +++--------------- .../official/cv/resnet_thor/src/resnet50.py | 262 ------- .../cv/resnet_thor/src/resnet_thor.py | 118 +++- model_zoo/official/cv/resnet_thor/src/thor.py | 138 +++- .../official/cv/resnet_thor/src/thor_layer.py | 329 ++++++++- model_zoo/official/cv/resnet_thor/train.py | 150 ++-- 17 files changed, 1106 insertions(+), 1105 deletions(-) mode change 100644 => 100755 model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh create mode 100755 model_zoo/official/cv/resnet_thor/scripts/run_distribute_train_gpu.sh mode change 100644 => 100755 model_zoo/official/cv/resnet_thor/scripts/run_eval.sh create mode 100755 model_zoo/official/cv/resnet_thor/scripts/run_eval_gpu.sh rename model_zoo/official/cv/resnet_thor/src/{dataset_imagenet.py => dataset.py} (57%) delete mode 100644 model_zoo/official/cv/resnet_thor/src/resnet50.py diff --git a/model_zoo/official/cv/resnet_thor/README.md b/model_zoo/official/cv/resnet_thor/README.md index cecd934575..820d3c73f7 100644 --- a/model_zoo/official/cv/resnet_thor/README.md +++ b/model_zoo/official/cv/resnet_thor/README.md @@ -24,22 +24,24 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon . ├── resnet_thor ├── README.md - ├── src + ├──scripts + ├── run_distribute_train.sh # launch distributed training for Ascend + └── run_eval.sh # launch infering for Ascend + ├── run_distribute_train_gpu.sh # launch distributed training for GPU + └── run_eval_gpu.sh # launch infering for GPU + ├──src ├── crossentropy.py # CrossEntropy loss function ├── config.py # parameter configuration - ├── resnet50.py # resnet50 backbone ├── dataset_helper.py # dataset help for minddata dataset ├── grad_reducer_thor.py # grad reducer for thor - ├── model_thor.py # model + ├── model_thor.py # model for train ├── resnet_thor.py # resnet50_thor backone - ├── thor.py # thor + ├── thor.py # thor optimizer ├── thor_layer.py # thor layer - └── dataset_imagenet.py # data preprocessing - ├── scripts - ├── run_distribute_train.sh # launch distributed training(8 pcs) - └── run_eval.sh # launch infering + └── dataset.py # data preprocessing ├── eval.py # infer script └── train.py # train script + ``` @@ -48,26 +50,30 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon Parameters for both training and inference can be set in config.py. ``` -"class_num": 1000, # dataset class number +"class_num": 1001, # dataset class number "batch_size": 32, # batch size of input tensor "loss_scale": 128, # loss scale "momentum": 0.9, # momentum of THOR optimizer "weight_decay": 5e-4, # weight decay "epoch_size": 45, # only valid for taining, which is always 1 for inference -"buffer_size": 1000, # number of queue size in data preprocessing -"image_height": 224, # image height -"image_width": 224, # image width "save_checkpoint": True, # whether save checkpoint or not -"save_checkpoint_steps": 5004, # the step interval between two checkpoints. By default, the checkpoint will be saved every epoch -"keep_checkpoint_max": 20, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the checkpoint will be saved every epoch +"keep_checkpoint_max": 15, # only keep the last keep_checkpoint_max checkpoint "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path "label_smooth": True, # label smooth "label_smooth_factor": 0.1, # label smooth factor +"lr_init": 0.045, # learning rate init value +"lr_decay": 6, # learning rate decay rate value +"lr_end_epoch": 70, # learning rate end epoch value +"damping_init": 0.03, # damping init value for Fisher information matrix +"damping_decay": 0.87, # damping decay rate "frequency": 834, # the step interval to update second-order information matrix ``` ## Running the example +### 1 Running on Ascend 910 + ### Train #### Usage @@ -82,10 +88,10 @@ Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM] ```bash # distributed training example(8 pcs) -sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc +sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc 8 ``` -> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training_ascend.html). #### Result @@ -126,3 +132,35 @@ Inference result will be stored in the example path, whose folder name is "infer ``` result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt ``` + +### 2 Running on GPU + +### Train +``` +# distributed training example +sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM] +``` +#### Result +``` +# distribute training result(8 pcs) +epoch: 1 step: 5004, loss is 4.3069 +epoch: 2 step: 5004, loss is 3.5695 +epoch: 3 step: 5004, loss is 3.5893 +epoch: 4 step: 5004, loss is 3.1987 +epoch: 5 step: 5004, loss is 3.3526 +...... +epoch: 40 step: 5004, loss is 1.9482 +epoch: 41 step: 5004, loss is 1.8950 +epoch: 42 step: 5004, loss is 1.9023 +...... +``` + +### Infer +``` +# infer example +sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` +#### Result +``` +result: {'acc': 0.760143245838668} ckpt_0/resnet-40_5004.ckpt +``` diff --git a/model_zoo/official/cv/resnet_thor/eval.py b/model_zoo/official/cv/resnet_thor/eval.py index fb911787ba..6d857d06db 100644 --- a/model_zoo/official/cv/resnet_thor/eval.py +++ b/model_zoo/official/cv/resnet_thor/eval.py @@ -12,51 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -eval. -""" +"""train resnet.""" import os +import random import argparse - +import numpy as np from mindspore import context +from mindspore import dataset as de from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from src.dataset_imagenet import create_dataset -from src.config import config from src.crossentropy import CrossEntropy -from src.resnet50 import resnet50 +from src.config import config +from src.dataset import create_dataset +from src.resnet_thor import resnet50 as resnet parser = argparse.ArgumentParser(description='Image classification') -parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') -parser.add_argument('--device_num', type=int, default=1, help='Device num.') -parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.') -parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') args_opt = parser.parse_args() -device_id = int(os.getenv('DEVICE_ID')) - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) -context.set_context(device_id=device_id) +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) if __name__ == '__main__': + target = args_opt.device_target - net = resnet50(class_num=config.class_num) - if not config.label_smooth: + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + if target != "GPU": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id) + + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, + target=target) + + # define net + net = resnet(class_num=config.class_num) + net.add_flags_recursive(thor=False) + + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + keys = list(param_dict.keys()) + for key in keys: + if "damping" in key: + param_dict.pop(key) + load_param_into_net(net, param_dict) + net.set_train(False) + + # define loss, model + if not config.use_label_smooth: config.label_smooth_factor = 0.0 loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - if args_opt.do_eval: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - - if args_opt.checkpoint_path: - param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(net, param_dict) - net.set_train(False) + # define model + model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) - model = Model(net, loss_fn=loss, metrics={'acc'}) - res = model.eval(dataset) - print("result:", res, "ckpt=", args_opt.checkpoint_path) + # eval model + res = model.eval(dataset) + print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh old mode 100644 new mode 100755 index 63d192bfa1..c68a4fa159 --- a/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train.sh @@ -52,6 +52,6 @@ do echo "start training for rank $RANK_ID, device $DEVICE_ID" env > env.log - python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & + python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & cd .. done diff --git a/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train_gpu.sh new file mode 100755 index 0000000000..0e424e725a --- /dev/null +++ b/model_zoo/official/cv/resnet_thor/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) + +ulimit -u unlimited +export DEVICE_NUM=$2 +export RANK_SIZE=$2 + +rm -rf ./train_parallel +mkdir ./train_parallel +cp ../*.py ./train_parallel +cp *.sh ./train_parallel +cp -r ../src ./train_parallel +cd ./train_parallel || exit + +mpirun -n $RANK_SIZE \ +python train.py --run_distribute=True \ +--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log & \ No newline at end of file diff --git a/model_zoo/official/cv/resnet_thor/scripts/run_eval.sh b/model_zoo/official/cv/resnet_thor/scripts/run_eval.sh old mode 100644 new mode 100755 index eafba5fbea..e6a0174a2a --- a/model_zoo/official/cv/resnet_thor/scripts/run_eval.sh +++ b/model_zoo/official/cv/resnet_thor/scripts/run_eval.sh @@ -20,6 +20,7 @@ then exit 1 fi + get_real_path(){ if [ "${1:0:1}" == "/" ]; then echo "$1" @@ -44,9 +45,6 @@ then exit 1 fi -BASE_PATH=$(cd "`dirname $0`" || exit; pwd) -cd $BASE_PATH/../ || exit - ulimit -u unlimited export DEVICE_NUM=1 export DEVICE_ID=0 @@ -58,10 +56,11 @@ then rm -rf ./eval fi mkdir ./eval -cp *.py ./eval -cp -r ./src ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval cd ./eval || exit env > env.log -echo "start infering for device $DEVICE_ID" -python eval.py --do_eval=True --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & +echo "start evaluation for device $DEVICE_ID" +python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & cd .. diff --git a/model_zoo/official/cv/resnet_thor/scripts/run_eval_gpu.sh b/model_zoo/official/cv/resnet_thor/scripts/run_eval_gpu.sh new file mode 100755 index 0000000000..c9be5dc866 --- /dev/null +++ b/model_zoo/official/cv/resnet_thor/scripts/run_eval_gpu.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start evaluation for device $DEVICE_ID" +python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log & +cd .. diff --git a/model_zoo/official/cv/resnet_thor/src/config.py b/model_zoo/official/cv/resnet_thor/src/config.py index cd0a81d5e6..00de8f985f 100644 --- a/model_zoo/official/cv/resnet_thor/src/config.py +++ b/model_zoo/official/cv/resnet_thor/src/config.py @@ -17,21 +17,46 @@ network config setting, will be used in train.py and eval.py """ from easydict import EasyDict as ed +# config for resnet50, imagenet2012, Ascend 910 config = ed({ - "class_num": 1000, + "class_num": 1001, "batch_size": 32, "loss_scale": 128, "momentum": 0.9, "weight_decay": 5e-4, "epoch_size": 45, - "buffer_size": 1000, - "image_height": 224, - "image_width": 224, "save_checkpoint": True, - "save_checkpoint_steps": 5004, - "keep_checkpoint_max": 20, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 15, "save_checkpoint_path": "./", - "label_smooth": 1, + "use_label_smooth": True, "label_smooth_factor": 0.1, - "frequency": 834 + "lr_init": 0.045, + "lr_decay": 6, + "lr_end_epoch": 70, + "damping_init": 0.03, + "damping_decay": 0.87, + "frequency": 834, +}) + +# config for resnet50, imagenet2012, GPU +config_gpu = ed({ + "class_num": 1001, + "batch_size": 32, + "loss_scale": 128, + "momentum": 0.9, + "weight_decay": 5e-4, + "epoch_size": 45, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 15, + "save_checkpoint_path": "./", + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr_init": 0.04, + "lr_decay": 5, + "lr_end_epoch": 58, + "damping_init": 0.02, + "damping_decay": 0.87, + "frequency": 834, }) diff --git a/model_zoo/official/cv/resnet_thor/src/crossentropy.py b/model_zoo/official/cv/resnet_thor/src/crossentropy.py index e8681ff497..cdb27894f9 100644 --- a/model_zoo/official/cv/resnet_thor/src/crossentropy.py +++ b/model_zoo/official/cv/resnet_thor/src/crossentropy.py @@ -28,13 +28,10 @@ class CrossEntropy(_Loss): self.onehot = P.OneHot() self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) - # self.cast = P.Cast() self.ce = nn.SoftmaxCrossEntropyWithLogits() self.mean = P.ReduceMean(False) def construct(self, logit, label): - # one_hot_label = self.onehot(self.cast(label, mstype.int32), - # F.shape(logit)[1], self.on_value, self.off_value)、 one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) loss = self.ce(logit, one_hot_label) loss = self.mean(loss, 0) diff --git a/model_zoo/official/cv/resnet_thor/src/dataset_imagenet.py b/model_zoo/official/cv/resnet_thor/src/dataset.py similarity index 57% rename from model_zoo/official/cv/resnet_thor/src/dataset_imagenet.py rename to model_zoo/official/cv/resnet_thor/src/dataset.py index 296b675136..fc6dc1bac9 100644 --- a/model_zoo/official/cv/resnet_thor/src/dataset_imagenet.py +++ b/model_zoo/official/cv/resnet_thor/src/dataset.py @@ -16,30 +16,36 @@ create train or eval dataset. """ import os - import mindspore.common.dtype as mstype import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 -import mindspore.dataset.transforms.vision.c_transforms as V_C +from mindspore.communication.management import init, get_rank, get_group_size -def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): +def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): """ - create a train or eval dataset + create a train or eval imagenet2012 dataset for resnet50 + Args: dataset_path(string): the path of dataset. do_train(bool): whether dataset is used for train or eval. repeat_num(int): the repeat times of dataset. Default: 1 batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend Returns: dataset """ - device_num = int(os.getenv("RANK_SIZE")) - rank_id = int(os.getenv("RANK_ID")) + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + init("nccl") + rank_id = get_rank() + device_num = get_group_size() if device_num == 1: - ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False) + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) else: ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=device_num, shard_id=rank_id) @@ -47,29 +53,28 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): image_size = 224 mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations if do_train: - transform_img = [ - V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), - V_C.RandomHorizontalFlip(prob=0.5), - V_C.Normalize(mean=mean, std=std), - V_C.HWC2CHW() + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() ] else: - transform_img = [ - V_C.Decode(), - V_C.Resize((256, 256)), - V_C.CenterCrop(image_size), - V_C.Normalize(mean=mean, std=std), - V_C.HWC2CHW() + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() ] - # type_cast_op = C2.TypeCast(mstype.float16) - type_cast_op = C2.TypeCast(mstype.int32) - ds = ds.map(input_columns="image", operations=transform_img, num_parallel_workers=8) - ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8) + type_cast_op = C2.TypeCast(mstype.int32) - # apply shuffle operations - # ds = ds.shuffle(buffer_size=config.buffer_size) + ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans) + ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op) # apply batch operations ds = ds.batch(batch_size, drop_remainder=True) @@ -78,3 +83,18 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): ds = ds.repeat(repeat_num) return ds + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = 1 + rank_id = 0 + + return rank_size, rank_id diff --git a/model_zoo/official/cv/resnet_thor/src/dataset_helper.py b/model_zoo/official/cv/resnet_thor/src/dataset_helper.py index 304bc99cfd..79e3ef678d 100644 --- a/model_zoo/official/cv/resnet_thor/src/dataset_helper.py +++ b/model_zoo/official/cv/resnet_thor/src/dataset_helper.py @@ -13,34 +13,47 @@ # limitations under the License. # ============================================================================ """Dataset help for minddata dataset""" -from mindspore._checkparam import check_bool -from mindspore.parallel._utils import _get_device_num, _get_parallel_mode -from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ - _to_full_shapes -from mindspore.train.parallel_utils import ParallelMode +import math +import os +from mindspore._checkparam import check_bool, check_int +from mindspore import context +from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes +from mindspore.nn.wrap import GetNextSingleOp +from mindspore.parallel._utils import _get_device_num, _need_to_full -def _send_data(dataset): + +def _send_data(dataset, epoch_num): """Engine dataset to write data to tdt queue.""" if not hasattr(dataset, '__has_sent__'): exec_dataset = dataset.__TRANSFER_DATASET__ - exec_dataset.send() + exec_dataset.send(epoch_num) dataset.__has_sent__ = True +def _send_data_no_flag(dataset, epoch_num): + """Engine dataset to write data to tdt queue directly.""" + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send(epoch_num) + + class DatasetHelper: """ - Help function to use the Minddata dataset. + Help function to use the MindData dataset. - According to different context, change the iter of dataset, to use the same for loop in different context. + According to different contexts, change the iterations of dataset and use the same iteration for loop in different + contexts. Note: - The iter of DatasetHelper will give one epoch data. + The iteration of DatasetHelper will provide one epoch data. Args: - dataset (DataSet): The dataset. - dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. - Default: True. + dataset (DataSet): The training dataset iterator. + dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True. + sink_size (int): Control the amount of data in each sink. + If sink_size=-1, sink the complete dataset for each epoch. + If sink_size>0, sink sink_size data for each epoch. Default: -1. + epoch_num (int): Control the number of epoch data to send. Default: 1. Examples: >>> dataset_helper = DatasetHelper(dataset) @@ -48,81 +61,116 @@ class DatasetHelper: >>> outputs = network(*inputs) """ - def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0): + def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1): check_bool(dataset_sink_mode) - self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order) + check_int(sink_size) + if sink_size < -1 or sink_size == 0: + raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size)) + + if dataset_sink_mode: + if context.get_context("device_target") == "Ascend": + iterclass = _DatasetIterMSLoopSink + self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order) + elif context.get_context("device_target") == "GPU": + iterclass = _DatasetIterMS + self.iter = iterclass(dataset, sink_size, epoch_num) + elif context.get_context("device_target") == "CPU": + raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") def __iter__(self): return self.iter.__iter__() # A temp solution for loop sink. Delete later def types_shapes(self): - """Get the types and shapes from dataset on current config.""" + """Get the types and shapes from dataset on the current configuration.""" return self.iter.types_shapes() - def loop_size(self): - """Get loop_size for every iteration.""" - return self.iter.loop_size + def sink_size(self): + """Get sink_size for each iteration.""" + return self.iter.get_sink_size() + + def stop_send(self): + """Free up resources about data sink.""" + self.iter.stop_send() class _DatasetIter: - """Base iter for dataset help""" - - def __init__(self, dataset): - self.loop_size = 1 - if not hasattr(dataset, '__ME_INITED__'): - if not hasattr(dataset, '__loop_size__'): - self.loop_size = dataset.get_dataset_size() - else: - self.loop_size = dataset.__loop_size__ - dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + """Base iter for dataset helper""" + def __init__(self, dataset, sink_size, epoch_num): + self.dataset = dataset + self.sink_size = sink_size + self.sink_count = 1 + + if not hasattr(dataset, '__TRANSFER_DATASET__'): + if hasattr(dataset, '__loop_size__'): + self.sink_size = dataset.__loop_size__ + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name if not hasattr(dataset, '__no_send__'): - _send_data(dataset) + _send_data(dataset, epoch_num) else: - _send_data(dataset) + _send_data_no_flag(dataset, epoch_num) - self.ind = 0 - self.dataset = dataset - dataset_types, dataset_shapes = _get_types_and_shapes(dataset) - self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes + self.stop_send = dataset.__TRANSFER_DATASET__.stop_send + self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) def __iter__(self): - self.ind = 0 + self.index = 0 return self def __next__(self): - if self.ind >= self.loop_count: + if self.index >= self.sink_count: raise StopIteration() - self.ind += 1 + self.index += 1 return self.op() def types_shapes(self): return self.dataset_types, self.dataset_shapes - def get_loop_count(self, dataset): - loop_count = 1 + def get_sink_count(self, dataset): + sink_count = 1 if hasattr(dataset, '__loop_size__'): loop_size = dataset.__loop_size__ - if dataset.get_dataset_size() % loop_size != 0: + if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' - f'loop_size {loop_size} are not matched.') - loop_count = int(dataset.get_dataset_size() / loop_size) - return loop_count + f'sink_size {loop_size} are not matched.') + sink_count = math.ceil(dataset.get_dataset_size() / loop_size) + return sink_count + + def get_sink_size(self): + """get sink_size to device""" + sink_size = 1 + if hasattr(self.dataset, '__loop_size__'): + sink_size = self.dataset.__loop_size__ + else: + if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend": + if self.sink_size > 0: + sink_size = self.sink_size + else: + sink_size = self.dataset.get_dataset_size() + return sink_size class _DatasetIterMSLoopSink(_DatasetIter): """Iter for context (device_target=Ascend)""" - - def __init__(self, dataset, iter_first_order): - super(_DatasetIterMSLoopSink, self).__init__(dataset) - loop_size = dataset.__loop_size__ + iter_first_order - self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2 - # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to - # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number - # times the batch dimension of tensors for run. Now only support LoopSink. - if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + def __init__(self, dataset, sink_size, epoch_num, iter_first_order): + super().__init__(dataset, sink_size, epoch_num) + sink_count = 1 + if hasattr(dataset, '__loop_size__'): + loop_size = dataset.__loop_size__ + iter_first_order + if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0: + raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' + f'sink_size {loop_size} are not matched.') + sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2 + self.sink_count = sink_count + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + self.sink_count = 1 + # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, + # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for + # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. + if _need_to_full(): device_num = _get_device_num() self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num) @@ -130,3 +178,16 @@ class _DatasetIterMSLoopSink(_DatasetIter): return tuple() self.op = op + + +class _DatasetIterMS(_DatasetIter): + """Iter for MS(enable_loop_sink=False).""" + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + if sink_size > 0: + self.sink_count = sink_size + else: + self.sink_count = dataset.get_dataset_size() + + queue_name = dataset.__ME_INITED__ + self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) diff --git a/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py b/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py index ad8d8dd8e4..b8bbbf29b7 100644 --- a/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py @@ -174,10 +174,6 @@ class DistributedGradReducerThor(Cell): datatypes = self.hyper_map(F.partial(_get_datatype), grads) grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) - if self.mean: - new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) - else: - new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads) - + new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) return new_grad diff --git a/model_zoo/official/cv/resnet_thor/src/model_thor.py b/model_zoo/official/cv/resnet_thor/src/model_thor.py index 13b8c49996..4ba6a32e9f 100644 --- a/model_zoo/official/cv/resnet_thor/src/model_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/model_thor.py @@ -14,27 +14,19 @@ # ============================================================================ """Model.""" -import numpy as np +import math +from mindspore.train.callback import RunContext from mindspore import context -from mindspore import log as logger from mindspore import nn -from mindspore._c_expression import init_exec_dataset -from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool -from mindspore.common import dtype as mstype -from mindspore.common.dtype import pytype_to_dtype -from mindspore.common.tensor import Tensor -from mindspore.nn.metrics import Loss -from mindspore.nn.metrics import get_metrics -from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell -from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ - _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check -from mindspore.train import amp -from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager from mindspore.train.parallel_utils import ParallelMode +from mindspore.train._utils import _to_full_tensor +from mindspore.train.model import Model +from mindspore.parallel._utils import _need_to_full +from mindspore.common.dtype import pytype_to_dtype +from mindspore._c_expression import init_exec_dataset from src.dataset_helper import DatasetHelper - def _convert_type(types): """ Convert from numpy type to tensor type. @@ -76,194 +68,52 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): need_run=False) -class Model: +class Model_Thor(Model): """ High-Level API for Training or Testing. `Model` groups layers into an object with training and inference features. Args: - network (Cell): The training or testing network. + network (Cell): A training or testing network. loss_fn (Cell): Objective function, if loss_fn is None, the network should contain the logic of loss and grads calculation, and the logic of parallel if needed. Default: None. optimizer (Cell): Optimizer for updating the weights. Default: None. - metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during + metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during training and testing. eg: {'accuracy', 'recall'}. Default: None. eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as `eval_network`. Default: None. - eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of + eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the `eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three - elements, representing the positions of loss value, predict value and label, the loss - value would be passed to `Loss` metric, predict value and label would be passed to other - metric. Default: None. + elements, including the positions of loss value, predicted value and label. The loss + value would be passed to the `Loss` metric, the predicted value and label would be passed + to other metric. Default: None. amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed - precision training. Supports [O0, O2]. Default: "O0". + precision training. Supports [O0, O2, O3]. Default: "O0". - O0: Do not change. - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. + - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. - loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else - scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument. + O2 is recommended on GPU, O3 is recommended on Ascend. + + loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise, + scale the loss by LossScaleManager. It is a key argument. e.g. Use `loss_scale_manager=None` to set the value. - keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. - - Examples: - >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') - >>> self.bn = nn.BatchNorm2d(64) - >>> self.relu = nn.ReLU() - >>> self.flatten = nn.Flatten() - >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 - >>> - >>> def construct(self, x): - >>> x = self.conv(x) - >>> x = self.bn(x) - >>> x = self.relu(x) - >>> x = self.flatten(x) - >>> out = self.fc(x) - >>> return out - >>> - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) - >>> dataset = get_dataset() - >>> model.train(2, dataset) + keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before + will be overwritten. Default: True. """ def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, - eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs): - self._network = network - self._loss_fn = loss_fn - self._optimizer = optimizer - self._loss_scale_manager = None - self._loss_scale_manager_set = False - self._keep_bn_fp32 = True - self._check_kwargs(kwargs) - self._amp_level = amp_level - self._process_amp_args(kwargs) - self._parallel_mode = _get_parallel_mode() - self._device_number = _get_device_num() - self._global_rank = _get_global_rank() - self._parameter_broadcast = _get_parameter_broadcast() + eval_indexes=None, amp_level="O0", frequency=834, **kwargs): + super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network, + eval_indexes, amp_level, **kwargs) self._frequency = frequency - self._stop_epoch = stop_epoch - self._train_network = self._build_train_network() - self._build_eval_network(metrics, eval_network, eval_indexes) - self._build_predict_network() - - def _process_amp_args(self, kwargs): - if self._amp_level == "O0": - self._keep_bn_fp32 = False - if 'keep_batchnorm_fp32' in kwargs: - self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] - if 'loss_scale_manager' in kwargs: - self._loss_scale_manager = kwargs['loss_scale_manager'] - self._loss_scale_manager_set = True - - def _check_kwargs(self, kwargs): - for arg in kwargs: - if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: - raise ValueError(f"Unsupport arg '{arg}'") - - def _build_train_network(self): - """Build train network""" - network = self._network - if self._optimizer: - if self._loss_scale_manager_set: - network = amp.build_train_network(network, - self._optimizer, - self._loss_fn, - level=self._amp_level, - loss_scale_manager=self._loss_scale_manager, - keep_batchnorm_fp32=self._keep_bn_fp32) - else: - network = amp.build_train_network(network, - self._optimizer, - self._loss_fn, - level=self._amp_level, - keep_batchnorm_fp32=self._keep_bn_fp32) - elif self._loss_fn: - network = nn.WithLossCell(network, self._loss_fn) - # If need to check if loss_fn is not None, but optimizer is None - - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - network.set_auto_parallel() - return network - - def _build_eval_network(self, metrics, eval_network, eval_indexes): - """Build the network for evaluation.""" - self._metric_fns = get_metrics(metrics) - if not self._metric_fns: - return - - if eval_network is not None: - if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3): - raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \ - must be three. But got {}".format(eval_indexes)) - - self._eval_network = eval_network - self._eval_indexes = eval_indexes - else: - if self._loss_fn is None: - raise ValueError("loss_fn can not be None.") - self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") - self._eval_indexes = [0, 1, 2] - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self._eval_network.set_auto_parallel() - - def _build_predict_network(self): - """Build the network for prediction.""" - self._predict_network = self._network - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self._predict_network = _VirtualDatasetCell(self._network) - self._predict_network.set_auto_parallel() - - def _clear_metrics(self): - """Clear metrics local values.""" - for metric in self._metric_fns.values(): - metric.clear() - - def _update_metrics(self, outputs): - """Update metrics local values.""" - if not isinstance(outputs, tuple): - raise ValueError("The `outputs` is not tuple.") - - if self._eval_indexes is not None and len(outputs) < 3: - raise ValueError("The length of `outputs` must be greater than or equal to 3, \ - but got {}".format(len(outputs))) - - for metric in self._metric_fns.values(): - if self._eval_indexes is None: - metric.update(*outputs) - else: - if isinstance(metric, Loss): - metric.update(outputs[self._eval_indexes[0]]) - else: - metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]]) - - def _get_metrics(self): - """Get metrics local values.""" - metrics = dict() - for key, value in self._metric_fns.items(): - metrics[key] = value.eval() - return metrics - - def _get_scaling_sens(self): - """get the scaling sens""" - scaling_sens = 1 - if self._loss_scale_manager is not None: - scaling_sens = self._loss_scale_manager.get_loss_scale() - if self._parallel_mode == ParallelMode.DATA_PARALLEL: - scaling_sens /= self._device_number - return scaling_sens - - def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order): + def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, + epoch_num=1, iter_first_order=1): """Initializes dataset.""" need_wrap = False if dataset_sink_mode: @@ -275,7 +125,7 @@ class Model: if not is_train: dataset.__loop_size__ = 1 - dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order) + dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) # remove later to deal with loop sink if need_wrap: @@ -283,133 +133,31 @@ class Model: network.set_train(is_train) network.phase = phase - return dataset_helper, network - - def init(self, train_dataset=None, valid_dataset=None): - """ - Initializes compute graphs and data graphs with sink mode. - - Note: - Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently. - - Args: - train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be - initialized. Default: None. - valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will - be initialized, and `metrics` in `Model` can not be None. Default: None. - - Examples: - >>> train_dataset = get_train_dataset() - >>> valid_dataset = get_valid_dataset() - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'}) - >>> model.init(train_dataset, valid_dataset) - >>> model.train(2, train_dataset) - >>> model.eval(valid_dataset) - """ - if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend": - raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.') - - if not train_dataset and not valid_dataset: - raise ValueError('Both train_dataset and valid_dataset can not be None or empty.') - - _device_number_check(self._parallel_mode, self._device_number) - - if train_dataset: - _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) - self._train_network.set_train() - self._train_network.phase = 'train' - - if self._parameter_broadcast: - self._train_network.set_broadcast_flag() - - train_dataset_helper, train_network = self._exec_preprocess(self._train_network, - is_train=True, - phase='train', - dataset=train_dataset, - dataset_sink_mode=True) - self._train_network = train_network - for inputs in train_dataset_helper: - self._train_network.compile(*inputs) - break - - if valid_dataset: - if not self._metric_fns: - raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') - - self._eval_network.set_train(False) - self._eval_network.phase = 'eval' - valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, - is_train=False, - phase='eval', - dataset=valid_dataset, - dataset_sink_mode=True) - self._eval_network = eval_network - for inputs in valid_dataset_helper: - self._eval_network.compile(*inputs) - break - - def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): - """ - Training. - - Args: - epoch (int): Total number of iterations on the data. - train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be - returned and passed to the network. Otherwise, a tuple (data, label) will - be returned, and the data and label are passed to the network and loss - function respectively. - callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None. - dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. - Configure pynative mode, the training process will be performed with - dataset not sink. - """ - epoch = check_int_positive(epoch) - self._train_network.set_train() + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + network.set_auto_parallel() - if self._parameter_broadcast: - self._train_network.set_broadcast_flag() + return dataset_helper, network - # build callback list - cb_params = _InternalCallbackParam() - cb_params.train_network = self._train_network - cb_params.epoch_num = epoch - cb_params.batch_num = train_dataset.get_dataset_size() - cb_params.mode = "train" - cb_params.loss_fn = self._loss_fn - cb_params.optimizer = self._optimizer - cb_params.parallel_mode = self._parallel_mode - cb_params.device_number = self._device_number - cb_params.train_dataset = train_dataset - cb_params.list_callback = callbacks - - with _CallbackManager(callbacks) as list_callback: - if not dataset_sink_mode: - self._train_process(epoch, train_dataset, list_callback, cb_params) - elif context.get_context("mode") == context.PYNATIVE_MODE: - logger.warning("The pynative mode cannot support dataset sink mode currently." - "So the training process will be performed with dataset not sink.") - self._train_process(epoch, train_dataset, list_callback, cb_params) - else: - self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params) - - def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None): + def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1): """ Training process. The data would be passed to network through dataset channel. Args: epoch (int): Total number of iterations on the data. train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be + loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss + be returned. The data and label would be passed to the network and loss function respectively. list_callback (Callback): Executor of callback list. Default: None. cb_params (_InternalCallbackParam): Callback parameters. Default: None. + sink_size (int): Control the amount of data in each sink. Default: -1. """ + if sink_size == -1: + epoch_num = epoch + else: + epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) + iter_first_order = self._frequency - 1 iter_second_order = 1 train_dataset.__loop_size__ = iter_second_order @@ -418,308 +166,82 @@ class Model: phase='train', dataset=train_dataset, dataset_sink_mode=True, + sink_size=sink_size, + epoch_num=epoch_num, iter_first_order=iter_first_order) + self._train_network = train_network cb_params.train_network = self._train_network cb_params.cur_step_num = 0 - loop_size = dataset_helper.loop_size() run_context = RunContext(cb_params) list_callback.begin(run_context) # used to stop training for early stop, such as stopAtTIme or stopATStep should_stop = False - has_do_dataset_init = False switch_branch_one = True + index_first_order = 0 + train_network_init_flag = True + has_do_dataset_init = False + for i in range(epoch): cb_params.cur_epoch_num = i + 1 list_callback.epoch_begin(run_context) - # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: + if _need_to_full() and context.get_context("device_target") == "GPU": + inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) - if switch_branch_one: - cb_params.cur_step_num += loop_size - self._train_network.add_flags_recursive(thor=True) - self._train_network.phase = 'train0' + if context.get_context("device_target") == "GPU": + if switch_branch_one: + cb_params.cur_step_num += 1 + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=True) + self._train_network.phase = 'train0' + switch_branch_one = not switch_branch_one + outputs = self._train_network(*inputs) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) + else: + cb_params.cur_step_num += 1 + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=False) + train_network_init_flag = False + self._train_network.phase = 'train1' + outputs = self._train_network(*inputs) + cb_params.net_outputs = outputs + index_first_order += 1 + if index_first_order == iter_first_order: + index_first_order = 0 + switch_branch_one = not switch_branch_one + list_callback.step_end(run_context) else: - cb_params.cur_step_num += iter_first_order - self._train_network.add_flags_recursive(thor=False) - self._train_network.phase = 'train1' - if not has_do_dataset_init: - _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') - has_do_dataset_init = True - switch_branch_one = not switch_branch_one - outputs = self._train_network(*inputs) - cb_params.net_outputs = outputs - list_callback.step_end(run_context) + if switch_branch_one: + cb_params.cur_step_num += 1 + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=True) + self._train_network.phase = 'train0' + else: + cb_params.cur_step_num += iter_first_order + if train_network_init_flag: + self._train_network.add_flags_recursive(thor=False) + train_network_init_flag = False + self._train_network.phase = 'train1' + if not has_do_dataset_init: + _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') + has_do_dataset_init = True + switch_branch_one = not switch_branch_one + outputs = self._train_network(*inputs) + cb_params.net_outputs = outputs + list_callback.step_end(run_context) list_callback.epoch_end(run_context) should_stop = should_stop or run_context.get_stop_requested() if should_stop: break + dataset_helper.stop_send() list_callback.end(run_context) - def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None): - """ - Training process. The data would be passed to network directly. - - Args: - epoch (int): Total number of iterations on the data. - train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be - returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss - function respectively. - list_callback (Callback): Executor of callback list. Default: None. - cb_params (_InternalCallbackParam): Callback parameters. Default: None. - """ - dataset_helper, _ = self._exec_preprocess(self._train_network, - is_train=True, - phase='train', - dataset=train_dataset, - dataset_sink_mode=False) - cb_params.cur_step_num = 0 - run_context = RunContext(cb_params) - list_callback.begin(run_context) - # used to stop training for early stop, such as stopAtTIme or stopATStep - should_stop = False - - for i in range(epoch): - cb_params.cur_epoch_num = i + 1 - - list_callback.epoch_begin(run_context) - - for next_element in dataset_helper: - len_element = len(next_element) - if self._loss_fn and len_element != 2: - raise ValueError("when loss_fn is not None, train_dataset should" - "return two elements, but got {}".format(len_element)) - cb_params.cur_step_num += 1 - list_callback.step_begin(run_context) - - overflow = False - if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): - scaling_sens = self._get_scaling_sens() - next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) - - outputs = self._train_network(*next_element) - cb_params.net_outputs = outputs - if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): - _, overflow, _ = outputs - overflow = np.all(overflow.asnumpy()) - self._loss_scale_manager.update_loss_scale(overflow) - - list_callback.step_end(run_context) - should_stop = should_stop or run_context.get_stop_requested() - if should_stop: - break - - train_dataset.reset() - - list_callback.epoch_end(run_context) - should_stop = should_stop or run_context.get_stop_requested() - if should_stop: - break - - list_callback.end(run_context) - - def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True): - """ - Training API where the iteration is controlled by python front-end. - - When setting pynative mode, the training process will be performed with dataset not sink. - - Note: - CPU is not supported when dataset_sink_mode is true. - If dataset_sink_mode is True, epoch of training should be equal to the count of repeat - operation in dataset processing. Otherwise, errors could occur since the amount of data - is not the amount training requires. - If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features - of data will be transferred one by one. The limitation of data transmission per time is 256M. - - Args: - epoch (int): Total number of iterations on the data. - train_dataset (Dataset): A training dataset iterator. If there is no - loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be - returned and passed to the network. Otherwise, a tuple (data, label) should - be returned, and the data and label are passed to the network and loss - function respectively. - callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None. - dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. - Configure pynative mode, the training process will be performed with - dataset not sink. - - - Examples: - >>> dataset = get_dataset() - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> loss_scale_manager = FixedLossScaleManager() - >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) - >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) - >>> model.train(2, dataset) - """ - repeat_count = train_dataset.get_repeat_count() - if epoch != repeat_count and dataset_sink_mode is True: - logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}") - check_bool(dataset_sink_mode) - _device_number_check(self._parallel_mode, self._device_number) - _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) - - self._train(epoch, - train_dataset, - callbacks=callbacks, - dataset_sink_mode=dataset_sink_mode) - - def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None): - """ - Evaluation. The data would be passed to network through dataset channel. - - Args: - valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (Callback): Executor of callback list. Default: None. - cb_params (_InternalCallbackParam): Callback parameters. Default: None. - - Returns: - Dict, returns the loss value & metrics values for the model in test mode. - """ - run_context = RunContext(cb_params) - - dataset_helper, eval_network = self._exec_preprocess(self._eval_network, - is_train=False, - phase='eval', - dataset=valid_dataset, - dataset_sink_mode=True) - self._eval_network = eval_network - cb_params.eval_network = self._eval_network - list_callback.begin(run_context) - - for inputs in dataset_helper: - cb_params.cur_step_num += 1 - list_callback.step_begin(run_context) - - outputs = self._eval_network(*inputs) - - cb_params.net_outputs = outputs - list_callback.step_end(run_context) - self._update_metrics(outputs) - - metrics = self._get_metrics() - cb_params.metrics = metrics - list_callback.end(run_context) - - return metrics - - def _eval_process(self, valid_dataset, list_callback=None, cb_params=None): - """ - Evaluation. The data would be passed to network directly. - - Args: - valid_dataset (Dataset): Dataset to evaluate the model. - list_callback (Callback): Executor of callback list. Default: None. - cb_params (_InternalCallbackParam): Callback parameters. Default: None. - - Returns: - Dict, returns the loss value & metrics values for the model in test mode. - """ - run_context = RunContext(cb_params) - list_callback.begin(run_context) - - dataset_helper, _ = self._exec_preprocess(self._eval_network, - is_train=False, - phase='eval', - dataset=valid_dataset, - dataset_sink_mode=False) - for next_element in dataset_helper: - cb_params.cur_step_num += 1 - list_callback.step_begin(run_context) - outputs = self._eval_network(*next_element) - cb_params.net_outputs = outputs - list_callback.step_end(run_context) - self._update_metrics(outputs) - - metrics = self._get_metrics() - cb_params.metrics = metrics - list_callback.end(run_context) - return metrics - - def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True): - """ - Evaluation API where the iteration is controlled by python front-end. - - Configure to pynative mode, the evaluation will be performed with dataset non-sink mode. - - Note: - CPU is not supported when dataset_sink_mode is true. - If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features - of data will be transferred one by one. The limitation of data transmission per time is 256M. - - Args: - valid_dataset (Dataset): Dataset to evaluate the model. - callbacks (list): List of callback object. Callbacks which should be excuted - while training. Default: None. - dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. - - Returns: - Dict, returns the loss value & metrics values for the model in test mode. - - Examples: - >>> dataset = get_dataset() - >>> net = Net() - >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'}) - >>> model.eval(dataset) - """ - check_bool(dataset_sink_mode) - _device_number_check(self._parallel_mode, self._device_number) - if not self._metric_fns: - raise ValueError("metric fn can not be None or empty.") - - cb_params = _InternalCallbackParam() - cb_params.eval_network = self._eval_network - cb_params.valid_dataset = valid_dataset - cb_params.batch_num = valid_dataset.get_dataset_size() - cb_params.mode = "eval" - cb_params.cur_step_num = 0 - - self._eval_network.set_train(mode=False) - self._eval_network.phase = 'eval' - - self._clear_metrics() - - with _CallbackManager(callbacks) as list_callback: - if dataset_sink_mode: - return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params) - return self._eval_process(valid_dataset, list_callback, cb_params) - - def predict(self, *predict_data): - """ - Generates output predictions for the input samples. - - Data could be single tensor, or list of tensor, tuple of tensor. - - Note: - Batch data should be put together in one tensor. - - Args: - predict_data (Tensor): Tensor of predict data. can be array, list or tuple. - - Returns: - Tensor, array(s) of predictions. - - Examples: - >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) - >>> model = Model(Net()) - >>> model.predict(input_data) - """ - self._predict_network.set_train(False) - check_input_data(*predict_data, data_class=Tensor) - result = self._predict_network(*predict_data) - - check_output_data(result) - return result - -__all__ = ["Model"] +__all__ = ["Model_Thor"] diff --git a/model_zoo/official/cv/resnet_thor/src/resnet50.py b/model_zoo/official/cv/resnet_thor/src/resnet50.py deleted file mode 100644 index 00beac2fdf..0000000000 --- a/model_zoo/official/cv/resnet_thor/src/resnet50.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""ResNet.""" -import numpy as np -import mindspore.nn as nn -from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor - - -def _weight_variable(shape, factor=0.01): - init_value = np.random.randn(*shape).astype(np.float32) * factor - return Tensor(init_value) - - -def _conv3x3(in_channel, out_channel, stride=1): - weight_shape = (out_channel, in_channel, 3, 3) - weight = _weight_variable(weight_shape) - return nn.Conv2d(in_channel, out_channel, - kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) - - -def _conv1x1(in_channel, out_channel, stride=1): - weight_shape = (out_channel, in_channel, 1, 1) - weight = _weight_variable(weight_shape) - return nn.Conv2d(in_channel, out_channel, - kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) - - -def _conv7x7(in_channel, out_channel, stride=1): - weight_shape = (out_channel, in_channel, 7, 7) - weight = _weight_variable(weight_shape) - return nn.Conv2d(in_channel, out_channel, - kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) - - -def _bn(channel): - return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, - gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) - - -def _bn_last(channel): - return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, - gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) - - -def _fc(in_channel, out_channel): - weight_shape = (out_channel, in_channel) - weight = _weight_variable(weight_shape) - return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) - - -class ResidualBlock(nn.Cell): - """ - ResNet V1 residual block definition. - - Args: - in_channel (int): Input channel. - out_channel (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock(3, 256, stride=2) - """ - expansion = 4 - - def __init__(self, - in_channel, - out_channel, - stride=1): - super(ResidualBlock, self).__init__() - - channel = out_channel // self.expansion - self.conv1 = _conv1x1(in_channel, channel, stride=1) - self.bn1 = _bn(channel) - - self.conv2 = _conv3x3(channel, channel, stride=stride) - self.bn2 = _bn(channel) - - self.conv3 = _conv1x1(channel, out_channel, stride=1) - self.bn3 = _bn_last(out_channel) - - self.relu = nn.ReLU() - - self.down_sample = False - - if stride != 1 or in_channel != out_channel: - self.down_sample = True - self.down_sample_layer = None - - if self.down_sample: - self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), - _bn(out_channel)]) - self.add = P.TensorAdd() - - def construct(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.down_sample: - identity = self.down_sample_layer(identity) - - out = self.add(out, identity) - out = self.relu(out) - - return out - - -class ResNet(nn.Cell): - """ - ResNet architecture. - - Args: - block (Cell): Block for network. - layer_nums (list): Numbers of block in different layers. - in_channels (list): Input channel in each layer. - out_channels (list): Output channel in each layer. - strides (list): Stride size in each layer. - num_classes (int): The number of classes that the training images are belonging to. - Returns: - Tensor, output tensor. - - Examples: - >>> ResNet(ResidualBlock, - >>> [3, 4, 6, 3], - >>> [64, 256, 512, 1024], - >>> [256, 512, 1024, 2048], - >>> [1, 2, 2, 2], - >>> 10) - """ - - def __init__(self, - block, - layer_nums, - in_channels, - out_channels, - strides, - num_classes): - super(ResNet, self).__init__() - - if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: - raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") - - self.conv1 = _conv7x7(3, 64, stride=2) - self.bn1 = _bn(64) - self.relu = P.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") - - self.layer1 = self._make_layer(block, - layer_nums[0], - in_channel=in_channels[0], - out_channel=out_channels[0], - stride=strides[0]) - self.layer2 = self._make_layer(block, - layer_nums[1], - in_channel=in_channels[1], - out_channel=out_channels[1], - stride=strides[1]) - self.layer3 = self._make_layer(block, - layer_nums[2], - in_channel=in_channels[2], - out_channel=out_channels[2], - stride=strides[2]) - self.layer4 = self._make_layer(block, - layer_nums[3], - in_channel=in_channels[3], - out_channel=out_channels[3], - stride=strides[3]) - - self.mean = P.ReduceMean(keep_dims=True) - self.flatten = nn.Flatten() - self.end_point = _fc(out_channels[3], num_classes) - - def _make_layer(self, block, layer_num, in_channel, out_channel, stride): - """ - Make stage network of ResNet. - - Args: - block (Cell): Resnet block. - layer_num (int): Layer number. - in_channel (int): Input channel. - out_channel (int): Output channel. - stride (int): Stride size for the first convolutional layer. - - Returns: - SequentialCell, the output layer. - - Examples: - >>> _make_layer(ResidualBlock, 3, 128, 256, 2) - """ - layers = [] - - resnet_block = block(in_channel, out_channel, stride=stride) - layers.append(resnet_block) - - for _ in range(1, layer_num): - resnet_block = block(out_channel, out_channel, stride=1) - layers.append(resnet_block) - - return nn.SequentialCell(layers) - - def construct(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - c1 = self.maxpool(x) - - c2 = self.layer1(c1) - c3 = self.layer2(c2) - c4 = self.layer3(c3) - c5 = self.layer4(c4) - - out = self.mean(c5, (2, 3)) - out = self.flatten(out) - out = self.end_point(out) - - return out - - -def resnet50(class_num=10): - """ - Get ResNet50 neural network. - - Args: - class_num (int): Class number. - - Returns: - Cell, cell instance of ResNet50 neural network. - - Examples: - >>> net = resnet50(10) - """ - return ResNet(ResidualBlock, - [3, 4, 6, 3], - [64, 256, 512, 1024], - [256, 512, 1024, 2048], - [1, 2, 2, 2], - class_num) diff --git a/model_zoo/official/cv/resnet_thor/src/resnet_thor.py b/model_zoo/official/cv/resnet_thor/src/resnet_thor.py index 50e694854f..937e78939b 100644 --- a/model_zoo/official/cv/resnet_thor/src/resnet_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/resnet_thor.py @@ -18,8 +18,9 @@ import numpy as np import mindspore.nn as nn from mindspore.common.tensor import Tensor from mindspore.ops import operations as P +from mindspore import context -from src.thor_layer import Conv2d_Thor, Dense_Thor +from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU def calculate_gain(nonlinearity, param=None): @@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'): return np.random.normal(0, std, size=inputs_shape).astype(np.float32) -def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'): fan = _calculate_correct_fan(inputs_shape, mode) gain = calculate_gain(nonlinearity, a) std = gain / math.sqrt(fan) @@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu') return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32) -def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278): +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32): weight_shape = (out_channel, in_channel, 3, 3) weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) - return Conv2d_Thor(in_channel, out_channel, - kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight, - damping=damping, loss_scale=loss_scale, frequency=frequency) + if context.get_context('device_target') == "Ascend": + layer = Conv2d_Thor(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + else: + layer = Conv2d_Thor_GPU(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + return layer -def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278): +def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32): weight_shape = (out_channel, in_channel, 1, 1) weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) - return Conv2d_Thor(in_channel, out_channel, - kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight, - damping=damping, loss_scale=loss_scale, frequency=frequency) + if context.get_context('device_target') == "Ascend": + layer = Conv2d_Thor(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + else: + layer = Conv2d_Thor_GPU(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + return layer -def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278): +def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32): weight_shape = (out_channel, in_channel, 7, 7) weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) - return Conv2d_Thor(in_channel, out_channel, - kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight, - damping=damping, loss_scale=loss_scale, frequency=frequency) + if context.get_context('device_target') == "Ascend": + layer = Conv2d_Thor(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + else: + layer = Conv2d_Thor_GPU(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight, + damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size) + return layer def _bn(channel): @@ -120,14 +144,21 @@ def _bn(channel): def _bn_last(channel): return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, - gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) -def _fc(in_channel, out_channel, damping, loss_scale, frequency): +def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32): weight_shape = (out_channel, in_channel) weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))) - return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight, - bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency) + if context.get_context('device_target') == "Ascend": + layer = Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight, + bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency, + batch_size=batch_size) + else: + layer = Dense_Thor_GPU(in_channel, out_channel, has_bias=False, weight_init=weight, + bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency, + batch_size=batch_size) + return layer class ResidualBlock(nn.Cell): @@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell): stride=1, damping=0.03, loss_scale=1, - frequency=278): + frequency=278, + batch_size=32): super(ResidualBlock, self).__init__() channel = out_channel // self.expansion self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, batch_size=batch_size) self.bn1 = _bn(channel) self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, batch_size=batch_size) self.bn2 = _bn(channel) self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, batch_size=batch_size) self.bn3 = _bn_last(out_channel) self.relu = nn.ReLU() @@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell): if self.down_sample: self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride, damping=damping, loss_scale=loss_scale, - frequency=frequency), + frequency=frequency, + batch_size=batch_size), _bn(out_channel)]) self.add = P.TensorAdd() @@ -239,16 +272,19 @@ class ResNet(nn.Cell): num_classes, damping, loss_scale, - frequency): + frequency, + batch_size): super(ResNet, self).__init__() if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") - self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency) + self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, + frequency=frequency, batch_size=batch_size) self.bn1 = _bn(64) self.relu = P.ReLU() - self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) + # self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.layer1 = self._make_layer(block, layer_nums[0], @@ -257,7 +293,8 @@ class ResNet(nn.Cell): stride=strides[0], damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, + batch_size=batch_size) self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1], @@ -265,14 +302,16 @@ class ResNet(nn.Cell): stride=strides[1], damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, + batch_size=batch_size) self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2], out_channel=out_channels[2], stride=strides[2], damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, + batch_size=batch_size) self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3], @@ -280,14 +319,16 @@ class ResNet(nn.Cell): stride=strides[3], damping=damping, loss_scale=loss_scale, - frequency=frequency) + frequency=frequency, + batch_size=batch_size) self.mean = P.ReduceMean(keep_dims=True) self.flatten = nn.Flatten() - self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency) + self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, + frequency=frequency, batch_size=batch_size) def _make_layer(self, block, layer_num, in_channel, out_channel, stride, - damping, loss_scale, frequency): + damping, loss_scale, frequency, batch_size): """ Make stage network of ResNet. @@ -307,12 +348,14 @@ class ResNet(nn.Cell): layers = [] resnet_block = block(in_channel, out_channel, stride=stride, - damping=damping, loss_scale=loss_scale, frequency=frequency) + damping=damping, loss_scale=loss_scale, frequency=frequency, + batch_size=batch_size) layers.append(resnet_block) for _ in range(1, layer_num): resnet_block = block(out_channel, out_channel, stride=1, - damping=damping, loss_scale=loss_scale, frequency=frequency) + damping=damping, loss_scale=loss_scale, frequency=frequency, + batch_size=batch_size) layers.append(resnet_block) return nn.SequentialCell(layers) @@ -321,7 +364,7 @@ class ResNet(nn.Cell): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) - c1, _ = self.maxpool(x) + c1 = self.maxpool(x) c2 = self.layer1(c1) c3 = self.layer2(c2) @@ -335,7 +378,7 @@ class ResNet(nn.Cell): return out -def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278): +def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32): """ Get ResNet50 neural network. @@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278): class_num, damping, loss_scale, - frequency) + frequency, + batch_size) diff --git a/model_zoo/official/cv/resnet_thor/src/thor.py b/model_zoo/official/cv/resnet_thor/src/thor.py index cc82e6e369..7fc0d9b74e 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor.py +++ b/model_zoo/official/cv/resnet_thor/src/thor.py @@ -12,27 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""momentum""" -import mindspore.common.dtype as mstype +"""THOR""" +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.ops import _selected_ops from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -from mindspore.common.parameter import ParameterTuple +from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +from mindspore._checkparam import check_bool +from mindspore._checkparam import Validator as validator from mindspore.nn.optim.optimizer import Optimizer -from mindspore.ops import functional as F, composite as C, operations as P from mindspore.parallel._utils import _get_device_num, _get_mirror_mean from src.grad_reducer_thor import DistributedGradReducerThor -momentum_opt = C.MultitypeFuncGraph("momentum_opt") - - -@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): - """Apply momentum optimizer to the weight parameter using Tensor.""" - success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) - return success - +_momentum_opt = C.MultitypeFuncGraph("momentum_opt") op_add = P.AddN() apply_decay = C.MultitypeFuncGraph("apply_decay") @@ -46,6 +39,119 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): return gradient +@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment): + """Apply momentum optimizer to the weight parameter using Tensor.""" + success = True + success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) + return success + + +class THOR_GPU(Optimizer): + """ + THOR + """ + def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, + weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []): + super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale) + validator.check_value_type("momentum", momentum, [float], self.cls_name) + if isinstance(momentum, float) and momentum < 0.0: + raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") + self.params = self.parameters + self.use_nesterov = check_bool(use_nesterov) + self.moments = self.params.clone(prefix="moments", init='zeros') + self.hyper_map = C.HyperMap() + self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov) + + self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, + 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, + 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, + 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, + 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, + 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, + 1.0 / 196, 1.0 / 196, 1.0 / 196, + 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, + 1.0] + self.feature_map_new = [x ** 0.5 for x in self.feature_map] + self.transpose = P.Transpose() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.matmul = P.MatMul() + self.matrix_A = ParameterTuple(matrix_A) + self.matrix_G = ParameterTuple(matrix_G) + self.A_inv_max = ParameterTuple(A_inv_max) + self.G_inv_max = ParameterTuple(G_inv_max) + self.assign = P.Assign() + self.mul = P.Mul() + + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree) + self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree) + self.weight_decay = weight_decay + self.decay_flags = tuple(decay_filter(x) for x in self.parameters) + self.update_gradient = P.UpdateThorGradient(split_dim=128) + + def construct(self, gradients): + params = self.params + moments = self.moments + gradients = self.scale_grad(gradients) + new_grads = () + if self.thor: + matrix_A_allreduce = () + matrix_G_allreduce = () + for i in range(54): + g = gradients[i * 3] + matrix_A = self.matrix_A[i] + matrix_G = self.matrix_G[i] + matrix_A = F.depend(matrix_A, g) + matrix_G = F.depend(matrix_G, g) + matrix_A = self.mul(matrix_A, self.feature_map_new[i]) + matrix_G = self.mul(matrix_G, self.feature_map_new[i]) + matrix_A_allreduce = matrix_A_allreduce + (matrix_A,) + matrix_G_allreduce = matrix_G_allreduce + (matrix_G,) + matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce) + matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce) + for i in range(54): + g = gradients[i * 3] + g_shape = self.shape(g) + g = self.reshape(g, (g_shape[0], -1)) + matrix_A = matrix_A_allreduce[i] + matrix_G = matrix_G_allreduce[i] + g = self.update_gradient(matrix_G, g, matrix_A) + fake_A = self.assign(self.matrix_A[i], matrix_A) + fake_G = self.assign(self.matrix_G[i], matrix_G) + g = F.depend(g, fake_A) + g = F.depend(g, fake_G) + if i == 53: + new_grads = new_grads + (g,) + else: + g = self.reshape(g, g_shape) + new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2]) + else: + for i in range(54): + g = gradients[i * 3] + g_shape = self.shape(g) + g = self.reshape(g, (g_shape[0], -1)) + matrix_A = self.matrix_A[i] + matrix_G = self.matrix_G[i] + matrix_A = F.depend(matrix_A, g) + matrix_G = F.depend(matrix_G, g) + g = self.update_gradient(matrix_G, g, matrix_A) + if i == 53: + new_grads = new_grads + (g,) + else: + g = self.reshape(g, g_shape) + new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2]) + gradients = new_grads + if self.weight_decay > 0: + gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, + params, gradients) + lr = self.get_lr() + success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) + return success + class THOR(Optimizer): """THOR""" def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0, @@ -195,5 +301,5 @@ class THOR(Optimizer): params, gradients) gradients = self.scale_grad(gradients) lr = self.get_lr() - success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) + success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) return success diff --git a/model_zoo/official/cv/resnet_thor/src/thor_layer.py b/model_zoo/official/cv/resnet_thor/src/thor_layer.py index d84cbf7a93..f836258a80 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor_layer.py +++ b/model_zoo/official/cv/resnet_thor/src/thor_layer.py @@ -15,7 +15,6 @@ """thor_layer""" import numpy as np -import mindspore as ms import mindspore.common.dtype as mstype from mindspore._checkparam import check_bool, twice, check_int_positive from mindspore._extends import cell_attr_register @@ -25,8 +24,10 @@ from mindspore.common.tensor import Tensor from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation from mindspore.ops import operations as P + C0 = 16 + def caculate_device_shape(matrix_dim, channel, is_A): ll = (0) if is_A: @@ -37,6 +38,31 @@ def caculate_device_shape(matrix_dim, channel, is_A): ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim) return ll + +def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim): + split_dimA = split_dim + split_dimG = split_dim + if matrix_A_dim % split_dim == 0: + batch_w = matrix_A_dim // split_dim + else: + if matrix_A_dim < split_dim: + batch_w = 1 + split_dimA = matrix_A_dim + else: + batch_w = matrix_A_dim // split_dim + 1 + + if matrix_G_dim % split_dim == 0: + batch_h = matrix_G_dim // split_dim + else: + if matrix_G_dim < split_dim: + batch_h = 1 + split_dimG = matrix_G_dim + else: + batch_h = matrix_G_dim // split_dim + 1 + matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA) + matrix_G_shape = (batch_h, split_dimG, split_dimG) + return matrix_A_shape, matrix_G_shape + class _Conv(Cell): r"""Applies a N-D convolution over an input signal composed of several input planes. @@ -97,6 +123,286 @@ class _Conv(Cell): raise NotImplementedError +class Conv2d_Thor_GPU(_Conv): + """Conv2d_Thor""" + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + data_format='NCHW', + has_bias=False, + weight_init='normal', + damping=0.03, + loss_scale=1, + frequency=278, + batch_size=32, + bias_init='zeros'): + self.thor = True + self.hw = kernel_size * kernel_size + kernel_size = twice(kernel_size) + super(Conv2d_Thor_GPU, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + data_format, + has_bias, + weight_init, + bias_init, + ) + self.conv2d = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group + ) + + self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1] + self.matrix_G_dim = self.out_channels + + split_dim = 128 + matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.matrix_A_dim, self.matrix_G_dim, split_dim) + self.matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32), + name='matrix_A_inv', requires_grad=False) + self.matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32), + name='matrix_A_inv', requires_grad=False) + self.broadcast_to = P.BroadcastTo(matrix_A_shape) + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same") + self.matmul = P.MatMul(transpose_b=True) + self.shape = P.Shape() + self.reshape = P.Reshape() + self.mul = P.Mul() + self.getG = P.InsertGradientOf(self.save_gradient) + self.loss_scale = Tensor(1 / loss_scale, mstype.float16) + self.batch_size = Tensor(batch_size, mstype.float16) + self.transpose = P.Transpose() + self.cast = P.Cast() + self.gather = P.GatherV2() + self.freq = Tensor(frequency, mstype.int32) + self.axis = 0 + self.sqrt = P.Sqrt() + self.reduce_mean = P.ReduceMean(keep_dims=False) + self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False) + self.dampingA = Tensor(np.identity(self.matrix_A_dim), mstype.float32) + self.dampingG = Tensor(np.identity(self.matrix_G_dim), mstype.float32) + self.cholesky = P.Cholesky(split_dim=split_dim) + self.vector_matmul = P.BatchMatMul(transpose_a=True) + + def save_gradient(self, dout): + """save_gradient""" + out = dout + dout = self.mul(dout, self.loss_scale) + dout = self.mul(dout, self.batch_size) + dout = self.reduce_mean(dout, 0) + dout_shape = self.shape(dout) + dout = self.reshape(dout, (dout_shape[0], -1)) + dout_shape = self.shape(dout) + normalizer = dout_shape[1] + dout = self.cast(dout, mstype.float32) + matrix_G = self.matmul(dout, dout) + matrix_G = self.mul(matrix_G, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.freq + damping = self.mul(damping_step, 1.0 / normalizer) + damping = self.sqrt(damping) + matrix_G = matrix_G + damping * self.dampingG + matrix_G = self.cholesky(matrix_G) + matrix_G = self.vector_matmul(matrix_G, matrix_G) + self.matrix_G_inv = matrix_G + return out + + def construct(self, x): + if self.thor: + matrix_A = self.img2col(x) + matrix_A_shape = self.shape(matrix_A) + matrix_A = self.reshape(matrix_A, (matrix_A_shape[0]*matrix_A_shape[1]*matrix_A_shape[2], + matrix_A_shape[3], -1)) + matrix_A = self.reduce_mean(matrix_A, 1) + matrix_A_shape = self.shape(matrix_A) + normalizer = matrix_A_shape[1] + matrix_A = self.cast(matrix_A, mstype.float32) + matrix_A = self.matmul(matrix_A, matrix_A) + matrix_A = self.mul(matrix_A, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, self.axis) + damping_step = self.cast(damping_step, mstype.float32) + damping = self.mul(damping_step, 1.0 / normalizer) + damping = self.sqrt(damping) + matrix_A = matrix_A + damping * self.dampingA + matrix_A = self.cholesky(matrix_A) + matrix_A = self.vector_matmul(matrix_A, matrix_A) + matrix_A = self.broadcast_to(matrix_A) + self.matrix_A_inv = matrix_A + out = self.conv2d(x, self.weight) + out = self.getG(out) + else: + out = self.conv2d(x, self.weight) + + return out + + def extra_repr(self): + """extra_repr""" + s = 'input_channels={}, output_channels={}, kernel_size={},' \ + 'stride={}, pad_mode={}, padding={}, dilation={}, ' \ + 'group={}, data_format={}, has_bias={},' \ + 'weight_init={}, bias_init={}'.format( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.pad_mode, + self.padding, + self.dilation, + self.group, + self.data_format, + self.has_bias, + self.weight, + self.bias) + + if self.has_bias: + s += ', bias={}'.format(self.bias) + return s + + +class Dense_Thor_GPU(Cell): + """Dense_Thor""" + @cell_attr_register(attrs=['has_bias', 'activation']) + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + damping=0.03, + loss_scale=1, + frequency=278, + batch_size=32, + has_bias=True, + activation=None): + super(Dense_Thor_GPU, self).__init__() + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = check_bool(has_bias) + self.thor = True + if isinstance(weight_init, Tensor): + if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ + weight_init.shape[1] != in_channels: + raise ValueError("weight_init shape error") + + self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + + if self.has_bias: + if isinstance(bias_init, Tensor): + if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: + raise ValueError("bias_init shape error") + + self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + + self.activation = get_activation(activation) + self.activation_flag = self.activation is not None + split_dim = 128 + matrix_A_shape, matrix_G_shape = caculate_matmul_shape(self.in_channels, self.out_channels, split_dim) + self.matrix_A_inv = Parameter(Tensor(np.zeros(matrix_A_shape).astype(np.float32)), + name='matrix_A_inv', requires_grad=False) + self.matrix_G_inv = Parameter(Tensor(np.zeros(matrix_G_shape).astype(np.float32)), + name="matrix_G_inv", requires_grad=False) + self.broadcast_to = P.BroadcastTo(matrix_A_shape) + self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False) + self.shape = P.Shape() + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.mul = P.Mul() + self.cube_matmul = P.MatMul(transpose_a=True) + self.loss_scale = Tensor(1 / loss_scale, mstype.float16) + self.batch_size = Tensor(batch_size, mstype.float16) + self.getG = P.InsertGradientOf(self.save_gradient) + self.damping = Parameter(Tensor(damping), name="damping_value", requires_grad=False) + self.dampingA = Tensor(np.identity(in_channels), mstype.float32) + self.dampingG = Tensor(np.identity(out_channels), mstype.float32) + self.cast = P.Cast() + self.gather = P.GatherV2() + self.freq = Tensor(frequency, mstype.int32) + self.axis = 0 + self.add = P.TensorAdd() + self.sqrt = P.Sqrt() + self.cholesky = P.Cholesky(split_dim=split_dim) + self.vector_matmul = P.BatchMatMul(transpose_a=True) + + def save_gradient(self, dout): + """save_gradient""" + out = dout + dout = self.mul(dout, self.loss_scale) + dout = self.mul(dout, self.batch_size) + dout_shape = self.shape(dout) + normalizer = dout_shape[0] + dout = self.cast(dout, mstype.float32) + matrix_G = self.cube_matmul(dout, dout) + matrix_G = self.mul(matrix_G, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, 0) + damping_step = self.cast(damping_step, mstype.float32) + self.cov_step = self.cov_step + self.freq + damping = self.sqrt(damping_step) + matrix_G = matrix_G + damping * self.dampingG + matrix_G = self.cholesky(matrix_G) + matrix_G = self.vector_matmul(matrix_G, matrix_G) + self.matrix_G_inv = matrix_G + return out + + def construct(self, x): + """construct""" + if self.thor: + inputs = self.cast(x, mstype.float32) + inputs = self.cube_matmul(inputs, inputs) + inputs_shape = self.shape(inputs) + normalizer = inputs_shape[0] + matrix_A = self.mul(inputs, 1.0 / normalizer) + damping_step = self.gather(self.damping, self.cov_step, self.axis) + damping_step = self.cast(damping_step, mstype.float32) + damping = self.sqrt(damping_step) + matrix_A = matrix_A + damping * self.dampingA + matrix_A = self.cholesky(matrix_A) + matrix_A = self.vector_matmul(matrix_A, matrix_A) + matrix_A = self.broadcast_to(matrix_A) + self.matrix_A_inv = matrix_A + output = self.matmul(x, self.weight) + output = self.getG(output) + else: + output = self.matmul(x, self.weight) + + if self.has_bias: + output = self.bias_add(output, self.bias) + if self.activation_flag: + return self.activation(output) + return output + + def extend_repr(self): + """extend_repr""" + str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \ + .format(self.in_channels, self.out_channels, self.weight, self.has_bias) + if self.has_bias: + str_info = str_info + ', bias={}'.format(self.bias) + + if self.activation_flag: + str_info = str_info + ', activation={}'.format(self.activation) + + return str_info + + class Conv2d_Thor(_Conv): """Conv2d_Thor""" def __init__(self, @@ -114,6 +420,7 @@ class Conv2d_Thor(_Conv): damping=0.03, loss_scale=1, frequency=278, + batch_size=32, bias_init='zeros'): self.thor = True ksizes = (1, kernel_size, kernel_size, 1) @@ -143,7 +450,7 @@ class Conv2d_Thor(_Conv): dilation=self.dilation, group=self.group ) - + self.batch_size = batch_size self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides) self.cube_matmul = P.CusMatMulCube(transpose_a=True) self.matrix_combine = P.CusMatrixCombine() @@ -228,7 +535,7 @@ class Conv2d_Thor(_Conv): normalizer = dout_shape[0] matrix_G = self.cube_matmul(dout, dout) - normalizer = self.cast(normalizer, ms.float32) + normalizer = self.cast(normalizer, mstype.float32) matrix_G = self.mul(matrix_G, 1.0 / normalizer) damping_step = self.gather(self.damping, self.cov_step, 0) self.cov_step = self.cov_step + self.freq @@ -261,7 +568,7 @@ class Conv2d_Thor(_Conv): matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0)) matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels)) matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim)) - normalizer = self.cast(normalizer, ms.float32) + normalizer = self.cast(normalizer, mstype.float32) matrix_A = self.mul(matrix_A, 1.0 / normalizer) if self.padA_flag: matrix_A = self.padA(matrix_A) @@ -330,6 +637,7 @@ class Dense_Thor(Cell): damping=0.03, loss_scale=1, frequency=278, + batch_size=32, has_bias=True, activation=None): super(Dense_Thor, self).__init__() @@ -337,6 +645,7 @@ class Dense_Thor(Cell): self.out_channels = check_int_positive(out_channels) self.has_bias = check_bool(has_bias) self.thor = True + self.batch_size = batch_size if isinstance(weight_init, Tensor): if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ weight_init.shape[1] != in_channels: @@ -376,8 +685,8 @@ class Dense_Thor(Cell): self.damping = Tensor(damping) self.loss_scale = Tensor(1 / loss_scale, mstype.float16) self.vector_matmul = P.CusBatchMatMul() - self.pad = P.Pad(((0, 24), (0, 24))) - self.pad1 = P.Pad(((0, 8), (0, 8))) + self.pad = P.Pad(((0, 23), (0, 23))) + self.pad1 = P.Pad(((0, 7), (0, 7))) self.slice = P.Slice() self.gather = P.GatherV2() self.assignadd = P.AssignAdd() @@ -385,7 +694,7 @@ class Dense_Thor(Cell): self.axis = 0 self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False) self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False) - self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000]) + self.fused_abs_max1 = P.CusFusedAbsMax1([1001, 1001]) self.fused_abs_max2 = P.CusFusedAbsMax1() self.log = P.Log() self.exp = P.Exp() @@ -402,7 +711,7 @@ class Dense_Thor(Cell): dout = self.mul(dout, 32.0) normalizer = 32 matrix_G = self.cube_matmul(dout, dout) - normalizer = self.cast(normalizer, ms.float32) + normalizer = self.cast(normalizer, mstype.float32) matrix_G = self.mul(matrix_G, 1.0 / normalizer) matrix_G = self.pad(matrix_G) damping_step = self.gather(self.damping, self.cov_step, 0) @@ -417,7 +726,7 @@ class Dense_Thor(Cell): matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max) self.G_inv_max = matrix_G_inv_max matrix_G_inv = self.matrix_combine(matrix_G_inv) - matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000)) + matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001)) matrix_G_inv = self.pad1(matrix_G_inv) matrix_G_inv_shape = self.shape(matrix_G_inv) matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16)) @@ -431,7 +740,7 @@ class Dense_Thor(Cell): if self.thor: inputs = self.cube_matmul(x, x) normalizer = 32 - normalizer = self.cast(normalizer, ms.float32) + normalizer = self.cast(normalizer, mstype.float32) matrix_A = self.mul(inputs, 1.0 / normalizer) damping_step = self.gather(self.damping, self.cov_step, self.axis) diff --git a/model_zoo/official/cv/resnet_thor/train.py b/model_zoo/official/cv/resnet_thor/train.py index b6a84fe136..d7c667dffa 100644 --- a/model_zoo/official/cv/resnet_thor/train.py +++ b/model_zoo/official/cv/resnet_thor/train.py @@ -12,44 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""train_imagenet.""" -import argparse +"""train resnet.""" import os import random - +import argparse import numpy as np -from mindspore import Tensor from mindspore import context -from mindspore.communication.management import init +from mindspore import Tensor +from mindspore import dataset as de from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.model import ParallelMode -from src.model_thor import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.communication.management import init, get_rank, get_group_size + +from src.model_thor import Model_Thor as Model from src.resnet_thor import resnet50 -from src.thor import THOR -from src.config import config +from src.dataset import create_dataset from src.crossentropy import CrossEntropy -from src.dataset_imagenet import create_dataset - -random.seed(1) -np.random.seed(1) parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') -parser.add_argument('--device_num', type=int, default=1, help='Device num.') -parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') -parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') - +parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') +parser.add_argument('--device_num', type=int, default=1, help='Device num') args_opt = parser.parse_args() -device_id = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) +if args_opt.device_target == "Ascend": + from src.thor import THOR + from src.config import config +else: + from src.thor import THOR_GPU as THOR + from src.config import config_gpu as config + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) -def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): +def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100): """get_model_lr""" lr_each_step = [] total_steps = steps_per_epoch * total_epochs @@ -57,9 +59,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch): epoch = (i + 1) / steps_per_epoch base = (1.0 - float(epoch) / total_epochs) ** decay lr_local = lr_init * base - if epoch >= 39: + if epoch >= decay_epochs: lr_local = lr_local * 0.5 - if epoch >= 40: + if epoch >= decay_epochs + 1: lr_local = lr_local * 0.5 lr_each_step.append(lr_local) current_step = global_step @@ -76,7 +78,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps epoch = (step + 1) / steps_per_epoch damping_here = damping_init * (decay_rate ** (epoch / 10)) damping_each_step.append(damping_here) - current_step = global_step damping_each_step = np.array(damping_each_step).astype(np.float32) damping_now = damping_each_step[current_step:] @@ -84,49 +85,70 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps if __name__ == '__main__': - if not args_opt.do_eval and args_opt.run_distribute: - context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - mirror_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") - auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") - - init() - - epoch_size = config.epoch_size - damping = get_model_damping(0, 0.03, 0.87, 50, 5004) + target = args_opt.device_target + ckpt_save_dir = config.save_checkpoint_path + + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + + if args_opt.run_distribute: + # Ascend target + if target == "Ascend": + device_id = int(os.getenv('DEVICE_ID')) + context.set_context(device_id=device_id, enable_auto_mixed_precision=True) + context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") + auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") + init() + # GPU target + else: + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" + + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, + batch_size=config.batch_size, target=target) + + # define net + step_size = dataset.get_dataset_size() + damping = get_model_damping(0, config.damping_init, config.damping_decay, 90, step_size) + lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39) net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale, - frequency=config.frequency) + frequency=config.frequency, batch_size=config.batch_size) - if not config.label_smooth: + # define loss, model + if not config.use_label_smooth: config.label_smooth_factor = 0.0 loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - if args_opt.do_train: - dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, - batch_size=config.batch_size) - step_size = dataset.get_dataset_size() - - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004)) - opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, - filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), - filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), - filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()), - filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()), - config.weight_decay, config.loss_scale) - + opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum, + filter(lambda x: 'matrix_A' in x.name, net.get_parameters()), + filter(lambda x: 'matrix_G' in x.name, net.get_parameters()), + filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()), + filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()), + config.weight_decay, config.loss_scale) + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + if target == "Ascend": model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale, keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency) - - time_cb = TimeMonitor(data_size=step_size) - loss_cb = LossMonitor() - cb = [time_cb, loss_cb] - if config.save_checkpoint: - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) - cb += [ckpt_cb] - - model.train(epoch_size, dataset, callbacks=cb) + else: + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency) + + # define callbacks + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + + # train model + model.train(config.epoch_size, dataset, callbacks=cb)