commit 665e8c58a7
@@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the cropped images')

    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the output annotation file')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    crop image with location (h,w,x,y)
    save cropped image to output directory
    """
    # avoid negative value of coordinates in annotation
    start_x = np.maximum(location[2], 0)
    end_x = start_x + location[1]
    start_y = np.maximum(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        rects = img_label['rect']
        ext = image_name.split('.')[-1]
        name = image_name[:-(len(ext) + 1)]

        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(float(location['height']))
            w = int(float(location['width']))
            x = int(float(location['x']))
            y = int(float(location['y']))
            label = rect['label']
            loc = [h, w, x, y]
            output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + '.' + ext
            output_name = output_name.replace(",", "")
            output_file = os.path.join(args.output_dir, output_name)

            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ','
            annotation_list.append(ann)

    ann_file = args.output_annotation

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
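# Hypothetical invocation of this conversion script (the script name and the
# paths below are placeholders for illustration, not part of the change):
#   python convert_svt.py -d ./svt1/img -x ./svt1/test.xml \
#       -o ./processed -a ./annotation.txt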
@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
from scipy import io

###############################################
# load testdata
# testdata.mat structure
# test[:][0] : image name
# test[:][1] : label
# test[:][2] : 50 lexicon
# test[:][3] : 1000 lexicon
##############################################


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-m', '--mat_file', type=str, default='testdata.mat',
                        help='path to the testdata.mat annotation file')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the processed data')

    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the output annotation file')

    return parser.parse_args()


def mat_to_list(mat_file):
    ann_ori = io.loadmat(mat_file)
    testdata = ann_ori['testdata'][0]

    ann_output = []
    for elem in testdata:
        img_name = elem[0]
        label = elem[1]
        ann = img_name + ',' + label
        ann_output.append(ann)
    return ann_output


def convert():
    args = init_args()

    ann_list = mat_to_list(args.mat_file)

    ann_file = args.output_annotation
    with open(ann_file, 'w') as f:
        for line in ann_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
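# Hypothetical invocation (the script name and paths are placeholders):
#   python convert_iiit5k.py -m ./IIIT5K/testdata.mat -o ./processed -a ./annotation.txt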
@@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the cropped images')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            elif ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    crop image with location (h,w,x,y)
    save cropped image to output directory
    """
    # avoid negative value of coordinates in annotation
    start_x = max(location[2], 0)
    end_x = start_x + location[1]
    start_y = max(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    lexicon_list = []
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        lex = img_label['lex']
        rects = img_label['rect']
        name, ext = image_name.split('.')
        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        lexicon_list.append(lex)
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(location['height'])
            w = int(location['width'])
            x = int(location['x'])
            y = int(location['y'])
            label = rect['label']
            loc = [h, w, x, y]
            output_name = name + "_" + str(j) + "_" + label + '.' + ext
            output_file = os.path.join(args.output_dir, output_name)
            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ',' + str(i)
            annotation_list.append(ann)

    lex_file = './lexicon_ann_train.txt'
    ann_file = './annotation_train.txt'
    with open(lex_file, 'w') as f:
        for line in lexicon_list:
            txt = line + '\n'
            f.write(txt)

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
@@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN evaluation"""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.metric import CRNNAccuracy

set_seed(1)

parser = argparse.ArgumentParser(description="CRNN eval")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is lowcase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()

if args_opt.model == 'lowcase':
    from src.config import config1 as config
else:
    from src.config import config2 as config

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    # fall back to device 0 when DEVICE_ID is not set in the environment
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

if __name__ == '__main__':
    config.batch_size = 1
    max_text_length = config.max_text_length
    input_size = config.input_size
    # create dataset
    dataset = create_dataset(name=args_opt.dataset,
                             dataset_path=args_opt.dataset_path,
                             batch_size=config.batch_size,
                             is_training=False,
                             config=config)
    step_size = dataset.get_dataset_size()
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define model
    model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
    # start evaluation
    res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
    print("result:", res, flush=True)
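# Hypothetical invocation (paths are placeholders for illustration):
#   python eval.py --dataset=svt --dataset_path=./processed \
#       --checkpoint_path=./ckpt/crnn-10_1000.ckpt --platform=Ascend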
@@ -0,0 +1 @@
python-Levenshtein
@@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)

if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]; then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log
    python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
    cd ..
done
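# Hypothetical invocation (paths are placeholders for illustration):
#   sh run_distribute_train.sh synth ./hccl_8p.json /path/to/synth_dataset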
@@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]; then
    echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
PLATFORM=$4

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export DEVICE_ID=0
    export RANK_SIZE=$DEVICE_NUM
    export RANK_ID=0

    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    echo "start evaluation for device $DEVICE_ID"
    python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    python eval.py --dataset=$DATASET_NAME \
        --dataset_path=$1 \
        --checkpoint_path=$2 \
        --platform=GPU > log.txt 2>&1 &
    cd ..
}

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1 $PATH2
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
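# Hypothetical invocation (paths are placeholders for illustration):
#   sh run_eval.sh svt /path/to/processed_svt ./crnn-10_1000.ckpt Ascend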
@@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PLATFORM=$3

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

export DEVICE_ID=0
run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export RANK_ID=0
    export RANK_SIZE=1

    echo "start training for device $DEVICE_ID"
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
    cd ..
}

# use ${DEVICE_ID} (variable expansion), not $(DEVICE_ID) (command substitution)
WORKDIR=./train${DEVICE_ID}
if [ -d $WORKDIR ]; then
    rm -rf $WORKDIR
fi
mkdir $WORKDIR
cp ../*.py $WORKDIR
cp -r ../src $WORKDIR
cd $WORKDIR || exit

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
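# Hypothetical invocation (the dataset path is a placeholder):
#   sh run_standalone_train.sh synth /path/to/synth_dataset Ascend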
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict


label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


# used for lowercase letters and digits
config1 = EasyDict({
    "max_text_length": 23,
    "image_width": 100,
    "image_height": 32,
    "batch_size": 64,
    "epoch_size": 10,
    "hidden_size": 256,
    "learning_rate": 0.02,
    "momentum": 0.95,
    "nesterov": True,
    "save_checkpoint": True,
    "save_checkpoint_steps": 1000,
    "keep_checkpoint_max": 30,
    "save_checkpoint_path": "./",
    "class_num": 37,
    "input_size": 512,
    "num_step": 24,
    "use_dropout": True,
    "blank": 36
})
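# A minimal consistency sketch (not part of the original change): the 36-way
# label_dict plus one CTC blank symbol gives the 37 classes above, and the
# blank takes the last class index.
#   assert config1.class_num == len(label_dict) + 1  # 36 characters + blank
#   assert config1.blank == len(label_dict)          # blank uses index 36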
@@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN network definition."""

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
                          moving_var_init=1)


class Conv(nn.Cell):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                              padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))
        self.bn = _bn(out_channel)
        self.Relu = nn.ReLU()
        self.use_bn = use_bn

    def construct(self, x):
        out = self.conv(x)
        if self.use_bn:
            out = self.bn(out)
        out = self.Relu(out)
        return out


class VGG(nn.Cell):
    """VGG Network structure"""
    def __init__(self, is_training=True):
        super(VGG, self).__init__()
        self.conv1 = Conv(3, 64, use_bn=True)
        self.conv2 = Conv(64, 128, use_bn=True)
        self.conv3 = Conv(128, 256, use_bn=True)
        self.conv4 = Conv(256, 256, use_bn=True)
        self.conv5 = Conv(256, 512, use_bn=True)
        self.conv6 = Conv(512, 512, use_bn=True)
        self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)
        self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
        self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')
        self.bn1 = _bn(512)

    def construct(self, x):
        x = self.conv1(x)
        x = self.maxpool2d1(x)
        x = self.conv2(x)
        x = self.maxpool2d1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.maxpool2d2(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.maxpool2d2(x)
        x = self.conv7(x)
        return x


class CRNN(nn.Cell):
    """
    Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.

    Args:
        input_size(int): Size of time sequence. Usually, input_size is three times the image height
            for text images.
        batch_size(int): batch size of input data, default is 64
        hidden_size(int): the hidden size in LSTM layers, default is 512
    """
    def __init__(self, config):
        super(CRNN, self).__init__()
        self.batch_size = config.batch_size
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_classes = config.class_num
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
        self.rnn2 = P.DynamicRNN(forget_bias=0.0)
        self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)

        w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1 = Parameter(w1.astype(np.float16), name="w1")
        w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2 = Parameter(w2.astype(np.float16), name="w2")
        w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
        w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")

        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")

        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
        self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)

        self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
                           weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
        self.fc.to_float(mstype.float32)
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.squeeze = P.Squeeze(axis=0)
        self.vgg = VGG()
        self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.5)
        self.rnn_dropout = nn.Dropout(0.9)
        self.use_dropout = config.use_dropout

    def construct(self, x):
        x = self.vgg(x)
        x = self.cast(x, mstype.float16)

        # (batch, channel, 1, width) -> (num_step, batch_size, input_size)
        x = self.reshape(x, (self.batch_size, self.input_size, -1))
        x = self.transpose(x, (2, 0, 1))
        bw_x = self.reverse_seq1(x, self.seq_length)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
        y1_out = self.concat1((y1, y1_bw))
        if self.use_dropout:
            y1_out = self.rnn_dropout(y1_out)

        y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
        bw_y = self.reverse_seq3(y1_out, self.seq_length)
        y2_bw, _, _, _, _, _, _, _ = self.rnn2_bw(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
        y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
        y2_out = self.concat1((y2, y2_bw))
        if self.use_dropout:
            y2_out = self.dropout(y2_out)

        output = ()
        for i in range(F.shape(y2_out)[0]):
            # only the forward output of the second LSTM feeds the classifier,
            # matching the Dense layer's in_channels (hidden_size)
            y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
            y2_after_fc = self.expand_dims(y2_after_fc, 0)
            output += (y2_after_fc,)
        output = self.concat(output)
        return output
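# Shape walk-through for the default config1 (a sketch under the stated
# assumptions, not part of the original change): a (64, 3, 32, 100) batch
# leaves the VGG backbone as (64, 512, 1, 24) after four height halvings, two
# width halvings, and the final valid 2x2 conv; it is reshaped to (64, 512, 24)
# and transposed to (num_step=24, batch=64, input_size=512) for the stacked
# bidirectional DynamicRNN layers, and the classifier maps each of the 24
# steps to the 37 CTC classes, giving an output of shape (24, 64, 37).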
@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Automatic differentiation with grad clip."""
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor

compute_norm = C.MultitypeFuncGraph("compute_norm")


@compute_norm.register("Tensor")
def _compute_norm(grad):
    norm = nn.Norm()
    norm = norm(F.cast(grad, mstype.float32))
    ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return ret


grad_div = C.MultitypeFuncGraph("grad_div")


@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
    div = P.RealDiv()
    mul = P.Mul()
    scale = div(10.0, val)
    ret = mul(grad, scale)
    return ret


class TrainOneStepCellWithGradClip(Cell):
    """
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
    Backward graph with grad clip will be created in the construct function to do parameter updating.
    Different parallel modes are available to run the training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
        - data (Tensor) - Tensor of shape :math:`(N, ...)`.
        - label (Tensor) - Tensor of shape :math:`(N, ...)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.hyper_map = C.HyperMap()
        self.greater = P.Greater()
        self.select = P.Select()
        self.norm = nn.Norm(keep_dims=True)
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=0)
        self.ten = Tensor(np.array([10.0]).astype(np.float32))
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, data, label):
        weights = self.weights
        loss = self.network(data, label)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, label, sens)
        # global-norm clipping: scale gradients by 10 / norm when the global
        # norm exceeds the threshold of 10
        norm = self.hyper_map(F.partial(compute_norm), grads)
        norm = self.concat(norm)
        norm = self.norm(norm)
        cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
        clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
        grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
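# A minimal usage sketch (the wiring below is assumed for illustration and is
# not part of the original change):
#   net = CRNN(config)
#   loss = CTCLoss(max_sequence_length=config.num_step,
#                  max_label_length=config.max_text_length,
#                  batch_size=config.batch_size)
#   net_with_loss = nn.WithLossCell(net, loss)
#   opt = nn.SGD(params=net.trainable_params(), learning_rate=config.learning_rate,
#                momentum=config.momentum, nesterov=config.nesterov)
#   train_net = TrainOneStepCellWithGradClip(net_with_loss, opt)
#   train_net.set_train()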
@@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vc
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.ic03_dataset import IC03Dataset
from src.ic13_dataset import IC13Dataset
from src.iiit5k_dataset import IIIT5KDataset
from src.svt_dataset import SVTDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True


class CaptchaDataset:
    """
    create train or evaluation dataset for crnn

    Args:
        img_root_dir(str): root path of images
        is_training(bool): whether to load the training or the test annotation list
        config(EasyDict): network configuration, default is config1
    """

    def __init__(self, img_root_dir, is_training=True, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        if is_training:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_train.txt')
        else:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_test.txt')
        self.lexicon_file = os.path.join(self.img_root_dir, 'lexicon.txt')
        with open(self.lexicon_file, 'r') as f:
            self.lexicons = [line.strip('\n') for line in f]
        self.img_names = {}
        self.img_list = []
        with open(self.imgslist, 'r') as f:
            for line in f:
                img_name, label_index = line.strip('\n').split(" ")
                self.img_list.append(img_name)
                self.img_names[img_name] = self.lexicons[int(label_index)]
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        # pad the label to max_text_length with the blank index
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label


def create_dataset(name, dataset_path, batch_size=1, num_shards=1, shard_id=0, is_training=True, config=config1):
    """
    create train or evaluation dataset for crnn

    Args:
        name(str): dataset name, one of synth, ic03, ic13, svt, iiit5k
        dataset_path(str): dataset path
        batch_size(int): batch size of generated dataset, default is 1
        num_shards(int): number of devices
        shard_id(int): rank id
        is_training(bool): whether the dataset is used for training
        config(EasyDict): network configuration, default is config1
    """
    if name == 'synth':
        dataset = CaptchaDataset(dataset_path, is_training, config)
    elif name == 'ic03':
        dataset = IC03Dataset(dataset_path, "annotation.txt", config, True, 3)
    elif name == 'ic13':
        dataset = IC13Dataset(dataset_path, "Challenge2_Test_Task3_GT.txt", config)
    elif name == 'svt':
        dataset = SVTDataset(dataset_path, config)
    elif name == 'iiit5k':
        dataset = IIIT5KDataset(dataset_path, "annotation.txt", config)
    else:
        raise ValueError(f"unsupported dataset name: {name}")
    ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
    image_trans = [
        vc.Resize((config.image_height, config.image_width)),
        vc.Normalize([127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        vc.HWC2CHW()
    ]
    label_trans = [
        C.TypeCast(mstype.int32)
    ]
    ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
    ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)

    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
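# A minimal usage sketch (the path is a placeholder for illustration):
#   ds = create_dataset(name='svt', dataset_path='/path/to/processed_svt',
#                       batch_size=64, is_training=False, config=config1)
#   for image, label in ds.create_tuple_iterator():
#       ...  # image: (64, 3, 32, 100) float32, label: (64, 23) int32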
@@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IC03"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IC03Dataset:
    """
    create train or evaluation dataset for crnn

    Args:
        img_root_dir(str): root path of images
        anno_file(str): name of the annotation file under img_root_dir
        config(EasyDict): network configuration, default is config1
        filter_by_dict(bool): drop samples with characters outside label_dict
        filter_length(int): drop labels shorter than this length
    """

    def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1, filter_by_dict=True, filter_length=3):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        anno_file = os.path.join(img_root_dir, anno_file)

        self.img_names = {}
        self.img_list = []
        with open(anno_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split(",")[1].lower()
                if len(label) < filter_length:
                    continue
                if filter_by_dict:
                    flag = True
                    for c in label:
                        if c not in label_dict:
                            flag = False
                            break
                    if not flag:
                        continue
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IC13"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IC13Dataset:
    """
    create evaluation dataset for crnn

    Args:
        img_root_dir(str): root path of images
        label_file(str): name of the ground-truth file under img_root_dir
        config(EasyDict): network configuration, default is config1
        filter_by_dict(bool): drop samples with characters outside label_dict
        filter_length(int): drop labels shorter than this length
    """
    def __init__(self, img_root_dir, label_file="", config=config1, filter_by_dict=True, filter_length=3):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid".format(img_root_dir))
        self.img_root_dir = img_root_dir
        self.label_file = os.path.join(img_root_dir, label_file)
        self.img_names = {}
        self.img_list = []
        self.config = config
        with open(self.label_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split("\"")[1].lower()
                if len(label) < filter_length:
                    continue
                if filter_by_dict:
                    flag = True
                    for c in label:
                        if c not in label_dict:
                            flag = False
                            break
                    if not flag:
                        continue
                self.img_names[img_name] = label
                self.img_list.append(img_name)
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IIIT5K"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IIIT5KDataset:
    """
    create train or evaluation dataset for crnn

    Args:
        img_root_dir(str): root path of images
        anno_file(str): name of the annotation file under img_root_dir
        config(EasyDict): network configuration, default is config1
    """

    def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        anno_file = os.path.join(img_root_dir, anno_file)

        self.img_names = {}
        self.img_list = []
        with open(anno_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split(",")[1].lower()
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@ -0,0 +1,49 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""CTC Loss."""
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class CTCLoss(_Loss):
|
||||||
|
"""
|
||||||
|
CTCLoss definition
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_sequence_length(int): max number of sequence length. For text images, the value is equal to image
|
||||||
|
width
|
||||||
|
max_label_length(int): max number of label length for each input.
|
||||||
|
batch_size(int): batch size of input logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_sequence_length, max_label_length, batch_size):
|
||||||
|
super(CTCLoss, self).__init__()
|
||||||
|
self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
|
||||||
|
name="sequence_length")
|
||||||
|
labels_indices = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
for j in range(max_label_length):
|
||||||
|
labels_indices.append([i, j])
|
||||||
|
self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)
|
||||||
|
|
||||||
|
def construct(self, logit, label):
|
||||||
|
labels_values = self.reshape(label, (-1,))
|
||||||
|
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
|
||||||
|
return loss
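A hedged smoke-test sketch for the loss above. The shapes and config values (num_step = 16, batch_size = 2, max_label_length = 23, class_num = 37) are assumptions, and the blank-padded dense labels mirror what the dataset classes emit:

import numpy as np
from mindspore import Tensor

loss_fn = CTCLoss(max_sequence_length=16, max_label_length=23, batch_size=2)
logit = Tensor(np.random.randn(16, 2, 37).astype(np.float32))  # (T, N, C)
label_np = np.full((2, 23), 36, dtype=np.int32)                # blank-padded labels
label_np[0, :3] = [2, 0, 19]                                   # "cat"
label_np[1, :3] = [3, 14, 6]                                   # "dog"
print(loss_fn(logit, Tensor(label_np)))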

@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
import Levenshtein

label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


class CRNNAccuracy(nn.Metric):
    """
    Define accuracy metric for the CRNN network.
    """

    def __init__(self, config):
        super(CRNNAccuracy, self).__init__()
        self.config = config
        self._correct_num = 0
        self._total_num = 0
        self.blank = config.blank

    def clear(self):
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('CRNNAccuracy needs 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])
        str_pred = self._ctc_greedy_decoder(y_pred)
        str_label = self._convert_labels(y)

        for pred, label in zip(str_pred, str_label):
            print(pred, " :: ", label)
            edit_distance = Levenshtein.distance(pred, label)
            self._total_num += 1
            if edit_distance == 0:
                self._correct_num += 1

    def eval(self):
        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        print('correct num: ', self._correct_num, ', total num: ', self._total_num)
        sequence_accuracy = self._correct_num / self._total_num
        return sequence_accuracy

    def _arr2char(self, inputs):
        string = ""
        for i in inputs:
            if i < self.blank:
                string += label_dict[i]
        return string

    def _convert_labels(self, inputs):
        str_list = []
        for label in inputs:
            str_temp = self._arr2char(label)
            str_list.append(str_temp)
        return str_list

    def _ctc_greedy_decoder(self, y_pred):
        """
        parse predict result to labels
        """
        seq_len, batch_size, _ = y_pred.shape
        indices = y_pred.argmax(axis=2)
        lens = [seq_len] * batch_size
        pred_labels = []
        for i in range(batch_size):
            idx = indices[:, i]
            last_idx = self.blank
            pred_label = []
            for j in range(lens[i]):
                cur_idx = idx[j]
                # keep a class only if it is not blank and differs from the previous step
                if cur_idx not in [last_idx, self.blank]:
                    pred_label.append(cur_idx)
                last_idx = cur_idx
            pred_labels.append(pred_label)
        str_results = []
        for i in pred_labels:
            str_results.append(self._arr2char(i))
        return str_results
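A worked trace of the greedy decoder above, assuming blank = 36: the per-step argmax sequence "c c <blank> a a t <blank>" collapses to "cat", since repeats are merged and blanks dropped:

import numpy as np

blank = 36
steps = np.array([2, 2, 36, 0, 0, 19, 36])  # argmax index per time step
decoded, last = [], blank
for cur in steps:
    if cur not in (last, blank):
        decoded.append(cur)
    last = cur
print("".join(label_dict[i] for i in decoded))  # -> cat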

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for SVT"""
import os

import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict

ImageFile.LOAD_TRUNCATED_IMAGES = True


class SVTDataset:
    """
    Create a train or evaluation dataset for CRNN.

    Args:
        img_root_dir(str): root path of images; each ground-truth word is embedded in its file name.
        config: model config providing max_text_length, blank and class_num.
    """

    def __init__(self, img_root_dir, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        file_list = os.listdir(img_root_dir)
        self.img_names = {}
        self.img_list = []
        for f in file_list:
            # file name pattern "<prefix>_<LABEL>.jpg": the label is the last "_" field
            label = f.split(".jpg")[0]
            label = label.split("_")[-1].lower()
            self.img_names[f] = label
            self.img_list.append(f)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap channel order to BGR to match the training pipeline
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
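A one-line check of the file-name convention assumed by the constructor above (the file name itself is hypothetical):

f = "img_0042_APARTMENTS.jpg"
print(f.split(".jpg")[0].split("_")[-1].lower())  # -> apartments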

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""crnn training"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn_for_train import TrainOneStepCellWithGradClip

set_seed(1)

parser = argparse.ArgumentParser(description="crnn training")
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()

if args_opt.model == 'lowercase':
    from src.config import config1 as config
else:
    from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    # expects DEVICE_ID to be exported in the environment
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)


if __name__ == '__main__':
    lr_scale = 1
    if args_opt.run_distribute:
        if args_opt.platform == 'Ascend':
            init()
            device_num = int(os.environ.get("RANK_SIZE"))
            rank = int(os.environ.get("RANK_ID"))
        else:
            init()
            device_num = get_group_size()
            rank = get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        device_num = 1
        rank = 0

    max_text_length = config.max_text_length
    # create dataset
    dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
                             num_shards=device_num, shard_id=rank, config=config)
    step_size = dataset.get_dataset_size()
    # define lr
    lr_init = config.learning_rate
    lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)

    net = WithLossCell(net, loss)
    net = TrainOneStepCellWithGradClip(net, opt).set_train()
    # define model
    model = Model(net)
    # define callbacks
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)
    model.train(config.epoch_size, dataset, callbacks=callbacks)
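For intuition, a plain-Python sketch of the schedule built by nn.dynamic_lr.cosine_decay_lr above; epoch_size = 10, step_size = 100 and learning_rate = 0.02 are assumed values. The rate stays constant within an epoch and decays from learning_rate toward 0 on a cosine curve across epochs:

import math

min_lr, max_lr, epoch_size, step_size = 0.0, 0.02, 10, 100
lr = [min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * (i // step_size) / epoch_size))
      for i in range(epoch_size * step_size)]
print(lr[0], lr[len(lr) // 2], lr[-1])  # max at start, half-way value, near zero at the end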