!10530 Add simple-pose-net to model_zoo
From: @rmdyh Reviewed-by: @linqingke,@oacjiewen Signed-off-by:pull/10530/MERGE
commit
693fbf0dcf
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,180 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
||||
from mindspore import Tensor, float32, context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import flip_pairs, keypoint_dataset
|
||||
from src.evaluate.coco_eval import evaluate
|
||||
from src.model import get_pose_net
|
||||
from src.utils.transform import flip_back
|
||||
from src.predict import get_final_preds
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: the parsed options. Hyphenated option names are
        exposed with underscores (``--model-file`` -> ``args.model_file``).
    """
    cli = argparse.ArgumentParser(description='Train keypoints network')
    cli.add_argument("--train_url", type=str, default="", help="")
    cli.add_argument("--data_url", type=str, default="", help="data")

    # output
    cli.add_argument('--output-url', type=str, help='output dir')

    # training
    cli.add_argument('--workers', type=int, default=8,
                     help='num of dataloader workers')
    cli.add_argument('--model-file', type=str, help='model state file')
    cli.add_argument('--use-detect-bbox', action='store_true',
                     help='use detect bbox')
    # NOTE(review): default=True together with store_true means this option
    # is True whether or not the flag is given; it cannot be disabled here.
    cli.add_argument('--flip-test', action='store_true', default=True,
                     help='use flip test')
    cli.add_argument('--post-process', action='store_true',
                     help='use post process')
    cli.add_argument('--shift-heatmap', action='store_true',
                     help='shift heatmap')
    cli.add_argument('--coco-bbox-file', type=str,
                     help='coco detection bbox file')

    return cli.parse_args()
|
||||
|
||||
|
||||
def reset_config(cfg, args):
    """Apply command-line overrides to ``cfg.TEST`` in place.

    Only options whose parsed value is truthy are applied; everything else
    keeps the value already stored in ``cfg``.

    Args:
        cfg: configuration object exposing a ``TEST`` section.
        args: parsed command-line options (see ``parse_args``).
    """
    test_cfg = cfg.TEST

    if args.use_detect_bbox:
        # Detected boxes requested, so ground-truth boxes are switched off.
        test_cfg.USE_GT_BBOX = not args.use_detect_bbox

    if args.flip_test:
        test_cfg.FLIP_TEST = args.flip_test
        print('use flip test:', test_cfg.FLIP_TEST)

    # Remaining options map one-to-one onto a cfg.TEST key.
    passthrough = (
        ('post_process', 'POST_PROCESS'),
        ('shift_heatmap', 'SHIFT_HEATMAP'),
        ('model_file', 'MODEL_FILE'),
        ('coco_bbox_file', 'COCO_BBOX_FILE'),
    )
    for option, key in passthrough:
        value = getattr(args, option)
        if value:
            setattr(test_cfg, key, value)
|
||||
|
||||
|
||||
def validate(cfg, val_dataset, model, output_dir):
    """Run the model over ``val_dataset`` and score the predictions.

    Args:
        cfg: configuration; ``TEST.BATCH_SIZE``, ``TEST.FLIP_TEST``,
            ``TEST.SHIFT_HEATMAP`` and ``MODEL.NUM_JOINTS`` are read here.
        val_dataset: dataset whose dict iterator yields the columns
            ``image``, ``center``, ``scale``, ``score`` and ``id``.
        model: network called on a float32 ``Tensor`` of images.
        output_dir: directory handed to ``evaluate`` for result files.

    Returns:
        The performance indicator returned by ``evaluate``.
    """
    # switch to evaluate mode
    model.set_train(False)

    # Buffers are sized for the maximum possible sample count; only the first
    # ``seen`` rows are passed on to evaluation.
    capacity = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
    all_preds = np.zeros((capacity, cfg.MODEL.NUM_JOINTS, 3), dtype=np.float32)
    all_boxes = np.zeros((capacity, 2))
    image_id = []
    seen = 0

    # start eval
    tick = time.time()
    for batch in val_dataset.create_dict_iterator():
        images = batch['image'].asnumpy()

        # compute output
        heatmaps = model(Tensor(images, float32)).asnumpy()
        if cfg.TEST.FLIP_TEST:
            # Second pass on the images reversed along the last axis.
            mirrored = model(Tensor(images[:, :, :, ::-1], float32))
            heatmaps_flipped = flip_back(mirrored.asnumpy(), flip_pairs)

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if cfg.TEST.SHIFT_HEATMAP:
                heatmaps_flipped[:, :, :, 1:] = \
                    heatmaps_flipped.copy()[:, :, :, 0:-1]

            heatmaps = (heatmaps + heatmaps_flipped) * 0.5

        # meta data
        centers = batch['center'].asnumpy()
        scales = batch['scale'].asnumpy()
        scores = batch['score'].asnumpy()
        batch_ids = list(batch['id'].asnumpy())

        # pred by heatmaps
        preds, maxvals = get_final_preds(cfg, heatmaps.copy(), centers, scales)
        count = preds.shape[0]
        rows = slice(seen, seen + count)
        all_preds[rows, :, 0:2] = preds[:, :, 0:2]
        all_preds[rows, :, 2:3] = maxvals
        # Column 0 is consumed as 'area' and column 1 as 'score' by evaluate().
        all_boxes[rows, 0] = np.prod(scales * 200, 1)
        all_boxes[rows, 1] = scores
        image_id.extend(batch_ids)
        seen += count

        if seen % 1024 == 0:
            print('{} samples validated in {} seconds'.format(seen, time.time() - tick))
            tick = time.time()

    print(all_preds[:seen].shape, all_boxes[:seen].shape, len(image_id))
    _, perf_indicator = evaluate(
        cfg, all_preds[:seen], output_dir, all_boxes[:seen], image_id)
    print("AP:", perf_indicator)
    return perf_indicator
|
||||
|
||||
|
||||
def main():
    """Evaluate a trained pose network on the validation set.

    Reads the target device from the ``DEVICE_ID`` environment variable
    (falling back to device 0 when it is unset or empty), applies the
    command-line overrides to the global ``config``, loads the checkpoint
    named by ``config.TEST.MODEL_FILE`` and runs ``validate``.
    """
    # init seed
    set_seed(1)

    # set context
    # `or 0` covers both a missing variable and an empty string, either of
    # which would otherwise make int() raise.
    device_id = int(os.getenv('DEVICE_ID') or 0)
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False, device_id=device_id)

    args = parse_args()
    # update config
    reset_config(config, args)

    # init model
    model = get_pose_net(config, is_train=False)

    # load parameters
    ckpt_name = config.TEST.MODEL_FILE
    print('loading model ckpt from {}'.format(ckpt_name))
    load_param_into_net(model, load_checkpoint(ckpt_name))

    # Data loading code
    valid_dataset, _ = keypoint_dataset(
        config,
        bbox_file=config.TEST.COCO_BBOX_FILE,
        train_mode=False,
        num_parallel_workers=args.workers,
    )

    # evaluate on validation set
    # Strip only the file extension: splitting on the first '.' would turn
    # './ckpt/model.ckpt' into an empty output directory.
    output_dir = os.path.splitext(ckpt_name)[0]
    validate(config, valid_dataset, model, output_dir)


if __name__ == '__main__':
    main()
|
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# Evaluate on the device given as the first argument. The job runs in the
# background and its output goes to eval_log<DEVICE_ID>.txt.
# Expansions are quoted so an argument containing spaces is not word-split.
export DEVICE_ID="$1"

python eval.py > "eval_log$1.txt" 2>&1 &
|
@ -0,0 +1,44 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# Usage: bash train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [SAVE_CKPT_PATH] [RANK_SIZE]
# bash is required: the script uses arrays and a C-style for loop.

export RANK_TABLE_FILE="$1"
echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
export RANK_SIZE="$3"
SAVE_PATH="$2"

# Rank -> device id table. Only four entries are listed, so a RANK_SIZE above
# 4 leaves DEVICE_ID empty for the extra ranks.
device=(0 1 2 3)
for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID="${device[$i]}"
    export RANK_ID="$i"

    # Start each rank from a clean working directory.
    rm -rf "./train_parallel$i"
    mkdir "./train_parallel$i"
    echo "start training for rank $i, device $DEVICE_ID"

    # Record the environment this rank is launched with.
    cd "./train_parallel$i" || exit
    env > env.log
    cd ../
    python train.py \
        --run-distribute \
        --ckpt-path="$SAVE_PATH" > "train_parallel$i/log.txt" 2>&1 &

    echo "python train.py --run-distribute --ckpt-path=$SAVE_PATH > train_parallel$i/log.txt 2>&1 &"

done
|
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# Usage: train_standalone.sh [DEVICE_ID] [SAVE_CKPT_PATH]
|
||||
# $1 is the device id, $2 the checkpoint path. Training runs in the
# background and logs to train_log<DEVICE_ID>.txt.
# Expansions are quoted so paths containing spaces are not word-split.
export DEVICE_ID="$1"

python train.py \
    --ckpt-path="$2" --batch-size=128 \
    > "train_log$1.txt" 2>&1 &
echo " python train.py --ckpt-path=$2 --batch-size=128 > train_log$1.txt 2>&1 &"
|
@ -0,0 +1,77 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# Global configuration object; sections are attached below.
config = edict()

# pose_resnet related params
POSE_RESNET = edict(
    NUM_LAYERS=50,
    DECONV_WITH_BIAS=False,
    NUM_DECONV_LAYERS=3,
    NUM_DECONV_FILTERS=[256, 256, 256],
    NUM_DECONV_KERNELS=[4, 4, 4],
    FINAL_CONV_KERNEL=1,
    TARGET_TYPE='gaussian',
    HEATMAP_SIZE=[48, 64],  # width * height, ex: 24 * 32
    SIGMA=2,
)

# Architecture-specific parameter sets, keyed by model name.
MODEL_EXTRAS = {
    'pose_resnet': POSE_RESNET,
}

# common params for NETWORK
config.MODEL = edict(
    NAME='pose_resnet',
    INIT_WEIGHTS=True,
    PRETRAINED='./models/resnet50.ckpt',
    NUM_JOINTS=17,
    IMAGE_SIZE=[192, 256],  # width * height, ex: 192 * 256
)
# Attach the parameter set that matches the selected model name.
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]

# dataset
config.DATASET = edict(
    ROOT='/data/coco2017/',
    TEST_SET='val2017',
    TRAIN_SET='train2017',
    # data augmentation
    FLIP=True,
    ROT_FACTOR=40,
    SCALE_FACTOR=0.3,
)

# for train
config.TRAIN = edict(
    BATCH_SIZE=64,
    BEGIN_EPOCH=0,
    END_EPOCH=140,
    LR=0.001,
    LR_FACTOR=0.1,
    LR_STEP=[90, 120],
)

# test
config.TEST = edict(
    BATCH_SIZE=32,
    FLIP_TEST=True,
    POST_PROCESS=True,
    SHIFT_HEATMAP=True,
    USE_GT_BBOX=False,
    MODEL_FILE='',
    COCO_BBOX_FILE='experiments/COCO_val2017_detections_AP_H_56_person.json',
    # nms
    OKS_THRE=0.9,
    IN_VIS_THRE=0.2,
    BBOX_THRE=1.0,
    IMAGE_THRE=0.0,
    NMS_THRE=1.0,
)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,132 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict, OrderedDict
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
has_coco = True
|
||||
except ImportError:
|
||||
has_coco = False
|
||||
|
||||
from src.utils.nms import oks_nms
|
||||
|
||||
|
||||
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
|
||||
results = []
|
||||
|
||||
for img, items in img_kpts.items():
|
||||
item_size = len(items)
|
||||
if not items:
|
||||
continue
|
||||
|
||||
# keypoints array at coco format
|
||||
kpts = np.array([items[k]['keypoints']
|
||||
for k in range(item_size)])
|
||||
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
|
||||
keypoints[:, 0::3] = kpts[:, :, 0]
|
||||
keypoints[:, 1::3] = kpts[:, :, 1]
|
||||
keypoints[:, 2::3] = kpts[:, :, 2]
|
||||
|
||||
result = [{'image_id': int(img),
|
||||
'keypoints': list(keypoints[k]),
|
||||
'score': items[k]['score'],
|
||||
'category_id': 1,
|
||||
} for k in range(item_size)]
|
||||
results.extend(result)
|
||||
|
||||
with open(res_file, 'w') as f:
|
||||
json.dump(results, f, sort_keys=True, indent=4)
|
||||
|
||||
|
||||
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
    """Score a keypoint results file with ``COCOeval``.

    Args:
        res_file: JSON results file passed to ``COCO.loadRes``.
        res_folder: directory that receives ``keypoints_results.pkl``.
        ann_path: annotation file used to build the ``COCO`` object.

    Returns:
        list: ``(name, value)`` pairs pairing each label with the entry of
        ``COCOeval.stats`` at the same position.
    """
    coco_gt = COCO(ann_path)
    coco_dt = coco_gt.loadRes(res_file)

    evaluator = COCOeval(coco_gt, coco_dt, 'keypoints')
    evaluator.params.useSegm = None
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()

    stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
    info_str = [(label, evaluator.stats[pos]) for pos, label in enumerate(stats_names)]

    # Keep the whole evaluator object on disk next to the JSON results.
    eval_file = os.path.join(res_folder, 'keypoints_results.pkl')
    with open(eval_file, 'wb') as f:
        pickle.dump(evaluator, f, pickle.HIGHEST_PROTOCOL)
    print('coco eval results saved to %s' % eval_file)

    return info_str
|
||||
|
||||
|
||||
# need double check this API and classes field
|
||||
def evaluate(cfg, preds, output_dir, all_boxes, img_id):
    """Rescore predictions, apply OKS NMS per image, then write and score them.

    Args:
        cfg: configuration; reads ``MODEL.NUM_JOINTS``, ``TEST.IN_VIS_THRE``,
            ``TEST.OKS_THRE``, ``DATASET.TEST_SET`` and ``DATASET.ROOT``.
        preds: per-sample keypoints, indexable as ``preds[i][joint][0..2]``.
        output_dir: parent directory; results go to ``<output_dir>/results``.
        all_boxes: per-sample pairs, column 0 used as area, column 1 as score.
        img_id: image id of every sample, aligned with ``preds``.

    Returns:
        tuple: ``(name_value, AP)`` when the COCO evaluation runs, otherwise
        ``({'Null': 0}, 0)``.
    """
    res_folder = os.path.join(output_dir, 'results')
    if not os.path.exists(res_folder):
        os.makedirs(res_folder)
    res_file = os.path.join(res_folder, 'keypoints_results.json')

    # image -> list(keypoints/area/score)
    img_kpts_dict = defaultdict(list)
    for idx, file_id in enumerate(img_id):
        box = all_boxes[idx]
        img_kpts_dict[file_id].append({
            'keypoints': preds[idx],
            'area': box[0],
            'score': box[1],
        })

    # rescoring and oks nms
    num_joints = cfg.MODEL.NUM_JOINTS
    in_vis_thre = cfg.TEST.IN_VIS_THRE
    oks_thre = cfg.TEST.OKS_THRE
    oks_nmsed_kpts = {}
    for img, items in img_kpts_dict.items():
        for item in items:
            joint_conf = (item['keypoints'][jt][2] for jt in range(num_joints))
            visible = [conf for conf in joint_conf if conf > in_vis_thre]
            # Mean confidence over joints above the threshold (0 if none).
            kpt_score = sum(visible)
            if visible:
                kpt_score = kpt_score / len(visible)
            # rescoring
            item['score'] = kpt_score * item['score']

        keep = oks_nms(items, oks_thre)
        if keep:
            oks_nmsed_kpts[img] = [items[kept] for kept in keep]
        else:
            oks_nmsed_kpts[img] = items

    # evaluate and save
    image_set = cfg.DATASET.TEST_SET
    _write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)

    if 'test' in image_set or not has_coco:
        # No scoring for test splits or when pycocotools is not importable.
        return {'Null': 0}, 0

    ann_path = os.path.join(cfg.DATASET.ROOT, 'annotations',
                            'person_keypoints_' + image_set + '.json')
    info_str = _do_python_keypoint_eval(res_file, res_folder, ann_path)
    name_value = OrderedDict(info_str)
    return name_value, name_value['AP']
|
@ -0,0 +1,225 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as F
|
||||
from mindspore.common.initializer import Normal
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import ParameterTuple
|
||||
|
||||
BN_MOMENTUM = 0.1
|
||||
|
||||
|
||||
class MaxPool2dPytorch(nn.Cell):
    """Max pooling applied between two reversals of axes 2 and 3.

    The input is reversed along axes 2 and 3, pooled with ``nn.MaxPool2d``,
    and reversed back.
    NOTE(review): the name suggests this reproduces PyTorch's pooling window
    alignment -- confirm against the reference implementation.

    Args:
        kernel_size: forwarded to ``nn.MaxPool2d``.
        stride: forwarded to ``nn.MaxPool2d``.
        pad_mode: forwarded to ``nn.MaxPool2d``.
    """

    def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
        super(MaxPool2dPytorch, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
        self.reverse = F.ReverseV2(axis=[2, 3])

    def construct(self, x):
        """Return ``reverse(maxpool(reverse(x)))``."""
        flipped = self.reverse(x)
        pooled = self.maxpool(flipped)
        return self.reverse(pooled)
|
||||
|
||||
|
||||
class Bottleneck(nn.Cell):
    """Residual block: 1x1 conv, 3x3 conv, 1x1 conv, plus a shortcut.

    Args:
        inplanes: number of input channels.
        planes: channels of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional cell applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    # Ratio between output channels and `planes`.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               pad_mode='pad', padding=1, has_bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1,
                               has_bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU()
        self.down_sample_layer = downsample
        self.stride = stride

    def construct(self, x):
        """Apply the three conv/bn stages and add the shortcut before ReLU."""
        shortcut = x

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        if self.down_sample_layer is not None:
            shortcut = self.down_sample_layer(x)

        y = y + shortcut
        y = self.relu(y)

        return y
