!11788 Adding CRNN-Seq2Seq-OCR model to MindSpore model zoo
From: @alashkari Reviewed-by: Signed-off-by:pull/11788/MERGE
commit
3952a57d85
@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
|
||||
cd ..
|
||||
done
|
@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH1
|
||||
echo $PATH2
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a folder"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||
cd ..
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
|
||||
cd ..
|
@ -0,0 +1,178 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
CRNN-Seq2Seq-OCR model.
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
from src.seq2seq import Encoder, Decoder
|
||||
|
||||
|
||||
class NLLLoss(_Loss):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(NLLLoss, self).__init__(reduction)
|
||||
self.one_hot = P.OneHot()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, logits, label):
|
||||
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
|
||||
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
|
||||
return self.get_loss(loss)
|
||||
|
||||
|
||||
class AttentionOCRInfer(nn.Cell):
|
||||
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
|
||||
decoder_output_size, max_length, dropout_p=0.1):
|
||||
super(AttentionOCRInfer, self).__init__()
|
||||
|
||||
self.encoder = Encoder(batch_size=batch_size,
|
||||
conv_out_dim=conv_out_dim,
|
||||
hidden_size=encoder_hidden_size)
|
||||
|
||||
self.decoder = Decoder(hidden_size=decoder_hidden_size,
|
||||
output_size=decoder_output_size,
|
||||
max_length=max_length,
|
||||
dropout_p=dropout_p)
|
||||
|
||||
def construct(self, img, decoder_input, decoder_hidden):
|
||||
'''
|
||||
get token output
|
||||
'''
|
||||
encoder_outputs = self.encoder(img)
|
||||
decoder_output, decoder_hidden, decoder_attention = self.decoder(
|
||||
decoder_input, decoder_hidden, encoder_outputs)
|
||||
return decoder_output, decoder_hidden, decoder_attention
|
||||
|
||||
|
||||
class AttentionOCR(nn.Cell):
|
||||
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
|
||||
decoder_output_size, max_length, dropout_p=0.1):
|
||||
super(AttentionOCR, self).__init__()
|
||||
self.encoder = Encoder(batch_size=batch_size,
|
||||
conv_out_dim=conv_out_dim,
|
||||
hidden_size=encoder_hidden_size)
|
||||
self.decoder = Decoder(hidden_size=decoder_hidden_size,
|
||||
output_size=decoder_output_size,
|
||||
max_length=max_length,
|
||||
dropout_p=dropout_p)
|
||||
self.init_decoder_hidden = Tensor(np.zeros((1, batch_size, decoder_hidden_size),
|
||||
dtype=np.float16), mstype.float16)
|
||||
self.shape = P.Shape()
|
||||
self.split = P.Split(axis=1, output_num=max_length)
|
||||
self.concat = P.Concat()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.argmax = P.Argmax()
|
||||
self.select = P.Select()
|
||||
|
||||
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
|
||||
encoder_outputs = self.encoder(img)
|
||||
_, text_len = self.shape(decoder_inputs)
|
||||
decoder_outputs = ()
|
||||
decoder_input_tuple = self.split(decoder_inputs)
|
||||
decoder_target_tuple = self.split(decoder_targets)
|
||||
decoder_input = decoder_input_tuple[0]
|
||||
decoder_hidden = self.init_decoder_hidden
|
||||
|
||||
for i in range(text_len):
|
||||
decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
|
||||
topi = self.argmax(decoder_output)
|
||||
decoder_input_top = self.expand_dims(topi, 1)
|
||||
decoder_input = self.select(teacher_force, decoder_target_tuple[i], decoder_input_top)
|
||||
decoder_output = self.expand_dims(decoder_output, 0)
|
||||
decoder_outputs += (decoder_output,)
|
||||
outputs = self.concat(decoder_outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class AttentionOCRWithLossCell(nn.Cell):
|
||||
"""AttentionOCR with Loss"""
|
||||
def __init__(self, network, max_length):
|
||||
super(AttentionOCRWithLossCell, self).__init__()
|
||||
self.network = network
|
||||
self.loss = NLLLoss()
|
||||
self.shape = P.Shape()
|
||||
self.add = P.AddN()
|
||||
self.mean = P.ReduceMean()
|
||||
self.split = P.Split(axis=0, output_num=max_length)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
|
||||
decoder_outputs = self.network(img, decoder_inputs, decoder_targets, teacher_force)
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
|
||||
_, text_len = self.shape(decoder_targets)
|
||||
loss_total = ()
|
||||
decoder_output_tuple = self.split(decoder_outputs)
|
||||
for i in range(text_len):
|
||||
loss = self.loss(self.squeeze(decoder_output_tuple[i]), decoder_targets[:, i])
|
||||
loss = self.mean(loss)
|
||||
loss_total += (loss,)
|
||||
loss_output = self.add(loss_total)
|
||||
return loss_output
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * P.Reciprocal()(scale)
|
||||
|
||||
|
||||
class TrainingWrapper(nn.Cell):
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainingWrapper, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = ms.ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = None
|
||||
|
||||
# Set parallel_mode
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
if auto_parallel_context().get_device_num_is_set():
|
||||
degree = context.get_auto_parallel_context("device_num")
|
||||
else:
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, *args):
|
||||
weights = self.weights
|
||||
loss = self.network(*args)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(*args, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
@ -0,0 +1,195 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
CRN-Seq2Seq-OCR CNN model.
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
"""calculate_gain"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
res = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
"""_calculate_fan_in_and_fan_out"""
|
||||
dimensions = len(tensor)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2:
|
||||
fan_in = tensor[1]
|
||||
fan_out = tensor[0]
|
||||
else:
|
||||
num_input_fmaps = tensor[1]
|
||||
num_output_fmaps = tensor[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = tensor[2] * tensor[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_normal(inputs_shape, gain_param=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, gain_param)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
class ConvRelu(nn.Cell):
|
||||
"""
|
||||
Convolution Layer followed by Relu Layer
|
||||
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1):
|
||||
super(ConvRelu, self).__init__()
|
||||
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
|
||||
self.conv = nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
weight_init=Tensor(kaiming_normal(shape)))
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNRelu(nn.Cell):
|
||||
"""
|
||||
Convolution Layer followed by Batch Normalization and Relu Layer
|
||||
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same'):
|
||||
super(ConvBNRelu, self).__init__()
|
||||
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
|
||||
self.conv = nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size, stride,
|
||||
pad_mode=pad_mode,
|
||||
weight_init=Tensor(kaiming_normal(shape)))
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class CNN(nn.Cell):
|
||||
"""
|
||||
CNN Class for OCR
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, conv_out_dim):
|
||||
super(CNN, self).__init__()
|
||||
self.convRelu1 = ConvRelu(3, 64, (3, 3))
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
|
||||
|
||||
self.convRelu2 = ConvRelu(64, 128, (3, 3))
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu1 = ConvBNRelu(128, 256, (3, 3))
|
||||
self.convRelu3 = ConvRelu(256, 256, (3, 3))
|
||||
self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu2 = ConvBNRelu(256, 384, (3, 3))
|
||||
self.convRelu4 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu3 = ConvBNRelu(384, 384, (3, 3))
|
||||
self.convRelu5 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool5 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu4 = ConvBNRelu(384, 384, (3, 3))
|
||||
self.convRelu6 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool6 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 1)))
|
||||
self.convBNRelu5 = ConvBNRelu(384, conv_out_dim, (2, 2), pad_mode='valid')
|
||||
self.dropout = nn.Dropout(keep_prob=0.5)
|
||||
|
||||
self.squeeze = P.Squeeze(2)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.convRelu1(x)
|
||||
x = self.maxpool1(x)
|
||||
|
||||
x = self.convRelu2(x)
|
||||
x = self.maxpool2(x)
|
||||
|
||||
x = self.convBNRelu1(x)
|
||||
x = self.convRelu3(x)
|
||||
x = self.maxpool3(x)
|
||||
|
||||
x = self.convBNRelu2(x)
|
||||
x = self.convRelu4(x)
|
||||
x = self.maxpool4(x)
|
||||
|
||||
x = self.convBNRelu3(x)
|
||||
x = self.convRelu5(x)
|
||||
x = self.maxpool5(x)
|
||||
|
||||
x = self.convBNRelu4(x)
|
||||
x = self.convRelu6(x)
|
||||
x = self.maxpool6(x)
|
||||
|
||||
x = self.pad(x)
|
||||
x = self.convBNRelu5(x)
|
||||
x = self.dropout(x)
|
||||
x = self.squeeze(x)
|
||||
|
||||
return x
|
@ -0,0 +1,245 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create FSNS MindRecord files."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from src.config import config
|
||||
from src.utils import initialize_vocabulary
|
||||
|
||||
|
||||
def serialize_annotation(img_path, lex, vocab):
|
||||
|
||||
go_id = config.characters_dictionary.get("go_id")
|
||||
eos_id = config.characters_dictionary.get("eos_id")
|
||||
|
||||
word = [go_id]
|
||||
for special_label in config.labels_not_use:
|
||||
if lex == special_label:
|
||||
if config.print_no_train_label:
|
||||
print("label in for image: %s is special label, related label is: %s, skip ..." % (img_path, lex))
|
||||
return None
|
||||
|
||||
for c in lex:
|
||||
if c not in vocab:
|
||||
return None
|
||||
|
||||
c_idx = vocab.get(c)
|
||||
word.append(c_idx)
|
||||
|
||||
word.append(eos_id)
|
||||
word = np.array(word, dtype=np.int32)
|
||||
return word
|
||||
|
||||
def create_fsns_label(image_dir, anno_file_dirs):
|
||||
"""Get image path and annotation."""
|
||||
|
||||
if not os.path.isdir(image_dir):
|
||||
raise ValueError(f'Cannot find {image_dir} dataset path.')
|
||||
|
||||
image_files_dict = {}
|
||||
image_anno_dict = {}
|
||||
images = []
|
||||
img_id = 0
|
||||
|
||||
for anno_file_dir in anno_file_dirs:
|
||||
|
||||
anno_file = open(anno_file_dir, 'r').readlines()
|
||||
|
||||
for line in anno_file:
|
||||
|
||||
file_name = line.split('\t')[0]
|
||||
labels = line.split('\t')[1].split('\n')[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
|
||||
if not os.path.isfile(image_path):
|
||||
print(f'Cannot find image {image_path} according to annotations.')
|
||||
continue
|
||||
|
||||
if labels:
|
||||
images.append(img_id)
|
||||
image_files_dict[img_id] = image_path
|
||||
image_anno_dict[img_id] = labels
|
||||
img_id += 1
|
||||
|
||||
return images, image_files_dict, image_anno_dict
|
||||
|
||||
|
||||
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
|
||||
|
||||
anno_file_dirs = [config.train_annotation_file]
|
||||
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
|
||||
anno_file_dirs=anno_file_dirs)
|
||||
vocab, _ = initialize_vocabulary(config.vocab_path)
|
||||
|
||||
data_schema = {"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]},
|
||||
"decoder_input": {"type": "int32", "shape": [-1]},
|
||||
"decoder_mask": {"type": "int32", "shape": [-1]},
|
||||
"decoder_target": {"type": "int32", "shape": [-1]},
|
||||
"annotation": {"type": "string"}}
|
||||
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
writer.add_schema(data_schema, "ocr")
|
||||
|
||||
for img_id in images:
|
||||
|
||||
image_path = image_path_dict[img_id]
|
||||
annotation = image_anno_dict[img_id]
|
||||
|
||||
label_max_len = config.max_text_len
|
||||
text_max_len = config.max_text_len - 2
|
||||
|
||||
if len(annotation) > text_max_len:
|
||||
continue
|
||||
label = serialize_annotation(image_path, annotation, vocab)
|
||||
|
||||
if label is None:
|
||||
continue
|
||||
|
||||
label_len = len(label)
|
||||
decoder_input_len = label_max_len
|
||||
|
||||
if label_len <= decoder_input_len:
|
||||
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
|
||||
one_mask_len = label_len - config.go_shift
|
||||
target_weight = np.concatenate((np.ones(one_mask_len, dtype=np.float32),
|
||||
np.zeros(decoder_input_len - one_mask_len, dtype=np.float32)))
|
||||
else:
|
||||
continue
|
||||
|
||||
decoder_input = (np.array(label).T).astype(np.int32)
|
||||
target_weight = (np.array(target_weight).T).astype(np.int32)
|
||||
|
||||
if not len(decoder_input) == len(target_weight):
|
||||
continue
|
||||
|
||||
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
|
||||
target = (np.array(target)).astype(np.int32)
|
||||
|
||||
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
row = {"image": img,
|
||||
"label": label,
|
||||
"decoder_input": decoder_input,
|
||||
"decoder_mask": target_weight,
|
||||
"decoder_target": target,
|
||||
"annotation": str(annotation)}
|
||||
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
|
||||
|
||||
anno_file_dirs = [config.train_annotation_file]
|
||||
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
|
||||
anno_file_dirs=anno_file_dirs)
|
||||
vocab, _ = initialize_vocabulary(config.vocab_path)
|
||||
|
||||
data_schema = {"image": {"type": "bytes"},
|
||||
"decoder_input": {"type": "int32", "shape": [-1]},
|
||||
"decoder_target": {"type": "int32", "shape": [-1]},
|
||||
"annotation": {"type": "string"}}
|
||||
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
writer.add_schema(data_schema, "ocr")
|
||||
|
||||
for img_id in images:
|
||||
|
||||
image_path = image_path_dict[img_id]
|
||||
annotation = image_anno_dict[img_id]
|
||||
|
||||
label_max_len = config.max_text_len
|
||||
text_max_len = config.max_text_len - 2
|
||||
|
||||
if len(annotation) > text_max_len:
|
||||
continue
|
||||
label = serialize_annotation(image_path, annotation, vocab)
|
||||
|
||||
if label is None:
|
||||
continue
|
||||
|
||||
label_len = len(label)
|
||||
decoder_input_len = label_max_len
|
||||
|
||||
if label_len <= decoder_input_len:
|
||||
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
|
||||
else:
|
||||
continue
|
||||
|
||||
decoder_input = (np.array(label).T).astype(np.int32)
|
||||
|
||||
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
|
||||
target = (np.array(target)).astype(np.int32)
|
||||
|
||||
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
row = {"image": img,
|
||||
"decoder_input": decoder_input,
|
||||
"decoder_target": target,
|
||||
"annotation": str(annotation)}
|
||||
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True):
|
||||
print("Start creating dataset!")
|
||||
if is_training:
|
||||
mindrecord_dir = os.path.join(config.mindrecord_dir, "train")
|
||||
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
|
||||
|
||||
if not os.path.exists(mindrecord_files[0]):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if dataset == "fsns":
|
||||
if os.path.isdir(config.data_root):
|
||||
print("Create FSNS Mindrecord files for train pipeline.")
|
||||
fsns_train_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix, file_num=8)
|
||||
print("Create FSNS Mindrecord files for train pipeline Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("{} not exits!".format(config.data_root))
|
||||
else:
|
||||
print("{} dataset is not defined!".format(dataset))
|
||||
|
||||
if not is_training:
|
||||
mindrecord_dir = os.path.join(config.mindrecord_dir, "val")
|
||||
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
|
||||
|
||||
if not os.path.exists(mindrecord_files[0]):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if dataset == "fsns":
|
||||
if os.path.isdir(config.val_data_root):
|
||||
print("Create FSNS Mindrecord files for val pipeline.")
|
||||
fsns_val_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix)
|
||||
print("Create FSNS Mindrecord files for val pipeline Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("{} not exits!".format(config.val_data_root))
|
||||
else:
|
||||
print("{} dataset is not defined!".format(dataset))
|
||||
|
||||
return mindrecord_files
|
@ -0,0 +1,144 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FSNS dataset"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.vision.py_transforms as P
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.config import config
|
||||
|
||||
|
||||
class AugmentationOps():
|
||||
def __init__(self, min_area_ratio=0.8, aspect_ratio_range=(0.8, 1.2), brightness=32./255.,
|
||||
contrast=0.5, saturation=0.5, hue=0.2, img_tile_shape=(150, 150)):
|
||||
self.min_area_ratio = min_area_ratio
|
||||
self.aspect_ratio_range = aspect_ratio_range
|
||||
self.img_tile_shape = img_tile_shape
|
||||
self.random_image_distortion_ops = P.RandomColorAdjust(brightness=brightness,
|
||||
contrast=contrast,
|
||||
saturation=saturation,
|
||||
hue=hue)
|
||||
|
||||
def __call__(self, img):
|
||||
img_h = self.img_tile_shape[0]
|
||||
img_w = self.img_tile_shape[1]
|
||||
img_new = np.zeros([128, 512, 3])
|
||||
|
||||
for i in range(4):
|
||||
img_tile = img[:, (i*150):((i+1)*150), :]
|
||||
# Random crop cut from the street sign image, resized to the same size.
|
||||
# Assures that the crop covers at least 0.8 area of the input image.
|
||||
# Aspect ratio of cropped image is within [0.8,1.2] range.
|
||||
h = img_h + 1
|
||||
w = img_w + 1
|
||||
|
||||
while (w >= img_w or h >= img_h):
|
||||
aspect_ratio = np.random.uniform(self.aspect_ratio_range[0],
|
||||
self.aspect_ratio_range[1])
|
||||
h_low = np.ceil(np.sqrt(self.min_area_ratio * img_h * img_w / aspect_ratio))
|
||||
h_high = np.floor(np.sqrt(img_h * img_w / aspect_ratio))
|
||||
h = np.random.randint(h_low, h_high)
|
||||
w = int(h * aspect_ratio)
|
||||
|
||||
y = np.random.randint(img_w - w)
|
||||
x = np.random.randint(img_h - h)
|
||||
img_tile = img_tile[x:(x+h), y:(y+w), :]
|
||||
# Randomly chooses one of the 4 interpolation resize methods.
|
||||
interpolation = np.random.choice([cv2.INTER_LINEAR,
|
||||
cv2.INTER_CUBIC,
|
||||
cv2.INTER_AREA,
|
||||
cv2.INTER_NEAREST])
|
||||
img_tile = cv2.resize(img_tile, (128, 128), interpolation=interpolation)
|
||||
# Random color distortion ops.
|
||||
img_tile_pil = Image.fromarray(img_tile)
|
||||
img_tile_pil = self.random_image_distortion_ops(img_tile_pil)
|
||||
img_tile = np.array(img_tile_pil)
|
||||
img_new[:, (i*128):((i+1)*128), :] = img_tile
|
||||
|
||||
img_new = 2 * (img_new / 255.) - 1
|
||||
return img_new
|
||||
|
||||
|
||||
class ImageResizeWithRescale():
|
||||
def __init__(self, standard_img_height, standard_img_width, channel_size=3):
|
||||
self.standard_img_height = standard_img_height
|
||||
self.standard_img_width = standard_img_width
|
||||
self.channel_size = channel_size
|
||||
|
||||
def __call__(self, img):
|
||||
img = cv2.resize(img, (self.standard_img_width, self.standard_img_height))
|
||||
img = 2 * (img / 255.) - 1
|
||||
return img
|
||||
|
||||
|
||||
def random_teacher_force(images, source_ids, target_ids):
|
||||
teacher_force = np.random.random() < config.teacher_force_ratio
|
||||
teacher_force_array = np.array([teacher_force], dtype=bool)
|
||||
return images, source_ids, target_ids, teacher_force_array
|
||||
|
||||
|
||||
def create_ocr_train_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
|
||||
is_training=True, num_parallel_workers=4, use_multiprocessing=True):
|
||||
ds = de.MindDataset(mindrecord_file,
|
||||
columns_list=["image", "decoder_input", "decoder_target"],
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=is_training)
|
||||
aug_ops = AugmentationOps()
|
||||
transforms = [C.Decode(),
|
||||
aug_ops,
|
||||
C.HWC2CHW()]
|
||||
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"])
|
||||
ds = ds.map(operations=random_teacher_force, input_columns=["image", "decoder_input", "decoder_target"],
|
||||
output_columns=["image", "decoder_input", "decoder_target", "teacher_force"],
|
||||
column_order=["image", "decoder_input", "decoder_target", "teacher_force"])
|
||||
type_cast_op_bool = ops.TypeCast(mstype.bool_)
|
||||
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
|
||||
print("Train dataset size= %s" % (int(ds.get_dataset_size())))
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
|
||||
def create_ocr_val_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
|
||||
num_parallel_workers=4, use_multiprocessing=True):
|
||||
ds = de.MindDataset(mindrecord_file,
|
||||
columns_list=["image", "annotation", "decoder_input", "decoder_target"],
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=False)
|
||||
resize_rescale_op = ImageResizeWithRescale(standard_img_height=128, standard_img_width=512)
|
||||
transforms = [C.Decode(),
|
||||
resize_rescale_op,
|
||||
C.HWC2CHW()]
|
||||
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"],
|
||||
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_input"],
|
||||
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
print("Val dataset size= %s" % (str(int(ds.get_dataset_size())*batch_size)))
|
||||
return ds
|
@ -0,0 +1,55 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
GRU cell
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.weight_init import gru_default_state
|
||||
|
||||
|
||||
class GRU(nn.Cell):
|
||||
'''
|
||||
GRU model
|
||||
|
||||
Args:
|
||||
input_size: The number of expected features in the input
|
||||
hidden_size: The number of features in the hidden state
|
||||
'''
|
||||
def __init__(self, input_size, hidden_size):
|
||||
super(GRU, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h = gru_default_state(self.input_size, self.hidden_size)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x, h):
|
||||
'''
|
||||
GRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): GRU input
|
||||
h(Tensor): GRU hidden state
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
h = self.cast(h, mstype.float16)
|
||||
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, h)
|
||||
return y1, h1
|
@ -0,0 +1,80 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
"""
|
||||
Logger.
|
||||
|
||||
Args:
|
||||
logger_name: String. Logger name.
|
||||
rank: Integer. Rank id.
|
||||
"""
|
||||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
self.rank = rank
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""Setup logging file."""
|
||||
self.rank = rank
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
"""Get Logger."""
|
||||
logger = LOGGER('crnn-seq2seq-ocr', rank)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
@ -0,0 +1,196 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""lstm"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import nn, context, Tensor, Parameter, ParameterTuple
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
@constexpr
|
||||
def _create_sequence_length(shape):
|
||||
num_step, batch_size, _ = shape
|
||||
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
|
||||
return sequence_length
|
||||
|
||||
class LSTM(nn.Cell):
|
||||
"""
|
||||
Stacked LSTM (Long Short-Term Memory) layers.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
num_layers (int): Number of layers of stacked LSTM . Default: 1.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
batch_first (bool): Specifies whether the first dimension of input is batch_size. Default: False.
|
||||
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
||||
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
|
||||
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
|
||||
(batch_size, seq_len, `input_size`).
|
||||
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or
|
||||
mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
Data type of `hx` must be the same as `input`.
|
||||
|
||||
Outputs:
|
||||
Tuple, a tuple contains (`output`, (`h_n`, `c_n`)).
|
||||
|
||||
- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
|
||||
- **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape
|
||||
(num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=True,
|
||||
batch_first=False,
|
||||
dropout=0,
|
||||
bidirectional=False):
|
||||
super(LSTM, self).__init__()
|
||||
self.is_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
self.batch_first = batch_first
|
||||
self.transpose = P.Transpose()
|
||||
self.num_layers = num_layers
|
||||
self.bidirectional = bidirectional
|
||||
self.dropout = dropout
|
||||
self.lstm = P.LSTM(input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
has_bias=has_bias,
|
||||
bidirectional=bidirectional,
|
||||
dropout=float(dropout))
|
||||
|
||||
weight_size = 0
|
||||
gate_size = 4 * hidden_size
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
num_directions = 2 if bidirectional else 1
|
||||
if self.is_ascend:
|
||||
self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.concat_2dim = P.Concat(axis=2)
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
if dropout < 0 or dropout > 1:
|
||||
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
|
||||
if dropout == 1:
|
||||
self.dropout_op = P.ZerosLike()
|
||||
else:
|
||||
self.dropout_op = nn.Dropout(float(1 - dropout))
|
||||
b0 = np.zeros(gate_size, dtype=np.float32)
|
||||
self.w_list = []
|
||||
self.b_list = []
|
||||
self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
|
||||
for layer in range(num_layers):
|
||||
w_shape = input_size if layer == 0 else (num_directions * hidden_size)
|
||||
w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
|
||||
self.w_list.append(Parameter(
|
||||
initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
|
||||
if has_bias:
|
||||
b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float32)
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
|
||||
else:
|
||||
self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
|
||||
if bidirectional:
|
||||
w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
|
||||
self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
|
||||
name='weight_bw' + str(layer)))
|
||||
b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float32) if has_bias else b0
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]),
|
||||
name='bias_bw' + str(layer)))
|
||||
self.w_list = ParameterTuple(self.w_list)
|
||||
self.b_list = ParameterTuple(self.b_list)
|
||||
else:
|
||||
for layer in range(num_layers):
|
||||
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
|
||||
increment_size = gate_size * input_layer_size
|
||||
increment_size += gate_size * hidden_size
|
||||
if has_bias:
|
||||
increment_size += 2 * gate_size
|
||||
weight_size += increment_size * num_directions
|
||||
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
|
||||
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
|
||||
|
||||
def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked bidirectional dynamic_rnn"""
|
||||
x_shape = self.shape(x)
|
||||
sequence_length = _create_sequence_length(x_shape)
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
output = x
|
||||
for i in range(self.num_layers):
|
||||
offset = i * 2
|
||||
weight_fw, weight_bw = weight[offset], weight[offset + 1]
|
||||
bias_fw, bias_bw = bias[offset], bias[offset + 1]
|
||||
init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
|
||||
init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
|
||||
bw_x = self.reverse_seq(pre_layer, sequence_length)
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
|
||||
y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
|
||||
y_bw = self.reverse_seq(y_bw, sequence_length)
|
||||
output = self.concat_2dim((y, y_bw))
|
||||
pre_layer = self.dropout_op(output) if self.dropout else output
|
||||
hn += (h[-1:, :, :],)
|
||||
hn += (h_bw[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
cn += (c_bw[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return output, status_h, status_c
|
||||
|
||||
def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked mutil_layer dynamic_rnn"""
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
y = 0
|
||||
for i in range(self.num_layers):
|
||||
weight_fw, bias_bw = weight[i], bias[i]
|
||||
init_h_fw, init_c_bw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_bw, None, init_h_fw, init_c_bw)
|
||||
pre_layer = self.dropout_op(y) if self.dropout else y
|
||||
hn += (h[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return y, status_h, status_c
|
||||
|
||||
def construct(self, x, hx):
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
h, c = hx
|
||||
if self.is_ascend:
|
||||
x = self.cast(x, mstype.float16)
|
||||
h = self.cast(h, mstype.float16)
|
||||
c = self.cast(c, mstype.float16)
|
||||
if self.bidirectional:
|
||||
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
else:
|
||||
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
else:
|
||||
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
return x, (h, c)
|
@ -0,0 +1,165 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
Seq2Seq_OCR model.
|
||||
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.cnn import CNN
|
||||
from src.gru import GRU
|
||||
from src.lstm import LSTM
|
||||
from src.weight_init import lstm_default_state
|
||||
|
||||
|
||||
class BidirectionalLSTM(nn.Cell):
|
||||
"""Bidirectional LSTM with a Dense layer
|
||||
|
||||
Args:
|
||||
batch_size(int): batch size of input data
|
||||
input_size(int): Size of time sequence
|
||||
hidden_size(int): the hidden size of LSTM layers
|
||||
output_size(int): the output size of the dense layer
|
||||
"""
|
||||
def __init__(self, batch_size, input_size, hidden_size, output_size):
|
||||
super(BidirectionalLSTM, self).__init__()
|
||||
self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True).to_float(mstype.float16)
|
||||
self.h, self.c = lstm_default_state(batch_size, hidden_size, bidirectional=True)
|
||||
self.embedding = nn.Dense(hidden_size * 2, output_size).to_float(mstype.float16)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, inputs):
|
||||
inputs = self.cast(inputs, mstype.float16)
|
||||
recurrent, _ = self.rnn(inputs, (self.h, self.c))
|
||||
T, b, h = self.shape(recurrent)
|
||||
t_rec = self.reshape(recurrent, (T * b, h))
|
||||
output = self.embedding(t_rec)
|
||||
output = self.reshape(output, (T, b, -1))
|
||||
return output
|
||||
|
||||
|
||||
class AttnDecoderRNN(nn.Cell):
|
||||
"""Attention Decoder Structure with a one-layer GRU
|
||||
|
||||
Args:
|
||||
hidden_size(int): the hidden size
|
||||
output_size(int): the output size
|
||||
max_length(iht): max time step of the decoder
|
||||
dropout_p(float): dropout probability, default is 0.1
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
|
||||
super(AttnDecoderRNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
self.dropout_p = dropout_p
|
||||
self.max_length = max_length
|
||||
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
|
||||
self.attn = nn.Dense(in_channels=self.hidden_size * 2, out_channels=self.max_length).to_float(mstype.float16)
|
||||
self.attn_combine = nn.Dense(in_channels=self.hidden_size * 2,
|
||||
out_channels=self.hidden_size).to_float(mstype.float16)
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p)
|
||||
self.gru = GRU(hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16)
|
||||
self.transpose = P.Transpose()
|
||||
self.concat = P.Concat(axis=2)
|
||||
self.concat1 = P.Concat(axis=1)
|
||||
self.softmax = P.Softmax(axis=1)
|
||||
self.relu = P.ReLU()
|
||||
self.log_softmax = P.LogSoftmax(axis=1)
|
||||
self.bmm = P.BatchMatMul()
|
||||
self.unsqueeze = P.ExpandDims()
|
||||
self.squeeze = P.Squeeze(1)
|
||||
self.squeeze1 = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, inputs, hidden, encoder_outputs):
|
||||
embedded = self.embedding(inputs)
|
||||
embedded = self.transpose(embedded, (1, 0, 2))
|
||||
embedded = self.dropout(embedded)
|
||||
embedded = self.cast(embedded, mstype.float16)
|
||||
|
||||
embedded_concat = self.concat((embedded, hidden))
|
||||
embedded_concat = self.squeeze1(embedded_concat)
|
||||
attn_weights = self.softmax(self.attn(embedded_concat))
|
||||
attn_weights = self.unsqueeze(attn_weights, 1)
|
||||
perm_encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
|
||||
attn_applied = self.bmm(attn_weights, perm_encoder_outputs)
|
||||
attn_applied = self.squeeze(attn_applied)
|
||||
embedded_squeeze = self.squeeze1(embedded)
|
||||
|
||||
output = self.concat1((embedded_squeeze, attn_applied))
|
||||
output = self.attn_combine(output)
|
||||
output = self.unsqueeze(output, 0)
|
||||
output = self.relu(output)
|
||||
|
||||
gru_hidden = self.squeeze1(hidden)
|
||||
output, hidden, _, _, _, _ = self.gru(output, gru_hidden)
|
||||
output = self.squeeze1(output)
|
||||
output = self.log_softmax(self.out(output))
|
||||
|
||||
return output, hidden, attn_weights
|
||||
|
||||
|
||||
class Encoder(nn.Cell):
|
||||
"""Encoder with a CNN and two BidirectionalLSTM layers
|
||||
|
||||
Args:
|
||||
batch_size(int): batch size of input data
|
||||
conv_out_dim(int): the output dimension of the cnn layer
|
||||
hidden_size(int): the hidden size of LSTM layers
|
||||
"""
|
||||
def __init__(self, batch_size, conv_out_dim, hidden_size):
|
||||
super(Encoder, self).__init__()
|
||||
self.cnn = CNN(int(conv_out_dim/4))
|
||||
self.lstm1 = BidirectionalLSTM(batch_size, conv_out_dim, hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.lstm2 = BidirectionalLSTM(batch_size, hidden_size, hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.split = P.Split(axis=3, output_num=4)
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, inputs):
|
||||
inputs = self.cast(inputs, mstype.float32)
|
||||
(x1, x2, x3, x4) = self.split(inputs)
|
||||
conv1 = self.cnn(x1)
|
||||
conv2 = self.cnn(x2)
|
||||
conv3 = self.cnn(x3)
|
||||
conv4 = self.cnn(x4)
|
||||
conv = self.concat((conv1, conv2, conv3, conv4))
|
||||
conv = self.transpose(conv, (2, 0, 1))
|
||||
output = self.lstm1(conv)
|
||||
output = self.lstm2(output)
|
||||
return output
|
||||
|
||||
|
||||
class Decoder(nn.Cell):
|
||||
"""Decoder
|
||||
|
||||
Args:
|
||||
hidden_size(int): the hidden size
|
||||
output_size(int): the output size
|
||||
max_length(iht): max time step of the decoder
|
||||
dropout_p(float): dropout probability, default is 0.1
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
|
||||
super(Decoder, self).__init__()
|
||||
self.decoder = AttnDecoderRNN(hidden_size, output_size, max_length, dropout_p)
|
||||
|
||||
def construct(self, inputs, hidden, encoder_outputs):
|
||||
return self.decoder(inputs, hidden, encoder_outputs)
|
@ -0,0 +1,51 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import logging
|
||||
|
||||
|
||||
def initialize_vocabulary(vocabulary_path):
|
||||
"""
|
||||
initialize vocabulary from file.
|
||||
assume the vocabulary is stored one-item-per-line
|
||||
"""
|
||||
characters_class = 9999
|
||||
|
||||
if os.path.exists(vocabulary_path):
|
||||
rev_vocab = []
|
||||
with codecs.open(vocabulary_path, 'r', encoding='utf-8') as voc_file:
|
||||
rev_vocab = [line.strip() for line in voc_file]
|
||||
|
||||
vocab = {x: y for (y, x) in enumerate(rev_vocab)}
|
||||
|
||||
reserved_char_size = characters_class - len(rev_vocab)
|
||||
if reserved_char_size < 0:
|
||||
raise ValueError("Number of characters in vocabulary is equal or larger than config.characters_class")
|
||||
|
||||
for _ in range(reserved_char_size):
|
||||
rev_vocab.append('')
|
||||
|
||||
# put space at the last position
|
||||
vocab[' '] = len(rev_vocab)
|
||||
rev_vocab.append(' ')
|
||||
logging.info("Initializing vocabulary ends: %s", vocabulary_path)
|
||||
return vocab, rev_vocab
|
||||
|
||||
raise ValueError("Initializing vocabulary ends: %s" % vocabulary_path)
|
@ -0,0 +1,41 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
weights initialization
|
||||
"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import Tensor, Parameter
|
||||
|
||||
|
||||
def lstm_default_state(batch_size, hidden_size, bidirectional, num_layers=1):
|
||||
"""init default input."""
|
||||
num_directions = 2 if bidirectional else 1
|
||||
h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
return h, c
|
||||
|
||||
|
||||
def gru_default_state(input_size, hidden_size):
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
weight_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
|
||||
name='weight_i')
|
||||
weight_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
|
||||
name='weight_h')
|
||||
bias_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
|
||||
name='bias_i')
|
||||
bias_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
|
||||
name='bias_h')
|
||||
return weight_i, weight_h, bias_i, bias_h
|
@ -0,0 +1,158 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
CRNN-Seq2Seq-OCR train.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import ModelCheckpoint
|
||||
from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import create_ocr_train_dataset
|
||||
from src.logger import get_logger
|
||||
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse train arguments."""
|
||||
parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training')
|
||||
|
||||
# device related
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
|
||||
|
||||
# distributed related
|
||||
parser.add_argument('--is_distributed', type=int, default=0,
|
||||
help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
|
||||
parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1')
|
||||
|
||||
#dataset related
|
||||
parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.')
|
||||
|
||||
# logging related
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
|
||||
parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Checkpoint save location.')
|
||||
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
|
||||
|
||||
parser.add_argument('--is_save_on_master', type=int, default=0,
|
||||
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# logger
|
||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
"""Train function."""
|
||||
args = parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
|
||||
if args.is_distributed:
|
||||
rank = args.rank_id
|
||||
device_num = args.device_num
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
# Logger
|
||||
args.logger = get_logger(args.outputs_dir, rank)
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if rank == 0:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
else:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
|
||||
# DATASET
|
||||
dataset = create_ocr_train_dataset(args.mindrecord_file,
|
||||
config.batch_size,
|
||||
rank_size=device_num,
|
||||
rank_id=rank)
|
||||
args.steps_per_epoch = dataset.get_dataset_size()
|
||||
args.logger.info('Finish loading dataset')
|
||||
|
||||
if not args.ckpt_interval:
|
||||
args.ckpt_interval = args.steps_per_epoch
|
||||
args.logger.save_args(args)
|
||||
|
||||
network = AttentionOCR(config.batch_size,
|
||||
int(config.img_width / 4),
|
||||
config.encoder_hidden_size,
|
||||
config.decoder_hidden_size,
|
||||
config.decoder_output_size,
|
||||
config.max_length,
|
||||
config.dropout_p)
|
||||
|
||||
if args.pre_checkpoint_path:
|
||||
param_dict = load_checkpoint(args.pre_checkpoint_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
network = AttentionOCRWithLossCell(network, config.max_length)
|
||||
|
||||
lr = Tensor(config.lr, mstype.float32)
|
||||
opt = nn.Adam(network.trainable_params(), lr, beta1=config.adam_beta1, beta2=config.adam_beta2,
|
||||
loss_scale=config.loss_scale)
|
||||
|
||||
network = TrainingWrapper(network, opt, sens=config.loss_scale)
|
||||
|
||||
args.logger.info('Finished get network')
|
||||
|
||||
callback = [TimeMonitor(data_size=1), LossMonitor()]
|
||||
if args.rank_save_ckpt_flag:
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/')
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||
directory=save_ckpt_path,
|
||||
prefix="crnn_seq2seq_ocr")
|
||||
callback.append(ckpt_cb)
|
||||
|
||||
model = Model(network)
|
||||
model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False)
|
||||
|
||||
args.logger.info('==========Training Done===============')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
Loading…
Reference in new issue