Pre Merge pull request !14578 from 童志豪/master
commit
fdabd0deca
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,144 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
This file evaluates the model used.
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, float32, context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.pose_resnet import GetPoseResNet
|
||||
from src.dataset import flip_pairs
|
||||
from src.dataset import CreateDatasetCoco
|
||||
from src.utils.coco import evaluate
|
||||
from src.utils.transforms import flip_back
|
||||
from src.utils.inference import get_final_preds
|
||||
|
||||
# moxing is only available inside the ModelArts environment, so import lazily.
if config.MODELARTS.IS_MODEL_ARTS:
    import moxing as mox

# Fix random seeds so evaluation results are reproducible.
set_seed(config.GENERAL.EVAL_SEED)
# Default to device 0 when DEVICE_ID is not exported (previously this raised
# TypeError because os.getenv returned None).
device_id = int(os.getenv('DEVICE_ID', '0'))
|
||||
|
||||
def parse_args():
    """Parse command-line arguments for evaluation.

    Returns:
        argparse.Namespace with ``data_url`` and ``train_url`` attributes.
    """
    parser = argparse.ArgumentParser(description='Evaluate')
    for flag, text in (('--data_url', 'Location of data.'),
                       ('--train_url', 'Location of evaluate outputs.')):
        parser.add_argument(flag, required=True, default=None, help=text)
    return parser.parse_args()
|
||||
|
||||
def validate(cfg, val_dataset, model, output_dir, ann_path):
    '''
    Run the model over ``val_dataset`` and report the COCO keypoint AP.

    Args:
        cfg: configuration object (see src/config.py).
        val_dataset: MindSpore dataset yielding dicts with keys
            'image', 'center', 'scale', 'score' and 'id'.
        model: network to evaluate; switched to inference mode here.
        output_dir: directory where result files are written.
        ann_path: path to the COCO annotation json.

    Returns:
        The AP value returned by ``evaluate``.
    '''
    model.set_train(False)
    # Upper bound on the number of samples; only the first ``idx`` rows of
    # the pre-allocated arrays are passed to evaluate() at the end.
    num_samples = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
    all_preds = np.zeros((num_samples, cfg.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    all_boxes = np.zeros((num_samples, 2))
    image_id = []
    idx = 0

    start = time.time()
    for item in val_dataset.create_dict_iterator():
        inputs = item['image'].asnumpy()
        output = model(Tensor(inputs, float32)).asnumpy()
        if cfg.TEST.FLIP_TEST:
            # Average the heatmaps of the image and its horizontal flip.
            inputs_flipped = Tensor(inputs[:, :, :, ::-1], float32)
            output_flipped = model(inputs_flipped)
            output_flipped = flip_back(output_flipped.asnumpy(), flip_pairs)

            if cfg.TEST.SHIFT_HEATMAP:
                # Shift the flipped heatmaps one pixel right to compensate
                # for the flip-induced misalignment.
                output_flipped[:, :, :, 1:] = \
                    output_flipped.copy()[:, :, :, 0:-1]

            output = (output + output_flipped) * 0.5

        c = item['center'].asnumpy()
        s = item['scale'].asnumpy()
        score = item['score'].asnumpy()
        file_id = list(item['id'].asnumpy())

        preds, maxvals = get_final_preds(cfg, output.copy(), c, s)
        num_images, _ = preds.shape[:2]
        all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
        all_preds[idx:idx + num_images, :, 2:3] = maxvals
        # Box area derived from scale; 200 presumably is the pixel standard
        # of the COCO person pipeline -- TODO confirm against src/utils.
        all_boxes[idx:idx + num_images, 0] = np.prod(s * 200, 1)
        all_boxes[idx:idx + num_images, 1] = score
        image_id.extend(file_id)
        idx += num_images
        if idx % 1024 == 0:
            print('{} samples validated in {} seconds'.format(idx, time.time() - start))
            start = time.time()

    print(all_preds[:idx].shape, all_boxes[:idx].shape, len(image_id))
    _, perf_indicator = evaluate(cfg, all_preds[:idx], output_dir, all_boxes[:idx], image_id, ann_path)
    print("AP:", perf_indicator)
    return perf_indicator
|
||||
|
||||
|
||||
def main():
    '''
    Entry point: restore a checkpoint and evaluate it on the COCO val set.
    '''
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False, device_id=device_id)

    args = parse_args()

    # On ModelArts the dataset is first staged into the local cache.
    if config.MODELARTS.IS_MODEL_ARTS:
        mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)

    model = GetPoseResNet(config)

    # Resolve the checkpoint location depending on the runtime environment.
    ckpt_name = ''
    if config.MODELARTS.IS_MODEL_ARTS:
        ckpt_name = config.MODELARTS.CACHE_INPUT
    else:
        ckpt_name = config.DATASET.ROOT
    ckpt_name = ckpt_name + config.TEST.MODEL_FILE
    print('loading model ckpt from {}'.format(ckpt_name))
    load_param_into_net(model, load_checkpoint(ckpt_name))

    valid_dataset = CreateDatasetCoco(
        train_mode=False,
        num_parallel_workers=config.TEST.NUM_PARALLEL_WORKERS,
    )

    # Reduce the checkpoint path to its bare file name (no directories, no
    # extension) and use it to label the output directory.
    ckpt_name = ckpt_name.split('/')
    ckpt_name = ckpt_name[len(ckpt_name) - 1]
    ckpt_name = ckpt_name.split('.')[0]
    output_dir = ''
    ann_path = ''
    if config.MODELARTS.IS_MODEL_ARTS:
        output_dir = config.MODELARTS.CACHE_OUTPUT
        ann_path = config.MODELARTS.CACHE_INPUT
    else:
        output_dir = config.TEST.OUTPUT_DIR
        ann_path = config.DATASET.ROOT
    output_dir = output_dir + ckpt_name
    ann_path = ann_path + config.DATASET.TEST_JSON
    validate(config, valid_dataset, model, output_dir, ann_path)

    # Copy results back to the job output bucket when running on ModelArts.
    if config.MODELARTS.IS_MODEL_ARTS:
        mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)


