From 5db704dfe35381aa48fed5f2204ebfd32f2f156d Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 15 Dec 2020 17:21:14 +0800 Subject: [PATCH] update gpu benchmark --- model_zoo/official/cv/resnet/README.md | 126 +++++++++++------- .../cv/resnet/gpu_resnet_benchmark.py | 74 ++++++++-- .../scripts/run_eval_gpu_resnet_benckmark.sh | 51 +++++++ .../scripts/run_gpu_resnet_benchmark.sh | 23 +++- 4 files changed, 206 insertions(+), 68 deletions(-) create mode 100644 model_zoo/official/cv/resnet/scripts/run_eval_gpu_resnet_benckmark.sh diff --git a/model_zoo/official/cv/resnet/README.md b/model_zoo/official/cv/resnet/README.md index 7b0f3047eb..407586ab55 100644 --- a/model_zoo/official/cv/resnet/README.md +++ b/model_zoo/official/cv/resnet/README.md @@ -19,14 +19,16 @@ - [Description of Random Situation](#description-of-random-situation) - [ModelZoo Homepage](#modelzoo-homepage) - # [ResNet Description](#contents) + ## Description + ResNet (residual neural network) was proposed by Kaiming He and other four Chinese of Microsoft Research Institute. Through the use of ResNet unit, it successfully trained 152 layers of neural network, and won the championship in ilsvrc2015. The error rate on top 5 was 3.57%, and the parameter quantity was lower than vggnet, so the effect was very outstanding. Traditional convolution network or full connection network will have more or less information loss. At the same time, it will lead to the disappearance or explosion of gradient, which leads to the failure of deep network training. ResNet solves this problem to a certain extent. By passing the input information to the output, the integrity of the information is protected. The whole network only needs to learn the part of the difference between input and output, which simplifies the learning objectives and difficulties.The structure of ResNet can accelerate the training of neural network very quickly, and the accuracy of the model is also greatly improved. At the same time, ResNet is very popular, even can be directly used in the concept net network. These are examples of training ResNet50/ResNet101/SE-ResNet50 with CIFAR-10/ImageNet2012 dataset in MindSpore.ResNet50 and ResNet101 can reference [paper 1](https://arxiv.org/pdf/1512.03385.pdf) below, and SE-ResNet50 is a variant of ResNet50 which reference [paper 2](https://arxiv.org/abs/1709.01507) and [paper 3](https://arxiv.org/abs/1812.01187) below, Training SE-ResNet50 for just 24 epochs using 8 Ascend 910, we can reach top-1 accuracy of 75.9%.(Training ResNet101 with dataset CIFAR-10 and SE-ResNet50 with CIFAR-10 is not supported yet.) ## Paper + 1.[paper](https://arxiv.org/pdf/1512.03385.pdf):Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition" 2.[paper](https://arxiv.org/abs/1709.01507):Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu. 
"Squeeze-and-Excitation Networks" @@ -41,14 +43,15 @@ The overall network architecture of ResNet is show below: # [Dataset](#contents) Dataset used: [CIFAR-10]() + - Dataset size:60,000 32*32 colorful images in 10 classes - - Train:50,000 images - - Test: 10,000 images + - Train:50,000 images + - Test: 10,000 images - Data format:binary files - - Note:Data will be processed in dataset.py + - Note:Data will be processed in dataset.py - Download the dataset, the directory structure is as follows: -``` +```bash ├─cifar-10-batches-bin │ └─cifar-10-verify-bin @@ -57,13 +60,13 @@ Dataset used: [CIFAR-10]() Dataset used: [ImageNet2012](http://www.image-net.org/) - Dataset size 224*224 colorful images in 1000 classes - - Train:1,281,167 images - - Test: 50,000 images + - Train:1,281,167 images + - Test: 50,000 images - Data format:jpeg - - Note:Data will be processed in dataset.py + - Note:Data will be processed in dataset.py - Download the dataset, the directory structure is as follows: - ``` + ```bash └─dataset ├─ilsvrc # train dataset └─validation_preprocess # evaluate dataset @@ -79,21 +82,20 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil # [Environment Requirements](#contents) - Hardware(Ascend/GPU) - - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. + - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. 
- Framework - - [MindSpore](https://www.mindspore.cn/install/en) + - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: - - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - - + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) # [Quick Start](#contents) After installing MindSpore via the official website, you can start training and evaluation as follows: - Running on Ascend -``` + +```bash # distributed training Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) @@ -106,7 +108,8 @@ Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [D ``` - Running on GPU -``` + +```bash # distributed training example sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) @@ -115,6 +118,9 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA # infer example sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] + +# gpu benchmark example +sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional) ``` # [Script Description](#contents) @@ -134,7 +140,8 @@ sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [C ├── run_parameter_server_train_gpu.sh # launch gpu parameter server training(8 pcs) ├── run_eval_gpu.sh # launch gpu evaluation ├── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs) - └── run_gpu_resnet_benchmark.sh # GPU benchmark for resnet50 with imagenet2012(1 pcs) + ├── run_gpu_resnet_benchmark.sh # launch gpu benchmark train for resnet50 with imagenet2012 + └── run_eval_gpu_resnet_benckmark.sh # launch gpu benchmark eval for resnet50 with imagenet2012 ├── src ├── config.py # parameter configuration ├── dataset.py # data preprocessing @@ -155,13 +162,13 @@ Parameters for both training and evaluation can be set in config.py. - Config for ResNet50, CIFAR-10 dataset -``` +```bash "class_num": 10, # dataset class num "batch_size": 32, # batch size of input tensor "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum -"weight_decay": 1e-4, # weight decay -"epoch_size": 90, # only valid for taining, which is always 1 for inference +"weight_decay": 1e-4, # weight decay +"epoch_size": 90, # only valid for taining, which is always 1 for inference "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size "save_checkpoint": True, # whether save checkpoint or not "save_checkpoint_epochs": 5, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last step @@ -176,13 +183,13 @@ Parameters for both training and evaluation can be set in config.py. 
- Config for ResNet50, ImageNet2012 dataset -``` +```bash "class_num": 1001, # dataset class number "batch_size": 256, # batch size of input tensor "loss_scale": 1024, # loss scale "momentum": 0.9, # momentum optimizer -"weight_decay": 1e-4, # weight decay -"epoch_size": 90, # only valid for taining, which is always 1 for inference +"weight_decay": 1e-4, # weight decay +"epoch_size": 90, # only valid for taining, which is always 1 for inference "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size "save_checkpoint": True, # whether save checkpoint or not "save_checkpoint_epochs": 5, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch @@ -199,7 +206,7 @@ Parameters for both training and evaluation can be set in config.py. - Config for ResNet101, ImageNet2012 dataset -``` +```bash "class_num": 1001, # dataset class number "batch_size": 32, # batch size of input tensor "loss_scale": 1024, # loss scale @@ -220,7 +227,7 @@ Parameters for both training and evaluation can be set in config.py. - Config for SE-ResNet50, ImageNet2012 dataset -``` +```bash "class_num": 1001, # dataset class number "batch_size": 32, # batch size of input tensor "loss_scale": 1024, # loss scale @@ -245,8 +252,10 @@ Parameters for both training and evaluation can be set in config.py. ## [Training Process](#contents) ### Usage + #### Running on Ascend -``` + +```bash # distributed training Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) @@ -258,6 +267,7 @@ Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imag Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] ``` + For distributed training, a hccl configuration file with JSON format needs to be created in advance. Please follow the instructions in the link [hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). @@ -266,7 +276,7 @@ Training result will be stored in the example path, whose folder name begins wit #### Running on GPU -``` +```bash # distributed training example sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) @@ -276,20 +286,28 @@ sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATA # infer example sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] -# gpu benchmark example -sh run_gpu_resnet_benchmark.sh [IMAGENET_DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) +# gpu benchmark training example +sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional) [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional) + +# gpu benckmark infer example +sh run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) [DTYPE](optional) ``` +For distributed training, a hostfile configuration needs to be created in advance. + +Please follow the instructions in the link [GPU-Multi-Host](https://www.mindspore.cn/tutorial/training/zh-CN/r1.0/advanced_use/distributed_training_gpu.html). 
+ #### Running parameter server mode training - Parameter server training Ascend example -``` +```bash sh run_parameter_server_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) ``` - Parameter server training GPU example -``` + +```bash sh run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) ``` @@ -297,7 +315,7 @@ sh run_parameter_server_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] - Training ResNet50 with CIFAR-10 dataset -``` +```bash # distribute training result(8 pcs) epoch: 1 step: 195, loss is 1.9601055 epoch: 2 step: 195, loss is 1.8555021 @@ -309,7 +327,7 @@ epoch: 5 step: 195, loss is 1.393667 - Training ResNet50 with ImageNet2012 dataset -``` +```bash # distribute training result(8 pcs) epoch: 1 step: 5004, loss is 4.8995576 epoch: 2 step: 5004, loss is 3.9235563 @@ -321,7 +339,7 @@ epoch: 5 step: 5004, loss is 3.1978393 - Training ResNet101 with ImageNet2012 dataset -``` +```bash # distribute training result(8 pcs) epoch: 1 step: 5004, loss is 4.805483 epoch: 2 step: 5004, loss is 3.2121816 @@ -330,9 +348,10 @@ epoch: 4 step: 5004, loss is 3.3667371 epoch: 5 step: 5004, loss is 3.1718972 ... ``` + - Training SE-ResNet50 with ImageNet2012 dataset -``` +```bash # distribute training result(8 pcs) epoch: 1 step: 5004, loss is 5.1779146 epoch: 2 step: 5004, loss is 4.139395 @@ -341,9 +360,10 @@ epoch: 4 step: 5004, loss is 3.5011306 epoch: 5 step: 5004, loss is 3.3501816 ... ``` + - GPU Benchmark of ResNet50 with ImageNet2012 dataset -``` +```bash # ========START RESNET50 GPU BENCHMARK======== Epoch time: 12416.098 ms, fps: 412 img/sec. epoch: 1 step: 20, loss is 6.940182 Epoch time: 3472.037 ms, fps: 1474 img/sec. epoch: 2 step: 20, loss is 7.078993 @@ -352,17 +372,19 @@ Epoch time: 3460.311 ms, fps: 1479 img/sec. epoch: 4 step: 20, loss is 6.920937 Epoch time: 3460.543 ms, fps: 1479 img/sec. epoch: 5 step: 20, loss is 6.814013 ... ``` + ## [Evaluation Process](#contents) ### Usage #### Running on Ascend -``` + +```bash # evaluation Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] ``` -``` +```bash # evaluation example sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt ``` @@ -370,7 +392,8 @@ sh run_eval.sh resnet50 cifar10 ~/cifar10-10-verify-bin ~/resnet50_cifar10/train > checkpoint can be produced in training process. 
#### Running on GPU -``` + +```bash sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH] ``` @@ -380,35 +403,37 @@ Evaluation result will be stored in the example path, whose folder name is "eval - Evaluating ResNet50 with CIFAR-10 dataset -``` +```bash result: {'acc': 0.91446314102564111} ckpt=~/resnet50_cifar10/train_parallel0/resnet-90_195.ckpt ``` - Evaluating ResNet50 with ImageNet2012 dataset -``` +```bash result: {'acc': 0.7671054737516005} ckpt=train_parallel0/resnet-90_5004.ckpt ``` - Evaluating ResNet101 with ImageNet2012 dataset -``` +```bash result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199744} ckpt=train_parallel0/resnet-120_5004.ckpt ``` - Evaluating SE-ResNet50 with ImageNet2012 dataset -``` +```bash result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.768065781049936} ckpt=train_parallel0/resnet-24_5004.ckpt ``` # [Model Description](#contents) + ## [Performance](#contents) -### Evaluation Performance +### Evaluation Performance #### ResNet50 on CIFAR-10 + | Parameters | Ascend 910 | GPU | | -------------------------- | -------------------------------------- |---------------------------------- | | Model Version | ResNet50-v1.5 |ResNet50-v1.5| @@ -428,6 +453,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) | #### ResNet50 on ImageNet2012 + | Parameters | Ascend 910 | GPU | | -------------------------- | -------------------------------------- |---------------------------------- | | Model Version | ResNet50-v1.5 |ResNet50-v1.5| @@ -447,6 +473,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnet) | #### ResNet101 on ImageNet2012 + | Parameters | Ascend 910 | GPU | | -------------------------- | -------------------------------------- |---------------------------------- | | Model Version | ResNet101 |ResNet101| @@ -488,6 +515,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 ### Inference Performance #### ResNet50 on CIFAR-10 + | Parameters | Ascend | GPU | | ------------------- | --------------------------- | --------------------------- | | Model Version | ResNet50-v1.5 | ResNet50-v1.5 | @@ -501,6 +529,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Model for inference | 91M (.air file) | | #### ResNet50 on ImageNet2012 + | Parameters | Ascend | GPU | | ------------------- | --------------------------- | --------------------------- | | Model Version | ResNet50-v1.5 | ResNet50-v1.5 | @@ -514,6 +543,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Model for inference | 98M (.air file) | | #### ResNet101 on ImageNet2012 + | Parameters | Ascend | GPU | | ------------------- | --------------------------- | --------------------------- | | Model Version | ResNet101 | ResNet101 | @@ -527,6 +557,7 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Model for inference | 171M (.air file) | | #### SE-ResNet50 on ImageNet2012 + | Parameters | Ascend | | ------------------- | --------------------------- | | Model Version | 
SE-ResNet50 | @@ -539,11 +570,10 @@ result: {'top_5_accuracy': 0.9342589628681178, 'top_1_accuracy': 0.7680657810499 | Accuracy | 76.80% | | Model for inference | 109M (.air file) | - # [Description of Random Situation](#contents) In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py. - # [ModelZoo Homepage](#contents) - Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). \ No newline at end of file + + Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py index 757b9564bb..f678476ed9 100644 --- a/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py +++ b/model_zoo/official/cv/resnet/gpu_resnet_benchmark.py @@ -22,23 +22,28 @@ from mindspore import Tensor from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model from mindspore.context import ParallelMode -from mindspore.train.callback import Callback, LossMonitor -from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits +from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.communication.management import init, get_group_size +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed import mindspore.nn as nn import mindspore.common.initializer as weight_init import mindspore.dataset.engine as de import mindspore.dataset.vision.c_transforms as C from src.resnet_gpu_benchmark import resnet50 as resnet +from src.CrossEntropySmooth import CrossEntropySmooth parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.') parser.add_argument('--epoch_size', type=str, default="2", help='Epoch_size: default 2') parser.add_argument('--print_per_steps', type=str, default="20", help='Print loss and time per steps: default 20') parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute') +parser.add_argument('--save_ckpt', type=ast.literal_eval, default=False, help='Save ckpt or not: default False') +parser.add_argument('--eval', type=ast.literal_eval, default=False, help='Eval ckpt : default False') parser.add_argument('--dataset_path', type=str, default=None, help='Imagenet dataset path') +parser.add_argument('--ckpt_path', type=str, default="./", help='The path to save ckpt if save_ckpt is True;\ + Or the ckpt model file when eval is True') parser.add_argument('--mode', type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"], help='Execute mode') parser.add_argument('--dtype', type=str, choices=["fp32", "fp16", "FP16", "FP32"], default="fp16",\ help='Compute data type fp32 or fp16: default fp16') @@ -107,14 +112,16 @@ def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per lr_each_step = np.array(lr_each_step).astype(np.float32) return lr_each_step -if __name__ == '__main__': +def train(): # set args dev = "GPU" epoch_size = int(args_opt.epoch_size) total_batch = int(args_opt.batch_size) print_per_steps = int(args_opt.print_per_steps) compute_type = str(args_opt.dtype).lower() - + ckpt_save_dir = str(args_opt.ckpt_path) + save_ckpt = 
bool(args_opt.save_ckpt) + device_num = 1 # init context if args_opt.mode == "GRAPH": mode = context.GRAPH_MODE @@ -123,12 +130,14 @@ if __name__ == '__main__': context.set_context(mode=mode, device_target=dev, save_graphs=False) if args_opt.run_distribute: init() - context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, + device_num = get_group_size() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, all_reduce_fusion_config=[85, 160]) + ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/" # create dataset dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, - batch_size=total_batch, target=dev, dtype=compute_type) + batch_size=total_batch, target=dev, dtype=compute_type, device_num=device_num) step_size = dataset.get_dataset_size() if (print_per_steps > step_size or print_per_steps < 1): print("Arg: print_per_steps should lessequal to dataset_size ", step_size) @@ -162,16 +171,14 @@ if __name__ == '__main__': else: no_decayed_params.append(param) - group_params = [{'params': decayed_params, 'weight_decay': 1e-4}, - {'params': no_decayed_params}, - {'order_params': net.trainable_params()}] # define loss, model - loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024) + loss = CrossEntropySmooth(sparse=True, reduction='mean', smooth_factor=0.1, num_classes=1001) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4) loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False) - model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) # Mixed precision if compute_type == "fp16": + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024) model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False) # define callbacks @@ -180,10 +187,49 @@ if __name__ == '__main__': time_cb = MyTimeMonitor(total_batch, print_per_steps) loss_cb = LossMonitor() cb = [time_cb, loss_cb] - + if save_ckpt: + config_ck = CheckpointConfig(save_checkpoint_steps=5 * step_size, keep_checkpoint_max=5) + ckpt_cb = ModelCheckpoint(prefix="resnet_benchmark", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] # train model print("========START RESNET50 GPU BENCHMARK========") if mode == context.GRAPH_MODE: model.train(int(epoch_size * step_size / print_per_steps), dataset, callbacks=cb, sink_size=print_per_steps) else: model.train(epoch_size, dataset, callbacks=cb) + +def eval_(): + # set args + dev = "GPU" + compute_type = str(args_opt.dtype).lower() + ckpt_dir = str(args_opt.ckpt_path) + total_batch = int(args_opt.batch_size) + # init context + if args_opt.mode == "GRAPH": + mode = context.GRAPH_MODE + else: + mode = context.PYNATIVE_MODE + context.set_context(mode=mode, device_target=dev, save_graphs=False) + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, repeat_num=1, + batch_size=total_batch, target=dev, dtype=compute_type) + # define net + net = resnet(class_num=1001, dtype=compute_type) + # load checkpoint + param_dict = load_checkpoint(ckpt_dir) + load_param_into_net(net, param_dict) + net.set_train(False) + # define loss, model + 
loss = CrossEntropySmooth(sparse=True, reduction='mean', smooth_factor=0.1, num_classes=1001)
+    # define model
+    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    # eval model
+    print("========START EVAL RESNET50 ON GPU ========")
+    res = model.eval(dataset)
+    print("result:", res, "ckpt=", ckpt_dir)
+
+if __name__ == '__main__':
+    if not args_opt.eval:
+        train()
+    else:
+        eval_()
diff --git a/model_zoo/official/cv/resnet/scripts/run_eval_gpu_resnet_benckmark.sh b/model_zoo/official/cv/resnet/scripts/run_eval_gpu_resnet_benckmark.sh
new file mode 100644
index 0000000000..71e6f39eb6
--- /dev/null
+++ b/model_zoo/official/cv/resnet/scripts/run_eval_gpu_resnet_benckmark.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ]
+then
+    echo "Usage: sh run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) \
+    [DTYPE](optional)"
+    echo "Example: sh run_eval_gpu_resnet_benchmark.sh /path/imagenet/train /path/ckpt 256 FP16"
+exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+DATAPATH=$(get_real_path $1)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+if [ $# == 2 ]
+then
+    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --eval=True --ckpt_path=$2
+fi
+
+if [ $# == 3 ]
+then
+    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --eval=True --ckpt_path=$2 \
+    --batch_size=$3
+fi
+
+if [ $# == 4 ]
+then
+    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --eval=True --ckpt_path=$2 \
+    --batch_size=$3 --dtype=$4
+fi
diff --git a/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh b/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh
index c138900c98..705604988d 100644
--- a/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_gpu_resnet_benchmark.sh
@@ -14,11 +14,11 @@
 # limitations under the License.
# ============================================================================
 
-if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ]
+if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ] && [ $# != 6 ]
 then
     echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional)\
-          [DEVICE_NUM](optional)"
-    echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8"
+          [DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)"
+    echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8 True /path/ckpt"
 exit 1
 fi
 
@@ -45,12 +45,23 @@ fi
 
 if [ $# == 3 ]
 then
-    python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True --dtype=$3 \
-    --dataset_path=$DATAPATH --batch_size=$2
+    python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --batch_size=$2 --dtype=$3
fi
 
 if [ $# == 4 ]
 then
     mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
     --dataset_path=$DATAPATH --batch_size=$2 --dtype=$3
-fi
\ No newline at end of file
+fi
+
+if [ $# == 5 ]
+then
+    mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
+    --dataset_path=$DATAPATH --batch_size=$2 --dtype=$3 --save_ckpt=$5
+fi
+
+if [ $# == 6 ]
+then
+    mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
+    --dataset_path=$DATAPATH --batch_size=$2 --dtype=$3 --save_ckpt=$5 --ckpt_path=$6
+fi
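
For reference, a minimal sketch of how the two benchmark entry points introduced by this patch chain together, assuming the ImageNet2012 layout described in the README. All paths, the device count, and the checkpoint file name are illustrative; the `resnet_benchmark` prefix comes from the `ModelCheckpoint` callback added in `gpu_resnet_benchmark.py`, and rank subdirectories (`ckpt_0/`, ...) are only created in the distributed (mpirun) case.

```bash
# 8-GPU FP16 benchmark training that also saves checkpoints under /path/ckpt/
# (--save_ckpt is parsed with ast.literal_eval, so pass a Python literal such as True)
sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8 True /path/ckpt/

# evaluate one of the saved checkpoints with the eval script added by this patch
# (the file is named run_eval_gpu_resnet_benckmark.sh here; the .ckpt name is illustrative)
sh run_eval_gpu_resnet_benckmark.sh /path/imagenet/validation_preprocess \
    /path/ckpt/ckpt_0/resnet_benchmark-2_125.ckpt 256 FP16
```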