parent f6450a614b
commit 830b8f3e93
eval.py
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate CNN direction model."""
import argparse
import os
import random

import numpy as np
from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_eval

from mindspore import context
from mindspore import dataset as de
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

parser = argparse.ArgumentParser(description='Image classification')

parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
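# Example invocation (placeholder paths, matching run_eval.sh):
#   python eval.py --dataset_path=/path/to/eval_data --checkpoint_path=/path/to/model.ckpt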
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

if __name__ == '__main__':
    # init context
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    # default to device 0 when DEVICE_ID is not set (matches train.py)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id)

    # create dataset
    dataset = create_dataset_eval(args_opt.dataset_path + "/ocr_eval_pos.mindrecord", config=config)
    step_size = dataset.get_dataset_size()

    print("step_size ", step_size)

    # define net
    net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])

    # load checkpoint and switch to inference mode
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # define loss and model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")
    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy'})

    # eval model
    res = model.eval(dataset, dataset_sink_mode=False)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
requirements.txt
@@ -0,0 +1,5 @@
mindspore
numpy
Pillow
opencv-python
scikit-image
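# Install with: pip install -r requirements.txt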
scripts/run_distribute_train.sh
@@ -0,0 +1,88 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi
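# Example invocation (placeholder paths; the rank table json is generated per cluster):
#   sh run_distribute_train.sh ./rank_table_8pcs.json /path/to/dataset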

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ $# == 3 ]
then
    PATH3=$(get_real_path $3)
fi

if [ ! -f "$PATH1" ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d "$PATH2" ]
then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

if [ $# == 3 ] && [ ! -f "$PATH3" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log

    if [ $# == 2 ]
    then
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
    fi

    if [ $# == 3 ]
    then
        python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
    fi

    cd ..
done
scripts/run_eval.sh
@@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi
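# Example invocation (placeholder paths):
#   sh run_eval.sh /path/to/eval_data /path/to/model.ckpt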

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)

if [ ! -f "$PATH2" ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

if [ -d "eval" ];
then
    rm -rf ./eval
fi

mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
echo "start evaluation for device $DEVICE_ID"
env > env.log

python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 #&> log &

cd ..
scripts/run_standalone_train.sh
@@ -0,0 +1,72 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 1 ] && [ $# != 2 ]
then
    echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi
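# Example invocations (placeholder paths; the pretrained checkpoint is optional):
#   sh run_standalone_train.sh /path/to/dataset
#   sh run_standalone_train.sh /path/to/dataset /path/to/pretrained.ckpt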

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=3
export RANK_ID=0
export RANK_SIZE=1

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)

if [ $# == 2 ]
then
    PATH2=$(get_real_path $2)
fi

if [ $# == 2 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
    python train.py --dataset_path=$PATH1 &> log &
fi

if [ $# == 2 ]
then
    python train.py --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi

cd ..
File diff suppressed because it is too large
src/config.py
@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used in train.py and eval.py.
"""
from easydict import EasyDict as ed

config1 = ed({
    "batch_size": 8,
    "epoch_size": 5,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 10,
    "keep_checkpoint_max": 20,
    "save_checkpoint_path": "./",
    "warmup_epochs": 5,
    "lr_decay_mode": "poly",
    "lr": 1e-4,
    "work_nums": 4,
    "im_size_w": 512,          # input image width
    "im_size_h": 64,           # input image height
    "pos_samples_size": 100,
    "augment_severity": 0.1,   # strength of train-time augmentation, in [0, 1]
    "augment_prob": 0.3        # probability of applying each augmentation op
})
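# EasyDict gives attribute-style access, e.g. config1.batch_size or config1.im_size_h,
# which is how train.py, eval.py, and src/dataset.py read these settings.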
src/dataset.py
@@ -0,0 +1,246 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, used in train.py and eval.py.
"""
import os

import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
from src.dataset_utils import lucky, noise_blur, noise_speckle, noise_gamma, noise_gaussian, noise_salt_pepper, \
    shift_color, enhance_brightness, enhance_sharpness, enhance_contrast, enhance_color, gaussian_blur, \
    randcrop, resize, rdistort, rgeometry, rotate_about_center, whole_rdistort, warp_perspective, random_contrast, \
    unify_img_label

import cv2
import numpy as np

cv2.setNumThreads(0)

# set by create_dataset_eval before resize_image is mapped over the dataset
image_height = None
image_width = None


class Augmentor():
    """
    Augment image with random noise and transformation

    Controlled by severity level [0, 1]

    Usage:
        augmentor = Augmentor(severity=0.3,
                              prob=0.5,
                              enable_transform=True,
                              enable_crop=False)
        image_new = augmentor.process(image)
    """

    def __init__(self, severity, prob, enable_transform=True, enable_crop=False):
        """
        severity: in [0, 1], from min to max level of noise/transformation
        prob: in [0, 1], probability to apply each operator
        enable_transform: enable all transformation operators
        enable_crop: enable crop operator
        """
        self.severity = np.clip(severity, 0, 1)
        self.prob = np.clip(prob, 0, 1)
        self.enable_transform = enable_transform
        self.enable_crop = enable_crop

    def add_noise(self, im):
        """randomly add noise to image"""

        severity = self.severity
        prob = self.prob

        if lucky(prob):
            im = noise_gamma(im, severity=severity)
        if lucky(prob):
            im = noise_blur(im, severity=severity)
        if lucky(prob):
            im = noise_gaussian(im, severity=severity)
        if lucky(prob):
            im = noise_salt_pepper(im, severity=severity)
        if lucky(prob):
            im = shift_color(im, severity=severity)
        if lucky(prob):
            im = gaussian_blur(im, severity=severity)
        if lucky(prob):
            im = noise_speckle(im, severity=severity)
        if lucky(prob):
            im = enhance_sharpness(im, severity=severity)
        if lucky(prob):
            im = enhance_contrast(im, severity=severity)
        if lucky(prob):
            im = enhance_brightness(im, severity=severity)
        if lucky(prob):
            im = enhance_color(im, severity=severity)
        if lucky(prob):
            im = random_contrast(im)

        return im

    def convert_color(self, im, cval):
        """resolve a fill color: 'median'/'md' or 'mean' of the image, or a plain value"""
        if cval in ['median', 'md']:
            cval = np.median(im, axis=(0, 1)).astype(int)
        elif cval == 'mean':
            cval = np.mean(im, axis=(0, 1)).astype(int)
        if hasattr(cval, '__iter__'):
            cval = [int(i) for i in cval]
        else:
            cval = int(cval)
        return cval

    def transform(self, im, cval=255, **kw):
        """According to the parameters initialized by the class, deform the incoming image"""
        severity = self.severity
        prob = self.prob
        cval = self.convert_color(im, cval)
        if lucky(prob):
            # affine transform
            im = rgeometry(im, severity=severity, cval=cval)
        if lucky(prob):
            im = rdistort(im, severity=severity, cval=cval)
        if lucky(prob):
            im = warp_perspective(im, severity=severity, cval=cval)
        if lucky(prob):
            im = resize(im, fx=kw.get('fx'), fy=kw.get('fy'), severity=severity)
        if lucky(prob):
            im = rotate_about_center(im, severity=severity, cval=cval)
        if lucky(prob):
            # the overall distortion of the image
            im = whole_rdistort(im, severity=severity)
        if lucky(prob) and self.enable_crop:
            # random crop
            im = randcrop(im, severity=severity)
        return im

    def process(self, im, cval='median', **kw):
        """Apply the configured transformations and noise; supports variable parameters"""
        if self.enable_transform:
            im = self.transform(im, cval=cval, **kw)
        im = self.add_noise(im)
        return im


def rotate_and_set_neg(img, label):
    """rotate image by 180 degrees and mark it as a negative sample"""
    label = label - 1
    img_rotate = np.rot90(img)
    img_rotate = np.rot90(img_rotate)
    return img_rotate, np.array(label).astype(np.int32)


def rotate(img, label):
    """rotate image by 180 degrees, keeping the label"""
    img_rotate = np.rot90(img)
    img_rotate = np.rot90(img_rotate)
    return img_rotate, label


def random_neg_with_rotate(img, label):
    """with 50% chance, rotate the image by 180 degrees and mark it as a negative sample"""
    if lucky(0.5):
        # 50% of samples set to negative samples
        label = label - 1
        # rotate by 180 degrees
        img_rotate = np.rot90(img)
        img = np.rot90(img_rotate)
    return img, np.array(label).astype(np.int32)


def transform_image(img, label):
    # scale pixel values from [0, 255] to [-1, 1] and reorder HWC -> CHW
    data = np.array([img[...]], np.float32)
    data = data / 127.5 - 1
    return data.transpose((0, 3, 1, 2))[0], label


def create_dataset_train(mindrecord_file_pos, config):
    """
    create a train dataset

    Args:
        mindrecord_file_pos(string): mindrecord file for positive samples.
        config(dict): config of dataset.

    Returns:
        dataset
    """
    rank_size = int(os.getenv("RANK_SIZE", '1'))
    rank_id = int(os.getenv("RANK_ID", '0'))
    decode = C.Decode()

    ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=4,
                        num_shards=rank_size, shard_id=rank_id, shuffle=True)
    ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)

    augmentor = Augmentor(config.augment_severity, config.augment_prob)
    operation = augmentor.process
    ds = ds.map(operations=operation, input_columns=["image"],
                num_parallel_workers=1, python_multiprocessing=True)
    # randomly augment half of the samples to be negative samples
    ds = ds.map(operations=[random_neg_with_rotate, unify_img_label, transform_image], input_columns=["image", "label"],
                num_parallel_workers=8, python_multiprocessing=True)
    # for training, double the dataset to account for positive and negative samples
    ds = ds.repeat(2)

    # apply batch operations
    ds = ds.batch(config.batch_size, drop_remainder=True)
    return ds
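# Typical use, mirroring train.py (the mindrecord path is a placeholder):
#   ds = create_dataset_train("/path/to/ocr_pos.mindrecord0", config=config)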

def resize_image(img, label):
    """resize to the fixed height, then crop or white-pad to the fixed width"""
    color_fill = 255
    scale = image_height / img.shape[0]
    img = cv2.resize(img, None, fx=scale, fy=scale)
    if img.shape[1] > image_width:
        img = img[:, 0:image_width]
    else:
        blank_img = np.zeros((image_height, image_width, 3), np.uint8)
        # fill the image with white
        blank_img.fill(color_fill)
        blank_img[:image_height, :img.shape[1]] = img
        img = blank_img
    data = np.array([img[...]], np.float32)
    data = data / 127.5 - 1
    return data.transpose((0, 3, 1, 2))[0], label


def create_dataset_eval(mindrecord_file_pos, config):
    """
    create an eval dataset

    Args:
        mindrecord_file_pos(string): mindrecord file for positive samples.
        config(dict): config of dataset.

    Returns:
        dataset
    """
    rank_size = int(os.getenv("RANK_SIZE", '1'))
    rank_id = int(os.getenv("RANK_ID", '0'))
    decode = C.Decode()

    ds = de.MindDataset(mindrecord_file_pos, columns_list=["image", "label"], num_parallel_workers=1,
                        num_shards=rank_size, shard_id=rank_id, shuffle=False)
    ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=8)

    global image_height
    global image_width
    image_height = config.im_size_h
    image_width = config.im_size_w
    ds = ds.map(operations=resize_image, input_columns=["image", "label"], num_parallel_workers=config.work_nums,
                python_multiprocessing=False)
    # apply batch operations
    ds = ds.batch(1, drop_remainder=True)

    return ds
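# Typical use, mirroring eval.py (the mindrecord path is a placeholder):
#   ds = create_dataset_eval("/path/to/ocr_eval_pos.mindrecord", config=config)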
File diff suppressed because it is too large
train.py
@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train CNN direction model."""
import argparse
import os
import random

from src.cnn_direction_model import CNNDirectionModel
from src.config import config1 as config
from src.dataset import create_dataset_train

import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import dataset as de
from mindspore.communication.management import init
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.metrics import Accuracy
from mindspore.nn.optim.adam import Adam
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')

parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')

args_opt = parser.parse_args()
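# Example invocations (placeholder paths, matching the launch scripts):
#   python train.py --dataset_path=/path/to/dataset
#   python train.py --dataset_path=/path/to/dataset --pre_trained=/path/to/pretrained.ckpt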

random.seed(11)
np.random.seed(11)
de.config.set_seed(11)
ms.common.set_seed(11)

if __name__ == '__main__':

    target = args_opt.device_target
    ckpt_save_dir = config.save_checkpoint_path

    # init context
    device_id = int(os.getenv('DEVICE_ID', '0'))
    rank_id = int(os.getenv('RANK_ID', '0'))
    rank_size = int(os.getenv('RANK_SIZE', '1'))
    run_distribute = rank_size > 1
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id, save_graphs=False)

    print("train args: ", args_opt, "\ncfg: ", config,
          "\nparallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))

    if run_distribute:
        context.set_auto_parallel_context(device_num=rank_size, parallel_mode=ParallelMode.DATA_PARALLEL)
        init()

    # create dataset
    dataset = create_dataset_train(args_opt.dataset_path + "/ocr_pos.mindrecord0", config=config)
    step_size = dataset.get_dataset_size()

    # define net
    net = CNNDirectionModel([3, 64, 48, 48, 64], [64, 48, 48, 64, 64], [256, 64], [64, 512])

    # init weight
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(net, param_dict)

    lr = config.lr
    lr = Tensor(lr, ms.float32)

    # define opt
    opt = Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)

    # define loss, model
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="sum")

    model = Model(net, loss_fn=loss, optimizer=opt, metrics={"Accuracy": Accuracy()})

    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=2500,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="cnn_direction_model", directory=ckpt_save_dir, config=config_ck)
        cb += [ckpt_cb]

    # train model
    model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)