!3421 Add WarpCTC GPU script

Merge pull request !3421 from yangyongjie/master
pull/3421/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 669a8969c7

@ -31,7 +31,8 @@ These is an example of training Warpctc with self-generated captcha image datase
└──warpct └──warpct
├── README.md ├── README.md
├── script ├── script
├── run_distribute_train.sh # launch distributed training(8 pcs) ├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
├── run_eval.sh # launch evaluation ├── run_eval.sh # launch evaluation
├── run_process_data.sh # launch dataset generation ├── run_process_data.sh # launch dataset generation
└── run_standalone_train.sh # launch standalone training(1 pcs) └── run_standalone_train.sh # launch standalone training(1 pcs)
@ -75,22 +76,31 @@ Parameters for both training and evaluation can be set in config.py.
#### Usage #### Usage
``` ```
# distributed training # distributed training in Ascend
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# distributed training in GPU
Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
# standalone training # standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH] Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]
``` ```
#### Launch #### Launch
``` ```
# distribute training example # distribute training example in Ascend
sh run_distribute_train.sh rank_table.json ../data/train sh run_distribute_train.sh rank_table.json ../data/train
# standalone training example # distribute training example in GPU
sh run_standalone_train.sh ../data/train sh run_distribute_train.sh 8 ../data/train
# standalone training example in Ascend
sh run_standalone_train.sh ../data/train Ascend
# standalone training example in GPU
sh run_standalone_train.sh ../data/train GPU
``` ```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
@ -116,14 +126,17 @@ Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
``` ```
# evaluation # evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
``` ```
#### Launch #### Launch
``` ```
# evaluation example # evaluation example in Ascend
sh run_eval.sh ../data/test warpctc-30-98.ckpt sh run_eval.sh ../data/test warpctc-30-98.ckpt Ascend
# evaluation example in GPU
sh run_eval.sh ../data/test warpctc-30-98.ckpt GPU
``` ```
> checkpoint can be produced in training process. > checkpoint can be produced in training process.

@ -23,10 +23,10 @@ from mindspore import dataset as de
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf from src.config import config as cf
from src.dataset import create_dataset from src.dataset import create_dataset
from src.warpctc import StackedRNN from src.warpctc import StackedRNN, StackedRNNForGPU
from src.metric import WarpCTCAccuracy from src.metric import WarpCTCAccuracy
random.seed(1) random.seed(1)
@ -36,30 +36,38 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training") parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, context.set_context(device_id=device_id)
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if __name__ == '__main__': if __name__ == '__main__':
max_captcha_digits = cf.max_captcha_digits max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset # create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) dataset = create_dataset(dataset_path=args_opt.dataset_path,
batch_size=cf.batch_size,
device_target=args_opt.platform)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# define loss if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) loss = CTCLoss(max_sequence_length=cf.captcha_width,
# define net max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# load checkpoint # load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
# define model # define model
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
# start evaluation # start evaluation
res = model.eval(dataset) res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
print("result:", res, flush=True) print("result:", res, flush=True)

@ -57,6 +57,6 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
cd ./train_parallel$i || exit cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID" echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log env >env.log
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
cd .. cd ..
done done

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
RANK_SIZE=$1
DATASET_PATH=$(get_real_path $2)
if [ ! -d $DATASET_PATH ]; then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ -d "distribute_train" ]; then
rm -rf ./distribute_train
fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
cp -r ../src ./distribute_train
cd ./distribute_train || exit
mpirun --allow-run-as-root -n $RANK_SIZE \
python train.py \
--dataset_path=$DATASET_PATH \
--platform=GPU \
--run_distribute > log.txt 2>&1 &
cd ..

