parent 7d09cd1928
commit fa675f8954
@@ -1,103 +0,0 @@
Global:
  use_gpu: false
  epoch_num: 72
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/mv3_none_none_ctc/
  save_epoch_step: 500
  # evaluation is run every 2000 iterations
  eval_batch_step: 2000
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  load_static_weights: True
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: True
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  max_text_length: 25
  character_dict_path:
  character_type: 'en'
  use_space_char: False
  infer_mode: False
  use_tps: False

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  learning_rate:
    lr: 0.0005
  regularizer:
    name: 'L2'
    factor: 0.00001

Architecture:
  type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
    small_stride: [1, 2, 2, 2]
  Neck:
    name: SequenceEncoder
    encoder_type: reshape
  Head:
    name: CTC
    fc_decay: 0.00001

Loss:
  name: CTCLoss

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

TRAIN:
  dataset:
    name: LMDBDateSet
    file_list:
      - ./rec/train # dataset1
    ratio_list: [0.4, 0.6]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecAug:
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - keepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    batch_size: 256
    shuffle: True
    drop_last: True
    num_workers: 8

EVAL:
  dataset:
    name: LMDBDateSet
    file_list:
      - ./rec/val/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - keepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size: 256
    num_workers: 8
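For context, here is a minimal sketch of how a config file like this one is consumed. It assumes only the pyyaml package and an illustrative file name; the keys it reads are exactly the sections the dataset classes added below look up:

import yaml

# minimal sketch; 'config.yml' is an illustrative file name
with open('config.yml') as f:
    config = yaml.safe_load(f)

# the dataset classes in this commit look up these sections by key
dataset_name = config['TRAIN']['dataset']['name']      # 'LMDBDateSet'
transforms = config['TRAIN']['dataset']['transforms']  # operator pipeline
batch_size = config['TRAIN']['loader']['batch_size']
print(dataset_name, batch_size)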
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
@@ -0,0 +1,131 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
import lmdb
import cv2

from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()


class LMDBDateSet(Dataset):
    def __init__(self, config, mode):
        super(LMDBDateSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        batch_size = loader_config['batch_size_per_card']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)

        logger.info("Initialize indexes of datasets: %s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'], global_config)

        # # for rec
        # character = ''
        # for op in self.ops:
        #     if hasattr(op, 'character'):
        #         character = getattr(op, 'character')

        # self.info_dict = {'character': character}

    def load_hierarchical_lmdb_dataset(self, data_dir):
        # every leaf directory (one with no subdirectories) under data_dir
        # is treated as a separate LMDB dataset
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
            if not dirnames:
                env = lmdb.open(
                    dirpath,
                    max_readers=32,
                    readonly=True,
                    lock=False,
                    readahead=False,
                    meminit=False)
                txn = env.begin(write=False)
                num_samples = int(txn.get('num-samples'.encode()))
                lmdb_sets[dataset_idx] = {
                    "dirpath": dirpath,
                    "env": env,
                    "txn": txn,
                    "num_samples": num_samples
                }
                dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        # build a (total_sample_num, 2) table mapping a flat sample index to
        # (lmdb set index, sample index within that set)
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] \
                = list(range(tmp_sample_num))
            # LMDB sample keys are 1-based, so shift the in-set indices
            data_idx_order_list[beg_idx:end_idx, 1] += 1
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """get_img_data"""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_lmdb_sample_info(self, txn, index):
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            # missing sample: fall back to a random one
            return self.__getitem__(np.random.randint(self.__len__()))
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops)
        if outs is None:
            # a transform rejected the sample: fall back to a random one
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]
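To make the expected LMDB layout concrete, here is a hedged sketch that writes a toy dataset in the key scheme the class above reads: a num-samples count plus image-%09d / label-%09d entries with 1-based indices. The path, map_size, and sample bytes are illustrative assumptions:

import lmdb

# hypothetical samples: (encoded image bytes, label text)
samples = [(b'<jpeg bytes>', 'hello'), (b'<jpeg bytes>', 'world')]

env = lmdb.open('./rec/train/toy', map_size=2**26)  # illustrative path/size
with env.begin(write=True) as txn:
    txn.put(b'num-samples', str(len(samples)).encode())
    # keys are 1-based, matching the +1 shift in dataset_traversal above
    for i, (img_bytes, label) in enumerate(samples, start=1):
        txn.put(b'image-%09d' % i, img_bytes)
        txn.put(b'label-%09d' % i, label.encode('utf-8'))
env.close()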
@@ -0,0 +1,122 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time

from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger
logger = get_logger()


class SimpleDataSet(Dataset):
    def __init__(self, config, mode):
        super(SimpleDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        batch_size = loader_config['batch_size_per_card']

        self.delimiter = dataset_config.get('delimiter', '\t')
        label_file_list = dataset_config.pop('label_file_list')
        data_source_num = len(label_file_list)
        if data_source_num == 1:
            ratio_list = [1.0]
        else:
            ratio_list = dataset_config.pop('ratio_list')

        assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
        assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list."
        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        logger.info("Initialize indexes of datasets: %s" % label_file_list)
        self.data_lines_list, data_num_list = self.get_image_info_list(
            label_file_list)
        self.data_idx_order_list = self.dataset_traversal(
            data_num_list, ratio_list, batch_size)
        self.shuffle_data_random()

        self.ops = create_operators(dataset_config['transforms'], global_config)

    def get_image_info_list(self, file_list):
        if isinstance(file_list, str):
            file_list = [file_list]
        data_lines_list = []
        data_num_list = []
        for file in file_list:
            with open(file, "rb") as f:
                lines = f.readlines()
                data_lines_list.append(lines)
                data_num_list.append(len(lines))
        return data_lines_list, data_num_list

    def dataset_traversal(self, data_num_list, ratio_list, batch_size):
        # interleave the sources so that each batch draws roughly
        # batch_size * ratio_list[i] consecutive samples from source i
        select_num_list = []
        dataset_num = len(data_num_list)
        for dno in range(dataset_num):
            select_num = round(batch_size * ratio_list[dno])
            select_num = max(select_num, 1)
            select_num_list.append(select_num)
        data_idx_order_list = []
        cur_index_sets = [0] * dataset_num
        while True:
            finish_read_num = 0
            for dataset_idx in range(dataset_num):
                cur_index = cur_index_sets[dataset_idx]
                if cur_index >= data_num_list[dataset_idx]:
                    finish_read_num += 1
                else:
                    select_num = select_num_list[dataset_idx]
                    for sno in range(select_num):
                        cur_index = cur_index_sets[dataset_idx]
                        if cur_index >= data_num_list[dataset_idx]:
                            break
                        data_idx_order_list.append((dataset_idx, cur_index))
                        cur_index_sets[dataset_idx] += 1
            if finish_read_num == dataset_num:
                break
        return data_idx_order_list

    def shuffle_data_random(self):
        if self.do_shuffle:
            for dno in range(len(self.data_lines_list)):
                random.shuffle(self.data_lines_list[dno])
        return

    def __getitem__(self, idx):
        dataset_idx, file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines_list[dataset_idx][file_idx]
        data_line = data_line.decode('utf-8')
        # each label-file line: "<image path><delimiter><label>"
        substr = data_line.strip("\n").split(self.delimiter)
        file_name = substr[0]
        label = substr[1]
        img_path = os.path.join(self.data_dir, file_name)
        data = {'img_path': img_path, 'label': label}
        with open(data['img_path'], 'rb') as f:
            img = f.read()
            data['image'] = img
        outs = transform(data, self.ops)
        if outs is None:
            # a transform rejected the sample: fall back to a random one
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)
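The label-file format SimpleDataSet expects follows directly from __getitem__: one sample per line, an image path relative to data_dir, then the configured delimiter (tab by default), then the label. A small sketch with made-up paths:

# minimal sketch of a SimpleDataSet label file (paths are made up)
lines = ["imgs/word_1.jpg\thello", "imgs/word_2.jpg\tworld"]
with open("train_label.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")
# SimpleDataSet joins each path with dataset_config['data_dir'], reads the
# raw bytes, and pushes {'image': ..., 'label': ...} through the transforms.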
@@ -1,26 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from .losses import build_loss

__all__ = ['build_model', 'build_loss']


def build_model(config):
    from .architectures import Model

    config = copy.deepcopy(config)
    module_class = Model(config)
    return module_class
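A hedged usage sketch for this removed helper: it assumes the import path and that the dict passed in is shaped like the Architecture section of the YAML above; the actual constructor contract of Model lives in .architectures and is not shown in this diff:

from ppocr.modeling import build_model  # import path is an assumption

arch_config = {
    'type': 'rec',
    'algorithm': 'CRNN',
    # Backbone / Neck / Head sub-configs as in the YAML above
}
model = build_model(arch_config)  # deep-copies the dict and builds Model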
Some files were not shown because too many files have changed in this diff.