if __name__ == '__main__':
    main()
|
||||
@ -0,0 +1,51 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as ms
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.config import config
|
||||
from src.pose_resnet import GetPoseResNet
|
||||
|
||||
parser = argparse.ArgumentParser(description='simple_baselines')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
                    help="device target")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_url", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="simple_baselines", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["MINDIR"], default='MINDIR', help='file format')
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    net = GetPoseResNet(config)

    # Load the checkpoint given on the command line. The previous
    # ``assert cfg.checkpoint_dir is not None`` always raised, because
    # src/config.py defines no ``checkpoint_dir`` attribute; the checkpoint
    # actually comes from --ckpt_url.
    param_dict = load_checkpoint(args.ckpt_url)
    load_param_into_net(net, param_dict)

    # Dummy batch used to trace the graph for export.
    # NOTE(review): the spatial size 192x224 matches neither dimension of
    # config.MODEL.IMAGE_SIZE = [192, 256] -- confirm the intended shape.
    input_arr = Tensor(np.ones([64, 3, 192, 224]), ms.float32)
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
||||
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# Usage: bash run_eval.sh DEVICE_ID
if [ $# -lt 1 ]; then
    echo "Usage: bash run_eval.sh DEVICE_ID"
    exit 1
fi

export DEVICE_ID=$1

# Run evaluation in the background; output goes to eval_log<DEVICE_ID>.txt.
# NOTE(review): eval.py declares --data_url/--train_url as required, so this
# invocation without them will exit immediately -- confirm intended usage.
python eval.py > eval_log$1.txt 2>&1 &
|
||||
@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# Usage: sh train_distributed.sh [MINDSPORE_HCCL_CONFIG_PATH] [SAVE_CKPT_PATH] [RANK_SIZE]

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH RANK_SIZE"
echo "For example: bash run.sh /coco2017/train2017 2"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
export DATA_PATH=${DATA_PATH}
RANK_SIZE=$2

EXEC_PATH=$(pwd)

echo "$EXEC_PATH"

# Select the HCCL rank table matching the requested device count.
test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}

# Dispatch on $RANK_SIZE; only 2 and 8 are supported (other values fail here).
test_dist_${RANK_SIZE}pcs

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Create one working directory per device, each with its own copy of the
# sources, and launch one backgrounded training process per device.
for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cd ./device$i
    mkdir src
    mkdir src/utils
    cd ../
    cp ./train.py ./device$i
    cp ./src/config.py ./src/dataset.py ./src/pose_resnet.py ./src/network_with_loss.py ./device$i/src
    cp ./src/utils/transforms.py ./device$i/src/utils
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    # Logs go to device<i>/train<i>.log.
    python train.py > train$i.log 2>&1 &
    echo "$i finish"
    cd ../
done

# NOTE(review): the python processes run in the background, so $? here
# reflects the loop, not training success -- confirm whether a `wait`
# before this check is intended.
if [ $? -eq 0 ];then
    echo "training success"
else
    echo "training failed"
    exit 2
fi
echo "finish"
cd ../
|
||||
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DATA_PATH DEVICE_ID"
echo "For example: bash run_standalone_train.sh /path/dataset 0"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
DEVICE_ID=$2
export DATA_PATH=${DATA_PATH}

EXEC_PATH=$(pwd)

echo "$EXEC_PATH"

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Run training from the repository root (this script lives one level down).
cd ../
export DEVICE_ID=$2
export RANK_ID=$2
# NOTE(review): the env dump is always written to env0.log regardless of the
# device id passed in -- confirm whether env$2.log was intended.
env > env0.log
python train.py --data_url $1 --device_id $2 > train.log 2>&1

if [ $? -eq 0 ];then
    echo "training success"
else
    echo "training failed"
    exit 2