@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 2 ]; then if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
exit 1 exit 1
fi fi
@ -29,6 +29,7 @@ get_real_path() {
PATH1=$(get_real_path $1) PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2) PATH2=$(get_real_path $2)
PLATFORM=$3
if [ ! -d $PATH1 ]; then if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory" echo "error: DATASET_PATH=$PATH1 is not a directory"
@ -40,6 +41,7 @@ if [ ! -f $PATH2 ]; then
exit 1 exit 1
fi fi
run_ascend() {
ulimit -u unlimited ulimit -u unlimited
export DEVICE_NUM=1 export DEVICE_NUM=1
export DEVICE_ID=0 export DEVICE_ID=0
@ -51,10 +53,32 @@ if [ -d "eval" ]; then
fi fi
mkdir ./eval mkdir ./eval
cp ../*.py ./eval cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval cp -r ../src ./eval
cd ./eval || exit cd ./eval || exit
env >env.log env >env.log
echo "start evaluation for device $DEVICE_ID" echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
cd .. cd ..
}
run_gpu() {
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 &
cd ..
}
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1 $PATH2
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi

@ -14,8 +14,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 1 ]; then if [ $# != 2 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
exit 1 exit 1
fi fi
@ -28,27 +28,44 @@ get_real_path() {
} }
PATH1=$(get_real_path $1) PATH1=$(get_real_path $1)
PLATFORM=$2
if [ ! -d $PATH1 ]; then if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory" echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1 exit 1
fi fi
run_ascend() {
ulimit -u unlimited ulimit -u unlimited
export DEVICE_NUM=1 export DEVICE_NUM=1
export DEVICE_ID=0 export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
env >env.log
python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
cd ..
}
if [ -d "train" ]; then if [ -d "train" ]; then
rm -rf ./train rm -rf ./train
fi fi
mkdir ./train mkdir ./train
cp ../*.py ./train cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train cp -r ../src ./train
cd ./train || exit cd ./train || exit
echo "start training for device $DEVICE_ID"
env >env.log if [ "Ascend" == $PLATFORM ]; then
python train.py --dataset=$PATH1 &>log & run_ascend $PATH1
cd .. elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi

@ -24,24 +24,25 @@ from PIL import Image
from src.config import config as cf from src.config import config as cf
class _CaptchaDataset(): class _CaptchaDataset:
""" """
create train or evaluation dataset for warpctc create train or evaluation dataset for warpctc
Args: Args:
img_root_dir(str): root path of images img_root_dir(str): root path of images
max_captcha_digits(int): max number of digits in images. max_captcha_digits(int): max number of digits in images.
blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label device_target(str): platform of training, support Ascend and GPU.
length is less than max_captcha_digits, the remaining labels are padding with blank.
""" """
def __init__(self, img_root_dir, max_captcha_digits, blank=10): def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
if not os.path.exists(img_root_dir): if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir self.img_root_dir = img_root_dir
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
self.max_captcha_digits = max_captcha_digits self.max_captcha_digits = max_captcha_digits
self.blank = blank self.target = device_target
self.blank = 10 if self.target == 'Ascend' else 0
self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]
def __len__(self): def __len__(self):
return len(self.img_names) return len(self.img_names)
@ -54,27 +55,33 @@ class _CaptchaDataset():
image = np.array(im) image = np.array(im)
label_str = os.path.splitext(img_name)[0] label_str = os.path.splitext(img_name)[0]
label_str = label_str[label_str.find('-') + 1:] label_str = label_str[label_str.find('-') + 1:]
if self.target == 'Ascend':
label = [int(i) for i in label_str] label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
else:
label = [int(i) + 1 for i in label_str]
length = len(label)
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label.append(length)
label = np.array(label) label = np.array(label)
return image, label return image, label
def create_dataset(dataset_path, repeat_num=1, batch_size=1): def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
""" """
create train or evaluation dataset for warpctc create train or evaluation dataset for warpctc
Args: Args:
dataset_path(int): dataset path dataset_path(int): dataset path
repeat_num(int): dataset repetition num, default is 1
batch_size(int): batch size of generated dataset, default is 1 batch_size(int): batch size of generated dataset, default is 1
num_shards(int): number of devices
shard_id(int): rank id
device_target(str): platform of training, support Ascend and GPU
""" """
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
image_trans = [ image_trans = [
vc.Rescale(1.0 / 255.0, 0.0), vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
@ -87,6 +94,5 @@ def create_dataset(dataset_path, repeat_num=1, batch_size=1):
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
ds = ds.batch(batch_size) ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
return ds return ds

@ -47,3 +47,25 @@ class CTCLoss(_Loss):
labels_values = self.reshape(label, (-1,)) labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss return loss
class CTCLossV2(_Loss):
"""
CTCLoss definition
Args:
max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width
batch_size(int): batch size of input logits
"""
def __init__(self, max_sequence_length, batch_size):
super(CTCLossV2, self).__init__()
self.input_length = Tensor(np.array([max_sequence_length] * batch_size), mstype.int32)
self.reshape = P.Reshape()
self.ctc_loss = P.CTCLossV2()
def construct(self, logit, label):
labels_values = label[:, :-1]
labels_length = label[:, -1]
loss, _ = self.ctc_loss(logit, labels_values, self.input_length, labels_length)
return loss

@ -15,19 +15,19 @@
"""Metric for accuracy evaluation.""" """Metric for accuracy evaluation."""
from mindspore import nn from mindspore import nn
BLANK_LABLE = 10
class WarpCTCAccuracy(nn.Metric): class WarpCTCAccuracy(nn.Metric):
""" """
Define accuracy metric for warpctc network. Define accuracy metric for warpctc network.
""" """
def __init__(self): def __init__(self, device_target='Ascend'):
super(WarpCTCAccuracy).__init__() super(WarpCTCAccuracy).__init__()
self._correct_num = 0 self._correct_num = 0
self._total_num = 0 self._total_num = 0
self._count = 0 self._count = 0
self.device_target = device_target
self.blank = 10 if device_target == 'Ascend' else 0
def clear(self): def clear(self):
self._correct_num = 0 self._correct_num = 0
@ -39,6 +39,8 @@ class WarpCTCAccuracy(nn.Metric):
y_pred = self._convert_data(inputs[0]) y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1]) y = self._convert_data(inputs[1])
if self.device_target == 'GPU':
y = y[:, :-1]
self._count += 1 self._count += 1
@ -54,8 +56,7 @@ class WarpCTCAccuracy(nn.Metric):
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num return self._correct_num / self._total_num
@staticmethod def _is_eq(self, pred_lbl, target):
def _is_eq(pred_lbl, target):
""" """
check whether predict label is equal to target label check whether predict label is equal to target label
""" """
@ -63,11 +64,10 @@ class WarpCTCAccuracy(nn.Metric):
pred_diff = len(target) - len(pred_lbl) pred_diff = len(target) - len(pred_lbl)
if pred_diff > 0: if pred_diff > 0:
# padding by BLANK_LABLE # padding by BLANK_LABLE
pred_lbl.extend([BLANK_LABLE] * pred_diff) pred_lbl.extend([self.blank] * pred_diff)
return pred_lbl == target return pred_lbl == target
@staticmethod def _get_prediction(self, y_pred):
def _get_prediction(y_pred):
""" """
parse predict result to labels parse predict result to labels
""" """
@ -78,11 +78,11 @@ class WarpCTCAccuracy(nn.Metric):
pred_lbls = [] pred_lbls = []
for i in range(batch_size): for i in range(batch_size):
idx = indices[:, i] idx = indices[:, i]
last_idx = BLANK_LABLE last_idx = self.blank
pred_lbl = [] pred_lbl = []
for j in range(lens[i]): for j in range(lens[i]):
cur_idx = idx[j] cur_idx = idx[j]
if cur_idx not in [last_idx, BLANK_LABLE]: if cur_idx not in [last_idx, self.blank]:
pred_lbl.append(cur_idx) pred_lbl.append(cur_idx)
last_idx = cur_idx last_idx = cur_idx
pred_lbls.append(pred_lbl) pred_lbls.append(pred_lbl)

@ -88,3 +88,52 @@ class StackedRNN(nn.Cell):
output = self.concat((output, h2_after_fc)) output = self.concat((output, h2_after_fc))
return output return output
class StackedRNNForGPU(nn.Cell):
"""
Define a stacked RNN network which contains two LSTM layers and one full-connect layer.
Args:
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
captcha images.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
num_layer(int): the number of layer of LSTM.
"""
def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2):
super(StackedRNNForGPU, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.num_classes = 11
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / hidden_size) ** 0.5
weight_shape = 4 * hidden_size * (input_size + 3 * hidden_size + 4)
self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight')
self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
self.lstm.weight = self.weight
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.transpose = P.Transpose()
def construct(self, x):
x = self.transpose(x, (3, 0, 2, 1))
x = self.reshape(x, (-1, self.batch_size, self.input_size))
output, _ = self.lstm(x, (self.h, self.c))
res = ()
for i in range(F.shape(x)[0]):
res += (self.expand_dims(self.fc(output[i]), 0),)
res = self.concat(res)
return res

@ -42,7 +42,7 @@ grad_div = C.MultitypeFuncGraph("grad_div")
@grad_div.register("Tensor", "Tensor") @grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad): def _grad_div(val, grad):
div = P.Div() div = P.RealDiv()
mul = P.Mul() mul = P.Mul()
grad = mul(grad, 10.0) grad = mul(grad, 10.0)
ret = div(grad, val) ret = div(grad, val)

@ -24,12 +24,12 @@ from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model, ParallelMode
from mindspore.nn.wrap import WithLossCell from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf from src.config import config as cf
from src.dataset import create_dataset from src.dataset import create_dataset
from src.warpctc import StackedRNN from src.warpctc import StackedRNN, StackedRNNForGPU
from src.warpctc_for_train import TrainOneStepCellWithGradClip from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr from src.lr_schedule import get_lr
@ -38,38 +38,60 @@ np.random.seed(1)
de.config.set_seed(1) de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training") parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, context.set_context(device_id=device_id)
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if __name__ == '__main__': if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute: if args_opt.run_distribute:
if args_opt.platform == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
init('nccl')
lr_scale = 0.5
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=args_opt.device_num, context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True) mirror_mean=True)
init() else:
device_num = 1
rank = 0
max_captcha_digits = cf.max_captcha_digits max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset # create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# define lr # define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
lr = get_lr(cf.epoch_size, step_size, lr_init) lr = get_lr(cf.epoch_size, step_size, lr_init)
# define loss if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) loss = CTCLoss(max_sequence_length=cf.captcha_width,
# define net max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# define opt
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
net = WithLossCell(net, loss) net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train() net = TrainOneStepCellWithGradClip(net, opt).set_train()
# define model # define model
@ -79,6 +101,6 @@ if __name__ == '__main__':
if cf.save_checkpoint: if cf.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
keep_checkpoint_max=cf.keep_checkpoint_max) keep_checkpoint_max=cf.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck)
callbacks.append(ckpt_cb) callbacks.append(ckpt_cb)
model.train(cf.epoch_size, dataset, callbacks=callbacks) model.train(cf.epoch_size, dataset, callbacks=callbacks)

Loading…
Cancel
Save