pull/1637/head
unknown 6 years ago
parent 91adbf7e2c
commit 2d433e6408

@@ -1,4 +1,3 @@
-#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +13,13 @@
# limitations under the License.
# ============================================================================
"""evaluation."""
-import os, time
import argparse
from mindspore import context
-from mindspore import log as logger
-from mindspore.communication.management import init
-import mindspore.nn as nn
-from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore import Model, ParallelMode
-import argparse
+from mindspore import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import Callback,CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.miou_precision import MiouPrecision
from src.deeplabv3 import deeplabv3_resnet50
from src.config import config
parser = argparse.ArgumentParser(description="Deeplabv3 evaluation")
@@ -44,15 +35,16 @@ print(args_opt)
if __name__ == "__main__":
    args_opt.crop_size = config.crop_size
    args_opt.base_size = config.crop_size
    eval_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="eval")
-    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
+    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
-                             decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
-                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
+                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
+                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
    param_dict = load_checkpoint(args_opt.checkpoint_url)
    load_param_into_net(net, param_dict)
    mIou = MiouPrecision(config.seg_num_classes)
-    metrics={'mIou':mIou}
+    metrics = {'mIou': mIou}
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    model = Model(net, loss, metrics=metrics)
    model.eval(eval_dataset)
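
The metric the evaluation script wires into `Model.eval` above, `MiouPrecision`, lives in `src/miou_precision.py`, which this commit does not touch. As a minimal sketch of what a confusion-matrix mean-IoU metric typically computes (the function name, shapes, and the `ignore_label` default below are assumptions for illustration, not the repository's actual implementation):

import numpy as np

# Hedged sketch of confusion-matrix mean-IoU; not the code in
# src/miou_precision.py, which this commit leaves unchanged.
def mean_iou(preds, labels, num_classes, ignore_label=255):
    """preds/labels: iterables of integer class maps with matching shapes."""
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    for pred, label in zip(preds, labels):
        pred, label = np.asarray(pred).ravel(), np.asarray(label).ravel()
        keep = label != ignore_label          # drop ignored pixels
        pred, label = pred[keep], label[keep]
        # one bincount builds the whole confusion matrix: conf[label, pred]
        conf += np.bincount(label * num_classes + pred,
                            minlength=num_classes ** 2).reshape(num_classes,
                                                                num_classes)
    inter = np.diag(conf)                     # per-class true positives
    union = conf.sum(0) + conf.sum(1) - inter # TP + FP + FN per class
    return float(np.mean(inter / np.maximum(union, 1)))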

@@ -1,4 +1,3 @@
-#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,18 +13,13 @@
# limitations under the License.
# ============================================================================
"""train."""
-import os, time
import argparse
from mindspore import context
-from mindspore import log as logger
from mindspore.communication.management import init
-import mindspore.nn as nn
from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model, ParallelMode
-import argparse
from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import Callback,CheckpointConfig, ModelCheckpoint, TimeMonitor
+from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
from src.losses import OhemLoss
from src.deeplabv3 import deeplabv3_resnet50
@@ -40,8 +34,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id, default
parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path')
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.')
-parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
-                    "default is 1000.")
+parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
args_opt = parser.parse_args()
print(args_opt)
@@ -63,22 +56,22 @@ class LossCallBack(Callback):
        cb_params = run_context.original_args()
        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                           str(cb_params.net_outputs)))
-def model_fine_tune(flags, net, fix_weight_layer):
+def model_fine_tune(flags, train_net, fix_weight_layer):
    checkpoint_path = flags.checkpoint_url
    if checkpoint_path is None:
        return
    param_dict = load_checkpoint(checkpoint_path)
-    load_param_into_net(net, param_dict)
-    for para in net.trainable_params():
+    load_param_into_net(train_net, param_dict)
+    for para in train_net.trainable_params():
        if fix_weight_layer in para.name:
-            para.requires_grad=False
+            para.requires_grad = False
if __name__ == "__main__":
    if args_opt.distribute == "true":
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
        init()
    args_opt.base_size = config.crop_size
    args_opt.crop_size = config.crop_size
    train_dataset = create_dataset(args_opt, args_opt.data_url, args_opt.epoch_size, args_opt.batch_size, usage="train")
    dataset_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    callback = [time_cb, LossCallBack()]
@@ -87,13 +80,14 @@ if __name__ == "__main__":
                                 keep_checkpoint_max=args_opt.save_checkpoint_num)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck)
    callback.append(ckpoint_cb)
-    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size],
+    net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size],
                             infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates,
-                             decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride,
-                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid)
+                             decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride,
+                             fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid)
    net.set_train()
    model_fine_tune(args_opt, net, 'layer')
    loss = OhemLoss(config.seg_num_classes, config.ignore_label)
    opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay)
    model = Model(net, loss, opt)
    model.train(args_opt.epoch_size, train_dataset, callback)
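
Two details of the training script are easy to miss in this diff: `model_fine_tune` freezes every parameter whose name contains the given substring ('layer' here, i.e. the ResNet backbone blocks), and the `filter(...)` passed to `Momentum` keeps only parameters whose names contain none of beta/gamma/depth/bias, so normalization and bias terms are left out of the optimizer. A standalone sketch of that name-based handling follows; `FakeParam` and the parameter names are made up for illustration and merely stand in for MindSpore `Parameter` objects.

# Stand-in for mindspore.Parameter; only .name and .requires_grad matter here.
class FakeParam:
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

params = [FakeParam(n) for n in (
    'layer1.conv1.weight',   # backbone weight: frozen by model_fine_tune
    'aspp.conv.weight',      # head weight: kept and optimized
    'aspp.bn.gamma',         # norm scale: dropped by the name filter
    'aspp.bn.beta',          # norm shift: dropped by the name filter
    'decoder.conv.bias')]    # bias: dropped by the name filter

# model_fine_tune(args_opt, net, 'layer'): freeze by substring match.
for para in params:
    if 'layer' in para.name:
        para.requires_grad = False

# Mirrors the filter(...) handed to Momentum in train.py, applied to the
# still-trainable parameters (as net.trainable_params() would return).
trainable = [p for p in params if p.requires_grad]
optimized = [p for p in trainable
             if all(k not in p.name for k in ('beta', 'gamma', 'depth', 'bias'))]
print([p.name for p in optimized])   # ['aspp.conv.weight']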