!8899 Add OpenPose network to modelzoo
From: @zhanghuiyao Reviewed-by: Signed-off-by:pull/8899/MERGE
commit
f4c126ddeb
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,38 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.openposenet import OpenPoseNet
|
||||
|
||||
# Command-line interface of the export script.
parser = argparse.ArgumentParser(description='checkpoint export')
# The checkpoint is mandatory: without it the path would be None and
# load_checkpoint() below would be called with no file to read.
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
args_opt = parser.parse_args()

if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)

    # define net
    net = OpenPoseNet()

    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)

    # A random 1x3x368x368 float32 tensor is used as the example input for export.
    inputs = np.random.uniform(0.0, 1.0, size=[1, 3, 368, 368]).astype(np.float32)
    export(net, Tensor(inputs), file_name="openpose.air", file_format='AIR')
|
@ -0,0 +1,61 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Exactly one argument (the rank table file) is required.
# The script relies on bash-only syntax (substring expansion, C-style for
# loops), so the usage text advertises bash and is written to stderr.
if [ $# != 1 ]
then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE]" >&2
    exit 1
fi
|
||||
|
||||
# Print an absolute path for $1.
# Paths that already start with "/" are echoed unchanged; anything else is
# resolved against the current directory with `realpath -m`.
# The argument is quoted so that paths containing spaces survive intact.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
|
||||
# Resolve and validate the rank table file (quoted so paths with spaces work).
RANK_TABLE_FILE=$(get_real_path "$1")

echo "$RANK_TABLE_FILE"

