!9976 GPU update benchmark

From: @VectorSL
Reviewed-by: @limingqi107,@cristoval
Signed-off-by: @cristoval
pull/9976/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 56e17168e2

File diff suppressed because it is too large Load Diff

@ -22,23 +22,28 @@ from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback, LossMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_group_size
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
from src.resnet_gpu_benchmark import resnet50 as resnet
from src.CrossEntropySmooth import CrossEntropySmooth
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--batch_size', type=str, default="256", help='Batch_size: default 256.')
parser.add_argument('--epoch_size', type=str, default="2", help='Epoch_size: default 2')
parser.add_argument('--print_per_steps', type=str, default="20", help='Print loss and time per steps: default 20')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--save_ckpt', type=ast.literal_eval, default=False, help='Save ckpt or not: default False')
parser.add_argument('--eval', type=ast.literal_eval, default=False, help='Eval ckpt : default False')
parser.add_argument('--dataset_path', type=str, default=None, help='Imagenet dataset path')
parser.add_argument('--ckpt_path', type=str, default="./", help='The path to save ckpt if save_ckpt is True;\
Or the ckpt model file when eval is True')
parser.add_argument('--mode', type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"], help='Execute mode')
parser.add_argument('--dtype', type=str, choices=["fp32", "fp16", "FP16", "FP32"], default="fp16",\
help='Compute data type fp32 or fp16: default fp16')
@ -107,14 +112,16 @@ def get_liner_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
if __name__ == '__main__':
def train():
# set args
dev = "GPU"
epoch_size = int(args_opt.epoch_size)
total_batch = int(args_opt.batch_size)
print_per_steps = int(args_opt.print_per_steps)
compute_type = str(args_opt.dtype).lower()
ckpt_save_dir = str(args_opt.ckpt_path)
save_ckpt = bool(args_opt.save_ckpt)
device_num = 1
# init context
if args_opt.mode == "GRAPH":
mode = context.GRAPH_MODE
@ -123,12 +130,14 @@ if __name__ == '__main__':
context.set_context(mode=mode, device_target=dev, save_graphs=False)
if args_opt.run_distribute:
init()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
device_num = get_group_size()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, all_reduce_fusion_config=[85, 160])
ckpt_save_dir = ckpt_save_dir + "ckpt_" + str(get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=total_batch, target=dev, dtype=compute_type)
batch_size=total_batch, target=dev, dtype=compute_type, device_num=device_num)
step_size = dataset.get_dataset_size()
if (print_per_steps > step_size or print_per_steps < 1):
print("Arg: print_per_steps should lessequal to dataset_size ", step_size)
@ -162,16 +171,14 @@ if __name__ == '__main__':
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': 1e-4},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
# define loss, model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
loss = CrossEntropySmooth(sparse=True, reduction='mean', smooth_factor=0.1, num_classes=1001)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4)
loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Mixed precision
if compute_type == "fp16":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 1e-4, 1024)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
# define callbacks
@ -180,10 +187,49 @@ if __name__ == '__main__':
time_cb = MyTimeMonitor(total_batch, print_per_steps)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if save_ckpt:
config_ck = CheckpointConfig(save_checkpoint_steps=5 * step_size, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="resnet_benchmark", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model
print("========START RESNET50 GPU BENCHMARK========")
if mode == context.GRAPH_MODE:
model.train(int(epoch_size * step_size / print_per_steps), dataset, callbacks=cb, sink_size=print_per_steps)
else:
model.train(epoch_size, dataset, callbacks=cb)
def eval_():
# set args
dev = "GPU"
compute_type = str(args_opt.dtype).lower()
ckpt_dir = str(args_opt.ckpt_path)
total_batch = int(args_opt.batch_size)
# init context
if args_opt.mode == "GRAPH":
mode = context.GRAPH_MODE
else:
mode = context.PYNATIVE_MODE
context.set_context(mode=mode, device_target=dev, save_graphs=False)
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, repeat_num=1,
batch_size=total_batch, target=dev, dtype=compute_type)
# define net
net = resnet(class_num=1001, dtype=compute_type)
# load checkpoint
param_dict = load_checkpoint(ckpt_dir)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
loss = CrossEntropySmooth(sparse=True, reduction='mean', smooth_factor=0.1, num_classes=1001)
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
print("========START EVAL RESNET50 ON GPU ========")
res = model.eval(dataset)
print("result:", res, "ckpt=", ckpt_dir)
if __name__ == '__main__':
if not args_opt.eval:
train()
else:
eval_()

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_eval_gpu_resnet_benchmark.sh [DATASET_PATH] [CKPT_PATH] [BATCH_SIZE](optional) \
[DTYPE](optional)"
echo "Example: sh run_eval_gpu_resnet_benchmark.sh /path/imagenet/train /path/ckpt 256 FP16"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATAPATH=$(get_real_path $1)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
if [ $# == 2 ]
then
python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --eval=True --ckpt_path=$2
fi
if [ $# == 3 ]
then
python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --eval=True --ckpt_path=$2 \
--batch_size=$3
fi
if [ $# == 4 ]
then
python ${self_path}/../gpu_resnet_benchmark.py--dataset_path=$DATAPATH --eval=True --ckpt_path=$2 \
--batch_size=$3 --dtype=$4
fi

@ -14,11 +14,11 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ]
if [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_gpu_resnet_benchmark.sh [DATASET_PATH] [BATCH_SIZE](optional) [DTYPE](optional)\
[DEVICE_NUM](optional)"
echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8"
[DEVICE_NUM](optional) [SAVE_CKPT](optional) [SAVE_PATH](optional)"
echo "Example: sh run_gpu_resnet_benchmark.sh /path/imagenet/train 256 FP16 8 true /path/ckpt"
exit 1
fi
@ -45,12 +45,23 @@ fi
if [ $# == 3 ]
then
python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True --dtype=$3 \
--dataset_path=$DATAPATH --batch_size=$2
python ${self_path}/../gpu_resnet_benchmark.py --dataset_path=$DATAPATH --batch_size=$2 --dtype=$3
fi
if [ $# == 4 ]
then
mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
--dataset_path=$DATAPATH --batch_size=$2 --dtype=$3
fi
fi
if [ $# == 5 ]
then
mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
--dataset_path=$DATAPATH --batch_size=$2 --dtype=$3 --save_ckpt=$5
fi
if [ $# == 6 ]
then
mpirun --allow-run-as-root -n $4 python ${self_path}/../gpu_resnet_benchmark.py --run_distribute=True \
--dataset_path=$DATAPATH --batch_size=$2 --dtype=$3 --save_ckpt=$5 --ckpt_path=$6
fi

Loading…
Cancel
Save