|
|
|
@ -37,6 +37,7 @@ from mindspore.train.model import Model, ParallelMode
|
|
|
|
|
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
|
|
|
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
import mindspore.dataset.engine as de
|
|
|
|
|
from mindspore.communication.management import init
|
|
|
|
|
|
|
|
|
@ -46,6 +47,7 @@ de.config.set_seed(1)
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Image classification')
|
|
|
|
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
|
|
|
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
device_id = int(os.getenv('DEVICE_ID'))
|
|
|
|
@ -166,6 +168,9 @@ if __name__ == '__main__':
|
|
|
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
|
|
|
|
|
repeat_num=epoch_size, batch_size=config.batch_size)
|
|
|
|
|
step_size = dataset.get_dataset_size()
|
|
|
|
|
if args_opt.pre_trained:
|
|
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
|
|
|
|
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
|
|
|
|
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr,
|
|
|
|
|