if [ ! -f "$RANK_TABLE_FILE" ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
    exit 1
fi

export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE

# Start one background training process per device, each in its own
# working directory holding a fresh copy of the sources.
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    # group_size follows RANK_SIZE instead of repeating the literal 8.
    python train.py \
        --train_dir train2017 \
        --group_size "${RANK_SIZE}" \
        --train_ann person_keypoints_train2017.json > log.txt 2>&1 &
    cd ..
done
|
||||
|
@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Run evaluation on device 0 in the background; all output goes to eval.log.
export DEVICE_ID=0
export RANK_ID=0

CKPT_FILE=./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt
VAL_IMAGES=/data0/zhy/dataset/coco/val2017
VAL_ANN=/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json

python eval.py --model_path "$CKPT_FILE" --imgpath_val "$VAL_IMAGES" --ann "$VAL_ANN" > eval.log 2>&1 &
|
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Training is launched from the parent directory; stop if it cannot be
# entered so that train.py is never started from the wrong location.
cd .. || exit 1
python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 &
|
@ -0,0 +1,171 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from enum import IntEnum
|
||||
|
||||
class JointType(IntEnum):
    """Integer index of each of the 18 body joints used by this model."""

    # head / torso anchor
    Nose = 0
    Neck = 1

    # right arm
    RightShoulder = 2
    RightElbow = 3
    RightHand = 4

    # left arm
    LeftShoulder = 5
    LeftElbow = 6
    LeftHand = 7

    # right leg
    RightWaist = 8
    RightKnee = 9
    RightFoot = 10

    # left leg
    LeftWaist = 11
    LeftKnee = 12
    LeftFoot = 13

    # face
    RightEye = 14
    LeftEye = 15
    RightEar = 16
    LeftEar = 17
|
||||
|
||||
# Global configuration shared by the training, evaluation and data-preparation code.
params = {
    # ---- paths ----
    'data_dir': '/data0/zhy/dataset/coco',
    'vgg_path': '/data0/zhy/dataset/coco/vgg19-0-97_5004.ckpt',
    'save_model_path': './checkpoints/',
    'load_pretrain': False,
    'pretrained_model_path': "",

    # ---- training params ----
    'batch_size': 10,

    'lr': 1e-4,
    'lr_gamma': 0.1,
    # Comma-separated step indices; parsed with split(',') by the training script.
    'lr_steps': '100000,200000,250000',
    'lr_steps_NP': '250000',

    # NOTE(review): 16386 is not a power of two (16384 is) — confirm this value is intended.
    'loss_scale': 16386,
    'max_epoch_train': 60,
    'min_keypoints': 5,
    'min_area': 32 * 32,
    'insize': 368,
    'downscale': 8,
    'paf_sigma': 8,
    'heatmap_sigma': 7,
    'eva_num': 100,
    'keep_checkpoint_max': 5,
    'log_interval': 100,
    'ckpt_interval': 663,  # 5000,

    'min_box_size': 64,
    'max_box_size': 512,
    'min_scale': 0.5,
    'max_scale': 2.0,
    'max_rotate_degree': 40,
    'center_perterb_max': 40,

    # ---- inference params ----
    'inference_img_size': 368,
    'inference_scales': [0.5, 1, 1.5, 2],
    # 'inference_scales': [1.0],
    'heatmap_size': 320,
    'gaussian_sigma': 2.5,
    'ksize': 17,
    'n_integ_points': 10,
    'n_integ_points_thresh': 8,
    'heatmap_peak_thresh': 0.05,
    'inner_product_thresh': 0.05,
    'limb_length_ratio': 1.0,
    'length_penalty_value': 1,
    'n_subset_limbs_thresh': 3,
    'subset_score_thresh': 0.2,
    # Pairs of joints connected by a limb.
    'limbs_point': [
        [JointType.Neck, JointType.RightWaist],
        [JointType.RightWaist, JointType.RightKnee],
        [JointType.RightKnee, JointType.RightFoot],
        [JointType.Neck, JointType.LeftWaist],
        [JointType.LeftWaist, JointType.LeftKnee],
        [JointType.LeftKnee, JointType.LeftFoot],
        [JointType.Neck, JointType.RightShoulder],
        [JointType.RightShoulder, JointType.RightElbow],
        [JointType.RightElbow, JointType.RightHand],
        [JointType.RightShoulder, JointType.RightEar],
        [JointType.Neck, JointType.LeftShoulder],
        [JointType.LeftShoulder, JointType.LeftElbow],
        [JointType.LeftElbow, JointType.LeftHand],
        [JointType.LeftShoulder, JointType.LeftEar],
        [JointType.Neck, JointType.Nose],
        [JointType.Nose, JointType.RightEye],
        [JointType.Nose, JointType.LeftEye],
        [JointType.RightEye, JointType.RightEar],
        [JointType.LeftEye, JointType.LeftEar],
    ],
    # Joint order list (17 entries; Neck is not included).
    'joint_indices': [
        JointType.Nose,
        JointType.LeftEye,
        JointType.RightEye,
        JointType.LeftEar,
        JointType.RightEar,
        JointType.LeftShoulder,
        JointType.RightShoulder,
        JointType.LeftElbow,
        JointType.RightElbow,
        JointType.LeftHand,
        JointType.RightHand,
        JointType.LeftWaist,
        JointType.RightWaist,
        JointType.LeftKnee,
        JointType.RightKnee,
        JointType.LeftFoot,
        JointType.RightFoot,
    ],

    # ---- face params ----
    'face_inference_img_size': 368,
    'face_heatmap_peak_thresh': 0.1,
    'face_crop_scale': 1.5,
    # Consecutive landmark pairs [i, i + 1]; the last four groups are closed
    # with an explicit pair back to their first landmark.
    'face_line_indices': [
        *[[i, i + 1] for i in range(0, 16)],   # contour
        *[[i, i + 1] for i in range(17, 21)],
        *[[i, i + 1] for i in range(22, 26)],
        *[[i, i + 1] for i in range(27, 30)],
        *[[i, i + 1] for i in range(31, 35)],
        *[[i, i + 1] for i in range(36, 41)], [41, 36],
        *[[i, i + 1] for i in range(42, 47)], [47, 42],
        *[[i, i + 1] for i in range(48, 59)], [59, 48],   # outer lip contour
        *[[i, i + 1] for i in range(60, 67)], [67, 60],
    ],

    # ---- hand params ----
    'hand_inference_img_size': 368,
    'hand_heatmap_peak_thresh': 0.1,
    # One chain of four segments per finger, each starting at point 0.
    'fingers_indices': [
        [[0, base], [base, base + 1], [base + 1, base + 2], [base + 2, base + 3]]
        for base in (1, 5, 9, 13, 17)
    ],
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,133 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pycocotools.coco import COCO as ReadJson
|
||||
|
||||
from config import params
|
||||
|
||||
class DataLoader():
    """Reads images and annotations through a COCO API object and builds ignore masks.

    Args:
        coco: COCO API object; it must provide getCatIds, getImgIds, getAnnIds,
            loadAnns, loadImgs and annToMask.
        dir_name (str): image directory name, joined below params['data_dir'].
        data_mode (str): either 'train' or 'val'.
    """

    def __init__(self, coco, dir_name, data_mode='train'):
        self.train = coco
        self.dir_name = dir_name
        assert data_mode in ['train', 'val'], 'Data loading mode is invalid.'
        self.mode = data_mode
        self.catIds = coco.getCatIds()  # catNms=['person']
        self.imgIds = sorted(coco.getImgIds(catIds=self.catIds))

    def __len__(self):
        """Return the number of image ids handled by this loader."""
        return len(self.imgIds)

    def gen_masks(self, image, anns):
        """Build the 'all annotations' mask and the 'ignore' mask for one image.

        Args:
            image: array whose first two dimensions give the mask size.
            anns: iterable of annotation dicts ('iscrowd', 'num_keypoints', 'area').

        Returns:
            tuple: (mask_all, mask_miss), two boolean arrays of shape image.shape[:2].
            mask_all is the union of every annotation mask. mask_miss holds the
            crowd regions not covered by earlier annotations plus the annotations
            with too few keypoints or too small an area.
        """
        mask_all = np.zeros(image.shape[:2], dtype=bool)
        mask_miss = np.zeros(image.shape[:2], dtype=bool)
        for ann in anns:
            mask = self.train.annToMask(ann).astype(bool)
            if ann['iscrowd'] == 1:
                # Only the part of the crowd region not already annotated is ignored.
                mask_miss |= mask & ~mask_all
            elif ann['num_keypoints'] < params['min_keypoints'] or ann['area'] <= params['min_area']:
                mask_miss |= mask
            mask_all |= mask
        return mask_all, mask_miss

    @staticmethod
    def _blend_mask(image, mask, color):
        """Return `image` blended with `color` where `mask` is set (float result)."""
        bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        mskd = image * bimsk.astype(np.int32)
        # Per-channel colour layer: 255 * color inside the mask, 0 elsewhere.
        clmsk = bimsk * (np.asarray(color, dtype=np.float64) * 255)
        return image + 0.7 * clmsk - 0.7 * mskd

    def dwaw_gen_masks(self, image, mask, color=(0, 0, 1)):
        """Overlay `mask` on `image` in the given colour and return a uint8 image."""
        return self._blend_mask(image, mask, color).astype(np.uint8)

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    draw_gen_masks = dwaw_gen_masks

    def draw_masks_and_keypoints(self, image, anns):
        """Overlay every annotation mask and its keypoints on `image`.

        Crowd annotations use colour (0, 0, 1), annotations without keypoints
        (0, 1, 0) and all others (1, 0, 0). Keypoints with flag 1 and flag 2
        are drawn as filled circles in two different colours.

        Returns:
            The annotated image as uint8.
        """
        for ann in anns:
            # masks
            mask = self.train.annToMask(ann).astype(np.uint8)
            if ann['iscrowd'] == 1:
                color = (0, 0, 1)
            elif ann['num_keypoints'] == 0:
                color = (0, 1, 0)
            else:
                color = (1, 0, 0)
            image = self._blend_mask(image, mask, color)

            # keypoints
            for x, y, v in np.array(ann['keypoints']).reshape(-1, 3):
                # cv2 expects plain integer pixel coordinates.
                center = (int(x), int(y))
                if v == 1:
                    cv2.circle(image, center, 3, (255, 255, 0), -1)
                elif v == 2:
                    cv2.circle(image, center, 3, (255, 0, 255), -1)
        return image.astype(np.uint8)

    def get_img_annotation(self, ind=None, image_id=None):
        """Load one image together with its annotations.

        Args:
            ind (int, optional): position in the sorted image id list; takes
                precedence over `image_id` when given.
            image_id (optional): image id to load directly.

        Returns:
            tuple: (image, annotations, image_id). The image is whatever
            cv2.imread returns for the resolved file path.

        Raises:
            ValueError: if neither `ind` nor `image_id` is given.
        """
        if ind is not None:
            image_id = self.imgIds[ind]
        if image_id is None:
            raise ValueError('Either ind or image_id must be given.')

        anno_ids = self.train.getAnnIds(imgIds=[image_id])
        _annotations = self.train.loadAnns(anno_ids)

        img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([image_id])[0]['file_name'])
        _image = cv2.imread(img_file)
        return _image, _annotations, image_id
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Generates the ignore-mask PNGs for the train and val splits, or shows
    # the annotations interactively when --vis is given.
    parser = argparse.ArgumentParser()
    parser.add_argument('--vis', action='store_true', help='visualize annotations and ignore masks')
    parser.add_argument('--train_ann', type=str, help='train annotations json')
    parser.add_argument('--val_ann', type=str, help='val annotations json')
    parser.add_argument('--train_dir', type=str, help='name of train dir')
    parser.add_argument('--val_dir', type=str, help='name of val dir')
    args = parser.parse_args()

    # Layout: [train_ann, val_ann, train_dir, val_dir]; index selects the
    # annotation file and index + 2 the matching image directory.
    path_list = [args.train_ann, args.val_ann, args.train_dir, args.val_dir]
    for index, mode in enumerate(['train', 'val']):
        train = ReadJson(path_list[index])
        # The constructor parameter is named `data_mode`; the previous
        # `mode=` keyword raised TypeError before any work was done.
        data_loader = DataLoader(train, path_list[index+2], data_mode=mode)

        save_dir = os.path.join(params['data_dir'], 'ignore_mask_{}'.format(mode))
        os.makedirs(save_dir, exist_ok=True)

        for i in tqdm(range(len(data_loader))):
            img, annotations, img_id = data_loader.get_img_annotation(ind=i)
            mask_all, mask_miss = data_loader.gen_masks(img, annotations)

            if args.vis:
                ann_img = data_loader.draw_masks_and_keypoints(img, annotations)
                msk_img = data_loader.dwaw_gen_masks(img, mask_miss)
                cv2.imshow('image', np.hstack((ann_img, msk_img)))
                k = cv2.waitKey()
                if k == ord('q'):
                    # Leaves the loop over the current split only.
                    break
                elif k == ord('s'):
                    cv2.imwrite('aaa.png', np.hstack((ann_img, msk_img)))

            # Masks are written only in non-visual mode and only when non-empty.
            if np.any(mask_miss) and not args.vis:
                mask_miss = mask_miss.astype(np.uint8) * 255
                save_path = os.path.join(save_dir, '{:012d}.png'.format(img_id))
                cv2.imwrite(save_path, mask_miss)
|
@ -0,0 +1,207 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import time

import mindspore.nn as nn
from mindspore import RowTensor
from mindspore import context
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.callback import Callback
|
||||
|
||||
# Executed at import time: graph mode with graph dumping enabled.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

# Shared by every LossCallBack instance: the first one created records the
# reference time stamp used in the loss log.
time_stamp_init = False
time_stamp_first = 0

# Type-dispatched gradient scaling and the reciprocal op it relies on.
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

# Gradient clipping settings.
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Divide a dense gradient by `scale` (multiplies by 1/scale cast to the gradient dtype)."""
    inverse_scale = F.cast(reciprocal(scale), F.dtype(grad))
    return grad * inverse_scale
|
||||
|
||||
@grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    """Divide the values of a RowTensor gradient by `scale`.

    Indices and dense shape are passed through unchanged; only the values are
    multiplied by 1/scale cast to their own dtype. RowTensor must be imported
    at module level for this overload to work.
    """
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
# NOTE(review): a `@clip_grad.register("Number", "Number", "Tensor")` decorator
# used to sit directly on this class, registering the class itself as the
# clip_grad overload. No clipping function is defined in this module, so the
# decorator was removed; add the function back if clipping is required.
class openpose_loss(_Loss):
    """Masked mean-squared-error loss summed over every prediction stage."""

    def __init__(self):
        super(openpose_loss, self).__init__()
        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.mul = P.Mul()
        self.l2_loss = P.L2Loss()
        self.square = P.Square()
        self.reduceMean = P.ReduceMean()
        self.reduceSum = P.ReduceSum()
        self.print = P.Print()
        self.shape = P.Shape()
        self.maxoftensor = P.ArgMaxWithValue(-1)

    def mean_square_error(self, map1, map2, mask=None):
        """Return the mean of (map1 - map2)^2, multiplied element-wise by `mask` when given."""
        if mask is None:
            mse = self.reduceMean((map1 - map2) ** 2)
            return mse

        squareMap = self.square(map1 - map2)
        squareMap_mask = self.mul(squareMap, mask)
        mse = self.reduceMean(squareMap_mask)
        return mse

    def construct(self, logit_paf, logit_heatmap, gt_paf, gt_heatmap, ignore_mask):
        """Compute the total loss and the per-stage losses.

        Args:
            logit_paf: iterable of per-stage PAF predictions.
            logit_heatmap: iterable of per-stage heatmap predictions.
            gt_paf: ground-truth PAF tensor.
            gt_heatmap: ground-truth heatmap tensor.
            ignore_mask: mask multiplied into the squared error; make sure it
                is a 0-1 array rather than a boolean array.

        Returns:
            (total_loss, heatmaps_loss, pafs_loss): the sum over all stages
            and the two per-stage loss lists.
        """
        heatmaps_loss = []
        pafs_loss = []
        total_loss = 0

        # Add an axis at position 1 and repeat the mask along it so that it
        # matches dimension 1 of each ground-truth tensor.
        paf_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_paf)[1], 1, 1))
        heatmap_masks = self.tile(self.expand_dims(ignore_mask, 1), (1, self.shape(gt_heatmap)[1], 1, 1))

        # The masks are constants for the purpose of differentiation.
        paf_masks = F.stop_gradient(paf_masks)
        heatmap_masks = F.stop_gradient(heatmap_masks)
        for logit_paf_t, logit_heatmap_t in zip(logit_paf, logit_heatmap):
            pafs_loss_t = self.mean_square_error(logit_paf_t, gt_paf, paf_masks)
            heatmaps_loss_t = self.mean_square_error(logit_heatmap_t, gt_heatmap, heatmap_masks)

            total_loss += pafs_loss_t + heatmaps_loss_t
            heatmaps_loss.append(heatmaps_loss_t)
            pafs_loss.append(pafs_loss_t)

        return total_loss, heatmaps_loss, pafs_loss
|
||||
|
||||
class Depend_network(nn.Cell):
    """Wraps a network returning (loss, heatmaps_loss, pafs_loss) and exposes only the loss."""

    def __init__(self, network):
        super(Depend_network, self).__init__()
        self.network = network

    def construct(self, *args):
        # The wrapped network yields three outputs; only the first one is forwarded.
        total_loss, _, _ = self.network(*args)
        return total_loss
|
||||
|
||||
class TrainingWrapper(nn.Cell):
    """One training step: forward pass, gradients, optional all-reduce, optimizer update.

    Args:
        network: cell returning (loss, heatmaps_loss, pafs_loss).
        optimizer: optimizer whose `parameters` are differentiated and updated.
        sens: value used to fill the sensitivity tensor fed to the gradient
            operation. Default: 1.
    """

    def __init__(self, network, optimizer, sens=1):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        # Gradients are taken through a wrapper that returns the total loss only.
        self.depend_network = Depend_network(network)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.print = P.Print()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")

        distributed_modes = [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
        if self.parallel_mode in distributed_modes:
            self.reducer_flag = True
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss, heatmaps_loss, pafs_loss = self.network(*args)
        # Sensitivity tensor with the dtype and shape of the loss, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.depend_network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        # Tie the optimizer update to the returned loss.
        loss = F.depend(loss, self.optimizer(grads))
        return loss, heatmaps_loss, pafs_loss
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
    """Chains a network and a criterion and returns the total loss only."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, gt_paf, gt_heatmap, mask):
        # The network returns the PAF and heatmap predictions, in that order.
        pred_pafs, pred_heatmaps = self.network(input_data)
        total_loss, _, _ = self.criterion(pred_pafs, pred_heatmaps, gt_paf, gt_heatmap, mask)
        return total_loss
|
||||
|
||||
class LossCallBack(Callback):
    """
    Append the loss of every training step to ./loss.log.

    Each line holds the seconds elapsed since the first LossCallBack was
    created, the epoch number, the step within the epoch and the loss.

    Note:
        per_print_times is validated and stored but is not used to throttle
        the output: a line is written at every step.

    Args:
        per_print_times (int): Must be an int >= 0. Default: 1.

    Raises:
        ValueError: If per_print_times is not an int or is negative.
    """

    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.count = 0
        self.loss_sum = 0

        # The reference time stamp is module-wide and set only once.
        global time_stamp_init, time_stamp_first
        if not time_stamp_init:
            time_stamp_first = time.time()
            time_stamp_init = True

    def step_end(self, run_context):
        """Write one log line for the step that just finished."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()

        self.count += 1
        self.loss_sum += float(loss)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # count was just incremented, so it is always >= 1 here.
        time_stamp_current = time.time()
        loss = self.loss_sum / self.count

        # The context manager closes the file even when the write fails.
        with open("./loss.log", "a+") as loss_file:
            loss_file.write("%lu epoch: %s step: %s ,loss: %.5f" %
                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
                             loss))
            loss_file.write("\n")

        self.count = 0
        self.loss_sum = 0
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,157 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import LossMonitor
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
from src.config import params
|
||||
|
||||
class MyLossMonitor(LossMonitor):
    """Loss monitor that prints the mean of the collected losses every 100 steps.

    Args:
        per_print_times (int): a loss is collected whenever the global step
            number is a multiple of this value; 0 disables collection.
            Default: 1.
    """

    def __init__(self, per_print_times=1):
        super(MyLossMonitor, self).__init__()
        self._per_print_times = per_print_times
        self._start_time = time.time()
        self._loss_list = []

    def step_end(self, run_context):
        """Collect the step loss, stop on NaN/Inf and print periodically.

        Raises:
            ValueError: if the loss is NaN or infinite.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        # When the network returns several outputs the first one is the loss.
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            # Convert to a built-in float: np.mean of a float32 array yields
            # np.float32, which is not a `float` instance, so the validity
            # check below would otherwise never run.
            loss = float(np.mean(loss.asnumpy()))

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            self._loss_list.append(loss)
            if cb_params.cur_step_num % 100 == 0:
                print("epoch: %s, steps: [%s] mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch,
                                                                 np.array(self._loss_list).mean()), flush=True)
                self._loss_list = []

            self._start_time = time.time()
