!6383 Add modelzoo CNNCTC Network.
Merge pull request !6383 from linqingke/fasterrcnnpull/6383/MERGE
commit
737e27d721
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""cnnctc eval"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor, context
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.dataset import GeneratorDataset
|
||||||
|
|
||||||
|
from src.util import CTCLabelConverter, AverageMeter
|
||||||
|
from src.config import Config_CNNCTC
|
||||||
|
from src.dataset import IIIT_Generator_batch
|
||||||
|
from src.cnn_ctc import CNNCTC_Model
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||||
|
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||||
|
|
||||||
|
def test_dataset_creator():
|
||||||
|
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def test(config):
|
||||||
|
ds = test_dataset_creator()
|
||||||
|
|
||||||
|
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||||
|
|
||||||
|
ckpt_path = config.CKPT_PATH
|
||||||
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
print('parameters loaded! from: ', ckpt_path)
|
||||||
|
|
||||||
|
converter = CTCLabelConverter(config.CHARACTER)
|
||||||
|
|
||||||
|
model_run_time = AverageMeter()
|
||||||
|
npu_to_cpu_time = AverageMeter()
|
||||||
|
postprocess_time = AverageMeter()
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
correct_count = 0
|
||||||
|
for data in ds.create_tuple_iterator():
|
||||||
|
img, _, text, _, length = data
|
||||||
|
|
||||||
|
img_tensor = Tensor(img, mstype.float32)
|
||||||
|
|
||||||
|
model_run_begin = time.time()
|
||||||
|
model_predict = net(img_tensor)
|
||||||
|
model_run_end = time.time()
|
||||||
|
model_run_time.update(model_run_end - model_run_begin)
|
||||||
|
|
||||||
|
npu_to_cpu_begin = time.time()
|
||||||
|
model_predict = np.squeeze(model_predict.asnumpy())
|
||||||
|
npu_to_cpu_end = time.time()
|
||||||
|
npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
|
||||||
|
|
||||||
|
postprocess_begin = time.time()
|
||||||
|
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
|
||||||
|
preds_index = np.argmax(model_predict, 2)
|
||||||
|
preds_index = np.reshape(preds_index, [-1])
|
||||||
|
preds_str = converter.decode(preds_index, preds_size)
|
||||||
|
postprocess_end = time.time()
|
||||||
|
postprocess_time.update(postprocess_end - postprocess_begin)
|
||||||
|
|
||||||
|
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
|
||||||
|
|
||||||
|
if count == 0:
|
||||||
|
model_run_time.reset()
|
||||||
|
npu_to_cpu_time.reset()
|
||||||
|
postprocess_time.reset()
|
||||||
|
else:
|
||||||
|
print('---------model run time--------', model_run_time.avg)
|
||||||
|
print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
|
||||||
|
print('---------postprocess run time--------', postprocess_time.avg)
|
||||||
|
|
||||||
|
print("Prediction samples: \n", preds_str[:5])
|
||||||
|
print("Ground truth: \n", label_str[:5])
|
||||||
|
for pred, label in zip(preds_str, label_str):
|
||||||
|
if pred == label:
|
||||||
|
correct_count += 1
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
print('accuracy: ', correct_count / count)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||||
|
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
cfg = Config_CNNCTC()
|
||||||
|
if args_opt.ckpt_path != "":
|
||||||
|
cfg.CKPT_PATH = args_opt.ckpt_path
|
||||||
|
test(cfg)
|
@ -0,0 +1,57 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
echo $PATH1
|
||||||
|
|
||||||
|
PATH2=$(get_real_path $2)
|
||||||
|
echo $PATH2
|
||||||
|
|
||||||
|
python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
|
||||||
|
export RANK_TABLE_FILE=$PATH1
|
||||||
|
export RANK_SIZE=8
|
||||||
|
ulimit -u unlimited
|
||||||
|
for((i=0;i<$RANK_SIZE;i++));
|
||||||
|
do
|
||||||
|
rm ./train_parallel_$i/ -rf
|
||||||
|
mkdir ./train_parallel_$i
|
||||||
|
cp ./*.py ./train_parallel_$i
|
||||||
|
cp ./scripts/*.sh ./train_parallel_$i
|
||||||
|
cp -r ./src ./train_parallel_$i
|
||||||
|
cd ./train_parallel_$i || exit
|
||||||
|
export RANK_ID=$i
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
|
if [ -f $PATH2 ]
|
||||||
|
then
|
||||||
|
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
|
||||||
|
else
|
||||||
|
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
|
||||||
|
fi
|
||||||
|
cd .. || exit
|
||||||
|
done
|
||||||
|
|
@ -0,0 +1,54 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# -ne 1 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
echo $PATH1
|
||||||
|
if [ ! -f $PATH1 ]
|
||||||
|
then
|
||||||
|
echo "error: TRAINED_CKPT=$PATH1 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
export DEVICE_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
cp ./*.py ./eval
|
||||||
|
cp ./scripts/*.sh ./eval
|
||||||
|
cp -r ./src ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
echo "start infering for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
|
||||||
|
cd .. || exit
|
@ -0,0 +1,45 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
PATH1=$(get_real_path $1)
|
||||||
|
|
||||||
|
ulimit -u unlimited
|
||||||
|
|
||||||
|
if [ -d "train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cp ./*.py ./train
|
||||||
|
cp ./scripts/*.sh ./train
|
||||||
|
cp -r ./src ./train
|
||||||
|
cd ./train || exit
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
if [ -f $PATH1 ]
|
||||||
|
then
|
||||||
|
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
|
||||||
|
else
|
||||||
|
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
|
||||||
|
fi
|
||||||
|
cd .. || exit
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""src init file"""
|
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""loss callback"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from mindspore.train.callback import Callback
|
||||||
|
from .util import AverageMeter
|
||||||
|
|
||||||
|
class LossCallBack(Callback):
|
||||||
|
"""
|
||||||
|
Monitor the loss in training.
|
||||||
|
|
||||||
|
If the loss is NAN or INF terminating training.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If per_print_times is 0 do not print loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
per_print_times (int): Print loss every times. Default: 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, per_print_times=1):
|
||||||
|
super(LossCallBack, self).__init__()
|
||||||
|
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||||
|
raise ValueError("print_step must be int and >= 0.")
|
||||||
|
self._per_print_times = per_print_times
|
||||||
|
self.loss_avg = AverageMeter()
|
||||||
|
self.timer = AverageMeter()
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
def step_end(self, run_context):
|
||||||
|
cb_params = run_context.original_args()
|
||||||
|
|
||||||
|
loss = cb_params.net_outputs.asnumpy()
|
||||||
|
|
||||||
|
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||||
|
cur_num = cb_params.cur_step_num
|
||||||
|
|
||||||
|
if cur_step_in_epoch % 2000 == 1:
|
||||||
|
self.loss_avg = AverageMeter()
|
||||||
|
self.timer = AverageMeter()
|
||||||
|
self.start_time = time.time()
|
||||||
|
else:
|
||||||
|
self.timer.update(time.time() - self.start_time)
|
||||||
|
self.start_time = time.time()
|
||||||
|
|
||||||
|
self.loss_avg.update(loss)
|
||||||
|
|
||||||
|
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||||
|
loss_file = open("./loss.log", "a+")
|
||||||
|
loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||||
|
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||||
|
self.loss_avg.avg, self.timer.avg))
|
||||||
|
loss_file.write("\n")
|
||||||
|
loss_file.close()
|
||||||
|
|
||||||
|
print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||||
|
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||||
|
self.loss_avg.avg, self.timer.avg))
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""network config setting, will be used in train.py and eval.py"""
|
||||||
|
|
||||||
|
class Config_CNNCTC():
|
||||||
|
# model config
|
||||||
|
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||||
|
NUM_CLASS = len(CHARACTER) + 1
|
||||||
|
HIDDEN_SIZE = 512
|
||||||
|
FINAL_FEATURE_WIDTH = 26
|
||||||
|
|
||||||
|
# dataset config
|
||||||
|
IMG_H = 32
|
||||||
|
IMG_W = 100
|
||||||
|
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
|
||||||
|
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
|
||||||
|
TRAIN_BATCH_SIZE = 192
|
||||||
|
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
|
||||||
|
TEST_BATCH_SIZE = 256
|
||||||
|
TEST_DATASET_SIZE = 2976
|
||||||
|
TRAIN_EPOCHS = 3
|
||||||
|
|
||||||
|
# training config
|
||||||
|
CKPT_PATH = ''
|
||||||
|
SAVE_PATH = './'
|
||||||
|
LR = 1e-4
|
||||||
|
LR_PARA = 5e-4
|
||||||
|
MOMENTUM = 0.8
|
||||||
|
LOSS_SCALE = 8096
|
||||||
|
SAVE_CKPT_PER_N_STEP = 2000
|
||||||
|
KEEP_CKPT_MAX_NUM = 5
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,88 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""generate ascend rank file"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="ascend distribute rank.")
|
||||||
|
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.")
|
||||||
|
|
||||||
|
def main(rank_table_file):
|
||||||
|
nproc_per_node = 8
|
||||||
|
|
||||||
|
visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']
|
||||||
|
|
||||||
|
server_id = socket.gethostbyname(socket.gethostname())
|
||||||
|
|
||||||
|
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
|
||||||
|
device_ips = {}
|
||||||
|
for hccn_item in hccn_configs:
|
||||||
|
hccn_item = hccn_item.strip()
|
||||||
|
if hccn_item.startswith('address_'):
|
||||||
|
device_id, device_ip = hccn_item.split('=')
|
||||||
|
device_id = device_id.split('_')[1]
|
||||||
|
device_ips[device_id] = device_ip
|
||||||
|
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
|
||||||
|
|
||||||
|
hccn_table = {}
|
||||||
|
hccn_table['board_id'] = '0x002f' # A+K
|
||||||
|
# hccn_table['board_id'] = '0x0000' # A+X
|
||||||
|
|
||||||
|
hccn_table['chip_info'] = '910'
|
||||||
|
hccn_table['deploy_mode'] = 'lab'
|
||||||
|
hccn_table['group_count'] = '1'
|
||||||
|
hccn_table['group_list'] = []
|
||||||
|
instance_list = []
|
||||||
|
for instance_id in range(nproc_per_node):
|
||||||
|
instance = {}
|
||||||
|
instance['devices'] = []
|
||||||
|
device_id = visible_devices[instance_id]
|
||||||
|
device_ip = device_ips[device_id]
|
||||||
|
instance['devices'].append({
|
||||||
|
'device_id': device_id,
|
||||||
|
'device_ip': device_ip,
|
||||||
|
})
|
||||||
|
instance['rank_id'] = str(instance_id)
|
||||||
|
instance['server_id'] = server_id
|
||||||
|
instance_list.append(instance)
|
||||||
|
hccn_table['group_list'].append({
|
||||||
|
'device_num': str(nproc_per_node),
|
||||||
|
'server_num': '1',
|
||||||
|
'group_name': '',
|
||||||
|
'instance_count': str(nproc_per_node),
|
||||||
|
'instance_list': instance_list,
|
||||||
|
})
|
||||||
|
hccn_table['para_plane_nic_location'] = 'device'
|
||||||
|
hccn_table['para_plane_nic_name'] = []
|
||||||
|
for instance_id in range(nproc_per_node):
|
||||||
|
eth_id = visible_devices[instance_id]
|
||||||
|
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
|
||||||
|
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
|
||||||
|
hccn_table['status'] = 'completed'
|
||||||
|
import json
|
||||||
|
with open(rank_table_file, 'w') as table_fp:
|
||||||
|
json.dump(hccn_table, table_fp, indent=4)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
rank_table = args_opt.rank_file
|
||||||
|
if os.path.exists(rank_table):
|
||||||
|
print('Rank table file exists.')
|
||||||
|
else:
|
||||||
|
print('Generating rank table file.')
|
||||||
|
main(rank_table)
|
||||||
|
print('Rank table file generated')
|
@ -0,0 +1,171 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""preprocess dataset"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import lmdb
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def combine_lmdbs(lmdb_paths, lmdb_save_path):
|
||||||
|
max_len = int((26 + 1) // 2)
|
||||||
|
character = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||||
|
|
||||||
|
env_save = lmdb.open(
|
||||||
|
lmdb_save_path,
|
||||||
|
map_size=1099511627776)
|
||||||
|
|
||||||
|
cnt = 0
|
||||||
|
for lmdb_path in lmdb_paths:
|
||||||
|
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||||
|
with env.begin(write=False) as txn:
|
||||||
|
nSamples = int(txn.get('num-samples'.encode()))
|
||||||
|
nSamples = nSamples
|
||||||
|
|
||||||
|
# Filtering
|
||||||
|
for index in tqdm(range(nSamples)):
|
||||||
|
index += 1 # lmdb starts with 1
|
||||||
|
label_key = 'label-%09d'.encode() % index
|
||||||
|
label = txn.get(label_key).decode('utf-8')
|
||||||
|
|
||||||
|
if len(label) > max_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
illegal_sample = False
|
||||||
|
for char_item in label.lower():
|
||||||
|
if char_item not in character:
|
||||||
|
illegal_sample = True
|
||||||
|
break
|
||||||
|
if illegal_sample:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_key = 'image-%09d'.encode() % index
|
||||||
|
imgbuf = txn.get(img_key)
|
||||||
|
|
||||||
|
with env_save.begin(write=True) as txn_save:
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
label_key_save = 'label-%09d'.encode() % cnt
|
||||||
|
label_save = label.encode()
|
||||||
|
image_key_save = 'image-%09d'.encode() % cnt
|
||||||
|
image_save = imgbuf
|
||||||
|
|
||||||
|
txn_save.put(label_key_save, label_save)
|
||||||
|
txn_save.put(image_key_save, image_save)
|
||||||
|
|
||||||
|
nSamples = cnt
|
||||||
|
with env_save.begin(write=True) as txn_save:
|
||||||
|
txn_save.put('num-samples'.encode(), str(nSamples).encode())
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
|
||||||
|
label_length_dict = {}
|
||||||
|
|
||||||
|
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||||
|
with env.begin(write=False) as txn:
|
||||||
|
nSamples = int(txn.get('num-samples'.encode()))
|
||||||
|
nSamples = nSamples
|
||||||
|
|
||||||
|
for index in tqdm(range(nSamples)):
|
||||||
|
index += 1 # lmdb starts with 1
|
||||||
|
label_key = 'label-%09d'.encode() % index
|
||||||
|
label = txn.get(label_key).decode('utf-8')
|
||||||
|
|
||||||
|
label_length = len(label)
|
||||||
|
if label_length in label_length_dict:
|
||||||
|
label_length_dict[label_length] += 1
|
||||||
|
else:
|
||||||
|
label_length_dict[label_length] = 1
|
||||||
|
|
||||||
|
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
label_length_sum = 0
|
||||||
|
label_num = 0
|
||||||
|
lengths = []
|
||||||
|
p = []
|
||||||
|
for l, num in sorted_label_length:
|
||||||
|
label_length_sum += l * num
|
||||||
|
label_num += num
|
||||||
|
p.append(num)
|
||||||
|
lengths.append(l)
|
||||||
|
for i, _ in enumerate(p):
|
||||||
|
p[i] /= label_num
|
||||||
|
|
||||||
|
average_overall_length = int(label_length_sum / label_num * batch_size)
|
||||||
|
|
||||||
|
def get_combinations_of_fix_length(fix_length, items, p, batch_size):
|
||||||
|
ret = []
|
||||||
|
cur_sum = 0
|
||||||
|
ret = np.random.choice(items, batch_size - 1, True, p)
|
||||||
|
cur_sum = sum(ret)
|
||||||
|
ret = list(ret)
|
||||||
|
if fix_length - cur_sum in items:
|
||||||
|
ret.append(fix_length - cur_sum)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return ret
|
||||||
|
|
||||||
|
result = []
|
||||||
|
while len(result) < num_of_combinations:
|
||||||
|
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
|
||||||
|
if ret is not None:
|
||||||
|
result.append(ret)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
|
||||||
|
length_index_dict = {}
|
||||||
|
|
||||||
|
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||||
|
with env.begin(write=False) as txn:
|
||||||
|
nSamples = int(txn.get('num-samples'.encode()))
|
||||||
|
nSamples = nSamples
|
||||||
|
|
||||||
|
for index in tqdm(range(nSamples)):
|
||||||
|
index += 1 # lmdb starts with 1
|
||||||
|
label_key = 'label-%09d'.encode() % index
|
||||||
|
label = txn.get(label_key).decode('utf-8')
|
||||||
|
|
||||||
|
label_length = len(label)
|
||||||
|
if label_length in length_index_dict:
|
||||||
|
length_index_dict[label_length].append(index)
|
||||||
|
else:
|
||||||
|
length_index_dict[label_length] = [index]
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
for _ in range(num_of_iters):
|
||||||
|
comb = random.choice(combinations)
|
||||||
|
for l in comb:
|
||||||
|
ret.append(random.choice(length_index_dict[l]))
|
||||||
|
|
||||||
|
with open(pkl_save_path, 'wb') as f:
|
||||||
|
pickle.dump(ret, f, -1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
|
||||||
|
print('Begin to combine multiple lmdb datasets')
|
||||||
|
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
|
||||||
|
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
|
||||||
|
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||||
|
|
||||||
|
# step 2: generate the order of input data, guarantee that the input batch shape is fixed
|
||||||
|
print('Begin to generate the index order of input data')
|
||||||
|
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||||
|
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
|
||||||
|
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')
|
||||||
|
|
||||||
|
print('Done')
|
@ -0,0 +1,102 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""util file"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class AverageMeter():
|
||||||
|
"""Computes and stores the average and current value"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.val = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, val, n=1):
|
||||||
|
self.val = val
|
||||||
|
self.sum += val * n
|
||||||
|
self.count += n
|
||||||
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
|
class CTCLabelConverter():
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self, character):
|
||||||
|
# character (str): set of the possible characters.
|
||||||
|
dict_character = list(character)
|
||||||
|
|
||||||
|
self.dict = {}
|
||||||
|
for i, char in enumerate(dict_character):
|
||||||
|
self.dict[char] = i
|
||||||
|
|
||||||
|
self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss (index 0)
|
||||||
|
self.dict['[blank]'] = len(dict_character)
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
"""convert text-label into text-index.
|
||||||
|
input:
|
||||||
|
text: text labels of each image. [batch_size]
|
||||||
|
|
||||||
|
output:
|
||||||
|
text: concatenated text index for CTCLoss.
|
||||||
|
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||||
|
length: length of each text. [batch_size]
|
||||||
|
"""
|
||||||
|
length = [len(s) for s in text]
|
||||||
|
text = ''.join(text)
|
||||||
|
text = [self.dict[char] for char in text]
|
||||||
|
|
||||||
|
return np.array(text), np.array(length)
|
||||||
|
|
||||||
|
def decode(self, text_index, length):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
texts = []
|
||||||
|
index = 0
|
||||||
|
for l in length:
|
||||||
|
t = text_index[index:index + l]
|
||||||
|
|
||||||
|
char_list = []
|
||||||
|
for i in range(l):
|
||||||
|
# if t[i] != self.dict['[blank]'] and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||||
|
if t[i] != self.dict['[blank]'] and (
|
||||||
|
not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||||
|
char_list.append(self.character[t[i]])
|
||||||
|
text = ''.join(char_list)
|
||||||
|
|
||||||
|
texts.append(text)
|
||||||
|
index += l
|
||||||
|
return texts
|
||||||
|
|
||||||
|
def reverse_encode(self, text_index, length):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
texts = []
|
||||||
|
index = 0
|
||||||
|
for l in length:
|
||||||
|
t = text_index[index:index + l]
|
||||||
|
|
||||||
|
char_list = []
|
||||||
|
for i in range(l):
|
||||||
|
if t[i] != self.dict['[blank]']: # removing repeated characters and blank.
|
||||||
|
char_list.append(self.character[t[i]])
|
||||||
|
text = ''.join(char_list)
|
||||||
|
|
||||||
|
texts.append(text)
|
||||||
|
index += l
|
||||||
|
return texts
|
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""cnnctc train"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.dataset import GeneratorDataset
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.communication.management import init
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
from src.config import Config_CNNCTC
|
||||||
|
from src.callback import LossCallBack
|
||||||
|
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
|
||||||
|
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||||
|
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_creator(run_distribute):
|
||||||
|
if run_distribute:
|
||||||
|
st_dataset = ST_MJ_Generator_batch_fixed_length_para()
|
||||||
|
else:
|
||||||
|
st_dataset = ST_MJ_Generator_batch_fixed_length()
|
||||||
|
|
||||||
|
ds = GeneratorDataset(st_dataset,
|
||||||
|
['img', 'label_indices', 'text', 'sequence_length'],
|
||||||
|
num_parallel_workers=8)
|
||||||
|
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def train(args_opt, config):
|
||||||
|
if args_opt.run_distribute:
|
||||||
|
init()
|
||||||
|
context.set_auto_parallel_context(parallel_mode="data_parallel")
|
||||||
|
|
||||||
|
ds = dataset_creator(args_opt.run_distribute)
|
||||||
|
|
||||||
|
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||||
|
net.set_train(True)
|
||||||
|
|
||||||
|
if config.CKPT_PATH != '':
|
||||||
|
param_dict = load_checkpoint(config.CKPT_PATH)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
print('parameters loaded!')
|
||||||
|
else:
|
||||||
|
print('train from scratch...')
|
||||||
|
|
||||||
|
criterion = ctc_loss()
|
||||||
|
opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
|
||||||
|
momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
|
||||||
|
|
||||||
|
net = WithLossCell(net, criterion)
|
||||||
|
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
|
||||||
|
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
|
||||||
|
|
||||||
|
callback = LossCallBack()
|
||||||
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
|
||||||
|
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
|
||||||
|
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
|
||||||
|
|
||||||
|
if args_opt.device_id == 0:
|
||||||
|
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
|
||||||
|
else:
|
||||||
|
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='CNNCTC arg')
|
||||||
|
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||||
|
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
|
||||||
|
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||||
|
help="Run distribute, default is false.")
|
||||||
|
args_cfg = parser.parse_args()
|
||||||
|
|
||||||
|
cfg = Config_CNNCTC()
|
||||||
|
if args_cfg.ckpt_path != "":
|
||||||
|
cfg.CKPT_PATH = args_cfg.ckpt_path
|
||||||
|
train(args_cfg, cfg)
|
Loading…
Reference in new issue