Add modelzoo CNNCTC Network.

pull/6383/head
linqingke 5 years ago
parent aa9fb70f3c
commit ca90924fa4

File diff suppressed because it is too large Load Diff

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""
import argparse
import time
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from src.util import CTCLabelConverter, AverageMeter
from src.config import Config_CNNCTC
from src.dataset import IIIT_Generator_batch
from src.cnn_ctc import CNNCTC_Model
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
save_graphs_path=".", enable_auto_mixed_precision=False)
def test_dataset_creator():
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
return ds
def test(config):
ds = test_dataset_creator()
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
ckpt_path = config.CKPT_PATH
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
print('parameters loaded! from: ', ckpt_path)
converter = CTCLabelConverter(config.CHARACTER)
model_run_time = AverageMeter()
npu_to_cpu_time = AverageMeter()
postprocess_time = AverageMeter()
count = 0
correct_count = 0
for data in ds.create_tuple_iterator():
img, _, text, _, length = data
img_tensor = Tensor(img, mstype.float32)
model_run_begin = time.time()
model_predict = net(img_tensor)
model_run_end = time.time()
model_run_time.update(model_run_end - model_run_begin)
npu_to_cpu_begin = time.time()
model_predict = np.squeeze(model_predict.asnumpy())
npu_to_cpu_end = time.time()
npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
postprocess_begin = time.time()
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
preds_index = np.argmax(model_predict, 2)
preds_index = np.reshape(preds_index, [-1])
preds_str = converter.decode(preds_index, preds_size)
postprocess_end = time.time()
postprocess_time.update(postprocess_end - postprocess_begin)
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
if count == 0:
model_run_time.reset()
npu_to_cpu_time.reset()
postprocess_time.reset()
else:
print('---------model run time--------', model_run_time.avg)
print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
print('---------postprocess run time--------', postprocess_time.avg)
print("Prediction samples: \n", preds_str[:5])
print("Ground truth: \n", label_str[:5])
for pred, label in zip(preds_str, label_str):
if pred == label:
correct_count += 1
count += 1
print('accuracy: ', correct_count / count)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="FasterRcnn training")
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
args_opt = parser.parse_args()
cfg = Config_CNNCTC()
if args_opt.ckpt_path != "":
cfg.CKPT_PATH = args_opt.ckpt_path
test(cfg)