|
||||
|
||||
|
||||
def parse_args():
    """Parse the training command line and attach the derived dataset paths.

    Unknown command-line options are ignored. Besides the parsed options the
    returned namespace carries jsonpath_train, imgpath_train and
    maskpath_train, all located below params['data_dir'].
    """
    arg_parser = argparse.ArgumentParser('mindspore openpose training')

    # dataset related
    arg_parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir')
    arg_parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json',
                            help='train annotations json')
    arg_parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    parsed, _ = arg_parser.parse_known_args()

    data_root = params['data_dir']
    parsed.jsonpath_train = os.path.join(data_root, 'annotations/' + parsed.train_ann)
    parsed.imgpath_train = os.path.join(data_root, parsed.train_dir)
    parsed.maskpath_train = os.path.join(data_root, 'ignore_mask_train')

    return parsed
|
||||
|
||||
|
||||
def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size, vgg_freeze_step=2000):
    """Build per-step learning-rate schedules with step decay.

    Args:
        lr (float): initial learning rate.
        lr_gamma (float): factor applied from each decay step onwards.
        steps_per_epoch (int): number of steps in one epoch.
        max_epoch_train (int): number of epochs.
        lr_steps (iterable of int): decay steps; each one is floor-divided by
            `group_size` before being applied.
        group_size (int): divisor applied to every entry of `lr_steps`.
        vgg_freeze_step (int): number of initial steps during which the third
            schedule is zero. Default: 2000.

    Returns:
        tuple of float32 arrays of length steps_per_epoch * max_epoch_train:
        (lr_stage, lr_base, lr_vgg), where lr_base is lr_stage / 4 and lr_vgg
        is lr_base with its first `vgg_freeze_step` entries set to 0.
    """
    total_steps = steps_per_epoch * max_epoch_train
    lr_stage = np.full(total_steps, lr, dtype=np.float32)
    for step in lr_steps:
        lr_stage[step // group_size:] *= lr_gamma

    lr_base = lr_stage / 4

    lr_vgg = lr_base.copy()
    lr_vgg[:vgg_freeze_step] = 0
    return lr_stage, lr_base, lr_vgg
|
||||
|
||||
# zhang add
|
||||
def adjust_learning_rate(init_lr, lr_gamma, steps_per_epoch, max_epoch_train, stepvalues, vgg_freeze_step=2000):
    """Build per-step learning-rate schedules with decay at given epochs.

    Args:
        init_lr (float): initial learning rate.
        lr_gamma (float): factor applied from each decay epoch onwards.
        steps_per_epoch (int): number of steps in one epoch.
        max_epoch_train (int): number of epochs.
        stepvalues (iterable of int): epochs at whose first step the rate is
            multiplied by `lr_gamma`.
        vgg_freeze_step (int): number of initial steps during which the third
            schedule is zero. Default: 2000.

    Returns:
        tuple of float32 arrays of length steps_per_epoch * max_epoch_train:
        (lr_stage, lr_base, lr_vgg), where lr_base is lr_stage / 4 and lr_vgg
        is lr_base with its first `vgg_freeze_step` entries set to 0.
    """
    total_steps = steps_per_epoch * max_epoch_train
    lr_stage = np.full(total_steps, init_lr, dtype=np.float32)
    for epoch in stepvalues:
        lr_stage[epoch * steps_per_epoch:] *= lr_gamma

    lr_base = lr_stage / 4

    lr_vgg = lr_base.copy()
    lr_vgg[:vgg_freeze_step] = 0
    return lr_stage, lr_base, lr_vgg
|
||||
|
||||
|
||||
def load_model(test_net, model_path):
    """Load the 'network.'-prefixed checkpoint entries into `test_net`.

    Nothing happens when `model_path` is empty. Entries whose key starts with
    'network.' are loaded with that prefix removed; every other entry
    (including the 'moment...' ones) is skipped.
    """
    if not model_path:
        return

    prefix = 'network.'
    checkpoint = load_checkpoint(model_path)
    stripped = {}
    for name, value in checkpoint.items():
        if name.startswith(prefix):
            stripped[name[len(prefix):]] = value
    load_param_into_net(test_net, stripped)
|
||||
|
||||
|
||||
class show_loss_list():
    """Accumulates per-stage losses and prints their running average.

    Args:
        name (str): label printed in front of the averages.
        num_stages (int): number of per-stage slots. Default: 6.
    """

    def __init__(self, name, num_stages=6):
        self.num_stages = num_stages
        self.loss_list = np.zeros(num_stages).astype('f')
        self.sums = 0
        self.name = name

    def add(self, list_of_tensor):
        """Add one value per stage; every item must provide asnumpy()."""
        self.sums += 1
        for i, loss_tensor in enumerate(list_of_tensor):
            self.loss_list[i] += loss_tensor.asnumpy()

    def show(self):
        """Print the average of each stage since the last call, then reset."""
        # The small epsilon avoids a division by zero when nothing was added.
        print(self.name + ' stage_loss:', self.loss_list / (self.sums + 1e-8), flush=True)
        self.loss_list = np.zeros(self.num_stages).astype('f')
        self.sums = 0
|
||||
|
||||
|
||||
class AverageMeter():
    """Running average of values that provide an asnumpy() method."""

    def __init__(self):
        self.loss = 0
        self.sum = 0

    def add(self, tensor):
        """Accumulate one value and count it."""
        self.loss += tensor.asnumpy()
        self.sum += 1

    def meter(self):
        """Return the average since the last call and reset the accumulator."""
        # The epsilon keeps the division defined when nothing was added.
        total, count = self.loss, self.sum
        self.loss, self.sum = 0, 0
        return total / (count + 1e-8)
|
@ -0,0 +1,124 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.nn.optim import Adam
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.openposenet import OpenPoseNet
|
||||
from src.loss import openpose_loss, BuildTrainNetwork
|
||||
from src.config import params
|
||||
from src.utils import parse_args, get_lr, load_model, MyLossMonitor
|
||||
|
||||
# Module-level MindSpore context setup, executed when this script is loaded:
# graph mode, Ascend device target, graph saving switched off.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
    save_graphs=False,
)
|
||||
|
||||
def train():
    """Train function.

    Reads the command-line arguments via ``parse_args()`` and the hyper-parameters
    from ``params``, builds the OpenPose network with its loss, the training
    dataset, a grouped Adam optimizer and the callbacks, then runs ``Model.train``.

    Raises:
        FileNotFoundError: If the training annotation, image or mask path does
            not exist.
    """
    args = parse_args()

    args.outputs_dir = params['save_model_path']

    # Rank-dependent settings: checkpoint directory, loss scale and LR milestones.
    if args.group_size > 1:
        init()
        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_{}/".format(str(get_rank())))
        args.rank = get_rank()
        # Multi-device runs use half of the configured loss scale and the
        # "lr_steps_NP" milestone list.
        args.loss_scale = params['loss_scale'] / 2
        args.lr_steps = list(map(int, params["lr_steps_NP"].split(',')))
    else:
        args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/")
        args.rank = 0
        args.loss_scale = params['loss_scale']
        args.lr_steps = list(map(int, params["lr_steps"].split(',')))

    # create network
    print('start create network')
    criterion = openpose_loss()
    criterion.add_flags_recursive(fp32=True)
    network = OpenPoseNet(vggpath=params['vgg_path'])

    if params["load_pretrain"]:
        print("load pretrain model:", params["pretrained_model_path"])
        load_model(network, params["pretrained_model_path"])
    train_net = BuildTrainNetwork(network, criterion)

    # create dataset
    # Fail fast on a missing path instead of printing a message and going on to
    # build the dataset from locations that do not exist.
    data_paths = (args.jsonpath_train, args.imgpath_train, args.maskpath_train)
    missing_paths = [path for path in data_paths if not os.path.exists(path)]
    if missing_paths:
        raise FileNotFoundError('Error: wrong data path: {}'.format(', '.join(map(str, missing_paths))))
    print('start create dataset')

    num_worker = 20 if args.group_size > 1 else 48
    de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train,
                                      batch_size=params['batch_size'],
                                      rank=args.rank,
                                      group_size=args.group_size,
                                      num_worker=num_worker,
                                      multiprocessing=True,
                                      shuffle=True,
                                      repeat_num=1)
    steps_per_epoch = de_dataset_train.get_dataset_size()
    print("steps_per_epoch: ", steps_per_epoch)

    # lr scheduler
    lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size,
                                       params['lr_gamma'],
                                       steps_per_epoch,
                                       params["max_epoch_train"],
                                       args.lr_steps,
                                       args.group_size)

    # Split the trainable parameters into three groups by name so that each
    # group gets its own learning-rate schedule.
    trainable = train_net.trainable_params()
    vgg19_base_params = [p for p in trainable if 'base.vgg_base' in p.name]
    base_params = [p for p in trainable if 'base.conv' in p.name]
    stages_params = [p for p in trainable if 'base' not in p.name]

    group_params = [{'params': vgg19_base_params, 'lr': lr_vgg},
                    {'params': base_params, 'lr': lr_base},
                    {'params': stages_params, 'lr': lr_stage}]

    opt = Adam(group_params, loss_scale=args.loss_scale)

    train_net.set_train(True)
    loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

    model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager)

    # Checkpoint no more often than once per epoch.
    params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval'])
    config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'],
                                 keep_checkpoint_max=params["keep_checkpoint_max"])
    ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck)
    time_cb = TimeMonitor(data_size=steps_per_epoch)
    callback_list = [MyLossMonitor(), time_cb, ckpoint_cb]
    print("============== Starting Training ==============")
    model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list,
                dataset_sink_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point. Seeding is left disabled:
    # mindspore.common.seed.set_seed(1)
    train()
|
Loading…
Reference in new issue