fi
echo "finish"
cd ../
|
||||
|
||||
@ -0,0 +1,106 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
config
|
||||
'''
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# Global configuration tree; every section below is an attribute-accessible
# EasyDict (config.SECTION.KEY).
config = edict()

# ---- General run settings --------------------------------------------------
config.GENERAL = edict()
config.GENERAL.VERSION = 'commit'
config.GENERAL.TRAIN_SEED = 1
config.GENERAL.EVAL_SEED = 1
config.GENERAL.DATASET_SEED = 1
config.GENERAL.RUN_DISTRIBUTE = True

# ---- ModelArts (Huawei cloud) settings -------------------------------------
config.MODELARTS = edict()
# NOTE(review): defaults to True, so local (non-cloud) runs must set this to
# False or the code will try to import moxing and use the cache paths below.
config.MODELARTS.IS_MODEL_ARTS = True
config.MODELARTS.CACHE_INPUT = '/cache/data_tzh/'
config.MODELARTS.CACHE_OUTPUT = '/cache/train_out/'

# ---- Model -----------------------------------------------------------------
config.MODEL = edict()
config.MODEL.IS_TRAINED = True
config.MODEL.INIT_WEIGHTS = True
config.MODEL.PRETRAINED = 'resnet50.ckpt'
config.MODEL.NUM_JOINTS = 17
# Input resolution; presumably [width, height] -- confirm against the
# dataset/transform code before relying on the ordering.
config.MODEL.IMAGE_SIZE = [192, 256]

# ---- Network architecture --------------------------------------------------
config.NETWORK = edict()
config.NETWORK.NUM_LAYERS = 50
config.NETWORK.DECONV_WITH_BIAS = False
config.NETWORK.NUM_DECONV_LAYERS = 3
config.NETWORK.NUM_DECONV_FILTERS = [256, 256, 256]
config.NETWORK.NUM_DECONV_KERNELS = [4, 4, 4]
config.NETWORK.FINAL_CONV_KERNEL = 1
config.NETWORK.REVERSE = True

config.NETWORK.TARGET_TYPE = 'gaussian'
config.NETWORK.HEATMAP_SIZE = [48, 64]
config.NETWORK.SIGMA = 2

# ---- Loss ------------------------------------------------------------------
config.LOSS = edict()
config.LOSS.USE_TARGET_WEIGHT = True

# ---- Dataset ---------------------------------------------------------------
config.DATASET = edict()
config.DATASET.TYPE = 'COCO'
config.DATASET.ROOT = 'coco2017/'
config.DATASET.TRAIN_SET = 'train2017'
config.DATASET.TRAIN_JSON = 'annotations/person_keypoints_train2017.json'
config.DATASET.TEST_SET = 'val2017'
config.DATASET.TEST_JSON = 'annotations/person_keypoints_val2017.json'

# Training data augmentation.
config.DATASET.FLIP = True
config.DATASET.SCALE_FACTOR = 0.3
config.DATASET.ROT_FACTOR = 40

# ---- Training --------------------------------------------------------------
config.TRAIN = edict()
config.TRAIN.SHUFFLE = True
config.TRAIN.BATCH_SIZE = 64
config.TRAIN.BEGIN_EPOCH = 0
config.TRAIN.END_EPOCH = 140
config.TRAIN.LR = 0.001
config.TRAIN.LR_FACTOR = 0.1
config.TRAIN.LR_STEP = [90, 120]
config.TRAIN.NUM_PARALLEL_WORKERS = 8
config.TRAIN.SAVE_CKPT = True
config.TRAIN.CKPT_PATH = 'check_point/'

# ---- Validation / test -----------------------------------------------------
config.TEST = edict()
config.TEST.BATCH_SIZE = 32
config.TEST.FLIP_TEST = True
config.TEST.POST_PROCESS = True
config.TEST.SHIFT_HEATMAP = True
config.TEST.USE_GT_BBOX = False
config.TEST.NUM_PARALLEL_WORKERS = 8
config.TEST.MODEL_FILE = 'multi_train_poseresnet_v5_2-140_2340.ckpt'
config.TEST.COCO_BBOX_FILE = 'annotations/COCO_val2017_detections_AP_H_56_person.json'
config.TEST.OUTPUT_DIR = 'results/'

# ---- OKS NMS / rescoring thresholds ----------------------------------------
config.TEST.OKS_THRE = 0.9
config.TEST.IN_VIS_THRE = 0.2
config.TEST.BBOX_THRE = 1.0
config.TEST.IMAGE_THRE = 0.0
config.TEST.NMS_THRE = 1.0
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,81 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
network_with_loss
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
class JointsMSELoss(_Loss):
    '''
    Mean-squared-error loss over per-joint heatmaps.

    Each joint's predicted heatmap is compared against its ground-truth
    heatmap with nn.MSELoss. When ``use_target_weight`` is True, both
    heatmaps are scaled by the per-joint target weight before the MSE, so
    low-weight (e.g. invisible) joints contribute less. The final loss is
    averaged over the joints.
    '''
    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.squeeze = P.Squeeze(1)
        self.mul = P.Mul()

    def construct(self, output, target, target_weight):
        '''
        Compute the loss.

        Args:
            output: predicted heatmaps, shape (batch, joints, *spatial).
            target: ground-truth heatmaps, same shape as ``output``.
            target_weight: per-joint weights, indexed target_weight[:, idx].

        Returns:
            Scalar loss averaged over the joints.
        '''
        total_shape = self.shape(output)
        batch_size = total_shape[0]
        num_joints = total_shape[1]
        # Collapse all trailing spatial dims so each joint's heatmap
        # becomes a single flat vector per sample.
        remained_size = 1
        for i in range(2, len(total_shape)):
            remained_size *= total_shape[i]

        # Split along the joint axis into ``num_joints`` slices.
        split = P.Split(1, num_joints)
        new_shape = (batch_size, num_joints, remained_size)
        heatmaps_pred = split(self.reshape(output, new_shape))
        heatmaps_gt = split(self.reshape(target, new_shape))
        loss = 0

        for idx in range(num_joints):
            heatmap_pred_squeezed = self.squeeze(heatmaps_pred[idx])
            heatmap_gt_squeezed = self.squeeze(heatmaps_gt[idx])
            if self.use_target_weight:
                # Weight both sides so the MSE of down-weighted joints shrinks.
                loss += 0.5 * self.criterion(self.mul(heatmap_pred_squeezed, target_weight[:, idx]),
                                             self.mul(heatmap_gt_squeezed, target_weight[:, idx]))
            else:
                loss += 0.5 * self.criterion(heatmap_pred_squeezed, heatmap_gt_squeezed)

        return loss / num_joints
|
||||
|
||||
class PoseResNetWithLoss(nn.Cell):
    """
    Wrapper cell that runs the backbone and returns the training loss.

    The extra dataset columns (scale, center, score, idx) are accepted so
    the cell can be fed directly from the dataset iterator, but they do not
    participate in the loss computation.
    """
    def __init__(self, network, loss):
        super(PoseResNetWithLoss, self).__init__()
        self.network = network
        self.loss = loss

    def construct(self, image, target, weight, scale=None, center=None, score=None, idx=None):
        """Forward the image through the network and compute the loss in float32."""
        heatmaps = self.network(image)
        # Cast everything to float32 so the loss stays stable under mixed precision.
        heatmaps = F.mixed_precision_cast(mstype.float32, heatmaps)
        labels = F.mixed_precision_cast(mstype.float32, target)
        joint_weights = F.mixed_precision_cast(mstype.float32, weight)
        return self.loss(heatmaps, labels, joint_weights)
|
||||
@ -0,0 +1,222 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
simple_baselines network
|
||||
'''
|
||||
from __future__ import division
|
||||
import os
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as init
|
||||
import mindspore.ops.operations as F
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
# BatchNorm momentum shared by every BN layer in this network.
BN_MOMENTUM = 0.1


