THOR optimizer for GPU

pull/4414/head
wangmin 5 years ago
parent ab84b2f18a
commit 87612bfcf2

@ -24,22 +24,24 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
.
├── resnet_thor
    ├── README.md
    ├── scripts
        ├── run_distribute_train.sh        # launch distributed training for Ascend
        ├── run_eval.sh                    # launch inference for Ascend
        ├── run_distribute_train_gpu.sh    # launch distributed training for GPU
        └── run_eval_gpu.sh                # launch inference for GPU
    ├── src
        ├── crossentropy.py                # CrossEntropy loss function
        ├── config.py                      # parameter configuration
        ├── dataset_helper.py              # dataset helper for minddata dataset
        ├── grad_reducer_thor.py           # gradient reducer for THOR
        ├── model_thor.py                  # model for training
        ├── resnet_thor.py                 # resnet50_thor backbone
        ├── thor.py                        # THOR optimizer
        ├── thor_layer.py                  # THOR layer
        └── dataset.py                     # data preprocessing
    ├── eval.py                            # inference script
    └── train.py                           # training script
```
@ -48,26 +50,30 @@ This is an example of training ResNet-50 V1.5 with ImageNet2012 dataset by secon
Parameters for both training and inference can be set in config.py.
```
"class_num": 1001,                # dataset class number
"batch_size": 32,                 # batch size of input tensor
"loss_scale": 128,                # loss scale
"momentum": 0.9,                  # momentum of THOR optimizer
"weight_decay": 5e-4,             # weight decay
"epoch_size": 45,                 # only valid for training; it is fixed to 1 for inference
"save_checkpoint": True,          # whether to save checkpoints
"save_checkpoint_epochs": 1,      # the epoch interval between two checkpoints; by default a checkpoint is saved every epoch
"keep_checkpoint_max": 15,        # only keep the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "./",     # path for saving checkpoints, relative to the execution path
"use_label_smooth": True,         # whether to use label smoothing
"label_smooth_factor": 0.1,       # label smoothing factor
"lr_init": 0.045,                 # initial learning rate
"lr_decay": 6,                    # learning rate decay rate
"lr_end_epoch": 70,               # learning rate end epoch
"damping_init": 0.03,             # initial damping value for the Fisher information matrix
"damping_decay": 0.87,            # damping decay rate
"frequency": 834,                 # the step interval for updating the second-order information matrices
```
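These values are plain Python attributes; a minimal sketch (assuming the working directory is resnet_thor, as the scripts above assume) of reading them from src/config.py:

```python
# Minimal sketch: inspect the training parameters defined in src/config.py.
# Assumes the current working directory is resnet_thor, as in the launch scripts above.
from src.config import config

print(config.class_num)    # 1001
print(config.batch_size)   # 32
print(config.frequency)    # 834, the step interval for the second-order matrices
```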
## Running the example
### 1 Running on Ascend 910
### Train
#### Usage
@ -82,10 +88,10 @@ Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [DEVICE_NUM]
```bash
# distributed training example (8 pcs)
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc 8
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training_ascend.html).
#### Result
@ -126,3 +132,35 @@ Inference result will be stored in the example path, whose folder name is "infer
```
result: {'acc': 0.759503041} ckpt=train_parallel0/resnet-42_5004.ckpt
```
### 2 Running on GPU
### Train
```
# distributed training example
sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]
```
#### Result
```
# distribute training result(8 pcs)
epoch: 1 step: 5004, loss is 4.3069
epoch: 2 step: 5004, loss is 3.5695
epoch: 3 step: 5004, loss is 3.5893
epoch: 4 step: 5004, loss is 3.1987
epoch: 5 step: 5004, loss is 3.3526
......
epoch: 40 step: 5004, loss is 1.9482
epoch: 41 step: 5004, loss is 1.8950
epoch: 42 step: 5004, loss is 1.9023
......
```
### Infer
```
# infer example
sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Result
```
result: {'acc': 0.760143245838668} ckpt_0/resnet-40_5004.ckpt
```

@ -12,51 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval resnet."""
import os
import random
import argparse
import numpy as np

from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.crossentropy import CrossEntropy
from src.config import config
from src.dataset import create_dataset
from src.resnet_thor import resnet50 as resnet

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

if __name__ == '__main__':
    target = args_opt.device_target

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
    if target != "GPU":
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
                             target=target)

    # define net
    net = resnet(class_num=config.class_num)
    net.add_flags_recursive(thor=False)

    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    keys = list(param_dict.keys())
    for key in keys:
        if "damping" in key:
            param_dict.pop(key)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # define loss, model
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

    # define model
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

    # eval model
    res = model.eval(dataset)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)

@ -52,6 +52,6 @@ do
echo "start training for rank $RANK_ID, device $DEVICE_ID" echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log env > env.log
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 & python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
cd .. cd ..
done done

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [DEVICE_NUM]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
ulimit -u unlimited
export DEVICE_NUM=$2
export RANK_SIZE=$2
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit
mpirun -n $RANK_SIZE \
python train.py --run_distribute=True \
--device_num=$DEVICE_NUM --device_target="GPU" --dataset_path=$PATH1 &> log &

@ -20,6 +20,7 @@ then
exit 1
fi

get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
@ -44,9 +45,6 @@ then
exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
@ -58,10 +56,11 @@ then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..

@ -17,21 +17,46 @@ network config setting, will be used in train.py and eval.py
""" """
from easydict import EasyDict as ed from easydict import EasyDict as ed
# config for resnet50, imagenet2012, Ascend 910
config = ed({ config = ed({
"class_num": 1000, "class_num": 1001,
"batch_size": 32, "batch_size": 32,
"loss_scale": 128, "loss_scale": 128,
"momentum": 0.9, "momentum": 0.9,
"weight_decay": 5e-4, "weight_decay": 5e-4,
"epoch_size": 45, "epoch_size": 45,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_steps": 5004, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20, "keep_checkpoint_max": 15,
"save_checkpoint_path": "./", "save_checkpoint_path": "./",
"label_smooth": 1, "use_label_smooth": True,
"label_smooth_factor": 0.1, "label_smooth_factor": 0.1,
"frequency": 834 "lr_init": 0.045,
"lr_decay": 6,
"lr_end_epoch": 70,
"damping_init": 0.03,
"damping_decay": 0.87,
"frequency": 834,
})
# config for resnet50, imagenet2012, GPU
config_gpu = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 45,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.04,
"lr_decay": 5,
"lr_end_epoch": 58,
"damping_init": 0.02,
"damping_decay": 0.87,
"frequency": 834,
}) })
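train.py (later in this commit) picks one of these two dictionaries based on the --device_target argument; a minimal sketch of that selection, where "device_target" stands in for the parsed argument:

```python
# Sketch of the config selection done in train.py; "device_target" stands in
# for the parsed --device_target command-line argument.
device_target = "GPU"

if device_target == "Ascend":
    from src.config import config as cfg
else:
    from src.config import config_gpu as cfg

print(cfg.lr_init, cfg.damping_init)   # 0.04 0.02 for the GPU config
```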

@ -28,13 +28,10 @@ class CrossEntropy(_Loss):
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)

@ -16,30 +16,36 @@
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for resnet50

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                     num_shards=device_num, shard_id=rank_id)
@ -47,29 +53,28 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
@ -78,3 +83,18 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    ds = ds.repeat(repeat_num)

    return ds


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id

@ -13,34 +13,47 @@
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
import math
import os

from mindspore._checkparam import check_bool, check_int
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full


def _send_data(dataset, epoch_num):
    """Engine dataset to write data to tdt queue."""
    if not hasattr(dataset, '__has_sent__'):
        exec_dataset = dataset.__TRANSFER_DATASET__
        exec_dataset.send(epoch_num)
        dataset.__has_sent__ = True


def _send_data_no_flag(dataset, epoch_num):
    """Engine dataset to write data to tdt queue directly."""
    exec_dataset = dataset.__TRANSFER_DATASET__
    exec_dataset.send(epoch_num)


class DatasetHelper:
    """
    Help function to use the MindData dataset.

    According to different contexts, change the iterations of dataset and use the same iteration for loop in
    different contexts.

    Note:
        The iteration of DatasetHelper will provide one epoch data.

    Args:
        dataset (DataSet): The training dataset iterator.
        dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
            Default: True.
        sink_size (int): Control the amount of data in each sink.
            If sink_size=-1, sink the complete dataset for each epoch.
            If sink_size>0, sink sink_size data for each epoch. Default: -1.
        epoch_num (int): Control the number of epoch data to send. Default: 1.

    Examples:
        >>> dataset_helper = DatasetHelper(dataset)
@ -48,81 +61,116 @@ class DatasetHelper:
        >>>     outputs = network(*inputs)
    """

    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
        check_bool(dataset_sink_mode)
        check_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

        if dataset_sink_mode:
            if context.get_context("device_target") == "Ascend":
                iterclass = _DatasetIterMSLoopSink
                self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
            elif context.get_context("device_target") == "GPU":
                iterclass = _DatasetIterMS
                self.iter = iterclass(dataset, sink_size, epoch_num)
            elif context.get_context("device_target") == "CPU":
                raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")

    def __iter__(self):
        return self.iter.__iter__()

    # A temp solution for loop sink. Delete later
    def types_shapes(self):
        """Get the types and shapes from dataset on the current configuration."""
        return self.iter.types_shapes()

    def sink_size(self):
        """Get sink_size for each iteration."""
        return self.iter.get_sink_size()

    def stop_send(self):
        """Free up resources about data sink."""
        self.iter.stop_send()


class _DatasetIter:
    """Base iter for dataset helper"""

    def __init__(self, dataset, sink_size, epoch_num):
        self.dataset = dataset
        self.sink_size = sink_size
        self.sink_count = 1

        if not hasattr(dataset, '__TRANSFER_DATASET__'):
            if hasattr(dataset, '__loop_size__'):
                self.sink_size = dataset.__loop_size__

            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset, epoch_num)
        else:
            _send_data_no_flag(dataset, epoch_num)

        self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.sink_count:
            raise StopIteration()
        self.index += 1
        return self.op()

    def types_shapes(self):
        return self.dataset_types, self.dataset_shapes

    def get_sink_count(self, dataset):
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
            if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                 f'sink_size {loop_size} are not matched.')
            sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
        return sink_count

    def get_sink_size(self):
        """get sink_size to device"""
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
        else:
            if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
                if self.sink_size > 0:
                    sink_size = self.sink_size
                else:
                    sink_size = self.dataset.get_dataset_size()
        return sink_size


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend)"""

    def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
        super().__init__(dataset, sink_size, epoch_num)
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__ + iter_first_order
            if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                 f'sink_size {loop_size} are not matched.')
            sink_count = math.ceil(dataset.get_dataset_size() / loop_size) * 2
        self.sink_count = sink_count
        ms_role = os.getenv("MS_ROLE")
        if ms_role in ("MS_PSERVER", "MS_SCHED"):
            self.sink_count = 1
        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
        # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
        # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
        if _need_to_full():
            device_num = _get_device_num()
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
@ -130,3 +178,16 @@ class _DatasetIterMSLoopSink(_DatasetIter):
            return tuple()
        self.op = op


class _DatasetIterMS(_DatasetIter):
    """Iter for MS(enable_loop_sink=False)."""

    def __init__(self, dataset, sink_size, epoch_num):
        super().__init__(dataset, sink_size, epoch_num)
        if sink_size > 0:
            self.sink_count = sink_size
        else:
            self.sink_count = dataset.get_dataset_size()

        queue_name = dataset.__ME_INITED__
        self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)

@ -174,10 +174,6 @@ class DistributedGradReducerThor(Cell):
        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
        new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
        return new_grad

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -18,8 +18,9 @@ import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore import context
from src.thor_layer import Conv2d_Thor, Dense_Thor, Conv2d_Thor_GPU, Dense_Thor_GPU


def calculate_gain(nonlinearity, param=None):
@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu')
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
    weight_shape = (out_channel, in_channel, 3, 3)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if context.get_context('device_target') == "Ascend":
        layer = Conv2d_Thor(in_channel, out_channel,
                            kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                            damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    else:
        layer = Conv2d_Thor_GPU(in_channel, out_channel,
                                kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                                damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    return layer


def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
    weight_shape = (out_channel, in_channel, 1, 1)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if context.get_context('device_target') == "Ascend":
        layer = Conv2d_Thor(in_channel, out_channel,
                            kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                            damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    else:
        layer = Conv2d_Thor_GPU(in_channel, out_channel,
                                kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                                damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    return layer


def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
    weight_shape = (out_channel, in_channel, 7, 7)
    weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    if context.get_context('device_target') == "Ascend":
        layer = Conv2d_Thor(in_channel, out_channel,
                            kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                            damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    else:
        layer = Conv2d_Thor_GPU(in_channel, out_channel,
                                kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                                damping=damping, loss_scale=loss_scale, frequency=frequency, batch_size=batch_size)
    return layer
def _bn(channel):
@ -120,14 +144,21 @@ def _bn(channel):
def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, damping, loss_scale, frequency, batch_size=32):
    weight_shape = (out_channel, in_channel)
    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    if context.get_context('device_target') == "Ascend":
        layer = Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
                           bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
                           batch_size=batch_size)
    else:
        layer = Dense_Thor_GPU(in_channel, out_channel, has_bias=False, weight_init=weight,
                               bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency,
                               batch_size=batch_size)
    return layer
class ResidualBlock(nn.Cell):
@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell):
                 stride=1,
                 damping=0.03,
                 loss_scale=1,
                 frequency=278,
                 batch_size=32):
        super(ResidualBlock, self).__init__()

        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
                              frequency=frequency, batch_size=batch_size)
        self.bn1 = _bn(channel)

        self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
                              frequency=frequency, batch_size=batch_size)
        self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
                              frequency=frequency, batch_size=batch_size)
        self.bn3 = _bn_last(out_channel)
        self.relu = nn.ReLU()
@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell):
        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                 damping=damping, loss_scale=loss_scale,
                                                                 frequency=frequency,
                                                                 batch_size=batch_size),
                                                        _bn(out_channel)])
        self.add = P.TensorAdd()
@ -239,16 +272,19 @@ class ResNet(nn.Cell):
                 num_classes,
                 damping,
                 loss_scale,
                 frequency,
                 batch_size):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

        self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale,
                              frequency=frequency, batch_size=batch_size)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        # self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
@ -257,7 +293,8 @@ class ResNet(nn.Cell):
                                       stride=strides[0],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
@ -265,14 +302,16 @@ class ResNet(nn.Cell):
                                       stride=strides[1],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2], damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
@ -280,14 +319,16 @@ class ResNet(nn.Cell):
                                       stride=strides[3],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale,
                             frequency=frequency, batch_size=batch_size)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
                    damping, loss_scale, frequency, batch_size):
        """
        Make stage network of ResNet.
@ -307,12 +348,14 @@ class ResNet(nn.Cell):
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride,
                             damping=damping, loss_scale=loss_scale, frequency=frequency,
                             batch_size=batch_size)
        layers.append(resnet_block)

        for _ in range(1, layer_num):
            resnet_block = block(out_channel, out_channel, stride=1,
                                 damping=damping, loss_scale=loss_scale, frequency=frequency,
                                 batch_size=batch_size)
            layers.append(resnet_block)

        return nn.SequentialCell(layers)
@ -321,7 +364,7 @@ class ResNet(nn.Cell):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
@ -335,7 +378,7 @@ class ResNet(nn.Cell):
        return out


def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278, batch_size=32):
    """
    Get ResNet50 neural network.
@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
                  class_num,
                  damping,
                  loss_scale,
                  frequency,
                  batch_size)

@ -12,27 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool
from mindspore._checkparam import Validator as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean

from src.grad_reducer_thor import DistributedGradReducerThor

_momentum_opt = C.MultitypeFuncGraph("momentum_opt")

op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@ -46,6 +39,119 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    return gradient


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
    """Apply momentum optimizer to the weight parameter using Tensor."""
    success = True
    success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
    return success


class THOR_GPU(Optimizer):
    """
    THOR
    """
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max,
                 weight_decay=0.0, loss_scale=1.0, use_nesterov=False, decay_filter=lambda x: x.name not in []):
        super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("momentum", momentum, [float], self.cls_name)
        if isinstance(momentum, float) and momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.hyper_map = C.HyperMap()
        self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)

        self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 196, 1.0 / 196, 1.0 / 196,
                            1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                            1.0]
        self.feature_map_new = [x ** 0.5 for x in self.feature_map]
        self.transpose = P.Transpose()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul()
        self.matrix_A = ParameterTuple(matrix_A)
        self.matrix_G = ParameterTuple(matrix_G)
        self.A_inv_max = ParameterTuple(A_inv_max)
        self.G_inv_max = ParameterTuple(G_inv_max)
        self.assign = P.Assign()
        self.mul = P.Mul()

        mean = _get_mirror_mean()
        degree = _get_device_num()
        self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
        self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
        self.weight_decay = weight_decay
        self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
        self.update_gradient = P.UpdateThorGradient(split_dim=128)

    def construct(self, gradients):
        params = self.params
        moments = self.moments
        gradients = self.scale_grad(gradients)
        new_grads = ()
        if self.thor:
            matrix_A_allreduce = ()
            matrix_G_allreduce = ()
            for i in range(54):
                g = gradients[i * 3]
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                matrix_A = self.mul(matrix_A, self.feature_map_new[i])
                matrix_G = self.mul(matrix_G, self.feature_map_new[i])
                matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
                matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
            matrix_A_allreduce = self.grad_reducer_thorA(matrix_A_allreduce)
            matrix_G_allreduce = self.grad_reducer_thorG(matrix_G_allreduce)
            for i in range(54):
                g = gradients[i * 3]
                g_shape = self.shape(g)
                g = self.reshape(g, (g_shape[0], -1))
                matrix_A = matrix_A_allreduce[i]
                matrix_G = matrix_G_allreduce[i]
                g = self.update_gradient(matrix_G, g, matrix_A)
                fake_A = self.assign(self.matrix_A[i], matrix_A)
                fake_G = self.assign(self.matrix_G[i], matrix_G)
                g = F.depend(g, fake_A)
                g = F.depend(g, fake_G)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    g = self.reshape(g, g_shape)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
        else:
            for i in range(54):
                g = gradients[i * 3]
                g_shape = self.shape(g)
                g = self.reshape(g, (g_shape[0], -1))
                matrix_A = self.matrix_A[i]
                matrix_G = self.matrix_G[i]
                matrix_A = F.depend(matrix_A, g)
                matrix_G = F.depend(matrix_G, g)
                g = self.update_gradient(matrix_G, g, matrix_A)
                if i == 53:
                    new_grads = new_grads + (g,)
                else:
                    g = self.reshape(g, g_shape)
                    new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
        gradients = new_grads

        if self.weight_decay > 0:
            gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                       params, gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success
class THOR(Optimizer):
    """THOR"""
    def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
@ -195,5 +301,5 @@ class THOR(Optimizer):
                                       params, gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
        return success

File diff suppressed because it is too large

@ -12,44 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train resnet."""
import os
import random
import argparse
import numpy as np

from mindspore import context
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size

from src.model_thor import Model_Thor as Model
from src.resnet_thor import resnet50
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=1, help='Device num')
args_opt = parser.parse_args()

if args_opt.device_target == "Ascend":
    from src.thor import THOR
    from src.config import config
else:
    from src.thor import THOR_GPU as THOR
    from src.config import config_gpu as config

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)


def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
    """get_model_lr"""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
@ -57,9 +59,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
        epoch = (i + 1) / steps_per_epoch
        base = (1.0 - float(epoch) / total_epochs) ** decay
        lr_local = lr_init * base
        if epoch >= decay_epochs:
            lr_local = lr_local * 0.5
        if epoch >= decay_epochs + 1:
            lr_local = lr_local * 0.5
        lr_each_step.append(lr_local)
    current_step = global_step
@ -76,7 +78,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
        epoch = (step + 1) / steps_per_epoch
        damping_here = damping_init * (decay_rate ** (epoch / 10))
        damping_each_step.append(damping_here)
    current_step = global_step
    damping_each_step = np.array(damping_each_step).astype(np.float32)
    damping_now = damping_each_step[current_step:]
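For reference, a standalone sketch (not part of the commit) of the two schedules above, evaluated with the Ascend config values and an assumed steps_per_epoch of 5004 (the 8-device ImageNet step count shown in the README results):

```python
# Sketch of the polynomial learning-rate and damping schedules used by this script.
# lr_init/decay/total_epochs/damping values follow the Ascend config; 5004 steps per
# epoch is an assumption matching the 8p ImageNet run above.
def thor_lr(step, lr_init=0.045, decay=6, total_epochs=70, steps_per_epoch=5004, decay_epochs=39):
    epoch = (step + 1) / steps_per_epoch
    lr = lr_init * (1.0 - epoch / total_epochs) ** decay
    if epoch >= decay_epochs:
        lr *= 0.5
    if epoch >= decay_epochs + 1:
        lr *= 0.5
    return lr

def thor_damping(step, damping_init=0.03, decay_rate=0.87, steps_per_epoch=5004):
    epoch = (step + 1) / steps_per_epoch
    return damping_init * decay_rate ** (epoch / 10)

print(thor_lr(0), thor_damping(0))                    # values at the first step
print(thor_lr(40 * 5004), thor_damping(40 * 5004))    # by epoch 40 the lr has been halved twice
```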
@ -84,49 +85,70 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
if __name__ == '__main__':
    target = args_opt.device_target
    ckpt_save_dir = config.save_checkpoint_path

    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)

    if args_opt.run_distribute:
        # Ascend target
        if target == "Ascend":
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              mirror_mean=True)
            auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
            init()
        # GPU target
        else:
            init("nccl")
            context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                              mirror_mean=True)
            ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"

    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                             batch_size=config.batch_size, target=target)

    # define net
    step_size = dataset.get_dataset_size()
    damping = get_model_damping(0, config.damping_init, config.damping_decay, 90, step_size)
    lr = get_model_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
    net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
                   frequency=config.frequency, batch_size=config.batch_size)

    # define loss, model
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), config.momentum,
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               config.weight_decay, config.loss_scale)
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    if target == "Ascend":
        model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
                      keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
    else:
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                      amp_level="O2", keep_batchnorm_fp32=True, frequency=config.frequency)

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
        cb += [ckpt_cb]

    # train model
    model.train(config.epoch_size, dataset, callbacks=cb)
