From c1f637c1072546873c38a5283777a34b07994acd Mon Sep 17 00:00:00 2001 From: qujianwei Date: Tue, 6 Apr 2021 14:49:09 +0800 Subject: [PATCH] add ctpn and crnn eval process while training --- model_zoo/official/cv/crnn/README.md | 5 + .../official/cv/crnn/src/eval_callback.py | 91 ++++++++++++++++++ model_zoo/official/cv/crnn/train.py | 43 ++++++++- model_zoo/official/cv/ctpn/README.md | 6 ++ model_zoo/official/cv/ctpn/eval.py | 78 +-------------- model_zoo/official/cv/ctpn/export.py | 2 +- .../scripts/run_distribute_train_ascend.sh | 2 + model_zoo/official/cv/ctpn/src/config.py | 17 ++-- .../official/cv/ctpn/src/create_dataset.py | 2 +- model_zoo/official/cv/ctpn/src/ctpn.py | 34 +++---- model_zoo/official/cv/ctpn/src/dataset.py | 4 +- .../official/cv/ctpn/src/eval_callback.py | 91 ++++++++++++++++++ model_zoo/official/cv/ctpn/src/eval_utils.py | 96 +++++++++++++++++++ model_zoo/official/cv/ctpn/train.py | 50 ++++++++-- 14 files changed, 406 insertions(+), 115 deletions(-) create mode 100644 model_zoo/official/cv/crnn/src/eval_callback.py create mode 100644 model_zoo/official/cv/ctpn/src/eval_callback.py create mode 100644 model_zoo/official/cv/ctpn/src/eval_utils.py diff --git a/model_zoo/official/cv/crnn/README.md b/model_zoo/official/cv/crnn/README.md index e110c022fa..c30c9be405 100644 --- a/model_zoo/official/cv/crnn/README.md +++ b/model_zoo/official/cv/crnn/README.md @@ -131,6 +131,7 @@ crnn │   ├── crnn.py # crnn network definition │   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip │   ├── dataset.py # Data preprocessing for training and evaluation +│   ├── eval_callback.py │   ├── ic03_dataset.py # Data preprocessing for IC03 │   ├── ic13_dataset.py # Data preprocessing for IC13 │   ├── iiit5k_dataset.py # Data preprocessing for IIIT5K @@ -225,6 +226,10 @@ Check the `eval/log.txt` and you will get outputs as following: result: {'CRNNAccuracy': (0.806)} ``` +### Evaluation while training + +You can add `run_eval` to start shell and set it True.You need also add `eval_dataset` to select which dataset to eval, and add eval_dataset_path to start shell if you want evaluation while training. And you can set argument option: `save_best_ckpt`, `eval_start_epoch`, `eval_interval` when `run_eval` is True. + ## [Inference Process](#contents) ### [Export MindIR](#contents) diff --git a/model_zoo/official/cv/crnn/src/eval_callback.py b/model_zoo/official/cv/crnn/src/eval_callback.py new file mode 100644 index 0000000000..f7badff7f8 --- /dev/null +++ b/model_zoo/official/cv/crnn/src/eval_callback.py @@ -0,0 +1,91 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evaluation callback when training""" + +import os +import stat +from mindspore import save_checkpoint +from mindspore import log as logger +from mindspore.train.callback import Callback + +class EvalCallBack(Callback): + """ + Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. + besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `acc`. + + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, + ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): + super(EvalCallBack, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_res = 0 + self.best_epoch = 0 + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = metrics_name + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + res = self.eval_function(self.eval_param_dict) + print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) + if res >= self.best_res: + self.best_res = res + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, + self.best_res, + self.best_epoch), flush=True) + \ No newline at end of file diff --git a/model_zoo/official/cv/crnn/train.py b/model_zoo/official/cv/crnn/train.py index 83159fa909..1974ce28a0 100644 --- a/model_zoo/official/cv/crnn/train.py +++ b/model_zoo/official/cv/crnn/train.py @@ -15,6 +15,7 @@ """crnn training""" import os import argparse +import ast import mindspore.nn as nn from mindspore import context from mindspore.common import set_seed @@ -28,7 +29,8 @@ from src.loss import CTCLoss from src.dataset import create_dataset from src.crnn import crnn from src.crnn_for_train import TrainOneStepCellWithGradClip - +from src.metric import CRNNAccuracy +from src.eval_callback import EvalCallBack set_seed(1) parser = argparse.ArgumentParser(description="crnn training") @@ -38,6 +40,16 @@ parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'] help='Running platform, only support Ascend now. Default is Ascend.') parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) +parser.add_argument('--eval_dataset', type=str, default='svt', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) +parser.add_argument('--eval_dataset_path', type=str, default=None, help='Dataset path, default is None') +parser.add_argument("--run_eval", type=ast.literal_eval, default=False, + help="Run evaluation when training, default is False.") +parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, + help="Save best checkpoint when run_eval is True, default is True.") +parser.add_argument("--eval_start_epoch", type=int, default=5, + help="Evaluation start epoch when run_eval is True, default is 5.") +parser.add_argument("--eval_interval", type=int, default=5, + help="Evaluation interval when run_eval is True, default is 5.") parser.set_defaults(run_distribute=False) args_opt = parser.parse_args() @@ -50,6 +62,12 @@ if args_opt.platform == 'Ascend': device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id) +def apply_eval(eval_param): + evaluation_model = eval_param["model"] + eval_ds = eval_param["dataset"] + metrics_name = eval_param["metrics_name"] + res = evaluation_model.eval(eval_ds) + return res[metrics_name] if __name__ == '__main__': lr_scale = 1 @@ -86,16 +104,31 @@ if __name__ == '__main__': net = crnn(config) opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov) - net = WithLossCell(net, loss) - net = TrainOneStepCellWithGradClip(net, opt).set_train() + net_with_loss = WithLossCell(net, loss) + net_with_grads = TrainOneStepCellWithGradClip(net_with_loss, opt).set_train() # define model - model = Model(net) + model = Model(net_with_grads) # define callbacks callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] + save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') + if args_opt.run_eval: + if args_opt.eval_dataset_path is None or (not os.path.isdir(args_opt.eval_dataset_path)): + raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path)) + eval_dataset = create_dataset(name=args_opt.eval_dataset, + dataset_path=args_opt.eval_dataset_path, + batch_size=config.batch_size, + is_training=False, + config=config) + eval_model = Model(net, loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)}) + eval_param_dict = {"model": eval_model, "dataset": eval_dataset, "metrics_name": "CRNNAccuracy"} + eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, + eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, + ckpt_directory=save_ckpt_path, besk_ckpt_name="best_acc.ckpt", + metrics_name="acc") + callbacks += [eval_cb] if config.save_checkpoint and rank == 0: config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, keep_checkpoint_max=config.keep_checkpoint_max) - save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck) callbacks.append(ckpt_cb) model.train(config.epoch_size, dataset, callbacks=callbacks) diff --git a/model_zoo/official/cv/ctpn/README.md b/model_zoo/official/cv/ctpn/README.md index 69bc893fb6..0bc6d5d7d5 100644 --- a/model_zoo/official/cv/ctpn/README.md +++ b/model_zoo/official/cv/ctpn/README.md @@ -96,6 +96,8 @@ Here we used 6 datasets for training, and 1 datasets for Evaluation. │   ├── create_dataset.py # create mindrecord dataset │   ├── ctpn.py # ctpn network definition │   ├── dataset.py # data proprocessing + │   ├── eval_callback.py # evaluation callback while training + │   ├── eval_utils.py # evaluation function │   ├── lr_schedule.py # learning rate scheduler │   ├── network_define.py # network definition │   └── text_connector @@ -235,6 +237,10 @@ Then you can run the scripts/eval_res.sh to calculate the evalulation result. bash eval_res.sh ``` +### Evaluation while training + +You can add `run_eval` to start shell and set it True, if you want evaluation while training. And you can set argument option: `eval_dataset_path`, `save_best_ckpt`, `eval_start_epoch`, `eval_interval` when `run_eval` is True. + ### Result Evaluation result will be stored in the example path, you can find result like the followings in `log`. diff --git a/model_zoo/official/cv/ctpn/eval.py b/model_zoo/official/cv/ctpn/eval.py index 17e3bfa075..77636d2e56 100644 --- a/model_zoo/official/cv/ctpn/eval.py +++ b/model_zoo/official/cv/ctpn/eval.py @@ -14,17 +14,14 @@ # ============================================================================ """Evaluation for CTPN""" -import os import argparse -import time -import numpy as np from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed from src.ctpn import CTPN from src.config import config from src.dataset import create_ctpn_dataset -from src.text_connector.detector import detect +from src.eval_utils import eval_for_ctpn set_seed(1) parser = argparse.ArgumentParser(description="CTPN evaluation") @@ -39,80 +36,13 @@ def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''): """ctpn infer.""" print("ckpt path is {}".format(ckpt_path)) ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False) - config.batch_size = config.test_batch_size total = ds.get_dataset_size() - print("*************total dataset size is {}".format(total)) - net = CTPN(config, is_training=False) + print("eval dataset size is {}".format(total)) + net = CTPN(config, batch_size=config.test_batch_size, is_training=False) param_dict = load_checkpoint(ckpt_path) load_param_into_net(net, param_dict) net.set_train(False) - eval_iter = 0 - - print("\n========================================\n") - print("Processing, please wait a moment.") - img_basenames = [] - output_dir = os.path.join(os.getcwd(), "submit") - if not os.path.exists(output_dir): - os.mkdir(output_dir) - for file in os.listdir(img_dir): - img_basenames.append(os.path.basename(file)) - for data in ds.create_dict_iterator(): - img_data = data['image'] - img_metas = data['image_shape'] - gt_bboxes = data['box'] - gt_labels = data['label'] - gt_num = data['valid_num'] - - start = time.time() - # run net - output = net(img_data, gt_bboxes, gt_labels, gt_num) - gt_bboxes = gt_bboxes.asnumpy() - gt_labels = gt_labels.asnumpy() - gt_num = gt_num.asnumpy().astype(bool) - end = time.time() - proposal = output[0] - proposal_mask = output[1] - print("start to draw pic") - for j in range(config.test_batch_size): - img = img_basenames[config.test_batch_size * eval_iter + j] - all_box_tmp = proposal[j].asnumpy() - all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1) - using_boxes_mask = all_box_tmp * all_mask_tmp - textsegs = using_boxes_mask[:, 0:4].astype(np.float32) - scores = using_boxes_mask[:, 4].astype(np.float32) - shape = img_metas.asnumpy()[0][:2].astype(np.int32) - bboxes = detect(textsegs, scores[:, np.newaxis], shape) - from PIL import Image, ImageDraw - im = Image.open(img_dir + '/' + img) - draw = ImageDraw.Draw(im) - image_h = img_metas.asnumpy()[j][2] - image_w = img_metas.asnumpy()[j][3] - gt_boxs = gt_bboxes[j][gt_num[j], :] - for gt_box in gt_boxs: - gt_x1 = gt_box[0] / image_w - gt_y1 = gt_box[1] / image_h - gt_x2 = gt_box[2] / image_w - gt_y2 = gt_box[3] / image_h - draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\ - fill='green', width=2) - file_name = "res_" + img.replace("jpg", "txt") - output_file = os.path.join(output_dir, file_name) - f = open(output_file, 'w') - for bbox in bboxes: - x1 = bbox[0] / image_w - y1 = bbox[1] / image_h - x2 = bbox[2] / image_w - y2 = bbox[3] / image_h - draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2) - str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2)) - f.write(str_tmp) - f.write("\n") - f.close() - im.save(img) - percent = round(eval_iter / total * 100, 2) - eval_iter = eval_iter + 1 - print("Iter {} cost time {}".format(eval_iter, end - start)) - print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r') + eval_for_ctpn(net, ds, img_dir) if __name__ == '__main__': ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path) diff --git a/model_zoo/official/cv/ctpn/export.py b/model_zoo/official/cv/ctpn/export.py index 9a365bfa29..c886ea1b62 100644 --- a/model_zoo/official/cv/ctpn/export.py +++ b/model_zoo/official/cv/ctpn/export.py @@ -36,7 +36,7 @@ if args.device_target == "Ascend": context.set_context(device_id=args.device_id) if __name__ == '__main__': - net = CTPN_Infer(config=config) + net = CTPN_Infer(config=config, batch_size=config.test_batch_size) param_dict = load_checkpoint(args.ckpt_file) diff --git a/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh b/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh index 98b5cee57d..7669552e36 100644 --- a/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh +++ b/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh @@ -56,6 +56,8 @@ do export RANK_ID=$i rm -rf ./train_parallel$i mkdir ./train_parallel$i + cp ./*.py ./train_parallel$i + cp ./*.zip ./train_parallel$i cp ../*.py ./train_parallel$i cp *.sh ./train_parallel$i cp -r ../src ./train_parallel$i diff --git a/model_zoo/official/cv/ctpn/src/config.py b/model_zoo/official/cv/ctpn/src/config.py index de114b72f5..45f3aed903 100644 --- a/model_zoo/official/cv/ctpn/src/config.py +++ b/model_zoo/official/cv/ctpn/src/config.py @@ -29,8 +29,7 @@ finetune_config = EasyDict({ "total_epoch": 50, }) -# use for low case number -config = EasyDict({ +config_default = EasyDict({ "img_width": 960, "img_height": 576, "keep_ratio": False, @@ -39,7 +38,6 @@ config = EasyDict({ "expand_ratio": 1.0, # anchor - "feature_shapes": (36, 60), "num_anchors": 14, "anchor_base": 16, "anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406], @@ -56,7 +54,6 @@ config = EasyDict({ "neg_iou_thr": 0.5, "pos_iou_thr": 0.7, "min_pos_iou": 0.001, - "num_bboxes": 30240, "num_gts": 256, "num_expected_neg": 512, "num_expected_pos": 256, @@ -75,12 +72,11 @@ config = EasyDict({ # rnn structure "input_size": 512, - "num_step": 60, - "rnn_batch_size": 36, "hidden_size": 128, # training "warmup_mode": "linear", + # batch_size only support 1 "batch_size": 1, "momentum": 0.9, "save_checkpoint": True, @@ -131,3 +127,12 @@ config = EasyDict({ "pretraining_dataset_file": "", "finetune_dataset_file": "" }) + +config_add = { + "feature_shapes": (config_default["img_height"] // 16, config_default["img_width"] // 16), + "num_bboxes": (config_default["img_height"] // 16) * \ + (config_default["img_width"] // 16) *config_default["num_anchors"], + "num_step": config_default["img_width"] // 16, + "rnn_batch_size": config_default["img_height"] // 16 +} +config = EasyDict({**config_default, **config_add}) diff --git a/model_zoo/official/cv/ctpn/src/create_dataset.py b/model_zoo/official/cv/ctpn/src/create_dataset.py index ef9a8faf2c..d0066dd87f 100644 --- a/model_zoo/official/cv/ctpn/src/create_dataset.py +++ b/model_zoo/official/cv/ctpn/src/create_dataset.py @@ -145,7 +145,7 @@ def create_train_dataset(dataset_type): # test: icdar2013 test icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\ config.icdar13_test_path[1], "") - image_files = icdar_test_image_files + image_files = sorted(icdar_test_image_files) image_anno_dict = icdar_test_anno_dict data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \ prefix="ctpn_test.mindrecord", file_num=1) diff --git a/model_zoo/official/cv/ctpn/src/ctpn.py b/model_zoo/official/cv/ctpn/src/ctpn.py index 2bc0125e48..f764a5e4b6 100644 --- a/model_zoo/official/cv/ctpn/src/ctpn.py +++ b/model_zoo/official/cv/ctpn/src/ctpn.py @@ -29,16 +29,13 @@ class BiLSTM(nn.Cell): Define a BiLSTM network which contains two LSTM layers Args: - input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for - captcha images. - batch_size(int): batch size of input data, default is 64 - hidden_size(int): the hidden size in LSTM layers, default is 512 + config(EasyDict): config for ctpn network + batch_size(int): batch size of input data, only support 1 """ - def __init__(self, config, is_training=True): + def __init__(self, config, batch_size): super(BiLSTM, self).__init__() - self.is_training = is_training - self.batch_size = config.batch_size * config.rnn_batch_size - print("batch size is {} ".format(self.batch_size)) + self.batch_size = batch_size + self.batch_size = self.batch_size * config.rnn_batch_size self.input_size = config.input_size self.hidden_size = config.hidden_size self.num_step = config.num_step @@ -84,25 +81,24 @@ class CTPN(nn.Cell): Define CTPN network Args: - input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for - captcha images. - batch_size(int): batch size of input data, default is 64 - hidden_size(int): the hidden size in LSTM layers, default is 512 + config(EasyDict): config for ctpn network + batch_size(int): batch size of input data, only support 1 + is_training(bool): whether training, default is True """ - def __init__(self, config, is_training=True): + def __init__(self, config, batch_size, is_training=True): super(CTPN, self).__init__() self.config = config - self.is_training = is_training + self.batch_size = batch_size self.num_step = config.num_step self.input_size = config.input_size - self.batch_size = config.batch_size self.hidden_size = config.hidden_size self.vgg16_feature_extractor = VGG16FeatureExtraction() self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same') - self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16) + self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16) self.reshape = P.Reshape() self.transpose = P.Transpose() self.cast = P.Cast() + self.is_training = is_training # rpn block self.rpn_with_loss = RPN(config, @@ -115,7 +111,7 @@ class CTPN(nn.Cell): self.featmap_size = config.feature_shapes self.anchor_list = self.get_anchors(self.featmap_size) self.proposal_generator_test = Proposal(config, - config.test_batch_size, + self.batch_size, config.activate_num_classes, config.use_sigmoid_cls) self.proposal_generator_test.set_train_local(config, False) @@ -143,9 +139,9 @@ class CTPN(nn.Cell): return Tensor(anchors, mstype.float16) class CTPN_Infer(nn.Cell): - def __init__(self, config): + def __init__(self, config, batch_size): super(CTPN_Infer, self).__init__() - self.network = CTPN(config, is_training=False) + self.network = CTPN(config, batch_size=batch_size, is_training=False) self.network.set_train(False) def construct(self, img_data): diff --git a/model_zoo/official/cv/ctpn/src/dataset.py b/model_zoo/official/cv/ctpn/src/dataset.py index 79b0db7feb..a7246e1439 100644 --- a/model_zoo/official/cv/ctpn/src/dataset.py +++ b/model_zoo/official/cv/ctpn/src/dataset.py @@ -289,11 +289,11 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num= input_columns=["image", "annotation"], output_columns=["image", "box", "label", "valid_num", "image_shape"], column_order=["image", "box", "label", "valid_num", "image_shape"], - num_parallel_workers=num_parallel_workers, + num_parallel_workers=8, python_multiprocessing=True) ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], - num_parallel_workers=24) + num_parallel_workers=8) # transpose_column from python to c ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) ds = ds.map(operations=[type_cast1], input_columns=["box"]) diff --git a/model_zoo/official/cv/ctpn/src/eval_callback.py b/model_zoo/official/cv/ctpn/src/eval_callback.py new file mode 100644 index 0000000000..f7badff7f8 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/eval_callback.py @@ -0,0 +1,91 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evaluation callback when training""" + +import os +import stat +from mindspore import save_checkpoint +from mindspore import log as logger +from mindspore.train.callback import Callback + +class EvalCallBack(Callback): + """ + Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. + besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `acc`. + + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True, + ckpt_directory="./", besk_ckpt_name="best.ckpt", metrics_name="acc"): + super(EvalCallBack, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_res = 0 + self.best_epoch = 0 + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = metrics_name + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + res = self.eval_function(self.eval_param_dict) + print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) + if res >= self.best_res: + self.best_res = res + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, + self.best_res, + self.best_epoch), flush=True) + \ No newline at end of file diff --git a/model_zoo/official/cv/ctpn/src/eval_utils.py b/model_zoo/official/cv/ctpn/src/eval_utils.py new file mode 100644 index 0000000000..56f64969b4 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/eval_utils.py @@ -0,0 +1,96 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evaluation utils for CTPN""" +import os +import subprocess +import numpy as np +from src.config import config +from src.text_connector.detector import detect + +def exec_shell_cmd(cmd): + sub = subprocess.Popen(args="{}".format(cmd), shell=True, stdin=subprocess.PIPE, \ + stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) + stdout_data, _ = sub.communicate() + if sub.returncode != 0: + raise ValueError("{} is not a executable command, please check.".format(cmd)) + return stdout_data.strip() + +def get_eval_result(): + create_eval_bbox = 'rm -rf submit*.zip;cd ./submit/;zip -r ../submit.zip *.txt;cd ../;bash eval_res.sh' + os.system(create_eval_bbox) + get_eval_output = "grep hmean log | awk '{print $NF}' | awk -F} '{print $1}' |tail -n 1" + hmean = exec_shell_cmd(get_eval_output) + return float(hmean) + +def eval_for_ctpn(network, dataset, eval_image_path): + network.set_train(False) + eval_iter = 0 + img_basenames = [] + output_dir = os.path.join(os.getcwd(), "submit") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + for file in os.listdir(eval_image_path): + img_basenames.append(os.path.basename(file)) + img_basenames = sorted(img_basenames) + for data in dataset.create_dict_iterator(): + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + # run net + output = network(img_data, gt_bboxes, gt_labels, gt_num) + gt_bboxes = gt_bboxes.asnumpy() + gt_labels = gt_labels.asnumpy() + gt_num = gt_num.asnumpy().astype(bool) + proposal = output[0] + proposal_mask = output[1] + for j in range(config.test_batch_size): + img = img_basenames[config.test_batch_size * eval_iter + j] + all_box_tmp = proposal[j].asnumpy() + all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1) + using_boxes_mask = all_box_tmp * all_mask_tmp + textsegs = using_boxes_mask[:, 0:4].astype(np.float32) + scores = using_boxes_mask[:, 4].astype(np.float32) + shape = img_metas.asnumpy()[0][:2].astype(np.int32) + bboxes = detect(textsegs, scores[:, np.newaxis], shape) + from PIL import Image, ImageDraw + im = Image.open(eval_image_path + '/' + img) + draw = ImageDraw.Draw(im) + image_h = img_metas.asnumpy()[j][2] + image_w = img_metas.asnumpy()[j][3] + gt_boxs = gt_bboxes[j][gt_num[j], :] + for gt_box in gt_boxs: + gt_x1 = gt_box[0] / image_w + gt_y1 = gt_box[1] / image_h + gt_x2 = gt_box[2] / image_w + gt_y2 = gt_box[3] / image_h + draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\ + fill='green', width=2) + file_name = "res_" + img.replace("jpg", "txt") + output_file = os.path.join(output_dir, file_name) + f = open(output_file, 'w') + for bbox in bboxes: + x1 = bbox[0] / image_w + y1 = bbox[1] / image_h + x2 = bbox[2] / image_w + y2 = bbox[3] / image_h + draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2) + str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2)) + f.write(str_tmp) + f.write("\n") + f.close() + im.save(img) + eval_iter = eval_iter + 1 diff --git a/model_zoo/official/cv/ctpn/train.py b/model_zoo/official/cv/ctpn/train.py index f90d7f9a40..f42d147752 100644 --- a/model_zoo/official/cv/ctpn/train.py +++ b/model_zoo/official/cv/ctpn/train.py @@ -32,6 +32,8 @@ from src.config import config, pretrain_config, finetune_config from src.dataset import create_ctpn_dataset from src.lr_schedule import dynamic_lr from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell +from src.eval_utils import eval_for_ctpn, get_eval_result +from src.eval_callback import EvalCallBack set_seed(1) @@ -43,10 +45,30 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums, parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") parser.add_argument("--task_type", type=str, default="Pretraining",\ choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining") +parser.add_argument("--run_eval", type=ast.literal_eval, default=False, \ + help="Run evaluation when training, default is False.") +parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True, \ + help="Save best checkpoint when run_eval is True, default is True.") +parser.add_argument("--eval_image_path", type=str, default="", \ + help="eval image path, when run_eval is True, eval_image_path should be set.") +parser.add_argument("--eval_dataset_path", type=str, default="", \ + help="eval dataset path, when run_eval is True, eval_dataset_path should be set.") +parser.add_argument("--eval_start_epoch", type=int, default=10, \ + help="Evaluation start epoch when run_eval is True, default is 10.") +parser.add_argument("--eval_interval", type=int, default=10, \ + help="Evaluation interval when run_eval is True, default is 10.") args_opt = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True) +def apply_eval(eval_param): + network = eval_param["eval_network"] + eval_ds = eval_param["eval_dataset"] + eval_image_path = eval_param["eval_image_path"] + eval_for_ctpn(network, eval_ds, eval_image_path) + hmean = get_eval_result() + return hmean + if __name__ == '__main__': if args_opt.run_distribute: rank = args_opt.rank_id @@ -78,7 +100,7 @@ if __name__ == '__main__': dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\ batch_size=config.batch_size, device_num=device_num, rank_id=rank) dataset_size = dataset.get_dataset_size() - net = CTPN(config=config, is_training=True) + net = CTPN(config=config, batch_size=config.batch_size) net = net.set_train() load_path = args_opt.pre_trained @@ -100,20 +122,34 @@ if __name__ == '__main__': weight_decay=config.weight_decay, loss_scale=config.loss_scale) net_with_loss = WithLossCell(net, loss) if args_opt.run_distribute: - net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, - mean=True, degree=device_num) + net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale, reduce_flag=True, \ + mean=True, degree=device_num) else: - net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale) + net_with_grads = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale) time_cb = TimeMonitor(data_size=dataset_size) loss_cb = LossCallBack(rank_id=rank) cb = [time_cb, loss_cb] + save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") if config.save_checkpoint: ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size, keep_checkpoint_max=config.keep_checkpoint_max) - save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig) cb += [ckpoint_cb] - - model = Model(net) + if args_opt.run_eval: + if args_opt.eval_dataset_path is None or (not os.path.isfile(args_opt.eval_dataset_path)): + raise ValueError("{} is not a existing path.".format(args_opt.eval_dataset_path)) + if args_opt.eval_image_path is None or (not os.path.isdir(args_opt.eval_image_path)): + raise ValueError("{} is not a existing path.".format(args_opt.eval_image_path)) + eval_dataset = create_ctpn_dataset(args_opt.eval_dataset_path, \ + batch_size=config.batch_size, repeat_num=1, is_training=False) + eval_net = net + eval_param_dict = {"eval_network": eval_net, "eval_dataset": eval_dataset, \ + "eval_image_path": args_opt.eval_image_path} + eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=args_opt.eval_interval, + eval_start_epoch=args_opt.eval_start_epoch, save_best_ckpt=True, + ckpt_directory=save_checkpoint_path, besk_ckpt_name="best_acc.ckpt", + metrics_name="hmean") + cb += [eval_cb] + model = Model(net_with_grads) model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)