class MPReverse(nn.Cell):
    '''
    Max pooling applied on spatially reversed feature maps.

    The input is flipped along H and W, pooled, then flipped back --
    presumably to reproduce the reference implementation's pooling
    alignment (TODO confirm).
    '''
    def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
        super(MPReverse, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, pad_mode=pad_mode)
        self.reverse = F.ReverseV2(axis=[2, 3])

    def construct(self, x):
        '''Flip -> pool -> flip back.'''
        return self.reverse(self.maxpool(self.reverse(x)))
|
||||
|
||||
class Bottleneck(nn.Cell):
    '''
    ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, with a
    residual shortcut (optionally transformed by ``downsample`` to match
    the output shape).
    '''
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, has_bias=False, pad_mode='pad')
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, has_bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        '''Run the three conv stages and add the (possibly downsampled) shortcut.'''
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        return self.relu(y + shortcut)
|
||||
|
||||
|
||||
class PoseResNet(nn.Cell):
    '''
    Simple-baselines pose estimation network.

    A ResNet backbone (conv stem + four bottleneck stages) is followed by a
    stack of deconvolution layers that upsample the features, then a final
    conv producing one heatmap per joint.
    '''

    def __init__(self, block, layers, cfg):
        # self.inplanes tracks the running channel count while stages are built.
        self.inplanes = 64
        self.deconv_with_bias = cfg.NETWORK.DECONV_WITH_BIAS

        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, has_bias=False, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU()
        self.maxpool = MPReverse(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Upsampling head: NUM_DECONV_LAYERS transposed-conv stages.
        self.deconv_layers = self._make_deconv_layer(
            cfg.NETWORK.NUM_DECONV_LAYERS,
            cfg.NETWORK.NUM_DECONV_FILTERS,
            cfg.NETWORK.NUM_DECONV_KERNELS,
        )

        # Final conv maps deconv features to NUM_JOINTS heatmap channels.
        self.final_layer = nn.Conv2d(
            in_channels=cfg.NETWORK.NUM_DECONV_FILTERS[-1],
            out_channels=cfg.MODEL.NUM_JOINTS,
            kernel_size=cfg.NETWORK.FINAL_CONV_KERNEL,
            stride=1,
            padding=1 if cfg.NETWORK.FINAL_CONV_KERNEL == 3 else 0,
            pad_mode='pad',
            has_bias=True,
            weight_init=init.Normal(0.001)
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        '''
        Build one ResNet stage of ``blocks`` bottleneck blocks.

        The first block may downsample (via ``stride``) and widen the
        channels; a 1x1-conv shortcut is added when shapes don't match.
        '''
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.SequentialCell([nn.Conv2d(self.inplanes, planes * block.expansion,
                                                      kernel_size=1, stride=stride, has_bias=False),
                                            nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM)])

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        # Remaining blocks keep the shape. (A leftover debug print of the
        # loop index was removed here.)
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.SequentialCell(layers)

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        '''
        Build ``num_layers`` transposed-conv upsampling stages, each
        followed by BatchNorm and ReLU.
        '''
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        layers = []
        for i in range(num_layers):
            kernel = num_kernels[i]
            padding = 1
            planes = num_filters[i]

            layers.append(nn.Conv2dTranspose(
                in_channels=self.inplanes,
                out_channels=planes,
                kernel_size=kernel,
                stride=2,
                padding=padding,
                has_bias=self.deconv_with_bias,
                pad_mode='pad',
                weight_init=init.Normal(0.001)
            ))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU())
            self.inplanes = planes

        return nn.SequentialCell(layers)

    def construct(self, x):
        '''
        Forward pass: stem -> four ResNet stages -> deconv upsampling ->
        per-joint heatmaps.
        '''
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.deconv_layers(x)
        x = self.final_layer(x)
        return x

    def init_weights(self, pretrained=''):
        '''
        Load backbone weights from the checkpoint at ``pretrained``.

        Raises:
            ValueError: if ``pretrained`` is not an existing file.
        '''
        if os.path.isfile(pretrained):
            # load params from pretrained
            param_dict = load_checkpoint(pretrained)
            load_param_into_net(self, param_dict)
            print('=> loading pretrained model {}'.format(pretrained))
        else:
            # Typo fix: "dose" -> "does".
            print('=> imagenet pretrained model does not exist')
            raise ValueError('{} is not a file'.format(pretrained))