|
||||
|
||||
|
||||
class PoseResNet(nn.Cell):
    """ResNet stages followed by transposed-convolution layers and a final
    convolution with ``cfg.MODEL.NUM_JOINTS`` output channels.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four residual stages.
        cfg: configuration; ``MODEL.EXTRA`` and ``MODEL.NUM_JOINTS`` are read.
        pytorch_mode (bool): use ``MaxPool2dPytorch`` instead of
            ``nn.MaxPool2d`` for the stem pooling.
    """

    # (padding, output_padding) for each deconvolution kernel size.
    _DECONV_CFG = {4: (1, 0), 3: (1, 1), 2: (0, 0)}

    def __init__(self, block, layers, cfg, pytorch_mode=True):
        super(PoseResNet, self).__init__()
        extra = cfg.MODEL.EXTRA
        # Running channel count; updated by _make_layer / _make_deconv_layer
        # as layers are stacked.
        self.inplanes = 64
        self.deconv_with_bias = extra.DECONV_WITH_BIAS

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2,
                               pad_mode='pad', padding=3, has_bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU()
        if pytorch_mode:
            self.maxpool = MaxPool2dPytorch(kernel_size=3, stride=2, pad_mode='same')
            print("use pytorch-style maxpool")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
            print("use mindspore-style maxpool")
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # used for deconv layers
        self.deconv_layers = self._make_deconv_layer(
            extra.NUM_DECONV_LAYERS,
            extra.NUM_DECONV_FILTERS,
            extra.NUM_DECONV_KERNELS,
        )

        self.final_layer = nn.Conv2d(
            in_channels=extra.NUM_DECONV_FILTERS[-1],
            out_channels=cfg.MODEL.NUM_JOINTS,
            kernel_size=extra.FINAL_CONV_KERNEL,
            stride=1,
            pad_mode='pad',
            padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0,
            has_bias=True,
            weight_init=Normal(0.001),
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        A 1x1 conv + batch-norm shortcut is created for the first block when
        the stride is not 1 or the channel count changes.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.SequentialCell(OrderedDict([
                ('0', nn.Conv2d(self.inplanes, out_planes,
                                kernel_size=1, stride=stride, has_bias=False)),
                ('1', nn.BatchNorm2d(out_planes, momentum=BN_MOMENTUM)),
            ]))

        layers = OrderedDict()
        layers['0'] = block(self.inplanes, planes, stride, downsample)
        self.inplanes = out_planes
        for i in range(1, blocks):
            layers['{}'.format(i)] = block(self.inplanes, planes)

        return nn.SequentialCell(layers)

    def _get_deconv_cfg(self, deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a deconv kernel.

        Only kernel size 4 is accepted; the other table entries are reachable
        only when assertions are disabled.
        """
        assert deconv_kernel == 4, 'only support kernel_size = 4 for deconvolution layers'
        padding, output_padding = self._DECONV_CFG[deconv_kernel]
        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Stack ``num_layers`` blocks of transposed conv + batch-norm + ReLU.

        ``num_filters`` and ``num_kernels`` must each have ``num_layers``
        entries.
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        # This check is about the kernel list, so the message names it.
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_kernels)'

        layers = OrderedDict()
        for i in range(num_layers):
            kernel, padding, _ = self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers['deconv_{}'.format(i)] = nn.SequentialCell(OrderedDict([
                ('deconv', nn.Conv2dTranspose(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    pad_mode='pad',
                    padding=padding,
                    has_bias=self.deconv_with_bias,
                    weight_init=Normal(0.001),
                )),
                ('bn', nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)),
                ('relu', nn.ReLU()),
            ]))
            # The next deconv layer consumes this layer's output channels.
            self.inplanes = planes

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run stem, residual stages, deconv layers and the final conv."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.deconv_layers(x)
        x = self.final_layer(x)
        return x

    def init_weights(self, pretrained=''):
        """Load parameters from the checkpoint file ``pretrained``.

        Every trainable parameter outside ``deconv_layers`` and
        ``final_layer`` must be present in the checkpoint.

        Raises:
            AssertionError: if ``pretrained`` is not a file or a required
                parameter is missing from the checkpoint.
        """
        if os.path.isfile(pretrained):
            # load params from pretrained
            param_dict = load_checkpoint(pretrained)
            weight = ParameterTuple(self.trainable_params())
            for w in weight:
                if w.name.split('.')[0] not in ('deconv_layers', 'final_layer'):
                    assert w.name in param_dict, "parameter %s not in checkpoint" % w.name
            load_param_into_net(self, param_dict)
            print('loading pretrained model {}'.format(pretrained))
        else:
            assert False, '{} is not a file'.format(pretrained)
|
||||
|
||||
|
||||
# Depth -> (block class, number of blocks in each of the four stages).
resnet_spec = {
    50: (Bottleneck, [3, 4, 6, 3]),
    101: (Bottleneck, [3, 4, 23, 3]),
    152: (Bottleneck, [3, 8, 36, 3]),
}


def get_pose_net(cfg, is_train, ckpt_path=None, pytorch_mode=False):
    """Build a ``PoseResNet`` for the depth in ``cfg.MODEL.EXTRA.NUM_LAYERS``.

    Args:
        cfg: configuration object.
        is_train (bool): when true and ``cfg.MODEL.INIT_WEIGHTS`` is set,
            ``init_weights(ckpt_path)`` is called on the new model.
        ckpt_path: checkpoint path forwarded to ``init_weights``.
            NOTE(review): the default ``None`` is forwarded unchanged;
            callers that train presumably need to pass a real path.
        pytorch_mode (bool): forwarded to ``PoseResNet``.

    Returns:
        PoseResNet: the constructed network.
    """
    depth = cfg.MODEL.EXTRA.NUM_LAYERS
    block_class, layers = resnet_spec[depth]

    model = PoseResNet(block_class, layers, cfg, pytorch_mode=pytorch_mode)

    wants_init = is_train and cfg.MODEL.INIT_WEIGHTS
    if wants_init:
        model.init_weights(ckpt_path)

    return model
|
@ -0,0 +1,85 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class JointsMSELoss(_Loss):
    """Mean-squared-error loss computed per joint and averaged over joints.

    Args:
        use_target_weight (bool): when true, prediction and target of each
            joint are multiplied by that joint's weight before the MSE.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.reshape = P.Reshape()
        self.squeeze = P.Squeeze(1)
        self.mul = P.Mul()

    def construct(self, output, target, target_weight):
        """Return the sum over joints of ``0.5 * MSE`` divided by the joint count.

        ``output`` and ``target`` are reshaped to
        ``(batch_size, num_joints, -1)`` and split along axis 1, one slice
        per joint.
        """
        out_shape = F.shape(output)
        batch_size = out_shape[0]
        num_joints = out_shape[1]
        flat_shape = (batch_size, num_joints, -1)

        split = P.Split(1, num_joints)
        heatmaps_pred = split(self.reshape(output, flat_shape))
        heatmaps_gt = split(self.reshape(target, flat_shape))

        loss = 0
        for idx in range(num_joints):
            heatmap_pred = self.squeeze(heatmaps_pred[idx])
            heatmap_gt = self.squeeze(heatmaps_gt[idx])
            if self.use_target_weight:
                joint_weight = target_weight[:, idx]
                loss += 0.5 * self.criterion(
                    self.mul(heatmap_pred, joint_weight),
                    self.mul(heatmap_gt, joint_weight)
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
        return loss / num_joints
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
    """
    Combine a network and a loss function into one cell that returns the loss.

    Args:
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
    """

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, image, target, weight, scale=None,
                  center=None, score=None, idx=None):
        """
        Run the backbone on ``image`` and evaluate the loss in float32.

        ``scale``, ``center``, ``score`` and ``idx`` are accepted but not used.
        """
        prediction = self._backbone(image)
        # Cast all three loss inputs to float32 before computing the loss.
        prediction = F.mixed_precision_cast(mstype.float32, prediction)
        target = F.mixed_precision_cast(mstype.float32, target)
        weight = F.mixed_precision_cast(mstype.float32, weight)
        return self._loss_fn(prediction, target, weight)

    @property
    def backbone_network(self):
        """
        Get the backbone network.

        Returns:
            Cell, return backbone network.
        """
        return self._backbone
|
@ -0,0 +1,78 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.utils.transform import transform_preds
|
||||
|
||||
|
||||
def get_max_preds(batch_heatmaps):
    '''
    Locate the maximum of every heatmap.

    batch_heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

    Returns (preds, maxvals): preds is float32 of shape
    [batch_size, num_joints, 2] holding (x, y) of each maximum, zeroed where
    the maximum is not greater than 0; maxvals has shape
    [batch_size, num_joints, 1].
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size, num_joints, _, width = batch_heatmaps.shape
    out_shape = (batch_size, num_joints, 1)

    flat = batch_heatmaps.reshape((batch_size, num_joints, -1))
    flat_idx = np.argmax(flat, 2).reshape(out_shape)
    maxvals = np.amax(flat, 2).reshape(out_shape)

    # Duplicate the flat index, then decode column (x) and row (y) from it.
    preds = np.tile(flat_idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = preds[:, :, 0] % width
    preds[:, :, 1] = np.floor(preds[:, :, 1] / width)

    # Zero the coordinates of joints whose peak value is not positive.
    positive = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
    preds *= positive
    return preds, maxvals
|
||||
|
||||
|
||||
def get_final_preds(config, batch_heatmaps, center, scale):
    '''
    Turn heatmaps into keypoint coordinates via transform_preds.

    config: configuration; TEST.POST_PROCESS enables the quarter-pixel shift
    batch_heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    center, scale: per-sample values passed to transform_preds

    Returns (preds, maxvals) with maxvals as produced by get_max_preds.
    '''
    coords, maxvals = get_max_preds(batch_heatmaps)
    heatmap_height, heatmap_width = batch_heatmaps.shape[2:4]

    # post-processing
    if config.TEST.POST_PROCESS:
        for n, p in np.ndindex(coords.shape[0], coords.shape[1]):
            hm = batch_heatmaps[n][p]
            px = int(math.floor(coords[n][p][0] + 0.5))
            py = int(math.floor(coords[n][p][1] + 0.5))
            inside = 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1
            if not inside:
                continue
            # Move a quarter pixel towards the larger neighbour on each axis.
            diff = np.array([hm[py][px + 1] - hm[py][px - 1],
                             hm[py + 1][px] - hm[py - 1][px]])
            coords[n][p] += np.sign(diff) * .25

    preds = coords.copy()

    # Transform back
    heatmap_size = [heatmap_width, heatmap_height]
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(coords[i], center[i], scale[i], heatmap_size)

    return preds, maxvals
|
@ -0,0 +1,55 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
    """
    Compute the object-keypoint-similarity between one pose and many poses.

    Args:
        g: flat array laid out as [x0, y0, v0, x1, y1, v1, ...] for the
            reference pose.
        d: 2-D array with one candidate pose per row, same layout as g.
        a_g: area associated with g.
        a_d: indexable of areas, one per row of d.
        sigmas: np.ndarray of per-keypoint constants. Any other value
            (including None) selects the built-in 17-entry table.
        in_vis_thre: when not None, only keypoints whose third component is
            above this threshold in BOTH g and the candidate contribute.

    Returns:
        1-D np.ndarray with one similarity value per row of d; 0.0 for a
        candidate with no contributing keypoints.
    """
    if not isinstance(sigmas, np.ndarray):
        sigmas = np.array(
            [.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
    vas = (sigmas * 2) ** 2
    xg = g[0::3]
    yg = g[1::3]
    vg = g[2::3]
    ious = np.zeros((d.shape[0]))
    for n_d in range(0, d.shape[0]):
        xd = d[n_d, 0::3]
        yd = d[n_d, 1::3]
        vd = d[n_d, 2::3]
        dx = xd - xg
        dy = yd - yg
        e = (dx ** 2 + dy ** 2) / vas / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
        if in_vis_thre is not None:
            # Element-wise AND. A Python `and` between two lists would just
            # return the second list and drop the reference pose's mask.
            ind = np.logical_and(vg > in_vis_thre, vd > in_vis_thre)
            e = e[ind]
        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
    return ious
|
||||
|
||||
|
||||
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
    """
    Greedy non-maximum suppression using oks_iou as the overlap measure.

    Entries are visited in order of decreasing 'score'. The best remaining
    entry is kept, and every other entry whose overlap with it is greater
    than thresh is discarded.

    :param kpts_db: sequence of dicts with 'score', 'keypoints' and 'area'
    :param thresh: entries with overlap <= thresh survive a round
    :param sigmas: forwarded to oks_iou
    :param in_vis_thre: forwarded to oks_iou
    :return: list of indexes into kpts_db to keep
    """
    if len(kpts_db) == 0:
        return []

    scores = np.array([entry['score'] for entry in kpts_db])
    kpts = np.array([entry['keypoints'].flatten() for entry in kpts_db])
    areas = np.array([entry['area'] for entry in kpts_db])

    remaining = scores.argsort()[::-1]

    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        overlaps = oks_iou(kpts[best], kpts[rest], areas[best], areas[rest], sigmas, in_vis_thre)

        # Survivors are the candidates that do not overlap the kept one too much.
        survivors = np.where(overlaps <= thresh)[0]
        remaining = rest[survivors]

    return keep
|
@ -0,0 +1,116 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """
    Mirror joint coordinates horizontally and swap paired joints.

    Both joints and joints_vis are modified in place: column 0 of joints
    becomes width - x - 1, then for every pair in matched_parts the two
    rows are exchanged in both arrays.

    Returns:
        (joints * joints_vis, joints_vis)
    """
    # Flip horizontal
    joints[:, 0] = width - joints[:, 0] - 1

    # Change left-right parts
    for pair in matched_parts:
        first, second = pair[0], pair[1]
        # Fancy indexing on the right-hand side yields a copy, so the swap
        # is safe without an explicit temporary.
        joints[[first, second]] = joints[[second, first]]
        joints_vis[[first, second]] = joints_vis[[second, first]]

    masked_joints = joints * joints_vis
    return masked_joints, joints_vis
|
||||
|
||||
|
||||
def flip_back(output_flipped, matched_parts):
    """
    Undo a horizontal flip on a batch of heatmaps.

    The last axis is reversed and, for every pair in matched_parts, the two
    channels on axis 1 are exchanged. The work is done on a copy, so the
    caller's array is left untouched and the result is a contiguous array
    rather than a negative-stride view.

    Args:
        output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
        matched_parts: iterable of index pairs to swap on axis 1.

    Returns:
        A new numpy.ndarray with the same shape as output_flipped.
    """
    assert output_flipped.ndim == 4, \
        'output_flipped should be [batch_size, num_joints, height, width]'

    # [::-1] alone is a view onto the input; copying keeps the channel swap
    # below from writing through to the caller's data.
    flipped = output_flipped[:, :, :, ::-1].copy()

    for pair in matched_parts:
        first, second = pair[0], pair[1]
        flipped[:, [first, second]] = flipped[:, [second, first]]

    return flipped
|
||||
|
||||
|
||||
def transform_preds(coords, center, scale, output_size):
    """
    Map each row of coords through the inverse affine transform.

    The transform is obtained from get_affine_transform with zero rotation
    and inv=1. Only the first two columns of every row are transformed; the
    result array has the shape of coords and is zero elsewhere.
    """
    inverse_trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    target_coords = np.zeros(coords.shape)
    for row, point in enumerate(coords):
        target_coords[row, 0:2] = affine_transform(point[0:2], inverse_trans)
    return target_coords
|
||||
|
||||
|
||||
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    """
    Build the affine transform between a source box and an output grid.

    Three point correspondences are constructed (box centre, a point offset
    from the centre along the rotated vertical direction, and a third point
    perpendicular to those two) and handed to cv2.getAffineTransform.

    Args:
        center: centre of the source box, (x, y).
        scale: a scalar (used for both axes) or a 2-element sequence/array.
            It is multiplied by 200.0 to get the source box size.
            NOTE(review): the 200.0 factor presumably mirrors how the dataset
            code normalises box sizes -- confirm against the dataset loader.
        rot: rotation in degrees.
        output_size: (width, height) of the destination grid.
        shift: fractional shift of the source box, in units of the box size.
            The default array is only read, never modified.
        inv: when truthy, return the destination-to-source transform.

    Returns:
        The matrix returned by cv2.getAffineTransform.
    """
    # Normalise scale to an array so that lists and tuples work as well;
    # a plain list would otherwise fail on `scale * 200.0`.
    scale = np.asarray(scale)
    if scale.ndim == 0:
        # A scalar means the same scale on both axes.
        scale = np.array([scale, scale])

    scale_tmp = scale * 200.0
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = _get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    # The third correspondence is derived from the first two.
    src[2:, :] = _get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
||||
|
||||
|
||||
def affine_transform(pt, t):
    """
    Apply the affine matrix t to the 2-D point pt.

    The point is extended with a trailing 1 before the matrix product and
    the first two components of the product are returned.
    """
    homogeneous = np.array([pt[0], pt[1], 1.]).T
    mapped = np.dot(t, homogeneous)
    return mapped[:2]
|
||||
|
||||
|
||||
def _get_3rd_point(a, b):
    """
    Return b plus the vector (a - b) with its components swapped and the
    new first component negated, as a float32 offset.
    """
    delta = a - b
    perpendicular = np.array([-delta[1], delta[0]], dtype=np.float32)
    return b + perpendicular
|
||||
|
||||
|
||||
def _get_dir(src_point, rot_rad):
    """
    Rotate the 2-D point src_point by rot_rad radians about the origin and
    return the result as a two-element list.
    """
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    x, y = src_point[0], src_point[1]
    return [x * cs - y * sn, x * sn + y * cs]
|
@ -0,0 +1,148 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import config
|
||||
from src.model import get_pose_net
|
||||
from src.network_define import JointsMSELoss, WithLossCell
|
||||
from src.dataset import keypoint_dataset
|
||||
|
||||
set_seed(1)
# Fall back to device 0 when DEVICE_ID is not exported; int(None) would
# otherwise raise TypeError as soon as this module is imported.
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
|
||||
|
||||
def get_lr(begin_epoch,
           total_epochs,
           steps_per_epoch,
           lr_init=0.1,
           factor=0.1,
           epoch_number_to_drop=(90, 120)
           ):
    """
    Generate a step-decay learning rate array.

    One value is produced per training step. The rate starts at lr_init and
    is multiplied by factor at the first step of every epoch listed in
    epoch_number_to_drop. Steps belonging to epochs before begin_epoch are
    cut from the front of the result.

    Args:
        begin_epoch (int): Initial epoch of training.
        total_epochs (int): Total epoch of training.
        steps_per_epoch (int): Steps of one epoch.
        lr_init (float): Initial learning rate. Default: 0.1.
        factor (float): Factor of lr to drop. Default: 0.1.
        epoch_number_to_drop: Learning rate will drop after these epochs.

    Returns:
        np.array of float32, learning rate for each remaining step.
    """
    drop_steps = [steps_per_epoch * epoch for epoch in epoch_number_to_drop]
    step_count = int(steps_per_epoch * total_epochs)

    schedule = []
    current_lr = lr_init
    for step in range(step_count):
        if step in drop_steps:
            current_lr = current_lr * factor
        schedule.append(current_lr)

    schedule = np.array(schedule, dtype=np.float32)
    first_step = steps_per_epoch * begin_epoch
    return schedule[first_step:]
|
||||
|
||||
|
||||
def parse_args():
    """
    Parse the command line of the training script.

    Recognised options: --run-distribute (flag), --ckpt-path (str) and
    --batch-size (int). Returns the argparse namespace.
    """
    option_specs = (
        ("--run-distribute", {"help": "Run distribute, default is false.", "action": 'store_true'}),
        ('--ckpt-path', {"type": str, "help": 'ckpt path to save'}),
        ('--batch-size', {"type": int, "help": 'training batch size'}),
    )
    parser = argparse.ArgumentParser(description="Simpleposenet training")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """
    Train the pose network.

    Reads the command line, configures the Ascend graph-mode context
    (optionally with data-parallel distribution), builds the dataset, the
    network with its loss, the learning rate schedule and the Adam
    optimizer, then runs Model.train for END_EPOCH - BEGIN_EPOCH epochs.
    Checkpoints are written only when --ckpt-path is given and this process
    is allowed to save.
    """
    # load parse and config
    print("loading parse...")
    args = parse_args()
    if args.batch_size:
        config.TRAIN.BATCH_SIZE = args.batch_size
    print('batch size :{}'.format(config.TRAIN.BATCH_SIZE))

    # distribution and context
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id)

    rank, device_num = 0, 1
    if args.run_distribute:
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)

    # only rank = 0 can write
    rank_save_flag = rank == 0 or device_num == 1

    # create dataset
    dataset, _ = keypoint_dataset(config,
                                  rank=rank,
                                  group_size=device_num,
                                  train_mode=True,
                                  num_parallel_workers=8)

    # network
    net = get_pose_net(config, True, ckpt_path=config.MODEL.PRETRAINED)
    net_with_loss = WithLossCell(net, JointsMSELoss(use_target_weight=True))

    # lr schedule and optim
    dataset_size = dataset.get_dataset_size()
    lr_schedule = get_lr(config.TRAIN.BEGIN_EPOCH,
                         config.TRAIN.END_EPOCH,
                         dataset_size,
                         lr_init=config.TRAIN.LR,
                         factor=config.TRAIN.LR_FACTOR,
                         epoch_number_to_drop=config.TRAIN.LR_STEP)
    opt = Adam(net.trainable_params(), learning_rate=Tensor(lr_schedule))

    # callback
    callbacks = [TimeMonitor(data_size=dataset_size), LossMonitor()]
    if args.ckpt_path and rank_save_flag:
        # One checkpoint per epoch, keeping at most 20 of them.
        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
        callbacks.append(ModelCheckpoint(prefix="simplepose", directory=args.ckpt_path, config=ckpt_config))

    # train model
    model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
    epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
    print('start training, epoch size = %d' % epoch_size)
    model.train(epoch_size, dataset, callbacks=callbacks)


if __name__ == '__main__':
    main()
|
Loading…
Reference in new issue