From 2d433e640864d77a1612df9126df0ecd79735851 Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 29 May 2020 10:11:23 +0800
Subject: [PATCH] deeplabv3: remove unused imports and fix code style

Drop the stray "#!/bin/bash" line from both Python scripts, remove
unused and duplicate imports, normalize whitespace around operators and
keyword arguments, rename model_fine_tune's net parameter to train_net
so it no longer shadows the global net, and end both files with a
newline.
---
 model_zoo/deeplabv3/evaluation.py | 27 +++++++++------------------
 model_zoo/deeplabv3/train.py      | 31 ++++++++++++-------------------
 2 files changed, 21 insertions(+), 37 deletions(-)

diff --git a/model_zoo/deeplabv3/evaluation.py b/model_zoo/deeplabv3/evaluation.py
index f4e3e38d9f..2c03467587 100644
--- a/model_zoo/deeplabv3/evaluation.py
+++ b/model_zoo/deeplabv3/evaluation.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """evaluation."""
-import os, time
 import argparse
 from mindspore import context
-from mindspore import log as logger
-from mindspore.communication.management import init
-import mindspore.nn as nn
-from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore import Model, ParallelMode
+from mindspore import Model
-import argparse
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import Callback,CheckpointConfig, ModelCheckpoint, TimeMonitor
 from src.md_dataset import create_dataset
 from src.losses import OhemLoss
-from src.miou_precision import MiouPrecision 
+from src.miou_precision import MiouPrecision
 from src.deeplabv3 import deeplabv3_resnet50
 from src.config import config
 parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
@@ -44,15 +35,15 @@ print(args_opt)
 if __name__ == "__main__":
     args_opt.crop_size = config.crop_size
     args_opt.base_size = config.crop_size
-    eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval") 
-    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
-                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
-                             decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
-                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
+    eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
+    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
+                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
+                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
+                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
     param_dict = load_checkpoint(args_opt.checkpoint_url)
     load_param_into_net(net, param_dict)
     mIou = MiouPrecision(config.seg_num_classes)
-    metrics={'mIou':mIou}
+    metrics = {'mIou': mIou}
     loss = OhemLoss(config.seg_num_classes, config.ignore_label)
     model = Model(net, loss, metrics=metrics)
-    model.eval(eval_dataset)
\ No newline at end of file
+    model.eval(eval_dataset)
diff --git a/model_zoo/deeplabv3/train.py b/model_zoo/deeplabv3/train.py
index ed625ede6b..237d2f4800 100644
--- a/model_zoo/deeplabv3/train.py
+++ b/model_zoo/deeplabv3/train.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,18 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """train."""
-import os, time
 import argparse
 from mindspore import context
-from mindspore import log as logger
 from mindspore.communication.management import init
-import mindspore.nn as nn
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore import Model, ParallelMode
-import argparse
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import Callback,CheckpointConfig, ModelCheckpoint, TimeMonitor
+from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
 from src.md_dataset import create_dataset
 from src.losses import OhemLoss
 from src.deeplabv3 import deeplabv3_resnet50
@@ -40,8 +34,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id, default
 parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
 parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
 parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.')
-parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
-                    "default is 1000.")
+parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
 parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
 args_opt = parser.parse_args()
 print(args_opt)
@@ -63,22 +56,22 @@ class LossCallBack(Callback):
         cb_params = run_context.original_args()
         print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                            str(cb_params.net_outputs)))
-def model_fine_tune(flags, net, fix_weight_layer):
+def model_fine_tune(flags, train_net, fix_weight_layer):
     checkpoint_path = flags.checkpoint_url
     if checkpoint_path is None:
         return
     param_dict = load_checkpoint(checkpoint_path)
-    load_param_into_net(net, param_dict)
-    for para in net.trainable_params():
+    load_param_into_net(train_net, param_dict)
+    for para in train_net.trainable_params():
         if fix_weight_layer in para.name:
-            para.requires_grad=False
+            para.requires_grad = False
 if __name__ == "__main__":
     if args_opt.distribute == "true":
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
         init()
     args_opt.base_size = config.crop_size
     args_opt.crop_size = config.crop_size
-    train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train") 
+    train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train")
     dataset_size = train_dataset.get_dataset_size()
     time_cb = TimeMonitor(data_size=dataset_size)
     callback = [time_cb, LossCallBack()]
@@ -87,13 +80,13 @@
                                    keep_checkpoint_max=args_opt.save_checkpoint_num)
     ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
     callback.append(ckpoint_cb)
-    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
-                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
-                             decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
-                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
+    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
+                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
+                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
+                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
     net.set_train()
     model_fine_tune(args_opt, net, 'layer')
     loss = OhemLoss(config.seg_num_classes, config.ignore_label)
     opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
     model = Model(net, loss, opt)
-    model.train(args_opt.epoch_size, train_dataset, callback)
\ No newline at end of file
+    model.train(args_opt.epoch_size, train_dataset, callback)
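--
Usage reference (a sketch, not part of the commit itself: paths are
placeholders, and the flags are only those the two scripts define or
read via argparse; every other option keeps its default):

    # fine-tune deeplabv3_resnet50, loading weights from --checkpoint_url
    # and freezing parameters whose names contain 'layer'
    python train.py --data_url /path/to/train_data --checkpoint_url /path/to/pretrained.ckpt

    # data-parallel training across devices (enters the init()/DATA_PARALLEL branch)
    python train.py --distribute true --data_url /path/to/train_data

    # evaluate a trained checkpoint with the mIoU metric
    python evaluation.py --data_url /path/to/eval_data --checkpoint_url /path/to/trained.ckpt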