|
||||
|
||||
|
||||
# (block constructor, per-stage block counts) keyed by ResNet depth.
resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]),
               101: (Bottleneck, [3, 4, 23, 3]),
               152: (Bottleneck, [3, 8, 36, 3])}


def GetPoseResNet(cfg):
    '''
    Build a PoseResNet configured by ``cfg`` and optionally load
    pretrained backbone weights.
    '''
    block_class, layers = resnet_spec[cfg.NETWORK.NUM_LAYERS]
    network = PoseResNet(block_class, layers, cfg)

    if cfg.MODEL.IS_TRAINED and cfg.MODEL.INIT_WEIGHTS:
        # Pretrained weights live in the ModelArts cache or the local
        # checkpoint directory depending on the runtime environment.
        if cfg.MODELARTS.IS_MODEL_ARTS:
            weights_root = cfg.MODELARTS.CACHE_INPUT
        else:
            weights_root = cfg.TRAIN.CKPT_PATH
        network.init_weights(weights_root + cfg.MODEL.PRETRAINED)
    return network
|
||||
@ -0,0 +1,137 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
coco
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict, OrderedDict
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
has_coco = True
|
||||
except ImportError:
|
||||
has_coco = False
|
||||
|
||||
from src.utils.nms import oks_nms
|
||||
|
||||
def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
|
||||
'''
|
||||
_write_coco_keypoint_results
|
||||
'''
|
||||
results = []
|
||||
|
||||
for img, items in img_kpts.items():
|
||||
item_size = len(items)
|
||||
if not items:
|
||||
continue
|
||||
kpts = np.array([items[k]['keypoints']
|
||||
for k in range(item_size)])
|
||||
keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float)
|
||||
keypoints[:, 0::3] = kpts[:, :, 0]
|
||||
keypoints[:, 1::3] = kpts[:, :, 1]
|
||||
keypoints[:, 2::3] = kpts[:, :, 2]
|
||||
|
||||
result = [{'image_id': int(img),
|
||||
'keypoints': list(keypoints[k]),
|
||||
'score': items[k]['score'],
|
||||
'category_id': 1,
|
||||
} for k in range(item_size)]
|
||||
results.extend(result)
|
||||
|
||||
with open(res_file, 'w') as f:
|
||||
json.dump(results, f, sort_keys=True, indent=4)
|
||||
|
||||
|
||||
def _do_python_keypoint_eval(res_file, res_folder, ann_path):
    '''
    Score a COCO keypoint results file with pycocotools.

    Args:
        res_file: json file produced by _write_coco_keypoint_results.
        res_folder: directory where the pickled COCOeval object is saved.
        ann_path: path to the ground-truth annotation json.

    Returns:
        List of (metric_name, value) pairs in COCO summary order.
    '''
    coco = COCO(ann_path)
    coco_dt = coco.loadRes(res_file)
    coco_eval = COCOeval(coco, coco_dt, 'keypoints')
    coco_eval.params.useSegm = None
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    # Names aligned with coco_eval.stats ordering after summarize().
    stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']

    info_str = []
    for ind, name in enumerate(stats_names):
        info_str.append((name, coco_eval.stats[ind]))

    # Persist the full COCOeval object for later offline inspection.
    eval_file = os.path.join(
        res_folder, 'keypoints_results.pkl')

    with open(eval_file, 'wb') as f:
        pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
    print('coco eval results saved to %s' % eval_file)

    return info_str
|
||||
|
||||
def evaluate(cfg, preds, output_dir, all_boxes, img_id, ann_path):
    '''
    Rescore detections, apply per-image OKS-NMS, dump COCO-format results
    and (when ground truth is available) run the COCO keypoint evaluation.

    Args:
        cfg: config node providing MODEL.NUM_JOINTS, TEST.IN_VIS_THRE,
            TEST.OKS_THRE and DATASET fields.
        preds: predicted keypoints, one (num_joints, 3) array per detection.
        output_dir: directory for the result files (created if missing).
        all_boxes: per-detection [area, score] values.
        img_id: image id of each detection, aligned with preds/all_boxes.
        ann_path: annotation file path ('' falls back to the default layout).

    Returns:
        (metric dict, AP) when evaluation ran, otherwise ({'Null': 0}, 0).
    '''
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    res_file = os.path.join(output_dir, 'keypoints_results.json')

    # Group the flat detection arrays by the image they belong to.
    img_kpts_dict = defaultdict(list)
    for idx, file_id in enumerate(img_id):
        img_kpts_dict[file_id].append({
            'keypoints': preds[idx],
            'area': all_boxes[idx][0],
            'score': all_boxes[idx][1],
        })

    # rescoring and oks nms
    num_joints = cfg.MODEL.NUM_JOINTS
    in_vis_thre = cfg.TEST.IN_VIS_THRE
    oks_thre = cfg.TEST.OKS_THRE
    oks_nmsed_kpts = {}
    for img, items in img_kpts_dict.items():
        for item in items:
            # Rescale the box score by the mean confidence of visible joints.
            kpt_score = 0
            valid_num = 0
            for n_jt in range(num_joints):
                max_jt = item['keypoints'][n_jt][2]
                if max_jt > in_vis_thre:
                    kpt_score += max_jt
                    valid_num += 1
            if valid_num != 0:
                kpt_score /= valid_num
            item['score'] = kpt_score * item['score']
        keep = oks_nms(items, oks_thre)
        # An empty keep list means nothing was suppressed: retain all items.
        oks_nmsed_kpts[img] = items if not keep else [items[kep] for kep in keep]

    # evaluate and save
    image_set = cfg.DATASET.TEST_SET
    _write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
    if 'test' not in image_set and has_coco:
        if not ann_path:
            ann_path = os.path.join(cfg.DATASET.ROOT, 'annotations',
                                    'person_keypoints_' + image_set + '.json')
        info_str = _do_python_keypoint_eval(res_file, output_dir, ann_path)
        name_value = OrderedDict(info_str)
        return name_value, name_value['AP']
    return {'Null': 0}, 0
