add ctpn and crnn eval process while training

pull/14736/head
qujianwei 4 years ago
parent 7ffcf05809
commit c1f637c107

@@ -131,6 +131,7 @@ crnn
│   ├── crnn.py # crnn network definition
│   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── eval_callback.py # evaluation callback while training
│   ├── ic03_dataset.py # Data preprocessing for IC03
│   ├── ic13_dataset.py # Data preprocessing for IC13
│   ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
@@ -225,6 +226,10 @@ Check the `eval/log.txt` and you will get outputs as following:
result: {'CRNNAccuracy': (0.806)}
```
### Evaluation while training
If you want to evaluate while training, add `run_eval` to the start shell and set it to True. You also need to add `eval_dataset` to select which dataset to evaluate on, and add `eval_dataset_path` to the start shell. When `run_eval` is True, you can additionally set the options `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
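For example, when launching `train.py` directly, the evaluation arguments added in this commit can be passed as follows (the regular training arguments are omitted here, and `/path/to/svt` is a placeholder):

```bash
python train.py --dataset synth \
    --run_eval True --eval_dataset svt --eval_dataset_path /path/to/svt \
    --save_best_ckpt True --eval_start_epoch 5 --eval_interval 5
```

With these options the checkpoint with the best `CRNNAccuracy` is kept as `best_acc.ckpt` in the checkpoint directory.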
## [Inference Process](#contents)
### [Export MindIR](#contents)

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

@@ -15,6 +15,7 @@
"""crnn training"""
import os
import argparse
import ast
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
@@ -28,7 +29,8 @@ from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import crnn
from src.crnn_for_train import TrainOneStepCellWithGradClip
from src.metric import CRNNAccuracy
from src.eval_callback import EvalCallBack
set_seed(1)
parser = argparse.ArgumentParser(description="crnn training")
@@ -38,6 +40,16 @@ parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend']
help='Running platform, only support Ascend now. Default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument("--run_eval", type=ast.literal_eval, default=False,
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_start_epoch", type=int, default=5,
help="Evaluation start epoch when run_eval is True, default is 5.")
parser.add_argument("--eval_interval", type=int, default=5,
help="Evaluation interval when run_eval is True, default is 5.")
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
@@ -50,6 +62,12 @@ if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
def apply_eval(eval_param):
evaluation_model = eval_param["model"]
eval_ds = eval_param["dataset"]
metrics_name = eval_param["metrics_name"]
res = evaluation_model.eval(eval_ds)
return res[metrics_name]
if __name__ == '__main__':
lr_scale = 1
@@ -86,16 +104,31 @@ if __name__ == '__main__':
net = crnn(config)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
-net = WithLossCell(net, loss)
+net_with_loss = WithLossCell(net, loss)
-net = TrainOneStepCellWithGradClip(net, opt).set_train()
+net_with_grads = TrainOneStepCellWithGradClip(net_with_loss, opt).set_train()
# define model
-model = Model(net)
+model = Model(net_with_grads)
# define callbacks
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
if args_opt.run_eval:
if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)):
raise ValueError("{} is not an existing path.".format(args_opt.eval_dataset_path))
eval_dataset = create_dataset(name=args_opt.eval_dataset,
dataset_path=args_opt.eval_dataset_path,
batch_size=config.batch_size,
is_training=False,
config=config)
eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=args_opt.save_best_ckpt,
ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt",
metrics_name="acc")
callbacks += [eval_cb]
if config.save_checkpoint and rank == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(config.epoch_size, dataset, callbacks=callbacks)

@@ -96,6 +96,8 @@ Here we used 6 datasets for training, and 1 datasets for Evaluation.
│   ├── create_dataset.py # create mindrecord dataset
│   ├── ctpn.py # ctpn network definition
│   ├── dataset.py # data preprocessing
│   ├── eval_callback.py # evaluation callback while training
│   ├── eval_utils.py # evaluation function
│   ├── lr_schedule.py # learning rate scheduler
│   ├── network_define.py # network definition
│   └── text_connector
@@ -235,6 +237,10 @@ Then you can run the scripts/eval_res.sh to calculate the evaluation result.
bash eval_res.sh
```
### Evaluation while training
If you want to evaluate while training, add `run_eval` to the start shell and set it to True. When `run_eval` is True, you also need to set `eval_dataset_path` and `eval_image_path`, and you can additionally set the options `save_best_ckpt`, `eval_start_epoch` and `eval_interval`.
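For example, when launching `train.py` directly, the evaluation arguments added in this commit can be set as follows (paths are placeholders; the other training arguments are omitted):

```bash
python train.py --task_type Finetune \
    --run_eval True --eval_dataset_path /path/to/ctpn_test.mindrecord \
    --eval_image_path /path/to/test_images \
    --save_best_ckpt True --eval_start_epoch 10 --eval_interval 10
```

`--eval_dataset_path` should point to the test MindRecord file and `--eval_image_path` to the directory with the corresponding test images; the checkpoint with the best `hmean` is kept as `best_acc.ckpt` in the checkpoint directory.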
### Result
The evaluation result will be stored in the example path; you can find results like the following in `log`.

@@ -14,17 +14,14 @@
# ============================================================================
"""Evaluation for CTPN"""
-import os
import argparse
-import time
-import numpy as np
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.ctpn import CTPN
from src.config import config
from src.dataset import create_ctpn_dataset
-from src.text_connector.detector import detect
+from src.eval_utils import eval_for_ctpn
set_seed(1)
parser = argparse.ArgumentParser(description="CTPN evaluation")
@@ -39,80 +36,13 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''):
"""ctpn infer."""
print("ckpt path is {}".format(ckpt_path))
ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False)
-config.batch_size = config.test_batch_size
total = ds.get_dataset_size()
-print("*************total dataset size is {}".format(total))
+print("eval dataset size is {}".format(total))
-net = CTPN(config, is_training=False)
+net = CTPN(config, batch_size=config.test_batch_size, is_training=False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
-eval_iter = 0
+eval_for_ctpn(net, ds, img_dir)
print("\n========================================\n")
print("Processing, please wait a moment.")
img_basenames = []
output_dir = os.path.join(os.getcwd(), "submit")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
for file in os.listdir(img_dir):
img_basenames.append(os.path.basename(file))
for data in ds.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
start = time.time()
# run net
output = net(img_data, gt_bboxes, gt_labels, gt_num)
gt_bboxes = gt_bboxes.asnumpy()
gt_labels = gt_labels.asnumpy()
gt_num = gt_num.asnumpy().astype(bool)
end = time.time()
proposal = output[0]
proposal_mask = output[1]
print("start to draw pic")
for j in range(config.test_batch_size):
img = img_basenames[config.test_batch_size * eval_iter + j]
all_box_tmp = proposal[j].asnumpy()
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
using_boxes_mask = all_box_tmp * all_mask_tmp
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
scores = using_boxes_mask[:, 4].astype(np.float32)
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
from PIL import Image, ImageDraw
im = Image.open(img_dir + '/' + img)
draw = ImageDraw.Draw(im)
image_h = img_metas.asnumpy()[j][2]
image_w = img_metas.asnumpy()[j][3]
gt_boxs = gt_bboxes[j][gt_num[j], :]
for gt_box in gt_boxs:
gt_x1 = gt_box[0] / image_w
gt_y1 = gt_box[1] / image_h
gt_x2 = gt_box[2] / image_w
gt_y2 = gt_box[3] / image_h
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
fill='green', width=2)
file_name = "res_" + img.replace("jpg", "txt")
output_file = os.path.join(output_dir, file_name)
f = open(output_file, 'w')
for bbox in bboxes:
x1 = bbox[0] / image_w
y1 = bbox[1] / image_h
x2 = bbox[2] / image_w
y2 = bbox[3] / image_h
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
f.write(str_tmp)
f.write("\n")
f.close()
im.save(img)
percent = round(eval_iter / total * 100, 2)
eval_iter = eval_iter + 1
print("Iter {} cost time {}".format(eval_iter, end - start))
print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r')
if __name__ == '__main__':
ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path)

@@ -36,7 +36,7 @@ if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
-net = CTPN_Infer(config=config)
+net = CTPN_Infer(config=config, batch_size=config.test_batch_size)
param_dict = load_checkpoint(args.ckpt_file)

@@ -56,6 +56,8 @@ do
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ./*.py ./train_parallel$i
cp ./*.zip ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i

@@ -29,8 +29,7 @@ finetune_config = EasyDict({
"total_epoch": 50,
})
-# use for low case number
-config = EasyDict({
+config_default = EasyDict({
"img_width": 960,
"img_height": 576,
"keep_ratio": False,
@@ -39,7 +38,6 @@ config = EasyDict({
"expand_ratio": 1.0,
# anchor
-"feature_shapes": (36, 60),
"num_anchors": 14,
"anchor_base": 16,
"anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
@@ -56,7 +54,6 @@ config = EasyDict({
"neg_iou_thr": 0.5,
"pos_iou_thr": 0.7,
"min_pos_iou": 0.001,
-"num_bboxes": 30240,
"num_gts": 256,
"num_expected_neg": 512,
"num_expected_pos": 256,
@@ -75,12 +72,11 @@ config = EasyDict({
# rnn structure
"input_size": 512,
-"num_step": 60,
-"rnn_batch_size": 36,
"hidden_size": 128,
# training
"warmup_mode": "linear",
-# batch_size only support 1
"batch_size": 1,
"momentum": 0.9,
"save_checkpoint": True,
@@ -131,3 +127,12 @@ config = EasyDict({
"pretraining_dataset_file": "",
"finetune_dataset_file": ""
})
config_add = {
"feature_shapes": (config_default["img_height"] // 16, config_default["img_width"] // 16),
"num_bboxes": (config_default["img_height"] // 16) * \
(config_default["img_width"] // 16) * config_default["num_anchors"],
"num_step": config_default["img_width"] // 16,
"rnn_batch_size": config_default["img_height"] // 16
}
config = EasyDict({**config_default, **config_add})
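With the default 576×960 input and the feature stride of 16 implied by the `// 16` above, the derived values reproduce the constants that were previously hard-coded (a quick sanity check, not part of the commit):

```python
img_height, img_width, num_anchors = 576, 960, 14
feature_shapes = (img_height // 16, img_width // 16)              # (36, 60), was hard-coded
num_step = img_width // 16                                        # 60
rnn_batch_size = img_height // 16                                 # 36
num_bboxes = feature_shapes[0] * feature_shapes[1] * num_anchors  # 30240
print(feature_shapes, num_step, rnn_batch_size, num_bboxes)
```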

@@ -145,7 +145,7 @@ def create_train_dataset(dataset_type):
# test: icdar2013 test
icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\
config.icdar13_test_path[1], "")
-image_files = icdar_test_image_files
+image_files = sorted(icdar_test_image_files)
image_anno_dict = icdar_test_anno_dict
data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \
prefix="ctpn_test.mindrecord", file_num=1)

@@ -29,16 +29,13 @@ class BiLSTM(nn.Cell):
Define a BiLSTM network which contains two LSTM layers
Args:
-input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
-captcha images.
-batch_size(int): batch size of input data, default is 64
-hidden_size(int): the hidden size in LSTM layers, default is 512
+config(EasyDict): config for ctpn network
+batch_size(int): batch size of input data, only support 1
"""
-def __init__(self, config, is_training=True):
+def __init__(self, config, batch_size):
super(BiLSTM, self).__init__()
-self.is_training = is_training
+self.batch_size = batch_size
-self.batch_size = config.batch_size * config.rnn_batch_size
+self.batch_size = self.batch_size * config.rnn_batch_size
-print("batch size is {} ".format(self.batch_size))
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_step = config.num_step
@@ -84,25 +81,24 @@ class CTPN(nn.Cell):
Define CTPN network
Args:
-input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
-captcha images.
-batch_size(int): batch size of input data, default is 64
-hidden_size(int): the hidden size in LSTM layers, default is 512
+config(EasyDict): config for ctpn network
+batch_size(int): batch size of input data, only support 1
+is_training(bool): whether training, default is True
"""
-def __init__(self, config, is_training=True):
+def __init__(self, config, batch_size, is_training=True):
super(CTPN, self).__init__()
self.config = config
-self.is_training = is_training
+self.batch_size = batch_size
self.num_step = config.num_step
self.input_size = config.input_size
-self.batch_size = config.batch_size
self.hidden_size = config.hidden_size
self.vgg16_feature_extractor = VGG16FeatureExtraction()
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
-self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
+self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16)
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cast = P.Cast()
+self.is_training = is_training
# rpn block
self.rpn_with_loss = RPN(config,
@@ -115,7 +111,7 @@ class CTPN(nn.Cell):
self.featmap_size = config.feature_shapes
self.anchor_list = self.get_anchors(self.featmap_size)
self.proposal_generator_test = Proposal(config,
-config.test_batch_size,
+self.batch_size,
config.activate_num_classes,
config.use_sigmoid_cls)
self.proposal_generator_test.set_train_local(config, False)
@@ -143,9 +139,9 @@ class CTPN(nn.Cell):
return Tensor(anchors, mstype.float16)
class CTPN_Infer(nn.Cell):
-def __init__(self, config):
+def __init__(self, config, batch_size):
super(CTPN_Infer, self).__init__()
-self.network = CTPN(config, is_training=False)
+self.network = CTPN(config, batch_size=batch_size, is_training=False)
self.network.set_train(False)
def construct(self, img_data):

@@ -289,11 +289,11 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=
input_columns=["image", "annotation"],
output_columns=["image", "box", "label", "valid_num", "image_shape"],
column_order=["image", "box", "label", "valid_num", "image_shape"],
-num_parallel_workers=num_parallel_workers,
+num_parallel_workers=8,
python_multiprocessing=True)
ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
-num_parallel_workers=24)
+num_parallel_workers=8)
# transpose_column from python to c
ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
ds = ds.map(operations=[type_cast1], input_columns=["box"])

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
besk_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
res = self.eval_function(self.eval_param_dict)
print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

@@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation utils for CTPN"""
import os
import subprocess
import numpy as np
from src.config import config
from src.text_connector.detector import detect
def exec_shell_cmd(cmd):
sub = subprocess.Popen(args="{}".format(cmd), shell=True, stdin=subprocess.PIPE, \
stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
stdout_data, _ = sub.communicate()
if sub.returncode != 0:
raise ValueError("{} is not an executable command, please check.".format(cmd))
return stdout_data.strip()
def get_eval_result():
create_eval_bbox = 'rm -rf submit*.zip;cd ./submit/;zip -r ../submit.zip *.txt;cd ../;bash eval_res.sh'
os.system(create_eval_bbox)
get_eval_output = "grep hmean log | awk '{print $NF}' | awk -F} '{print $1}' |tail -n 1"
hmean = exec_shell_cmd(get_eval_output)
return float(hmean)
def eval_for_ctpn(network, dataset, eval_image_path):
network.set_train(False)
eval_iter = 0
img_basenames = []
output_dir = os.path.join(os.getcwd(), "submit")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
for file in os.listdir(eval_image_path):
img_basenames.append(os.path.basename(file))
img_basenames = sorted(img_basenames)
for data in dataset.create_dict_iterator():
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
# run net
output = network(img_data, gt_bboxes, gt_labels, gt_num)
gt_bboxes = gt_bboxes.asnumpy()
gt_labels = gt_labels.asnumpy()
gt_num = gt_num.asnumpy().astype(bool)
proposal = output[0]
proposal_mask = output[1]
for j in range(config.test_batch_size):
img = img_basenames[config.test_batch_size * eval_iter + j]
all_box_tmp = proposal[j].asnumpy()
all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1)
using_boxes_mask = all_box_tmp * all_mask_tmp
textsegs = using_boxes_mask[:, 0:4].astype(np.float32)
scores = using_boxes_mask[:, 4].astype(np.float32)
shape = img_metas.asnumpy()[0][:2].astype(np.int32)
bboxes = detect(textsegs, scores[:, np.newaxis], shape)
from PIL import Image, ImageDraw
im = Image.open(eval_image_path + '/' + img)
draw = ImageDraw.Draw(im)
image_h = img_metas.asnumpy()[j][2]
image_w = img_metas.asnumpy()[j][3]
gt_boxs = gt_bboxes[j][gt_num[j], :]
for gt_box in gt_boxs:
gt_x1 = gt_box[0] / image_w
gt_y1 = gt_box[1] / image_h
gt_x2 = gt_box[2] / image_w
gt_y2 = gt_box[3] / image_h
draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\
fill='green', width=2)
file_name = "res_" + img.replace("jpg", "txt")
output_file = os.path.join(output_dir, file_name)
f = open(output_file, 'w')
for bbox in bboxes:
x1 = bbox[0] / image_w
y1 = bbox[1] / image_h
x2 = bbox[2] / image_w
y2 = bbox[3] / image_h
draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2)
str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2))
f.write(str_tmp)
f.write("\n")
f.close()
im.save(img)
eval_iter = eval_iter + 1

@@ -32,6 +32,8 @@ from src.config import config, pretrain_config, finetune_config
from src.dataset import create_ctpn_dataset
from src.lr_schedule import dynamic_lr
from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell
from src.eval_utils import eval_for_ctpn, get_eval_result
from src.eval_callback import EvalCallBack
set_seed(1)
@@ -43,10 +45,30 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
parser.add_argument("--task_type", type=str, default="Pretraining",\
choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining")
parser.add_argument("--run_eval", type=ast.literal_eval, default=False, \
help="Run evaluation when training, default is False.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, \
help="Save best checkpoint when run_eval is True, default is True.")
parser.add_argument("--eval_image_path", type=str, default="", \
help="eval image path, when run_eval is True, eval_image_path should be set.")
parser.add_argument("--eval_dataset_path", type=str, default="", \
help="eval dataset path, when run_eval is True, eval_dataset_path should be set.")
parser.add_argument("--eval_start_epoch", type=int, default=10, \
help="Evaluation start epoch when run_eval is True, default is 10.")
parser.add_argument("--eval_interval", type=int, default=10, \
help="Evaluation interval when run_eval is True, default is 10.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True)
def apply_eval(eval_param):
network = eval_param["eval_network"]
eval_ds = eval_param["eval_dataset"]
eval_image_path = eval_param["eval_image_path"]
eval_for_ctpn(network, eval_ds, eval_image_path)
hmean = get_eval_result()
return hmean
if __name__ == '__main__':
if args_opt.run_distribute:
rank = args_opt.rank_id
@@ -78,7 +100,7 @@ if __name__ == '__main__':
dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
-net = CTPN(config=config, is_training=True)
+net = CTPN(config=config, batch_size=config.batch_size)
net = net.set_train()
load_path = args_opt.pre_trained
@@ -100,20 +122,34 @@ if __name__ == '__main__':
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
-net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True,
+net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, \
mean=True, degree=device_num)
else:
-net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
+net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)
time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
if config.save_checkpoint:
ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size,
keep_checkpoint_max=config.keep_checkpoint_max)
-save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/")
ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
if args_opt.run_eval:
-model = Model(net)
+if args_opt.eval_dataset_path is None or (not os.path.isfile(args_opt.eval_dataset_path)):
raise ValueError("{} is not an existing path.".format(args_opt.eval_dataset_path))
if args_opt.eval_image_path is None or (not os.path.isdir(args_opt.eval_image_path)):
raise ValueError("{} is not an existing path.".format(args_opt.eval_image_path))
eval_dataset = create_ctpn_dataset(args_opt.eval_dataset_path, \
batch_size=config.batch_size, repeat_num=1, is_training=False)
eval_net = net
eval_param_dict = {"eval_network": eval_net, "eval_dataset": eval_dataset, \
"eval_image_path": args_opt.eval_image_path}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval,
eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=args_opt.save_best_ckpt,
ckpt_directory=save_checkpoint_path, besk_ckpt_name="best_acc.ckpt",
metrics_name="hmean")
cb += [eval_cb]
model = Model(net_with_grads)
model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
