From 4d3d9c1b851db1454b49fbe9d11265f4b63be843 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Wed, 28 Oct 2020 11:14:32 +0800 Subject: [PATCH] add gpu resent benchmark --- model_zoo/official/cv/resnet/README.md | 27 +- .../cv/resnet/gpu_resnet_benchmark.py | 160 +++++++++++ .../scripts/run_gpu_resnet_benchmark.sh | 42 +++ .../cv/resnet/src/resnet_gpu_benchmark.py | 258 ++++++++++++++++++ 4 files changed, 484 insertions(+), 3 deletions(-) create mode 100644 model_zoo/official/cv/resnet/gpu_resnet_benchmark.py create mode 100644 model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh create mode 100644 model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py diff --git a/model_zoo/official/cv/resnet/README.md b/model_zoo/official/cv/resnet/README.md index 246f8b390c..caa585ac1e 100644 --- a/model_zoo/official/cv/resnet/README.md +++ b/model_zoo/official/cv/resnet/README.md @@ -133,17 +133,20 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C ├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs) ├── run_parameter_server_train_gpu.sh # launch gpu parameter server training(8 pcs) ├── run_eval_gpu.sh # launch gpu evaluation - └── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) + ├── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) + └── run_gpu_resnet_benchmark.sh # GPU benchmark for resnet50 with imagenet2012(1 pcs) ├── src ├── config.py # parameter configuration ├── dataset.py # data preprocessing ├── CrossEntropySmooth.py # loss definition for ImageNet2012 dataset ├── lr_generator.py # generate learning rate for each step - └── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50 + ├── resnet.py # resnet backbone, including resnet50 and resnet101 and se-resnet50 + └── resnet_gpu_benchmark.py # resnet50 for GPU benchmark ├── export.py # export model for inference ├── mindspore_hub_conf.py # mindspore hub interface ├── eval.py # eval net - └── train.py # train net + ├── train.py # train net + └── gpu_resent_benchmark.py # GPU benchmark for resnet50 ``` ## [Script Parameters](#contents) @@ -272,6 +275,9 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA # infer example sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] + +# gpu benchmark example +sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional) ``` #### Running parameter server mode training @@ -335,7 +341,22 @@ epoch: 4 step: 5004, loss is 3.5011306 epoch: 5 step: 5004, loss is 3.3501816 ... ``` +- GPU Benchmark of ResNet50 with ImageNet2012 dataset +``` +# ========START RESNET50 GPU BENCHMARK======== +step time: 22549.130 ms, fps: 11 img/sec. epoch: 1 step: 1, loss is 6.940182 +step time: 182.485 ms, fps: 1402 img/sec. epoch: 1 step: 2, loss is 7.078993 +step time: 175.263 ms, fps: 1460 img/sec. epoch: 1 step: 3, loss is 7.559594 +step time: 174.775 ms, fps: 1464 img/sec. epoch: 1 step: 4, loss is 8.020937 +step time: 175.564 ms, fps: 1458 img/sec. epoch: 1 step: 5, loss is 8.140132 +step time: 175.438 ms, fps: 1459 img/sec. epoch: 1 step: 6, loss is 8.021118 +step time: 175.760 ms, fps: 1456 img/sec. epoch: 1 step: 7, loss is 7.910158 +step time: 176.033 ms, fps: 1454 img/sec. epoch: 1 step: 8, loss is 7.940162 +step time: 175.995 ms, fps: 1454 img/sec. epoch: 1 step: 9, loss is 7.740654 +step time: 175.313 ms, fps: 1460 img/sec. epoch: 1 step: 10, loss is 7.956182 +... +``` ## [Evaluation Process](#contents) ### Usage diff --git a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py new file mode 100644 index 0000000000..f22d744d1a --- /dev/null +++ b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py @@ -0,0 +1,160 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train resnet.""" +import argparse +import time +import numpy as np +from mindspore import context +from mindspore import Tensor +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model +from mindspore.train.callback import Callback, LossMonitor +from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.common import set_seed +import mindspore.nn as nn +import mindspore.common.initializer as weight_init +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +from src.resnet_gpu_benchmark import resnet50 as resnet + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.') +parser.add_argument('--epoch_size', type=str, default="2", help='Epoch_size: default 2') +parser.add_argument('--dataset_path', type=str, default=None, help='Imagenet dataset path') +args_opt = parser.parse_args() + +set_seed(1) + +class MyTimeMonitor(Callback): + def __init__(self, batch_size): + super(MyTimeMonitor, self).__init__() + self.batch_size = batch_size + def step_begin(self, run_context): + self.step_time = time.time() + def step_end(self, run_context): + step_mseconds = (time.time() - self.step_time) * 1000 + fps = self.batch_size / step_mseconds *1000 + print("step time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ") + +def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU"): + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + + image_size = 224 + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) + ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + ds = ds.map(operations=C2.PadEnd(pad_shape=[224, 224, 4], pad_value=0), input_columns="image", + num_parallel_workers=8) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + +def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch): + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + + for i in range(total_steps): + if i < warmup_steps: + lr_ = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr_ = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) + lr_each_step.append(lr_) + lr_each_step = np.array(lr_each_step).astype(np.float32) + return lr_each_step + +if __name__ == '__main__': + dev = "GPU" + epoch_size = int(args_opt.epoch_size) + total_batch = int(args_opt.batch_size) + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=dev, save_graphs=False) + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, + batch_size=total_batch, target=dev) + step_size = dataset.get_dataset_size() + + # define net + net = resnet(class_num=1001) + + # init weight + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(), + cell.weight.shape, + cell.weight.dtype)) + if isinstance(cell, nn.Dense): + cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(), + cell.weight.shape, + cell.weight.dtype)) + + # init lr + lr = get_liner_lr(lr_init=0, lr_end=0, lr_max=0.8, warmup_epochs=0, total_epochs=epoch_size, + steps_per_epoch=step_size) + lr = Tensor(lr) + + # define opt + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + + group_params = [{'params': decayed_params, 'weight_decay': 1e-4}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + # define loss, model + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024) + loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False) + # Mixed precision + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level="O2", keep_batchnorm_fp32=False) + + # define callbacks + time_cb = MyTimeMonitor(total_batch) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + + # train model + print("========START RESNET50 GPU BENCHMARK========") + model.train(epoch_size, dataset, callbacks=cb, sink_size=dataset.get_dataset_size()) diff --git a/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh b/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh new file mode 100644 index 0000000000..229f3134a3 --- /dev/null +++ b/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 1 ] && [ $# != 2 ] +then + echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional)" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATAPATH=$(get_real_path $1) +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +if [ $# == 1 ] +then + python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH +fi + +if [ $# == 2 ] +then + python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --batch_size=$2 +fi diff --git a/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py b/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py new file mode 100644 index 0000000000..4fb343ddcd --- /dev/null +++ b/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py @@ -0,0 +1,258 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ResNet.""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +from scipy.stats import truncnorm + +def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size): + fan_in = in_channel * kernel_size * kernel_size + scale = 1.0 + scale /= max(1., fan_in) + stddev = (scale ** 0.5) / .87962566103423978 + mu, sigma = 0, stddev + weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size) + weight = np.reshape(weight, (out_channel, kernel_size, kernel_size, in_channel)) + return Tensor(weight, dtype=mstype.float32) + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1): + weight_shape = (out_channel, 3, 3, in_channel) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, + padding=1, pad_mode='pad', weight_init=weight, data_format="NHWC") + +def _conv1x1(in_channel, out_channel, stride=1): + weight_shape = (out_channel, 1, 1, in_channel) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, + padding=0, pad_mode='pad', weight_init=weight, data_format="NHWC") + +def _conv7x7(in_channel, out_channel, stride=1): + weight_shape = (out_channel, 7, 7, in_channel) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride, + padding=3, pad_mode='pad', weight_init=weight, data_format="NHWC") + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, + moving_mean_init=0, moving_var_init=1, data_format="NHWC") + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=0, beta_init=0, + moving_mean_init=0, moving_var_init=1, data_format="NHWC") + +def _fc(in_channel, out_channel): + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + self.stride = stride + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1) + self.bn1 = _bn(channel) + self.conv2 = _conv3x3(channel, channel, stride=stride) + self.bn2 = _bn(channel) + + self.conv3 = _conv1x1(channel, out_channel, stride=1) + self.bn3 = _bn_last(out_channel) + self.relu = nn.ReLU() + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + self.conv1 = _conv7x7(4, 64, stride=2) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same", data_format="NHWC") + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + self.avg_pool = P.AvgPool(7, 1, data_format="NHWC") + self.flatten = nn.Flatten() + self.end_point = _fc(out_channels[3], num_classes) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.avg_pool(c5) + out = self.flatten(out) + out = self.end_point(out) + + return out + + +def resnet50(class_num=1001): + """ + Get ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + >>> net = resnet50(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num)