|
||||
@ -0,0 +1,83 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
inference
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from src.utils.transforms import transform_preds
|
||||
|
||||
def get_max_preds(batch_heatmaps):
    '''
    Locate the per-joint maximum of each heatmap.

    Args:
        batch_heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

    Returns:
        preds: (batch_size, num_joints, 2) float32 array of (x, y) peak
            positions; joints whose peak value is <= 0 are zeroed out.
        maxvals: (batch_size, num_joints, 1) peak value of each heatmap.
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size, num_joints = batch_heatmaps.shape[:2]
    width = batch_heatmaps.shape[3]
    flat = batch_heatmaps.reshape((batch_size, num_joints, -1))

    idx = np.argmax(flat, 2).reshape((batch_size, num_joints, 1))
    maxvals = np.amax(flat, 2).reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

    # Convert the flat argmax index into (x, y) heatmap coordinates.
    preds[:, :, 0] = preds[:, :, 0] % width
    preds[:, :, 1] = np.floor(preds[:, :, 1] / width)

    # Zero out joints whose heatmap peak is not strictly positive.
    visible = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
    preds *= visible
    return preds, maxvals
|
||||
|
||||
|
||||
def get_final_preds(config, batch_heatmaps, center, scale):
    '''
    Convert heatmaps into final keypoint predictions in original image space.

    Args:
        config: config node; TEST.POST_PROCESS toggles quarter-pixel refinement.
        batch_heatmaps: (batch, joints, height, width) heatmap array.
        center: per-sample crop centers for the affine back-transform.
        scale: per-sample crop scales for the affine back-transform.

    Returns:
        preds: keypoint coordinates mapped back to the original image frame.
        maxvals: heatmap confidence for every keypoint.
    '''
    coords, maxvals = get_max_preds(batch_heatmaps)

    heatmap_height = batch_heatmaps.shape[2]
    heatmap_width = batch_heatmaps.shape[3]

    # post-processing: nudge each peak a quarter pixel toward the
    # neighbouring-pixel gradient for sub-pixel accuracy.
    if config.TEST.POST_PROCESS:
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                # Only refine peaks strictly inside the heatmap border.
                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                    diff = np.array([hm[py][px + 1] - hm[py][px - 1],
                                     hm[py + 1][px] - hm[py - 1][px]])
                    coords[n][p] += np.sign(diff) * .25

    preds = coords.copy()

    # Transform back from heatmap space to the original image space.
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(coords[i], center[i], scale[i],
                                   [heatmap_width, heatmap_height])

    return preds, maxvals
|
||||
@ -0,0 +1,74 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
nms operation
|
||||
'''
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
|
||||
def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
    '''
    Object keypoint similarity (OKS) between one reference keypoint vector
    and every candidate vector.

    Args:
        g: flattened keypoints [x1, y1, v1, x2, ...] of the reference person.
        d: (n, 3 * num_joints) array of candidate keypoint vectors.
        a_g: area of the reference person.
        a_d: areas of the candidates, aligned with d.
        sigmas: per-joint falloff constants; COCO defaults when None.
        in_vis_thre: if set, only joints above this confidence contribute.

    Returns:
        numpy array of n OKS values in [0, 1].
    '''
    if not isinstance(sigmas, np.ndarray):
        # Default per-joint constants from the COCO keypoint benchmark.
        sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                           .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
    var = (sigmas * 2) ** 2
    xg, yg, vg = g[0::3], g[1::3], g[2::3]
    ious = np.zeros((d.shape[0]))
    for n_d in range(0, d.shape[0]):
        xd, yd, vd = d[n_d, 0::3], d[n_d, 1::3], d[n_d, 2::3]
        dx = xd - xg
        dy = yd - yg
        e = (dx ** 2 + dy ** 2) / var / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
        if in_vis_thre is not None:
            # NOTE(review): `and` between two non-empty lists evaluates to the
            # second list, so only the candidate's visibility mask is applied
            # here; the same quirk exists in the reference cocoapi code and is
            # preserved for score compatibility.
            ind = list(vg > in_vis_thre) and list(vd > in_vis_thre)
            e = e[ind]
        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
    return ious
|
||||
|
||||
def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
    """
    greedily select boxes with high confidence and overlap with current maximum <= thresh
    rule out overlap >= thresh, overlap = oks
    :param kpts_db
    :param thresh: retain overlap < thresh
    :return: indexes to keep
    """
    if not kpts_db:
        return []

    scores = np.array([entry['score'] for entry in kpts_db])
    kpts = np.array([entry['keypoints'].flatten() for entry in kpts_db])
    areas = np.array([entry['area'] for entry in kpts_db])

    # Visit candidates from highest to lowest score.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        top = order[0]
        keep.append(top)

        # Suppress every remaining candidate whose OKS overlap with the
        # current top detection exceeds the threshold.
        oks_ovr = oks_iou(kpts[top], kpts[order[1:]], areas[top],
                          areas[order[1:]], sigmas, in_vis_thre)
        survivors = np.where(oks_ovr <= thresh)[0]
        order = order[survivors + 1]
    return keep
