!4937 vgg16: modify readme format and replace callback

Merge pull request !4937 from ms_yan/vgg_format
pull/4937/MERGE
mindspore-ci-bot, committed via Gitee
commit 38c366306c

File diff suppressed because it is too large (presumably the README rework named in the PR title).

@@ -15,7 +15,7 @@
 # ============================================================================
 echo "=============================================================================================================="
-echo "Please run the scipt as: "
+echo "Please run the script as: "
 echo "bash run_distribute_train_gpu.sh DATA_PATH"
 echo "for example: bash run_distribute_train_gpu.sh /path/ImageNet2012/train"
 echo "=============================================================================================================="

@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval.sh DATA_PATH DATASET_TYPE DEVICE_TYPE CHECKPOINT_PATH"
+echo "for example: bash run_eval.sh /path/ImageNet2012/train cifar10 Ascend /path/a.ckpt "
+echo "=============================================================================================================="
+DATA_PATH=$1
+DATASET_TYPE=$2
+DEVICE_TYPE=$3
+CHECKPOINT_PATH=$4
+python eval.py \
+    --data_path=$DATA_PATH \
+    --dataset=$DATASET_TYPE \
+    --device_target=$DEVICE_TYPE \
+    --pre_trained=$CHECKPOINT_PATH > output.eval.log 2>&1 &
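For orientation, a hypothetical sketch of the argparse flags eval.py would need to accept for the invocation above; the flag names come from the script itself, while the types and help strings are assumptions, not the actual eval.py source:

# Hypothetical sketch of the eval.py CLI driven by run_eval.sh;
# flag names are taken from the script above, everything else is assumed.
import argparse

parser = argparse.ArgumentParser('mindspore classification evaluation')
parser.add_argument('--data_path', type=str, help='eval dataset directory')
parser.add_argument('--dataset', type=str, help='dataset type, e.g. cifar10 or imagenet2012')
parser.add_argument('--device_target', type=str, help='Ascend or GPU')
parser.add_argument('--pre_trained', type=str, help='checkpoint file to evaluate')
args = parser.parse_args()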

@@ -18,7 +18,6 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
 """
 import argparse
 import datetime
-import time
 import os
 import random
@@ -29,7 +28,7 @@ from mindspore import Tensor
 from mindspore import context
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.serialization import load_param_into_net, load_checkpoint
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
@@ -49,63 +48,6 @@ random.seed(1)
 np.random.seed(1)
-
-
-class ProgressMonitor(Callback):
-    """monitor loss and time"""
-    def __init__(self, args_param):
-        super(ProgressMonitor, self).__init__()
-        self.me_epoch_start_time = 0
-        self.me_epoch_start_step_num = 0
-        self.args = args_param
-        self.ckpt_history = []
-
-    def begin(self, run_context):
-        self.args.logger.info('start network train...')
-
-    def epoch_begin(self, run_context):
-        pass
-
-    def epoch_end(self, run_context):
-        """
-        Called after each epoch finished.
-
-        Args:
-            run_context (RunContext): Include some information of the model.
-        """
-        cb_params = run_context.original_args()
-        me_step = cb_params.cur_step_num - 1
-
-        real_epoch = me_step // self.args.steps_per_epoch
-        time_used = time.time() - self.me_epoch_start_time
-        fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used
-        self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}'
-                              'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
-
-        if self.args.rank_save_ckpt_flag:
-            import glob
-            ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
-            for ckpt in ckpts:
-                ckpt_fn = os.path.basename(ckpt)
-                if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
-                    continue
-                if ckpt in self.ckpt_history:
-                    continue
-                self.ckpt_history.append(ckpt)
-                self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
-                                      'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))
-
-        self.me_epoch_start_step_num = me_step
-        self.me_epoch_start_time = time.time()
-
-    def step_begin(self, run_context):
-        pass
-
-    def step_end(self, run_context, *me_args):
-        pass
-
-    def end(self, run_context):
-        self.args.logger.info('end network train...')
-
-
 def parse_args(cloud_args=None):
     """parameters"""
     parser = argparse.ArgumentParser('mindspore classification training')
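The removed ProgressMonitor was a hand-rolled Callback that logged the loss and the mean throughput once per epoch. For reference, a small illustration of its throughput formula, images per second aggregated across all devices in the group; every number below is assumed, not taken from the PR:

# Illustration of the removed ProgressMonitor's fps_mean formula;
# all values here are made up for the example.
per_batch_size = 32    # images per step on one device
steps = 100            # me_step - me_epoch_start_step_num
group_size = 8         # devices in the training group
time_used = 25.0       # seconds spent in the epoch

fps_mean = per_batch_size * steps * group_size / time_used
print('{:.2f} imgs/sec'.format(fps_mean))  # 1024.00 imgs/sec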
@@ -279,9 +221,10 @@ if __name__ == '__main__':
     loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
     model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")

-    # checkpoint save
-    progress_cb = ProgressMonitor(args)
-    callbacks = [progress_cb,]
+    # define callbacks
+    time_cb = TimeMonitor(data_size=batch_num)
+    loss_cb = LossMonitor(per_print_times=batch_num)
+    callbacks = [time_cb, loss_cb]
     if args.rank_save_ckpt_flag:
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
                                        keep_checkpoint_max=args.ckpt_save_max)
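Taken together, the change swaps the custom monitor for MindSpore's built-in TimeMonitor and LossMonitor, which cover the same per-epoch time and loss reporting. A minimal, self-contained sketch of how the new callbacks plug into Model.train; the toy network and dataset are stand-ins for the real VGG16 pipeline, so everything below except the callback wiring is an assumption:

# Minimal sketch of the replacement callbacks, not the PR's train.py.
import numpy as np
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.model import Model

# toy data: 320 samples of 16 features, 10 classes, batched by 32
data = np.random.randn(320, 16).astype(np.float32)
label = np.random.randint(0, 10, (320,)).astype(np.int32)
dataset = ds.NumpySlicesDataset({"data": data, "label": label}, shuffle=False).batch(32)
batch_num = dataset.get_dataset_size()  # batches per epoch, as in train.py

network = nn.Dense(16, 10)  # stand-in for vgg16()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, loss_fn=loss, optimizer=opt)

# TimeMonitor logs per-epoch time; LossMonitor prints the loss every
# per_print_times steps, i.e. once per epoch with these settings.
time_cb = TimeMonitor(data_size=batch_num)
loss_cb = LossMonitor(per_print_times=batch_num)
model.train(2, dataset, callbacks=[time_cb, loss_cb], dataset_sink_mode=False)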
