diff --git a/model_zoo/official/cv/shufflenetv1/train.py b/model_zoo/official/cv/shufflenetv1/train.py index 0c34de6b87..85c7cd45fb 100644 --- a/model_zoo/official/cv/shufflenetv1/train.py +++ b/model_zoo/official/cv/shufflenetv1/train.py @@ -16,13 +16,12 @@ import os import time import argparse -import numpy as np from mindspore import context from mindspore import Tensor from mindspore.common import set_seed from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model, ParallelMode -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.loss_scale_manager import FixedLossScaleManager @@ -34,53 +33,6 @@ from src.crossentropysmooth import CrossEntropySmooth set_seed(1) - -class Monitor(Callback): - """ - Monitor loss and time. - - Args: - lr_init (numpy array): train lr - - Returns: - None - - Examples: - >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy()) - """ - - def __init__(self, lr_init=None): - super(Monitor, self).__init__() - self.lr_init = lr_init - self.lr_init_len = len(lr_init) - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds, per_step_mseconds, - np.mean(self.losses))) - - def step_begin(self, run_context): - self.step_time = time.time() - - def step_end(self, run_context): - cb_params = run_context.original_args() - step_loss = cb_params.net_outputs - - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) - - self.losses.append(step_loss) - - if __name__ == '__main__': parser = argparse.ArgumentParser(description='image classification training') parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training') @@ -138,7 +90,7 @@ if __name__ == '__main__': loss_scale_manager=loss_scale_manager) # define callbacks - cb = [Monitor(lr_init=lr.asnumpy())] + cb = [TimeMonitor(), LossMonitor()] if config.save_checkpoint: save_ckpt_path = config.ckpt_path config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * batches_per_epoch, diff --git a/model_zoo/official/cv/xception/train.py b/model_zoo/official/cv/xception/train.py index 77595aafbf..ab92e716d0 100644 --- a/model_zoo/official/cv/xception/train.py +++ b/model_zoo/official/cv/xception/train.py @@ -88,7 +88,7 @@ class Monitor(Callback): print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format( cb_params.cur_epoch_num - 1 + config.finish_epoch, cb_params.epoch_num + config.finish_epoch, cur_step_in_epoch, cb_params.batch_num, step_loss, - np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]), flush=True) if __name__ == '__main__': parser = argparse.ArgumentParser(description='image classification training') @@ -164,10 +164,10 @@ if __name__ == '__main__': if args_opt.is_distributed: if rank == 0: cb += [ckpt_cb] - model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=False) + model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) else: cb += [ckpt_cb] - model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=False) + model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) print("train success") else: raise ValueError("Unsupported device_target.")