|
||||
@ -0,0 +1,137 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
transforms
|
||||
'''
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
def flip_back(output_flipped, matched_parts):
    '''
    Undo a horizontal flip on a batch of heatmaps.

    Args:
        output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
            heatmaps predicted from horizontally flipped input images.
        matched_parts: iterable of (left, right) joint index pairs to swap.

    Returns:
        Heatmaps mirrored along the width axis with left/right joint
        channels exchanged.
    '''
    assert output_flipped.ndim == 4,\
        'output_flipped should be [batch_size, num_joints, height, width]'

    # Mirror every heatmap along the width axis.
    output_flipped = output_flipped[:, :, :, ::-1]

    # Exchange each left/right joint channel pair.
    for pair in matched_parts:
        left, right = pair[0], pair[1]
        tmp = output_flipped[:, left, :, :].copy()
        output_flipped[:, left, :, :] = output_flipped[:, right, :, :]
        output_flipped[:, right, :, :] = tmp

    return output_flipped
|
||||
|
||||
|
||||
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """
    Horizontally flip joint coordinates and exchange left/right parts.

    Args:
        joints: (num_joints, 3) array of joint coordinates; column 0 is x.
        joints_vis: per-joint visibility array, same shape as joints.
        width: image width used to mirror the x coordinate.
        matched_parts: iterable of (left, right) joint index pairs to swap.

    Returns:
        (joints * joints_vis, joints_vis): flipped joints with invisible
        ones zeroed out, plus the swapped visibility array.
    """
    # Flip horizontal
    joints[:, 0] = width - joints[:, 0] - 1

    # Change left-right parts (fancy indexing copies the RHS, so the swap
    # is safe to do in place).
    for pair in matched_parts:
        fwd = [pair[0], pair[1]]
        rev = [pair[1], pair[0]]
        joints[fwd, :] = joints[rev, :]
        joints_vis[fwd, :] = joints_vis[rev, :]

    return joints * joints_vis, joints_vis
|
||||
|
||||
|
||||
def transform_preds(coords, center, scale, output_size):
    '''
    Map heatmap-space coordinates back into the original image space.

    Args:
        coords: (num_joints, >=2) array of heatmap coordinates.
        center: crop center used when the image was transformed.
        scale: crop scale used when the image was transformed.
        output_size: (width, height) of the heatmap.

    Returns:
        Array of the same shape with columns 0-1 mapped through the
        inverse affine transform.
    '''
    # Build the patch-to-image (inverse) affine transform once.
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    target_coords = np.zeros(coords.shape)
    for row in range(coords.shape[0]):
        target_coords[row, 0:2] = affine_transform(coords[row, 0:2], trans)
    return target_coords
|
||||
|
||||
|
||||
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    '''
    Build the 2x3 affine matrix mapping a scaled/rotated crop around
    `center` onto an `output_size` patch (or the inverse mapping).

    Args:
        center: (x, y) crop center in the source image.
        scale: crop scale; scalars are promoted to (scale, scale). One scale
            unit corresponds to 200 pixels (COCO person-box convention).
        rot: rotation in degrees.
        output_size: (width, height) of the destination patch.
        shift: extra translation as a fraction of the scaled box.
            NOTE: mutable default array; it is only read, never mutated.
        inv: when truthy, return the patch-to-image (inverse) transform.

    Returns:
        2x3 float affine matrix suitable for cv2.warpAffine.
    '''
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        # Promote a scalar scale to an isotropic (w, h) pair.
        # (Removed a leftover debug print(scale) here.)
        scale = np.array([scale, scale])

    scale_tmp = scale * 200.0
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    # Three non-collinear point pairs fully determine the affine transform:
    # the center, a rotated offset point, and a perpendicular third point.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
||||
|
||||
|
||||
def affine_transform(pt, t):
    '''
    Apply a 2x3 affine matrix `t` to a 2-D point `pt`.

    Returns:
        numpy array holding the transformed (x, y).
    '''
    homogeneous = np.array([pt[0], pt[1], 1.])
    return np.dot(t, homogeneous)[:2]
|
||||
|
||||
|
||||
def get_3rd_point(a, b):
    '''
    Third point completing a right angle at `b`: the vector a - b rotated
    by 90 degrees and added back to `b`.
    '''
    delta = a - b
    perpendicular = np.array([-delta[1], delta[0]], dtype=np.float32)
    return b + perpendicular
|
||||
|
||||
|
||||
def get_dir(src_point, rot_rad):
    '''
    Rotate a 2-D point by `rot_rad` radians around the origin.

    Returns:
        [x', y'] list with the rotated coordinates.
    '''
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    return [src_point[0] * cs - src_point[1] * sn,
            src_point[0] * sn + src_point[1] * cs]
|
||||
|
||||
|
||||
def crop(img, center, scale, output_size, rot=0):
    '''
    Cut an affine-warped patch of `output_size` out of `img`.

    Args:
        img: source image array.
        center: (x, y) crop center in the source image.
        scale: crop scale (200-pixel units, see get_affine_transform).
        output_size: (width, height) of the returned patch.
        rot: rotation in degrees applied during the warp.

    Returns:
        The warped patch produced by cv2.warpAffine.
    '''
    trans = get_affine_transform(center, scale, rot, output_size)
    patch_size = (int(output_size[0]), int(output_size[1]))
    return cv2.warpAffine(img, trans, patch_size, flags=cv2.INTER_LINEAR)
|
||||
@ -0,0 +1,141 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
'''
|
||||
train
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import config
|
||||
from src.pose_resnet import GetPoseResNet
|
||||
from src.network_with_loss import JointsMSELoss, PoseResNetWithLoss
|
||||
from src.dataset import CreateDatasetCoco
|
||||
|
||||
if config.MODELARTS.IS_MODEL_ARTS:
|
||||
import moxing as mox
|
||||
|
||||
set_seed(config.GENERAL.TRAIN_SEED)
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
|
||||
def get_lr(begin_epoch,
           total_epochs,
           steps_per_epoch,
           lr_init=0.1,
           factor=0.1,
           epoch_number_to_drop=(90, 120)
           ):
    '''
    Build a step-decay learning-rate schedule.

    The rate starts at `lr_init` and is multiplied by `factor` at the first
    step of every epoch listed in `epoch_number_to_drop`.

    Args:
        begin_epoch: epoch to resume from; earlier steps are sliced off.
        total_epochs: total number of training epochs in the schedule.
        steps_per_epoch: optimizer steps per epoch.
        lr_init: initial learning rate.
        factor: multiplicative decay applied at each drop epoch.
        epoch_number_to_drop: epochs whose first step triggers a decay.

    Returns:
        numpy.ndarray(float32) of per-step learning rates starting at
        `begin_epoch`.
    '''
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    # Use a set: membership is tested once per training step (was a list,
    # giving O(len(drops)) per step for no benefit).
    step_number_to_drop = {steps_per_epoch * x for x in epoch_number_to_drop}
    for i in range(int(total_steps)):
        if i in step_number_to_drop:
            lr_init = lr_init * factor
        lr_each_step.append(lr_init)
    current_step = steps_per_epoch * begin_epoch
    lr_each_step = np.array(lr_each_step, dtype=np.float32)
    learning_rate = lr_each_step[current_step:]
    return learning_rate
