add gpu resnet benchmark

pull/7909/head
VectorSL 4 years ago
parent 66dd2730b5
commit 4d3d9c1b85

@@ -133,17 +133,20 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C
  ├── run_distribute_train_gpu.sh         # launch gpu distributed training(8 pcs)
  ├── run_parameter_server_train_gpu.sh   # launch gpu parameter server training(8 pcs)
  ├── run_eval_gpu.sh                     # launch gpu evaluation
  ├── run_standalone_train_gpu.sh         # launch gpu standalone training(1 pcs)
  └── run_gpu_resnet_benchmark.sh         # GPU benchmark for resnet50 with imagenet2012(1 pcs)
├── src
  ├── config.py                           # parameter configuration
  ├── dataset.py                          # data preprocessing
  ├── CrossEntropySmooth.py               # loss definition for ImageNet2012 dataset
  ├── lr_generator.py                     # generate learning rate for each step
  ├── resnet.py                           # resnet backbone, including resnet50 and resnet101 and se-resnet50
  └── resnet_gpu_benchmark.py             # resnet50 for GPU benchmark
├── export.py                             # export model for inference
├── mindspore_hub_conf.py                 # mindspore hub interface
├── eval.py                               # eval net
├── train.py                              # train net
└── gpu_resnet_benchmark.py               # GPU benchmark for resnet50
```
## [Script Parameters](#contents)
@@ -272,6 +275,9 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA
# infer example
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
# gpu benchmark example
sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional)
```
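For example, assuming the ImageNet2012 train split lives at `/data/imagenet/train` (an illustrative path), the benchmark can be run with the default batch size of 256 or an explicit one:
```
# illustrative paths; BATCH_SIZE is optional and defaults to 256
sh run_gpu_resnet_benchmark.sh /data/imagenet/train
sh run_gpu_resnet_benchmark.sh /data/imagenet/train 192
```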
#### Running parameter server mode training
@@ -335,7 +341,22 @@ epoch: 4 step: 5004, loss is 3.5011306
epoch: 5 step: 5004, loss is 3.3501816
...
```
- GPU Benchmark of ResNet50 with ImageNet2012 dataset
```
# ========START RESNET50 GPU BENCHMARK========
step time: 22549.130 ms, fps: 11 img/sec. epoch: 1 step: 1, loss is 6.940182
step time: 182.485 ms, fps: 1402 img/sec. epoch: 1 step: 2, loss is 7.078993
step time: 175.263 ms, fps: 1460 img/sec. epoch: 1 step: 3, loss is 7.559594
step time: 174.775 ms, fps: 1464 img/sec. epoch: 1 step: 4, loss is 8.020937
step time: 175.564 ms, fps: 1458 img/sec. epoch: 1 step: 5, loss is 8.140132
step time: 175.438 ms, fps: 1459 img/sec. epoch: 1 step: 6, loss is 8.021118
step time: 175.760 ms, fps: 1456 img/sec. epoch: 1 step: 7, loss is 7.910158
step time: 176.033 ms, fps: 1454 img/sec. epoch: 1 step: 8, loss is 7.940162
step time: 175.995 ms, fps: 1454 img/sec. epoch: 1 step: 9, loss is 7.740654
step time: 175.313 ms, fps: 1460 img/sec. epoch: 1 step: 10, loss is 7.956182
...
```
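The reported fps is simply the effective throughput, batch size divided by step time. A minimal sketch with the numbers above, assuming the default batch size of 256:

```python
# throughput as computed in MyTimeMonitor.step_end (values from the log above)
batch_size = 256        # default --batch_size
step_mseconds = 175.4   # a steady-state step time from the log
fps = batch_size / step_mseconds * 1000
print(int(fps))         # 1459 -- matches the ~1454-1464 img/sec range
```

The first step includes graph compilation, which is why it reports a one-off 22549 ms and only 11 img/sec.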
## [Evaluation Process](#contents)
### Usage

gpu_resnet_benchmark.py
@@ -0,0 +1,160 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train resnet."""
import argparse
import time
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import Callback, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.resnet_gpu_benchmark import resnet50 as resnet
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.')
parser.add_argument('--epoch_size', type=str, default="2", help='Epoch_size: default 2')
parser.add_argument('--dataset_path', type=str, default=None, help='ImageNet dataset path')
args_opt = parser.parse_args()

set_seed(1)


class MyTimeMonitor(Callback):
    """Callback that prints the per-step time and throughput (img/sec)."""
    def __init__(self, batch_size):
        super(MyTimeMonitor, self).__init__()
        self.batch_size = batch_size

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        step_mseconds = (time.time() - self.step_time) * 1000
        fps = self.batch_size / step_mseconds * 1000
        print("step time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ")


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU"):
    """Create the ImageNet2012 data pipeline for training or evaluation."""
    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
        ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
    ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    # pad the HWC image from 3 to 4 channels, the input layout expected by
    # the GPU benchmark network
    ds = ds.map(operations=C2.PadEnd(pad_shape=[224, 224, 4], pad_value=0), input_columns="image",
                num_parallel_workers=8)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds


def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Generate the per-step learning rate: linear warmup, then linear decay."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr_ = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr_ = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr_)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step


if __name__ == '__main__':
    dev = "GPU"
    epoch_size = int(args_opt.epoch_size)
    total_batch = int(args_opt.batch_size)
    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target=dev, save_graphs=False)
    # create dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                             batch_size=total_batch, target=dev)
    step_size = dataset.get_dataset_size()
    # define net
    net = resnet(class_num=1001)
    # init weight
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))
    # init lr
    lr = get_liner_lr(lr_init=0, lr_end=0, lr_max=0.8, warmup_epochs=0, total_epochs=epoch_size,
                      steps_per_epoch=step_size)
    lr = Tensor(lr)
    # define opt: apply weight decay only to conv/dense weights,
    # not to batchnorm betas/gammas or biases
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': 1e-4},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    # define loss, model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    opt = Momentum(group_params, lr, 0.9, loss_scale=1024)
    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
    # mixed precision: O2 casts the network to float16; the loss is kept in float32
    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                  amp_level="O2", keep_batchnorm_fp32=False)
    # define callbacks
    time_cb = MyTimeMonitor(total_batch)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    # train model
    print("========START RESNET50 GPU BENCHMARK========")
    model.train(epoch_size, dataset, callbacks=cb, sink_size=dataset.get_dataset_size())
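For a quick sanity check, the benchmark can also be invoked directly rather than through the wrapper script below; the flags mirror the argparse options above, and the dataset path here is a placeholder:

```
# run from the resnet model directory so that src/ is importable
python gpu_resnet_benchmark.py --dataset_path=/data/imagenet/train --batch_size=256
```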

scripts/run_gpu_resnet_benchmark.sh
@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional)"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATAPATH=$(get_real_path $1)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
if [ $# == 1 ]
then
    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH
fi

if [ $# == 2 ]
then
    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --batch_size=$2
fi
