cnn direction model

pull/9330/head
avakh 4 years ago
parent f6450a614b
commit 830b8f3e93

eval.py
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train resnet."""
import argparse
import os
import random
import numpy as np
from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_eval
from mindspore import context
from mindspore import dataset as de
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
parser = argparse.ArgumentParser(description='CNN direction model evaluation')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
if __name__ == '__main__':
# init context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(device_id=device_id)
# create dataset
dataset = create_dataset_eval(args_opt.dataset_path + "/ocr_eval_pos.mindrecord", config=config)
step_size = dataset.get_dataset_size()
print("step_size ", step_size)
# define net
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
    # define loss
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")
    # define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})
# eval model
res = model.eval(dataset, dataset_sink_mode=False)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

requirements.txt
@@ -0,0 +1,5 @@
mindspore
numpy
Pillow
opencv-python
scikit-image
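# install with: pip install -r requirements.txt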

scripts/run_distribute_train.sh
@@ -0,0 +1,88 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
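# e.g. get_real_path ./data resolves to "$PWD/data"; realpath -m does not
# require the path to exist yet.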
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 3 ]
then
PATH3=$(get_real_path $3)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 2 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
fi
if [ $# == 3 ]
then
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
fi
cd ..
done
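# Example (hypothetical paths):
#   sh run_distribute_train.sh ./rank_table_8pcs.json /data/ocr_train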

scripts/run_standalone_eval.sh
@@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH]"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start evaluation for device $DEVICE_ID"
env > env.log
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log &
cd ..
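# Example (hypothetical paths):
#   sh run_standalone_eval.sh /data/ocr_eval ./cnn_direction_model-5_2500.ckpt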

scripts/run_standalone_train.sh
@@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=3
export RANK_ID=0
export RANK_SIZE=1
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
PATH2=$(get_real_path $2)
fi
if [ $# == 2 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
python train.py --dataset_path=$PATH1 &> log &
fi
if [ $# == 2 ]
then
python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi
cd ..
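# Example (hypothetical paths):
#   sh run_standalone_train.sh /data/ocr_train ./pretrained.ckpt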

src/config.py
@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used in train.py and eval.py.
"""
from easydict import EasyDict as ed
config1 = ed({
"batch_size": 8,
"epoch_size": 5,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "poly",
"lr": 1e-4,
"work_nums": 4,
"im_size_w": 512,
"im_size_h": 64,
"pos_samples_size": 100,
"augment_severity": 0.1,
"augment_prob": 0.3
})
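# EasyDict exposes keys as attributes, so consumers can use either form.
# A minimal sketch (not executed here):
#   from src.config import config1 as config
#   assert config.batch_size == config["batch_size"] == 8
#   h, w = config.im_size_h, config.im_size_w  # 64, 512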

src/dataset.py
@@ -0,0 +1,246 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, used in train.py and eval.py.
"""
import os
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
from src.dataset_utils import lucky, noise_blur, noise_speckle, noise_gamma, noise_gaussian, noise_salt_pepper, \
shift_color, enhance_brightness, enhance_sharpness, enhance_contrast, enhance_color, gaussian_blur, \
randcrop, resize, rdistort, rgeometry, rotate_about_center, whole_rdistort, warp_perspective, random_contrast, \
unify_img_label
import cv2
import numpy as np
cv2.setNumThreads(0)
image_height = None
image_width = None
class Augmentor():
"""
Augment image with random noise and transformation
Controlled by severity level [0, 1]
Usage:
augmentor = Augmentor(severity=0.3,
prob=0.5,
enable_transform=True,
enable_crop=False)
image_new = augmentor.process(image)
"""
def __init__(self, severity, prob, enable_transform=True, enable_crop=False):
"""
severity: in [0, 1], from min to max level of noise/transformation
prob: in [0, 1], probability to apply each operator
enable_transform: enable all transformation operators
enable_crop: enable crop operator
"""
self.severity = np.clip(severity, 0, 1)
self.prob = np.clip(prob, 0, 1)
self.enable_transform = enable_transform
self.enable_crop = enable_crop
def add_noise(self, im):
"""randomly add noise to image"""
severity = self.severity
prob = self.prob
if lucky(prob):
im = noise_gamma(im, severity=severity)
if lucky(prob):
im = noise_blur(im, severity=severity)
if lucky(prob):
im = noise_gaussian(im, severity=severity)
if lucky(prob):
im = noise_salt_pepper(im, severity=severity)
if lucky(prob):
im = shift_color(im, severity=severity)
if lucky(prob):
im = gaussian_blur(im, severity=severity)
if lucky(prob):
im = noise_speckle(im, severity=severity)
if lucky(prob):
im = enhance_sharpness(im, severity=severity)
if lucky(prob):
im = enhance_contrast(im, severity=severity)
if lucky(prob):
im = enhance_brightness(im, severity=severity)
if lucky(prob):
im = enhance_color(im, severity=severity)
if lucky(prob):
im = random_contrast(im)
return im
    def convert_color(self, im, cval):
        """Resolve a fill-color spec ('median'/'md', 'mean', an iterable, or a scalar) to int value(s)."""
if cval in ['median', 'md']:
cval = np.median(im, axis=(0, 1)).astype(int)
elif cval == 'mean':
cval = np.mean(im, axis=(0, 1)).astype(int)
if hasattr(cval, '__iter__'):
cval = [int(i) for i in cval]
else:
cval = int(cval)
return cval
def transform(self, im, cval=255, **kw):
"""According to the parameters initialized by the class, deform the incoming image"""
severity = self.severity
prob = self.prob
cval = self.convert_color(im, cval)
if lucky(prob):
# affine transform
im = rgeometry(im, severity=severity, cval=cval)
if lucky(prob):
im = rdistort(im, severity=severity, cval=cval)
if lucky(prob):
im = warp_perspective(im, severity=severity, cval=cval)
if lucky(prob):
im = resize(im, fx=kw.get('fx'), fy=kw.get('fy'), severity=severity)
if lucky(prob):
im = rotate_about_center(im, severity=severity, cval=cval)
if lucky(prob):
# the overall distortion of the image.
im = whole_rdistort(im, severity=severity)
if lucky(prob) and self.enable_crop:
# random crop
im = randcrop(im, severity=severity)
return im
def process(self, im, cval='median', **kw):
""" Execute code according to the effect of initial setting, and support variable parameters"""
if self.enable_transform:
im = self.transform(im, cval=cval, **kw)
im = self.add_noise(im)
return im
def rotate_and_set_neg(img, label):
    """Rotate the image by 180 degrees and decrement the label to mark the sample negative."""
    label = label - 1
    img_rotate = np.rot90(img)
    img_rotate = np.rot90(img_rotate)
    return img_rotate, np.array(label).astype(np.int32)
def rotate(img, label):
img_rotate = np.rot90(img)
img_rotate = np.rot90(img_rotate)
return img_rotate, label
def random_neg_with_rotate(img, label):
    """With 50% probability, make the sample negative by rotating it 180 degrees."""
    if lucky(0.5):
        # 50% of samples become negative samples
        label = label - 1
        # rotate by 180 degrees (two successive 90-degree rotations)
        img_rotate = np.rot90(img)
        img = np.rot90(img_rotate)
    return img, np.array(label).astype(np.int32)
def transform_image(img, label):
data = np.array([img[...]], np.float32)
data = data / 127.5 - 1
return data.transpose((0, 3, 1, 2))[0], label
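# Worked example of the normalization above: a uint8 pixel v maps to
# v / 127.5 - 1, so 0 -> -1.0 and 255 -> 1.0, and transpose((0, 3, 1, 2))
# reorders (1, H, W, C) to (1, C, H, W) before the batch axis is dropped.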
def create_dataset_train(mindrecord_file_pos, config):
"""
create a train dataset
Args:
mindrecord_file_pos(string): mindrecord file for positive samples.
config(dict): config of dataset.
Returns:
dataset
"""
rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0'))
decode = C.Decode()
ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=4,
num_shards=rank_size, shard_id=rank_id, shuffle=True)
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)
augmentor = Augmentor(config.augment_severity, config.augment_prob)
operation = augmentor.process
ds = ds.map(operations=operation, input_columns=["image"],
num_parallel_workers=1, python_multiprocessing=True)
    # randomly turn half of the samples into negative samples
ds = ds.map(operations=[random_neg_with_rotate, unify_img_label, transform_image], input_columns=["image", "label"],
num_parallel_workers=8, python_multiprocessing=True)
    # for training, double the dataset to account for positive and negative samples
ds = ds.repeat(2)
# apply batch operations
ds = ds.batch(config.batch_size, drop_remainder=True)
return ds
def resize_image(img, label):
color_fill = 255
scale = image_height / img.shape[0]
img = cv2.resize(img, None, fx=scale, fy=scale)
if img.shape[1] > image_width:
img = img[:, 0:image_width]
else:
blank_img = np.zeros((image_height, image_width, 3), np.uint8)
# fill the image with white
blank_img.fill(color_fill)
blank_img[:image_height, :img.shape[1]] = img
img = blank_img
data = np.array([img[...]], np.float32)
data = data / 127.5 - 1
return data.transpose((0, 3, 1, 2))[0], label
def create_dataset_eval(mindrecord_file_pos, config):
"""
create an eval dataset
Args:
mindrecord_file_pos(string): mindrecord file for positive samples.
config(dict): config of dataset.
Returns:
dataset
"""
rank_size = int(os.getenv("RANK_SIZE", '1'))
rank_id = int(os.getenv("RANK_ID", '0'))
decode = C.Decode()
ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=1,
num_shards=rank_size, shard_id=rank_id, shuffle=False)
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)
global image_height
global image_width
image_height = config.im_size_h
image_width = config.im_size_w
ds = ds.map(operations=resize_image, input_columns=["image", "label"], num_parallel_workers=config.work_nums,
python_multiprocessing=False)
# apply batch operations
ds = ds.batch(1, drop_remainder=True)
return ds
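# A minimal sketch of how these builders are consumed (the mindrecord path is
# hypothetical):
#   from src.config import config1 as config
#   from src.dataset import create_dataset_eval
#   ds = create_dataset_eval("/data/ocr/ocr_eval_pos.mindrecord", config=config)
#   print(ds.get_dataset_size())  # number of batches; eval batches have size 1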

train.py
@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train CNN direction model."""
import argparse
import os
import random
from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_train
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import dataset as de
from mindspore.communication.management import init
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim.adam import Adam
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
parser = argparse.ArgumentParser(description='CNN direction model training')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
args_opt = parser.parse_args()
random.seed(11)
np.random.seed(11)
de.config.set_seed(11)
ms.common.set_seed(11)
if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
# init context
device_id = int(os.getenv('DEVICE_ID', '0'))
rank_id = int(os.getenv('RANK_ID', '0'))
rank_size = int(os.getenv('RANK_SIZE', '1'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
print("train args: ", args_opt, "\ncfg: ", config,
"\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
init()
# create dataset
dataset = create_dataset_train(args_opt.dataset_path + "/ocr_pos.mindrecord0", config=config)
step_size = dataset.get_dataset_size()
# define net
net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
lr = config.lr
lr = Tensor(lr, ms.float32)
# define opt
opt = Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
# define loss, model
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"Accuracy": Accuracy()})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=2500,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
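# Example invocations (a sketch; paths are hypothetical):
#   single device:  DEVICE_ID=0 python train.py --dataset_path=/data/ocr_train
#   warm start:     DEVICE_ID=0 python train.py --dataset_path=/data/ocr_train --pre_trained=./pretrained.ckpt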