|
||||
|
||||
def parse_args():
    '''
    Parse the command-line arguments for SimplePoseNet training.

    Returns:
        argparse.Namespace with `data_url` and `train_url` attributes.
    '''
    parser = argparse.ArgumentParser(description="Simpleposenet training")
    parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
    return parser.parse_args()
|
||||
|
||||
def main():
    '''
    Entry point: configure the Ascend context, optionally set up
    data-parallel training, build SimplePoseNet and run the training loop.

    Side effects: uses the module-level device_id (read from DEVICE_ID),
    may copy data/outputs via ModelArts moxing, and writes checkpoint
    files when config.TRAIN.SAVE_CKPT is enabled.
    '''
    print("loading parse...")
    args = parse_args()
    # Graph mode on Ascend; device_id was resolved at module import time.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False,
                        device_id=device_id)

    if config.GENERAL.RUN_DISTRIBUTE:
        # Initialize collectives first, then enable data-parallel mode with
        # gradient averaging across devices.
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        rank = 0
        device_num = 1

    if config.MODELARTS.IS_MODEL_ARTS:
        # Stage the dataset into the local ModelArts cache before building
        # the data pipeline.
        mox.file.copy_parallel(src_url=args.data_url, dst_url=config.MODELARTS.CACHE_INPUT)

    dataset = CreateDatasetCoco(rank=rank,
                                group_size=device_num,
                                train_mode=True,
                                num_parallel_workers=config.TRAIN.NUM_PARALLEL_WORKERS,
                                )
    net = GetPoseResNet(config)
    loss = JointsMSELoss(config.LOSS.USE_TARGET_WEIGHT)
    net_with_loss = PoseResNetWithLoss(net, loss)
    dataset_size = dataset.get_dataset_size()
    # Step-decay learning-rate schedule over the configured epoch range.
    lr = Tensor(get_lr(config.TRAIN.BEGIN_EPOCH,
                       config.TRAIN.END_EPOCH,
                       dataset_size,
                       lr_init=config.TRAIN.LR,
                       factor=config.TRAIN.LR_FACTOR,
                       epoch_number_to_drop=config.TRAIN.LR_STEP))
    opt = Adam(net.trainable_params(), learning_rate=lr)
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossMonitor()
    cb = [time_cb, loss_cb]
    if config.TRAIN.SAVE_CKPT:
        # One checkpoint per epoch, keeping at most the 20 most recent.
        config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size, keep_checkpoint_max=20)
        # Checkpoint name encodes run mode, version and device id.
        prefix = ''
        if config.GENERAL.RUN_DISTRIBUTE:
            prefix = 'multi_'
        else:
            prefix = 'single_'
        prefix = prefix + 'train_poseresnet_' + config.GENERAL.VERSION + '_' + os.getenv('DEVICE_ID')

        # Checkpoints go to the ModelArts cache (copied out below) or to a
        # local per-device directory.
        directory = ''
        if config.MODELARTS.IS_MODEL_ARTS:
            directory = config.MODELARTS.CACHE_OUTPUT
        else:
            directory = config.TRAIN.CKPT_PATH
        directory = directory + 'device_'+ os.getenv('DEVICE_ID')

        ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=directory, config=config_ck)
        cb.append(ckpoint_cb)
    # O2 mixed precision; loss is computed inside net_with_loss, so no
    # separate loss_fn is passed.
    model = Model(net_with_loss, loss_fn=None, optimizer=opt, amp_level="O2")
    epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
    print("************ Start training now ************")
    print('start training, epoch size = %d' % epoch_size)
    model.train(epoch_size, dataset, callbacks=cb)

    if config.MODELARTS.IS_MODEL_ARTS:
        # Copy training outputs from the local cache back to OBS.
        mox.file.copy_parallel(src_url=config.MODELARTS.CACHE_OUTPUT, dst_url=args.train_url)

if __name__ == '__main__':
    main()
|
||||
Loading…
Reference in new issue