diff --git a/PPOCRLabel/README.md b/PPOCRLabel/README.md index 41a7ab4..a83b770 100644 --- a/PPOCRLabel/README.md +++ b/PPOCRLabel/README.md @@ -9,7 +9,7 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w ### Recent Update - 2021.1.11: Optimize the labeling experience (by [edencfc](https://github.com/edencfc)), - - Users can choose whether to pop up the label input dialog after drawing the detection box in "View - Pop-up Label Input Dialog". + - Users can choose whether to pop up the label input dialog after drawing the detection box in "View - Pop-up Label Input Dialog". - The recognition result scrolls synchronously when users click related detection box. - Click to modify the recognition result.(If you can't change the result, please switch to the system default input method, or switch back to the original input method again) - 2020.12.18: Support re-recognition of a single label box (by [ninetailskim](https://github.com/ninetailskim) ), perfect shortcut keys. @@ -49,7 +49,7 @@ python3 PPOCRLabel.py ``` pip3 install pyqt5 pip3 uninstall opencv-python # Uninstall opencv manually as it conflicts with pyqt -pip3 install opencv-contrib-python-headless # Install the headless version of opencv +pip3 install opencv-contrib-python-headless==4.2.0.32 # Install the headless version of opencv cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder python3 PPOCRLabel.py ``` @@ -127,7 +127,7 @@ Therefore, if the recognition result has been manually changed before, it may ch - Default model: PPOCRLabel uses the Chinese and English ultra-lightweight OCR model in PaddleOCR by default, supports Chinese, English and number recognition, and multiple language detection. -- Model language switching: Changing the built-in model language is supportable by clicking "PaddleOCR"-"Choose OCR Model" in the menu bar. Currently supported languages​include French, German, Korean, and Japanese. 
+- Model language switching: Changing the built-in model language is supportable by clicking "PaddleOCR"-"Choose OCR Model" in the menu bar. Currently supported languages​include French, German, Korean, and Japanese. For specific model download links, please refer to [PaddleOCR Model List](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md#multilingual-recognition-modelupdating) - Custom model: The model trained by users can be replaced by modifying PPOCRLabel.py in [PaddleOCR class instantiation](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/PPOCRLabel.py#L110) referring [Custom Model Code](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/whl_en.md#use-custom-model) @@ -160,11 +160,11 @@ For some data that are difficult to recognize, the recognition results will not ``` pyrcc5 -o libs/resources.py resources.qrc ``` -- If you get an error ``` module 'cv2' has no attribute 'INTER_NEAREST'```, you need to delete all opencv related packages first, and then reinstall the headless version of opencv +- If you get an error ``` module 'cv2' has no attribute 'INTER_NEAREST'```, you need to delete all opencv related packages first, and then reinstall the 4.2.0.32 version of headless opencv ``` - pip install opencv-contrib-python-headless + pip install opencv-contrib-python-headless==4.2.0.32 ``` - + ### Related 1.[Tzutalin. LabelImg. 
Git code (2015)](https://github.com/tzutalin/labelImg) diff --git a/PPOCRLabel/README_ch.md b/PPOCRLabel/README_ch.md index df4f7df..b9bfc9e 100644 --- a/PPOCRLabel/README_ch.md +++ b/PPOCRLabel/README_ch.md @@ -49,7 +49,7 @@ python3 PPOCRLabel.py --lang ch ``` pip3 install pyqt5 pip3 uninstall opencv-python # 由于mac版本的opencv与pyqt有冲突,需先手动卸载opencv -pip3 install opencv-contrib-python-headless # 安装headless版本的open-cv +pip3 install opencv-contrib-python-headless==4.2.0.32 # 安装headless版本的open-cv cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下 python3 PPOCRLabel.py --lang ch ``` @@ -132,22 +132,22 @@ PPOCRLabel支持三种保存方式: ### 错误提示 - 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。 - + - PPOCRLabel**不支持对中文文件名**的图片进行自动标注。 - 针对Linux用户:如果您在打开软件过程中出现**objc[XXXXX]**开头的错误,证明您的opencv版本太高,建议安装4.2版本: ``` pip install opencv-python==4.2.0.32 ``` - + - 如果出现 ```Missing string id``` 开头的错误,需要重新编译资源: ``` pyrcc5 -o libs/resources.py resources.qrc ``` - -- 如果出现``` module 'cv2' has no attribute 'INTER_NEAREST'```错误,需要首先删除所有opencv相关包,然后重新安装headless版本的opencv + +- 如果出现``` module 'cv2' has no attribute 'INTER_NEAREST'```错误,需要首先删除所有opencv相关包,然后重新安装4.2.0.32版本的headless opencv ``` - pip install opencv-contrib-python-headless + pip install opencv-contrib-python-headless==4.2.0.32 ``` ### 参考资料 diff --git a/README.md b/README.md index 67d65e9..fb88ef0 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ PaddleOCR supports both dynamic graph and static graph programming paradigm - Static graph: develop branch **Recent updates** +- 2021.2.8 Release PaddleOCRv2.0(branch release/2.0) and set as default branch. Check release note here: https://github.com/PaddlePaddle/PaddleOCR/releases/tag/v2.0.0 - 2021.1.21 update more than 25+ multilingual recognition models [models list](./doc/doc_en/models_list_en.md), including:English, Chinese, German, French, Japanese,Spanish,Portuguese Russia Arabic and so on. 
Models for more languages will continue to be updated [Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048). - 2020.12.15 update Data synthesis tool, i.e., [Style-Text](./StyleText/README.md),easy to synthesize a large number of images which are similar to the target scene image. - 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](./PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly. diff --git a/README_ch.md b/README_ch.md index 543315c..3119752 100755 --- a/README_ch.md +++ b/README_ch.md @@ -8,6 +8,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式 - 静态图版本:develop分支 **近期更新** +- 2021.2.8 正式发布PaddleOCRv2.0(branch release/2.0)并设置为推荐用户使用的默认分支. 发布的详细内容,请参考: https://github.com/PaddlePaddle/PaddleOCR/releases/tag/v2.0.0 - 2021.2.8 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数167个,每周一都会更新,欢迎大家持续关注。 - 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802) - 2021.1.21 更新多语言识别模型,目前支持语种超过27种,[多语言模型下载](./doc/doc_ch/models_list.md),包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048) diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml index 0ce0673..3cf1f7a 100644 --- a/configs/rec/rec_mv3_tps_bilstm_att.yml +++ b/configs/rec/rec_mv3_tps_bilstm_att.yml @@ -66,7 +66,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: ../training/ + data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image img_mode: BGR @@ -85,7 +85,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ../validation/ + data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image img_mode: BGR diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml index 02aeb8c..659a172 100644 --- 
a/configs/rec/rec_r34_vd_tps_bilstm_att.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml @@ -65,7 +65,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: ../training/ + data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image img_mode: BGR @@ -84,7 +84,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ../validation/ + data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image img_mode: BGR diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml index ec7f170..6b38616 100644 --- a/configs/rec/rec_r50_fpn_srn.yml +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -59,7 +59,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: ./train_data/srn_train_data_duiqi + data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image img_mode: BGR @@ -84,7 +84,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ./train_data/data_lmdb_release/evaluation + data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image img_mode: BGR diff --git a/doc/joinus.PNG b/doc/joinus.PNG index 22258be..064159c 100644 Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 55870a5..7a32d87 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -215,7 +215,7 @@ class AttnLabelEncode(BaseRecLabelEncode): return None data['length'] = np.array(len(text)) text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len - - len(text) - 1) + - len(text) - 2) data['label'] = np.array(text) return data @@ -261,7 +261,7 @@ class SRNLabelEncode(BaseRecLabelEncode): if len(text) > self.max_text_len: return None data['length'] = np.array(len(text)) - text = text + [char_num] * (self.max_text_len - len(text)) + text = text + [char_num - 1] * (self.max_text_len - len(text)) data['label'] = np.array(text) return data diff --git 
a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index a7cfe12..0d22271 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -57,6 +57,9 @@ class AttentionHead(nn.Layer): else: targets = paddle.zeros(shape=[batch_size], dtype="int32") probs = None + char_onehots = None + outputs = None + alpha = None for i in range(num_steps): char_onehots = self._char_to_onehot( diff --git a/requirements.txt b/requirements.txt index 1321896..2401d52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ shapely -imgaug +scikit-image==0.17.2 +imgaug==0.4.0 pyclipper lmdb opencv-python==4.2.0.32 diff --git a/tools/eval.py b/tools/eval.py index 16cfe53..4afed46 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -47,6 +47,7 @@ def main(): config['Architecture']["Head"]['out_channels'] = len( getattr(post_process_class, 'character')) model = build_model(config['Architecture']) + use_srn = config['Architecture']['algorithm'] == "SRN" best_model_dict = init_model(config, model, logger) if len(best_model_dict): @@ -59,7 +60,7 @@ def main(): # start eval metirc = program.eval(model, valid_dataloader, post_process_class, - eval_class) + eval_class, use_srn) logger.info('metric eval ***************') for k, v in metirc.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index fd895e5..b3d9d49 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -54,6 +54,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "RARE": + postprocess_params = { + 'name': 'AttnLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors = \ 
utility.create_predictor(args, 'rec', logger) diff --git a/tools/program.py b/tools/program.py index 34d484d..6277d74 100755 --- a/tools/program.py +++ b/tools/program.py @@ -182,6 +182,8 @@ def train(config, model_average = False model.train() + use_srn = config['Architecture']['algorithm'] == "SRN" + if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -200,7 +202,7 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - if config['Architecture']['algorithm'] == "SRN": + if use_srn: others = batch[-4:] preds = model(images, others) model_average = True @@ -256,8 +258,12 @@ def train(config, min_average_window=10000, max_average_window=15625) Model_Average.apply() - cur_metric = eval(model, valid_dataloader, post_process_class, - eval_class) + cur_metric = eval( + model, + valid_dataloader, + post_process_class, + eval_class, + use_srn=use_srn) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -321,7 +327,8 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class): +def eval(model, valid_dataloader, post_process_class, eval_class, + use_srn=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -332,7 +339,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class): break images = batch[0] start = time.time() - if "SRN" in str(model.head): + + if use_srn: others = batch[-4:] preds = model(images, others) else:
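Note on the `AttnLabelEncode` hunk in `ppocr/data/imaug/label_ops.py`: the padding count changes from `max_text_len - len(text) - 1` to `max_text_len - len(text) - 2`. The sketch below reproduces only the list arithmetic from that hunk to show the effect on label length. The helper name `pad_attn_label`, the `trailing` parameter and the sample values are hypothetical and not part of the patch; reading index `0` and index `len(self.character) - 1` as start and end markers is an assumption.

```python
def pad_attn_label(text, num_classes, max_text_len, trailing):
    """Mirror the list expression from the AttnLabelEncode hunk.

    `trailing` is the constant subtracted in the padding count:
    1 reproduces the pre-patch expression, 2 the post-patch one.
    """
    first = 0                 # leading index, as in `[0] + text`
    last = num_classes - 1    # trailing index, as in `len(self.character) - 1`
    padding = [0] * (max_text_len - len(text) - trailing)
    return [first] + text + [last] + padding


text = [5, 6, 7]  # hypothetical encoded characters
old = pad_attn_label(text, num_classes=40, max_text_len=10, trailing=1)
new = pad_attn_label(text, num_classes=40, max_text_len=10, trailing=2)

print(len(old), len(new))  # 11 10
```

With the old constant the label is one element longer than `max_text_len`, because the two added indices are only partly accounted for. With the new constant the label is exactly `max_text_len` long for this input.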