parent
aa9fb70f3c
commit
ca90924fa4
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,109 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnnctc eval"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
|
||||
from src.util import CTCLabelConverter, AverageMeter
|
||||
from src.config import Config_CNNCTC
|
||||
from src.dataset import IIIT_Generator_batch
|
||||
from src.cnn_ctc import CNNCTC_Model
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
def test_dataset_creator():
|
||||
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
|
||||
return ds
|
||||
|
||||
|
||||
def test(config):
|
||||
ds = test_dataset_creator()
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
|
||||
ckpt_path = config.CKPT_PATH
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded! from: ', ckpt_path)
|
||||
|
||||
converter = CTCLabelConverter(config.CHARACTER)
|
||||
|
||||
model_run_time = AverageMeter()
|
||||
npu_to_cpu_time = AverageMeter()
|
||||
postprocess_time = AverageMeter()
|
||||
|
||||
count = 0
|
||||
correct_count = 0
|
||||
for data in ds.create_tuple_iterator():
|
||||
img, _, text, _, length = data
|
||||
|
||||
img_tensor = Tensor(img, mstype.float32)
|
||||
|
||||
model_run_begin = time.time()
|
||||
model_predict = net(img_tensor)
|
||||
model_run_end = time.time()
|
||||
model_run_time.update(model_run_end - model_run_begin)
|
||||
|
||||
npu_to_cpu_begin = time.time()
|
||||
model_predict = np.squeeze(model_predict.asnumpy())
|
||||
npu_to_cpu_end = time.time()
|
||||
npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
|
||||
|
||||
postprocess_begin = time.time()
|
||||
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
|
||||
preds_index = np.argmax(model_predict, 2)
|
||||
preds_index = np.reshape(preds_index, [-1])
|
||||
preds_str = converter.decode(preds_index, preds_size)
|
||||
postprocess_end = time.time()
|
||||
postprocess_time.update(postprocess_end - postprocess_begin)
|
||||
|
||||
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
|
||||
|
||||
if count == 0:
|
||||
model_run_time.reset()
|
||||
npu_to_cpu_time.reset()
|
||||
postprocess_time.reset()
|
||||
else:
|
||||
print('---------model run time--------', model_run_time.avg)
|
||||
print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
|
||||
print('---------postprocess run time--------', postprocess_time.avg)
|
||||
|
||||
print("Prediction samples: \n", preds_str[:5])
|
||||
print("Ground truth: \n", label_str[:5])
|
||||
for pred, label in zip(preds_str, label_str):
|
||||
if pred == label:
|
||||
correct_count += 1
|
||||
count += 1
|
||||
|
||||
print('accuracy: ', correct_count / count)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_opt.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_opt.ckpt_path
|
||||
test(cfg)
|
@ -0,0 +1,57 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
|
||||
python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
export RANK_SIZE=8
|
||||
ulimit -u unlimited
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm ./train_parallel_$i/ -rf
|
||||
mkdir ./train_parallel_$i
|
||||
cp ./*.py ./train_parallel_$i
|
||||
cp ./scripts/*.sh ./train_parallel_$i
|
||||
cp -r ./src ./train_parallel_$i
|
||||
cd ./train_parallel_$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
if [ -f $PATH2 ]
|
||||
then
|
||||
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
|
||||
else
|
||||
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
|
||||
fi
|
||||
cd .. || exit
|
||||
done
|
||||
|
@ -0,0 +1,54 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: TRAINED_CKPT=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp ./scripts/*.sh ./eval
|
||||
cp -r ./src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start infering for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
|
||||
cd .. || exit
|
@ -0,0 +1,45 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
ulimit -u unlimited
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ./*.py ./train
|
||||
cp ./scripts/*.sh ./train
|
||||
cp -r ./src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ -f $PATH1 ]
|
||||
then
|
||||
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
|
||||
else
|
||||
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
|
||||
fi
|
||||
cd .. || exit
|
@ -0,0 +1,15 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""src init file"""
|
@ -0,0 +1,71 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""loss callback"""
|
||||
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
from .util import AverageMeter
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.loss_avg = AverageMeter()
|
||||
self.timer = AverageMeter()
|
||||
self.start_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
|
||||
if cur_step_in_epoch % 2000 == 1:
|
||||
self.loss_avg = AverageMeter()
|
||||
self.timer = AverageMeter()
|
||||
self.start_time = time.time()
|
||||
else:
|
||||
self.timer.update(time.time() - self.start_time)
|
||||
self.start_time = time.time()
|
||||
|
||||
self.loss_avg.update(loss)
|
||||
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
loss_file = open("./loss.log", "a+")
|
||||
loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
self.loss_avg.avg, self.timer.avg))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
|
||||
print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
self.loss_avg.avg, self.timer.avg))
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,43 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""network config setting, will be used in train.py and eval.py"""
|
||||
|
||||
class Config_CNNCTC():
|
||||
# model config
|
||||
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
NUM_CLASS = len(CHARACTER) + 1
|
||||
HIDDEN_SIZE = 512
|
||||
FINAL_FEATURE_WIDTH = 26
|
||||
|
||||
# dataset config
|
||||
IMG_H = 32
|
||||
IMG_W = 100
|
||||
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
|
||||
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
|
||||
TRAIN_BATCH_SIZE = 192
|
||||
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
|
||||
TEST_BATCH_SIZE = 256
|
||||
TEST_DATASET_SIZE = 2976
|
||||
TRAIN_EPOCHS = 3
|
||||
|
||||
# training config
|
||||
CKPT_PATH = ''
|
||||
SAVE_PATH = './'
|
||||
LR = 1e-4
|
||||
LR_PARA = 5e-4
|
||||
MOMENTUM = 0.8
|
||||
LOSS_SCALE = 8096
|
||||
SAVE_CKPT_PER_N_STEP = 2000
|
||||
KEEP_CKPT_MAX_NUM = 5
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,88 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""generate ascend rank file"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ascend distribute rank.")
|
||||
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.")
|
||||
|
||||
def main(rank_table_file):
|
||||
nproc_per_node = 8
|
||||
|
||||
visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']
|
||||
|
||||
server_id = socket.gethostbyname(socket.gethostname())
|
||||
|
||||
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
|
||||
device_ips = {}
|
||||
for hccn_item in hccn_configs:
|
||||
hccn_item = hccn_item.strip()
|
||||
if hccn_item.startswith('address_'):
|
||||
device_id, device_ip = hccn_item.split('=')
|
||||
device_id = device_id.split('_')[1]
|
||||
device_ips[device_id] = device_ip
|
||||
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
|
||||
|
||||
hccn_table = {}
|
||||
hccn_table['board_id'] = '0x002f' # A+K
|
||||
# hccn_table['board_id'] = '0x0000' # A+X
|
||||
|
||||
hccn_table['chip_info'] = '910'
|
||||
hccn_table['deploy_mode'] = 'lab'
|
||||
hccn_table['group_count'] = '1'
|
||||
hccn_table['group_list'] = []
|
||||
instance_list = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
instance = {}
|
||||
instance['devices'] = []
|
||||
device_id = visible_devices[instance_id]
|
||||
device_ip = device_ips[device_id]
|
||||
instance['devices'].append({
|
||||
'device_id': device_id,
|
||||
'device_ip': device_ip,
|
||||
})
|
||||
instance['rank_id'] = str(instance_id)
|
||||
instance['server_id'] = server_id
|
||||
instance_list.append(instance)
|
||||
hccn_table['group_list'].append({
|
||||
'device_num': str(nproc_per_node),
|
||||
'server_num': '1',
|
||||
'group_name': '',
|
||||
'instance_count': str(nproc_per_node),
|
||||
'instance_list': instance_list,
|
||||
})
|
||||
hccn_table['para_plane_nic_location'] = 'device'
|
||||
hccn_table['para_plane_nic_name'] = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
eth_id = visible_devices[instance_id]
|
||||
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
|
||||
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
|
||||
hccn_table['status'] = 'completed'
|
||||
import json
|
||||
with open(rank_table_file, 'w') as table_fp:
|
||||
json.dump(hccn_table, table_fp, indent=4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = parser.parse_args()
|
||||
rank_table = args_opt.rank_file
|
||||
if os.path.exists(rank_table):
|
||||
print('Rank table file exists.')
|
||||
else:
|
||||
print('Generating rank table file.')
|
||||
main(rank_table)
|
||||
print('Rank table file generated')
|
@ -0,0 +1,171 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""preprocess dataset"""
|
||||
|
||||
import random
|
||||
import pickle
|
||||
import numpy as np
|
||||
import lmdb
|
||||
from tqdm import tqdm
|
||||
|
||||
def combine_lmdbs(lmdb_paths, lmdb_save_path):
|
||||
max_len = int((26 + 1) // 2)
|
||||
character = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
env_save = lmdb.open(
|
||||
lmdb_save_path,
|
||||
map_size=1099511627776)
|
||||
|
||||
cnt = 0
|
||||
for lmdb_path in lmdb_paths:
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
# Filtering
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
if len(label) > max_len:
|
||||
continue
|
||||
|
||||
illegal_sample = False
|
||||
for char_item in label.lower():
|
||||
if char_item not in character:
|
||||
illegal_sample = True
|
||||
break
|
||||
if illegal_sample:
|
||||
continue
|
||||
|
||||
img_key = 'image-%09d'.encode() % index
|
||||
imgbuf = txn.get(img_key)
|
||||
|
||||
with env_save.begin(write=True) as txn_save:
|
||||
cnt += 1
|
||||
|
||||
label_key_save = 'label-%09d'.encode() % cnt
|
||||
label_save = label.encode()
|
||||
image_key_save = 'image-%09d'.encode() % cnt
|
||||
image_save = imgbuf
|
||||
|
||||
txn_save.put(label_key_save, label_save)
|
||||
txn_save.put(image_key_save, image_save)
|
||||
|
||||
nSamples = cnt
|
||||
with env_save.begin(write=True) as txn_save:
|
||||
txn_save.put('num-samples'.encode(), str(nSamples).encode())
|
||||
|
||||
|
||||
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
|
||||
label_length_dict = {}
|
||||
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
label_length = len(label)
|
||||
if label_length in label_length_dict:
|
||||
label_length_dict[label_length] += 1
|
||||
else:
|
||||
label_length_dict[label_length] = 1
|
||||
|
||||
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
label_length_sum = 0
|
||||
label_num = 0
|
||||
lengths = []
|
||||
p = []
|
||||
for l, num in sorted_label_length:
|
||||
label_length_sum += l * num
|
||||
label_num += num
|
||||
p.append(num)
|
||||
lengths.append(l)
|
||||
for i, _ in enumerate(p):
|
||||
p[i] /= label_num
|
||||
|
||||
average_overall_length = int(label_length_sum / label_num * batch_size)
|
||||
|
||||
def get_combinations_of_fix_length(fix_length, items, p, batch_size):
|
||||
ret = []
|
||||
cur_sum = 0
|
||||
ret = np.random.choice(items, batch_size - 1, True, p)
|
||||
cur_sum = sum(ret)
|
||||
ret = list(ret)
|
||||
if fix_length - cur_sum in items:
|
||||
ret.append(fix_length - cur_sum)
|
||||
else:
|
||||
return None
|
||||
return ret
|
||||
|
||||
result = []
|
||||
while len(result) < num_of_combinations:
|
||||
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
|
||||
if ret is not None:
|
||||
result.append(ret)
|
||||
return result
|
||||
|
||||
|
||||
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
|
||||
length_index_dict = {}
|
||||
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
label_length = len(label)
|
||||
if label_length in length_index_dict:
|
||||
length_index_dict[label_length].append(index)
|
||||
else:
|
||||
length_index_dict[label_length] = [index]
|
||||
|
||||
ret = []
|
||||
for _ in range(num_of_iters):
|
||||
comb = random.choice(combinations)
|
||||
for l in comb:
|
||||
ret.append(random.choice(length_index_dict[l]))
|
||||
|
||||
with open(pkl_save_path, 'wb') as f:
|
||||
pickle.dump(ret, f, -1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
|
||||
print('Begin to combine multiple lmdb datasets')
|
||||
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||
|
||||
# step 2: generate the order of input data, guarantee that the input batch shape is fixed
|
||||
print('Begin to generate the index order of input data')
|
||||
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')
|
||||
|
||||
print('Done')
|
@ -0,0 +1,102 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""util file"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class CTCLabelConverter():
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character):
|
||||
# character (str): set of the possible characters.
|
||||
dict_character = list(character)
|
||||
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
|
||||
self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss (index 0)
|
||||
self.dict['[blank]'] = len(dict_character)
|
||||
|
||||
def encode(self, text):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
||||
output:
|
||||
text: concatenated text index for CTCLoss.
|
||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||
length: length of each text. [batch_size]
|
||||
"""
|
||||
length = [len(s) for s in text]
|
||||
text = ''.join(text)
|
||||
text = [self.dict[char] for char in text]
|
||||
|
||||
return np.array(text), np.array(length)
|
||||
|
||||
def decode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
texts = []
|
||||
index = 0
|
||||
for l in length:
|
||||
t = text_index[index:index + l]
|
||||
|
||||
char_list = []
|
||||
for i in range(l):
|
||||
# if t[i] != self.dict['[blank]'] and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||
if t[i] != self.dict['[blank]'] and (
|
||||
not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||
char_list.append(self.character[t[i]])
|
||||
text = ''.join(char_list)
|
||||
|
||||
texts.append(text)
|
||||
index += l
|
||||
return texts
|
||||
|
||||
def reverse_encode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
texts = []
|
||||
index = 0
|
||||
for l in length:
|
||||
t = text_index[index:index + l]
|
||||
|
||||
char_list = []
|
||||
for i in range(l):
|
||||
if t[i] != self.dict['[blank]']: # removing repeated characters and blank.
|
||||
char_list.append(self.character[t[i]])
|
||||
text = ''.join(char_list)
|
||||
|
||||
texts.append(text)
|
||||
index += l
|
||||
return texts
|
@ -0,0 +1,100 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnnctc train"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import Config_CNNCTC
|
||||
from src.callback import LossCallBack
|
||||
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
|
||||
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
|
||||
|
||||
set_seed(1)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
|
||||
def dataset_creator(run_distribute):
|
||||
if run_distribute:
|
||||
st_dataset = ST_MJ_Generator_batch_fixed_length_para()
|
||||
else:
|
||||
st_dataset = ST_MJ_Generator_batch_fixed_length()
|
||||
|
||||
ds = GeneratorDataset(st_dataset,
|
||||
['img', 'label_indices', 'text', 'sequence_length'],
|
||||
num_parallel_workers=8)
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def train(args_opt, config):
|
||||
if args_opt.run_distribute:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel")
|
||||
|
||||
ds = dataset_creator(args_opt.run_distribute)
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
net.set_train(True)
|
||||
|
||||
if config.CKPT_PATH != '':
|
||||
param_dict = load_checkpoint(config.CKPT_PATH)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded!')
|
||||
else:
|
||||
print('train from scratch...')
|
||||
|
||||
criterion = ctc_loss()
|
||||
opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
|
||||
momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
|
||||
|
||||
net = WithLossCell(net, criterion)
|
||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
|
||||
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
|
||||
|
||||
callback = LossCallBack()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
|
||||
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
|
||||
|
||||
if args_opt.device_id == 0:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
|
||||
else:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='CNNCTC arg')
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
args_cfg = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_cfg.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_cfg.ckpt_path
|
||||
train(args_cfg, cfg)
|
Loading…
Reference in new issue