@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
echo ${current_exec_path}
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2
python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
export RANK_TABLE_FILE=$PATH1
export RANK_SIZE=8
ulimit -u unlimited
for((i=0;i<$RANK_SIZE;i++));
do
rm ./train_parallel_$i/ -rf
mkdir ./train_parallel_$i
cp ./*.py ./train_parallel_$i
cp ./scripts/*.sh ./train_parallel_$i
cp -r ./src ./train_parallel_$i
cd ./train_parallel_$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
if [ -f $PATH2 ]
then
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
else
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
fi
cd .. || exit
done

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]
then
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: TRAINED_CKPT=$PATH1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ./*.py ./eval
cp ./scripts/*.sh ./eval
cp -r ./src ./eval
cd ./eval || exit
echo "start infering for device $DEVICE_ID"
env > env.log
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
cd .. || exit

@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
ulimit -u unlimited
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp ./scripts/*.sh ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ -f $PATH1 ]
then
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
else
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
fi
cd .. || exit

@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""src init file"""

@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss callback"""
import time
from mindspore.train.callback import Callback
from .util import AverageMeter
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.loss_avg = AverageMeter()
self.timer = AverageMeter()
self.start_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
if cur_step_in_epoch % 2000 == 1:
self.loss_avg = AverageMeter()
self.timer = AverageMeter()
self.start_time = time.time()
else:
self.timer.update(time.time() - self.start_time)
self.start_time = time.time()
self.loss_avg.update(loss)
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_file = open("./loss.log", "a+")
loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
cb_params.cur_epoch_num, cur_step_in_epoch,
self.loss_avg.avg, self.timer.avg))
loss_file.write("\n")
loss_file.close()
print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
cb_params.cur_epoch_num, cur_step_in_epoch,
self.loss_avg.avg, self.timer.avg))

File diff suppressed because it is too large Load Diff

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network config setting, will be used in train.py and eval.py"""
class Config_CNNCTC():
# model config
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
NUM_CLASS = len(CHARACTER) + 1
HIDDEN_SIZE = 512
FINAL_FEATURE_WIDTH = 26
# dataset config
IMG_H = 32
IMG_W = 100
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
TRAIN_BATCH_SIZE = 192
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
TEST_BATCH_SIZE = 256
TEST_DATASET_SIZE = 2976
TRAIN_EPOCHS = 3
# training config
CKPT_PATH = ''
SAVE_PATH = './'
LR = 1e-4
LR_PARA = 5e-4
MOMENTUM = 0.8
LOSS_SCALE = 8096
SAVE_CKPT_PER_N_STEP = 2000
KEEP_CKPT_MAX_NUM = 5

File diff suppressed because it is too large Load Diff

@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate ascend rank file"""
import os
import socket
import argparse
parser = argparse.ArgumentParser(description="ascend distribute rank.")
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.")
def main(rank_table_file):
nproc_per_node = 8
visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']
server_id = socket.gethostbyname(socket.gethostname())
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
hccn_table['board_id'] = '0x002f' # A+K
# hccn_table['board_id'] = '0x0000' # A+X
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
for instance_id in range(nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
hccn_table['status'] = 'completed'
import json
with open(rank_table_file, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
if __name__ == '__main__':
args_opt = parser.parse_args()
rank_table = args_opt.rank_file
if os.path.exists(rank_table):
print('Rank table file exists.')
else:
print('Generating rank table file.')
main(rank_table)
print('Rank table file generated')

@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess dataset"""
import random
import pickle
import numpy as np
import lmdb
from tqdm import tqdm
def combine_lmdbs(lmdb_paths, lmdb_save_path):
max_len = int((26 + 1) // 2)
character = '0123456789abcdefghijklmnopqrstuvwxyz'
env_save = lmdb.open(
lmdb_save_path,
map_size=1099511627776)
cnt = 0
for lmdb_path in lmdb_paths:
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
nSamples = nSamples
# Filtering
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_len:
continue
illegal_sample = False
for char_item in label.lower():
if char_item not in character:
illegal_sample = True
break
if illegal_sample:
continue
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
with env_save.begin(write=True) as txn_save:
cnt += 1
label_key_save = 'label-%09d'.encode() % cnt
label_save = label.encode()
image_key_save = 'image-%09d'.encode() % cnt
image_save = imgbuf
txn_save.put(label_key_save, label_save)
txn_save.put(image_key_save, image_save)
nSamples = cnt
with env_save.begin(write=True) as txn_save:
txn_save.put('num-samples'.encode(), str(nSamples).encode())
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
label_length_dict = {}
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
nSamples = nSamples
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
label_length = len(label)
if label_length in label_length_dict:
label_length_dict[label_length] += 1
else:
label_length_dict[label_length] = 1
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)
label_length_sum = 0
label_num = 0
lengths = []
p = []
for l, num in sorted_label_length:
label_length_sum += l * num
label_num += num
p.append(num)
lengths.append(l)
for i, _ in enumerate(p):
p[i] /= label_num
average_overall_length = int(label_length_sum / label_num * batch_size)
def get_combinations_of_fix_length(fix_length, items, p, batch_size):
ret = []
cur_sum = 0
ret = np.random.choice(items, batch_size - 1, True, p)
cur_sum = sum(ret)
ret = list(ret)
if fix_length - cur_sum in items:
ret.append(fix_length - cur_sum)
else:
return None
return ret
result = []
while len(result) < num_of_combinations:
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
if ret is not None:
result.append(ret)
return result
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
length_index_dict = {}
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
nSamples = nSamples
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
label_length = len(label)
if label_length in length_index_dict:
length_index_dict[label_length].append(index)
else:
length_index_dict[label_length] = [index]
ret = []
for _ in range(num_of_iters):
comb = random.choice(combinations)
for l in comb:
ret.append(random.choice(length_index_dict[l]))
with open(pkl_save_path, 'wb') as f:
pickle.dump(ret, f, -1)
if __name__ == '__main__':
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
print('Begin to combine multiple lmdb datasets')
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
# step 2: generate the order of input data, guarantee that the input batch shape is fixed
print('Begin to generate the index order of input data')
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')
print('Done')

@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""util file"""
import numpy as np
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class CTCLabelConverter():
""" Convert between text-label and text-index """
def __init__(self, character):
# character (str): set of the possible characters.
dict_character = list(character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss (index 0)
self.dict['[blank]'] = len(dict_character)
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
length = [len(s) for s in text]
text = ''.join(text)
text = [self.dict[char] for char in text]
return np.array(text), np.array(length)
def decode(self, text_index, length):
""" convert text-index into text-label. """
texts = []
index = 0
for l in length:
t = text_index[index:index + l]
char_list = []
for i in range(l):
# if t[i] != self.dict['[blank]'] and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
if t[i] != self.dict['[blank]'] and (
not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
char_list.append(self.character[t[i]])
text = ''.join(char_list)
texts.append(text)
index += l
return texts
def reverse_encode(self, text_index, length):
""" convert text-index into text-label. """
texts = []
index = 0
for l in length:
t = text_index[index:index + l]
char_list = []
for i in range(l):
if t[i] != self.dict['[blank]']: # removing repeated characters and blank.
char_list.append(self.character[t[i]])
text = ''.join(char_list)
texts.append(text)
index += l
return texts

@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc train"""
import argparse
import ast
import mindspore
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.communication.management import init
from mindspore.common import set_seed
from src.config import Config_CNNCTC
from src.callback import LossCallBack
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
save_graphs_path=".", enable_auto_mixed_precision=False)
def dataset_creator(run_distribute):
if run_distribute:
st_dataset = ST_MJ_Generator_batch_fixed_length_para()
else:
st_dataset = ST_MJ_Generator_batch_fixed_length()
ds = GeneratorDataset(st_dataset,
['img', 'label_indices', 'text', 'sequence_length'],
num_parallel_workers=8)
return ds
def train(args_opt, config):
if args_opt.run_distribute:
init()
context.set_auto_parallel_context(parallel_mode="data_parallel")
ds = dataset_creator(args_opt.run_distribute)
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
net.set_train(True)
if config.CKPT_PATH != '':
param_dict = load_checkpoint(config.CKPT_PATH)
load_param_into_net(net, param_dict)
print('parameters loaded!')
else:
print('train from scratch...')
criterion = ctc_loss()
opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
net = WithLossCell(net, criterion)
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
callback = LossCallBack()
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
if args_opt.device_id == 0:
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
else:
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CNNCTC arg')
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
help="Run distribute, default is false.")
args_cfg = parser.parse_args()
cfg = Config_CNNCTC()
if args_cfg.ckpt_path != "":
cfg.CKPT_PATH = args_cfg.ckpt_path
train(args_cfg, cfg)

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
# ===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""

Loading…
Cancel
Save