commit
665e8c58a7
@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory where cropped images are stored')
    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the generated annotation file')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    Crop an image with location (h, w, x, y) and
    save the cropped region to the output directory.
    """
    # avoid negative coordinate values in the annotation
    start_x = np.maximum(location[2], 0)
    end_x = start_x + location[1]
    start_y = np.maximum(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        rects = img_label['rect']
        ext = image_name.split('.')[-1]
        name = image_name[:-(len(ext) + 1)]

        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(float(location['height']))
            w = int(float(location['width']))
            x = int(float(location['x']))
            y = int(float(location['y']))
            label = rect['label']
            loc = [h, w, x, y]
            output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + '.' + ext
            output_name = output_name.replace(",", "")
            output_file = os.path.join(args.output_dir, output_name)

            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ','
            annotation_list.append(ann)

    ann_file = args.output_annotation

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
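Note: xml_to_dict() assumes an annotation layout like the following; the tag and attribute names are taken from the parsing code above, while the file name and values are hypothetical. A minimal sketch:

sample = (
    '<tagset>'
    '<image>'
    '<imageName>img/sample_0.jpg</imageName>'
    '<taggedRectangles>'
    '<taggedRectangle height="65" width="220" x="39" y="128">'
    '<tag>LIPPO</tag>'
    '</taggedRectangle>'
    '</taggedRectangles>'
    '</image>'
    '</tagset>'
)
with open('sample.xml', 'w') as f:
    f.write(sample)
print(xml_to_dict('sample.xml'))
# [{'imageName': 'img/sample_0.jpg',
#   'rect': [{'location': {'height': '65', 'width': '220', 'x': '39', 'y': '128'},
#             'label': 'LIPPO'}]}]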
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
from scipy import io

###############################################
# load testdata
# testdata.mat structure
# test[:][0] : image name
# test[:][1] : label
# test[:][2] : 50 lexicon
# test[:][3] : 1000 lexicon
##############################################


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-m', '--mat_file', type=str, default='testdata.mat',
                        help='path to the .mat annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory where processed data are stored')
    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the generated annotation file')

    return parser.parse_args()


def mat_to_list(mat_file):
    ann_ori = io.loadmat(mat_file)
    testdata = ann_ori['testdata'][0]

    ann_output = []
    for elem in testdata:
        img_name = elem[0]
        label = elem[1]
        ann = img_name + ',' + label
        ann_output.append(ann)
    return ann_output


def convert():
    args = init_args()

    ann_list = mat_to_list(args.mat_file)

    ann_file = args.output_annotation
    with open(ann_file, 'w') as f:
        for line in ann_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
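Note: scipy's loadmat wraps MATLAB cell entries in nested arrays, so the fields described in the comment block above may need extra indexing depending on how the .mat file was written. A hedged sketch for inspecting the layout interactively:

from scipy import io

mat = io.loadmat('testdata.mat')
first = mat['testdata'][0][0]
print(first[0], first[1])                  # image name, ground-truth label
print(len(first[2][0]), len(first[3][0]))  # 50-word and 1000-word lexicons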
@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory where cropped images are stored')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            elif ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    Crop an image with location (h, w, x, y) and
    save the cropped region to the output directory.
    """
    # avoid negative coordinate values in the annotation,
    # matching the handling in the other converter above
    start_x = max(location[2], 0)
    end_x = start_x + location[1]
    start_y = max(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    lexicon_list = []
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        lex = img_label['lex']
        rects = img_label['rect']
        name, ext = image_name.split('.')
        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        lexicon_list.append(lex)
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(location['height'])
            w = int(location['width'])
            x = int(location['x'])
            y = int(location['y'])
            label = rect['label']
            loc = [h, w, x, y]
            output_name = name + "_" + str(j) + "_" + label + '.' + ext
            output_file = os.path.join(args.output_dir, output_name)
            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ',' + str(i)
            annotation_list.append(ann)

    lex_file = './lexicon_ann_train.txt'
    ann_file = './annotation_train.txt'
    with open(lex_file, 'w') as f:
        for line in lexicon_list:
            txt = line + '\n'
            f.write(txt)

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            txt = line + '\n'
            f.write(txt)


if __name__ == "__main__":
    convert()
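A quick worked example of the (h, w, x, y) crop arithmetic in image_crop_save above, with hypothetical values:

# loc = [65, 220, 39, 128] means height 65, width 220, top-left corner (39, 128)
h, w, x, y = 65, 220, 39, 128
start_x, end_x = max(x, 0), max(x, 0) + w   # columns 39..259
start_y, end_y = max(y, 0), max(y, 0) + h   # rows 128..193
# cropped = image[start_y:end_y, start_x:end_x]  -> a 65 x 220 patch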
@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN evaluation."""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.metric import CRNNAccuracy

set_seed(1)

parser = argparse.ArgumentParser(description="CRNN eval")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="Checkpoint file path, default is None.")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is lowcase.")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()

if args_opt.model == 'lowcase':
    from src.config import config1 as config
else:
    from src.config import config2 as config

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

if __name__ == '__main__':
    config.batch_size = 1
    max_text_length = config.max_text_length
    input_size = config.input_size
    # create dataset
    dataset = create_dataset(name=args_opt.dataset,
                             dataset_path=args_opt.dataset_path,
                             batch_size=config.batch_size,
                             is_training=False,
                             config=config)
    step_size = dataset.get_dataset_size()
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define model
    model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
    # start evaluation
    res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
    print("result:", res, flush=True)
@ -0,0 +1 @@
python-Levenshtein
@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)

if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]; then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log
    python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
    cd ..
done
@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]; then
    echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
PLATFORM=$4

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export DEVICE_ID=0
    export RANK_SIZE=$DEVICE_NUM
    export RANK_ID=0

    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    echo "start evaluation for device $DEVICE_ID"
    python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    python eval.py --dataset=$DATASET_NAME \
                   --dataset_path=$1 \
                   --checkpoint_path=$2 \
                   --platform=GPU > log.txt 2>&1 &
    cd ..
}

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1 $PATH2
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PLATFORM=$3

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

export DEVICE_ID=0
run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export RANK_ID=0
    export RANK_SIZE=1

    echo "start training for device $DEVICE_ID"
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
    cd ..
}

WORKDIR=./train${DEVICE_ID}
if [ -d $WORKDIR ]; then
    rm -rf $WORKDIR
fi
mkdir $WORKDIR
cp ../*.py $WORKDIR
cp -r ../src $WORKDIR
cd $WORKDIR || exit

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict


label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


# configuration for the lowercase-letters + digits model
config1 = EasyDict({
    "max_text_length": 23,
    "image_width": 100,
    "image_height": 32,
    "batch_size": 64,
    "epoch_size": 10,
    "hidden_size": 256,
    "learning_rate": 0.02,
    "momentum": 0.95,
    "nesterov": True,
    "save_checkpoint": True,
    "save_checkpoint_steps": 1000,
    "keep_checkpoint_max": 30,
    "save_checkpoint_path": "./",
    "class_num": 37,
    "input_size": 512,
    "num_step": 24,
    "use_dropout": True,
    "blank": 36
})
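Note how the class count in config1 hangs together: label_dict holds 36 characters, class_num is 37 because CTC reserves one extra class for the blank symbol, and blank is that class's index. A quick sanity check:

from src.config import config1, label_dict

assert len(label_dict) == 36
assert config1.class_num == len(label_dict) + 1  # +1 for the CTC blank class
assert config1.blank == len(label_dict)          # blank takes the last index, 36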
@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN network definition."""

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
                          moving_var_init=1)


class Conv(nn.Cell):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                              padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))
        self.bn = _bn(out_channel)
        self.Relu = nn.ReLU()
        self.use_bn = use_bn

    def construct(self, x):
        out = self.conv(x)
        if self.use_bn:
            out = self.bn(out)
        out = self.Relu(out)
        return out


class VGG(nn.Cell):
    """VGG network structure"""
    def __init__(self, is_training=True):
        super(VGG, self).__init__()
        self.conv1 = Conv(3, 64, use_bn=True)
        self.conv2 = Conv(64, 128, use_bn=True)
        self.conv3 = Conv(128, 256, use_bn=True)
        self.conv4 = Conv(256, 256, use_bn=True)
        self.conv5 = Conv(256, 512, use_bn=True)
        self.conv6 = Conv(512, 512, use_bn=True)
        self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)
        self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
        self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')
        self.bn1 = _bn(512)

    def construct(self, x):
        x = self.conv1(x)
        x = self.maxpool2d1(x)
        x = self.conv2(x)
        x = self.maxpool2d1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.maxpool2d2(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.maxpool2d2(x)
        x = self.conv7(x)
        return x


class CRNN(nn.Cell):
    """
    Define a CRNN network, which contains a VGG backbone followed by bidirectional LSTM layers.

    Args:
        input_size(int): size of each time-step feature. Usually, input_size equals three times
            the image height for text images.
        batch_size(int): batch size of input data, default is 64
        hidden_size(int): hidden size of the LSTM layers, default is 512
    """
    def __init__(self, config):
        super(CRNN, self).__init__()
        self.batch_size = config.batch_size
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_classes = config.class_num
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
        self.rnn2 = P.DynamicRNN(forget_bias=0.0)
        self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)

        w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1 = Parameter(w1.astype(np.float16), name="w1")
        w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2 = Parameter(w2.astype(np.float16), name="w2")
        w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
        w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")

        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")

        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
        self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)

        self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
                           weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
        self.fc.to_float(mstype.float32)
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.squeeze = P.Squeeze(axis=0)
        self.vgg = VGG()
        self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.5)
        self.rnn_dropout = nn.Dropout(0.9)
        self.use_dropout = config.use_dropout

    def construct(self, x):
        x = self.vgg(x)
        x = self.cast(x, mstype.float16)

        x = self.reshape(x, (self.batch_size, self.input_size, -1))
        x = self.transpose(x, (2, 0, 1))
        # the backward direction is emulated by reversing the sequence in time
        bw_x = self.reverse_seq1(x, self.seq_length)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
        y1_out = self.concat1((y1, y1_bw))
        if self.use_dropout:
            y1_out = self.rnn_dropout(y1_out)

        y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
        bw_y = self.reverse_seq3(y1_out, self.seq_length)
        # use the dedicated backward cell; the backward weights w2_bw/b2_bw are passed explicitly
        y2_bw, _, _, _, _, _, _, _ = self.rnn2_bw(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
        y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
        y2_out = self.concat1((y2, y2_bw))
        if self.use_dropout:
            y2_out = self.dropout(y2_out)

        output = ()
        for i in range(F.shape(y2_out)[0]):
            y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
            y2_after_fc = self.expand_dims(y2_after_fc, 0)
            output += (y2_after_fc,)
        output = self.concat(output)
        return output
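A note on the construct method above: P.DynamicRNN is unidirectional, so the backward half of each bidirectional layer is emulated by reversing the sequence in time, running a second forward RNN, and reversing its output back before concatenation. A NumPy-only sketch of the idea (the toy fw_rnn/bw_rnn callables are hypothetical stand-ins, not MindSpore API):

import numpy as np

def bidirectional(x, fw_rnn, bw_rnn):
    # x: (seq_len, batch, feature)
    y_fw = fw_rnn(x)                # forward pass over time
    y_bw = bw_rnn(x[::-1])[::-1]    # reverse, run forward, reverse back
    return np.concatenate([y_fw, y_bw], axis=2)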
@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Automatic differentiation with grad clip."""
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor

compute_norm = C.MultitypeFuncGraph("compute_norm")


@compute_norm.register("Tensor")
def _compute_norm(grad):
    norm = nn.Norm()
    norm = norm(F.cast(grad, mstype.float32))
    ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return ret


grad_div = C.MultitypeFuncGraph("grad_div")


@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
    div = P.RealDiv()
    mul = P.Mul()
    scale = div(10.0, val)
    ret = mul(grad, scale)
    return ret


class TrainOneStepCellWithGradClip(Cell):
    """
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
    A backward graph with gradient clipping is created in the construct function to update parameters.
    Different parallel modes are available for training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
        - data (Tensor) - Tensor of shape :(N, ...).
        - label (Tensor) - Tensor of shape :(N, ...).

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.hyper_map = C.HyperMap()
        self.greater = P.Greater()
        self.select = P.Select()
        self.norm = nn.Norm(keep_dims=True)
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=0)
        self.ten = Tensor(np.array([10.0]).astype(np.float32))
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, data, label):
        weights = self.weights
        loss = self.network(data, label)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, label, sens)
        # compute the global gradient norm, then rescale grads by 10 / max(norm, 10)
        norm = self.hyper_map(F.partial(compute_norm), grads)
        norm = self.concat(norm)
        norm = self.norm(norm)
        cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
        clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
        grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
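The clipping rule in construct rescales every gradient by 10 / max(global_norm, 10): when the global norm exceeds 10, the gradients are scaled down onto the norm-10 ball; otherwise the scale is 1 and they pass through unchanged. A plain-Python sketch of the select/greater logic:

def clip_scale(global_norm, clip=10.0):
    clip_val = global_norm if global_norm > clip else clip
    return clip / clip_val

print(clip_scale(25.0))  # 0.4 -> gradients shrunk to global norm 10
print(clip_scale(4.0))   # 1.0 -> gradients unchanged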
@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vc
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.ic03_dataset import IC03Dataset
from src.ic13_dataset import IC13Dataset
from src.iiit5k_dataset import IIIT5KDataset
from src.svt_dataset import SVTDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True


class CaptchaDataset:
    """
    Create a train or evaluation dataset source for CRNN.

    Args:
        img_root_dir(str): root path of images
        is_training(bool): whether to read the training or the test annotation list
        config: network configuration, which provides max_text_length, blank and class_num
    """

    def __init__(self, img_root_dir, is_training=True, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        if is_training:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_train.txt')
        else:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_test.txt')
        self.lexicon_file = os.path.join(self.img_root_dir, 'lexicon.txt')
        with open(self.lexicon_file, 'r') as f:
            self.lexicons = [line.strip('\n') for line in f]
        self.img_names = {}
        self.img_list = []
        with open(self.imgslist, 'r') as f:
            for line in f:
                img_name, label_index = line.strip('\n').split(" ")
                self.img_list.append(img_name)
                self.img_names[img_name] = self.lexicons[int(label_index)]
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        # pad the label with the blank index up to max_text_length
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label


def create_dataset(name, dataset_path, batch_size=1, num_shards=1, shard_id=0, is_training=True, config=config1):
    """
    Create a train or evaluation dataset for CRNN.

    Args:
        name(str): dataset name, one of 'synth', 'ic03', 'ic13', 'svt' and 'iiit5k'
        dataset_path(str): dataset path
        batch_size(int): batch size of the generated dataset, default is 1
        num_shards(int): number of devices
        shard_id(int): rank id
        is_training(bool): whether the dataset is used for training
        config: network configuration
    """
    if name == 'synth':
        dataset = CaptchaDataset(dataset_path, is_training, config)
    elif name == 'ic03':
        dataset = IC03Dataset(dataset_path, "annotation.txt", config, True, 3)
    elif name == 'ic13':
        dataset = IC13Dataset(dataset_path, "Challenge2_Test_Task3_GT.txt", config)
    elif name == 'svt':
        dataset = SVTDataset(dataset_path, config)
    elif name == 'iiit5k':
        dataset = IIIT5KDataset(dataset_path, "annotation.txt", config)
    else:
        raise ValueError(f"unsupported dataset name: {name}")
    ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
    image_trans = [
        vc.Resize((config.image_height, config.image_width)),
        vc.Normalize([127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        vc.HWC2CHW()
    ]
    label_trans = [
        C.TypeCast(mstype.int32)
    ]
    ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
    ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)

    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
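A minimal usage sketch for create_dataset in evaluation mode (the dataset path is a placeholder; the arguments mirror the call in eval.py above):

from src.config import config1
from src.dataset import create_dataset

ds = create_dataset(name='svt', dataset_path='/path/to/svt', batch_size=1,
                    is_training=False, config=config1)
for image, label in ds.create_tuple_iterator():
    print(image.shape, label.shape)  # (1, 3, 32, 100) and (1, 23) with config1
    break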
@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IC03"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IC03Dataset:
    """
    Create a train or evaluation dataset source for CRNN.

    Args:
        img_root_dir(str): root path of images
        anno_file(str): name of the annotation file under img_root_dir
        config: network configuration, which provides max_text_length, blank and class_num
        filter_by_dict(bool): drop samples whose label contains characters outside label_dict
        filter_length(int): drop samples whose label is shorter than this length
    """

    def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1, filter_by_dict=True, filter_length=3):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        anno_file = os.path.join(img_root_dir, anno_file)

        self.img_names = {}
        self.img_list = []
        with open(anno_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split(",")[1].lower()
                if len(label) < filter_length:
                    continue
                if filter_by_dict:
                    flag = True
                    for c in label:
                        if c not in label_dict:
                            flag = False
                            break
                    if not flag:
                        continue
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IC13"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IC13Dataset:
    """
    Create an evaluation dataset source for CRNN.

    Args:
        img_root_dir(str): root path of images
        label_file(str): name of the ground-truth file under img_root_dir
        config: network configuration, which provides max_text_length, blank and class_num
        filter_by_dict(bool): drop samples whose label contains characters outside label_dict
        filter_length(int): drop samples whose label is shorter than this length
    """

    def __init__(self, img_root_dir, label_file="", config=config1, filter_by_dict=True, filter_length=3):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid".format(img_root_dir))
        self.img_root_dir = img_root_dir
        self.label_file = os.path.join(img_root_dir, label_file)
        self.img_names = {}
        self.img_list = []
        self.config = config
        with open(self.label_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split("\"")[1].lower()
                if len(label) < filter_length:
                    continue
                if filter_by_dict:
                    flag = True
                    for c in label:
                        if c not in label_dict:
                            flag = False
                            break
                    if not flag:
                        continue
                self.img_names[img_name] = label
                self.img_list.append(img_name)
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IIIT5K"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IIIT5KDataset:
    """
    Create a train or evaluation dataset source for CRNN.

    Args:
        img_root_dir(str): root path of images
        anno_file(str): name of the annotation file under img_root_dir
        config: network configuration, which provides max_text_length, blank and class_num
    """

    def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        anno_file = os.path.join(img_root_dir, anno_file)

        self.img_names = {}
        self.img_list = []
        with open(anno_file, 'r') as f:
            for lines in f:
                img_name = lines.split(",")[0]
                label = lines.split(",")[1].lower()
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class CTCLoss(_Loss):
    """
    CTCLoss definition

    Args:
        max_sequence_length(int): maximum sequence length. For text images, the value is derived
            from the image width (the number of time steps).
        max_label_length(int): maximum label length for each input.
        batch_size(int): batch size of input logits
    """

    def __init__(self, max_sequence_length, max_label_length, batch_size):
        super(CTCLoss, self).__init__()
        self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
                                         name="sequence_length")
        labels_indices = []
        for i in range(batch_size):
            for j in range(max_label_length):
                labels_indices.append([i, j])
        self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
        self.reshape = P.Reshape()
        self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)

    def construct(self, logit, label):
        labels_values = self.reshape(label, (-1,))
        loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
        return loss
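labels_indices enumerates every (batch, position) coordinate of the dense label tensor, which lets P.CTCLoss interpret the padded labels, together with the flattened labels_values, as a sparse tensor. For batch_size=2 and max_label_length=3 the construction above yields:

batch_size, max_label_length = 2, 3
labels_indices = [[i, j] for i in range(batch_size) for j in range(max_label_length)]
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]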
@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
import Levenshtein

label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


class CRNNAccuracy(nn.Metric):
    """
    Define the accuracy metric for the CRNN network.
    """

    def __init__(self, config):
        super(CRNNAccuracy, self).__init__()
        self.config = config
        self._correct_num = 0
        self._total_num = 0
        self.blank = config.blank

    def clear(self):
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('CRNNAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])
        str_pred = self._ctc_greedy_decoder(y_pred)
        str_label = self._convert_labels(y)

        for pred, label in zip(str_pred, str_label):
            print(pred, " :: ", label)
            edit_distance = Levenshtein.distance(pred, label)
            self._total_num += 1
            if edit_distance == 0:
                self._correct_num += 1

    def eval(self):
        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        print('correct num: ', self._correct_num, ', total num: ', self._total_num)
        sequence_accuracy = self._correct_num / self._total_num
        return sequence_accuracy

    def _arr2char(self, inputs):
        string = ""
        for i in inputs:
            if i < self.blank:
                string += label_dict[i]
        return string

    def _convert_labels(self, inputs):
        str_list = []
        for label in inputs:
            str_temp = self._arr2char(label)
            str_list.append(str_temp)
        return str_list

    def _ctc_greedy_decoder(self, y_pred):
        """
        Parse the network prediction into label strings.
        """
        seq_len, batch_size, _ = y_pred.shape
        indices = y_pred.argmax(axis=2)
        lens = [seq_len] * batch_size
        pred_labels = []
        for i in range(batch_size):
            idx = indices[:, i]
            last_idx = self.blank
            pred_label = []
            for j in range(lens[i]):
                cur_idx = idx[j]
                if cur_idx not in [last_idx, self.blank]:
                    pred_label.append(cur_idx)
                last_idx = cur_idx
            pred_labels.append(pred_label)
        str_results = []
        for i in pred_labels:
            str_results.append(self._arr2char(i))
        return str_results
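_ctc_greedy_decoder implements the standard CTC collapse rule: take the argmax class at each time step, merge consecutive repeats, and drop blanks. A worked plain-Python example with blank index 36 and the label_dict above:

blank = 36
# per-step argmax: h, h, blank, e, blank, l, l, blank, l, o
idx = [7, 7, 36, 4, 36, 11, 11, 36, 11, 14]
out, last = [], blank
for cur in idx:
    if cur not in (last, blank):
        out.append(cur)
    last = cur
print(''.join('abcdefghijklmnopqrstuvwxyz0123456789'[i] for i in out))  # hello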
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset adaptor for SVT"""
import os

import numpy as np
from PIL import Image, ImageFile

from src.config import config1, label_dict

ImageFile.LOAD_TRUNCATED_IMAGES = True


class SVTDataset:
    """
    Create a train or evaluation dataset for CRNN; the text label of each
    image is parsed from its file name.

    Args:
        img_root_dir(str): root path of images.
        config: model configuration providing max_text_length, blank and class_num.
    """

    def __init__(self, img_root_dir, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        file_list = os.listdir(img_root_dir)
        self.img_names = {}
        self.img_list = []
        for f in file_list:
            # the label is the last underscore-separated token of the file name
            label = f.split(".jpg")[0]
            label = label.split("_")[-1].lower()
            self.img_names[f] = label
            self.img_list.append(f)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap the R and B channels (RGB -> BGR)
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        # pad the label with blanks up to the maximum text length
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
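A plausible way to feed this adaptor into a MindSpore pipeline (a sketch: the dataset path is a placeholder, and the real wiring, including resize/normalize ops, presumably lives in create_dataset in src/dataset.py):

# Hypothetical wiring of SVTDataset into a GeneratorDataset.
import mindspore.dataset as ds

svt = SVTDataset("/path/to/svt/cropped_images", config=config1)
data = ds.GeneratorDataset(svt, column_names=["image", "label"], shuffle=False)
# a real pipeline would resize/normalize the images before batching them
for image, label in data.create_tuple_iterator():
    print(image.shape, label.shape)
    break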
@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""crnn training"""
import os
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn_for_train import TrainOneStepCellWithGradClip

set_seed(1)

parser = argparse.ArgumentParser(description="crnn training")
parser.add_argument("--run_distribute", action='store_true', help="Run distributed training, default is False.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None.')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, choose from Ascend, GPU; default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase.")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()

if args_opt.model == 'lowercase':
    from src.config import config1 as config
else:
    from src.config import config2 as config

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    # default to device 0 when DEVICE_ID is not exported by the launcher
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)
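# NOTE (assumption): for distributed Ascend runs the launch tooling is
# expected to export DEVICE_ID, RANK_ID and RANK_SIZE for each process;
# the branch below reads RANK_SIZE/RANK_ID directly, while the GPU path
# queries the communication group instead.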


if __name__ == '__main__':
    lr_scale = 1
    if args_opt.run_distribute:
        init()
        if args_opt.platform == 'Ascend':
            device_num = int(os.environ.get("RANK_SIZE"))
            rank = int(os.environ.get("RANK_ID"))
        else:
            device_num = get_group_size()
            rank = get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        device_num = 1
        rank = 0

    max_text_length = config.max_text_length
    # create the (possibly sharded) training dataset
    dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
                             num_shards=device_num, shard_id=rank, config=config)
    step_size = dataset.get_dataset_size()
    # cosine-decay learning-rate schedule over the whole run
    lr_init = config.learning_rate
    lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)

    # wrap the network with the loss and a gradient-clipping train step
    net = WithLossCell(net, loss)
    net = TrainOneStepCellWithGradClip(net, opt).set_train()
    model = Model(net)
    # callbacks: loss/time monitoring plus optional checkpointing
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)
    model.train(config.epoch_size, dataset, callbacks=callbacks)
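For reference, a typical single-device launch of this script might look like the following (the dataset path is a placeholder; distributed runs additionally rely on the environment variables noted above):

python train.py --dataset_path=/path/to/synth --platform=GPU --model=lowercase --dataset=synth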