parent
6cf308076d
commit
2705bb2ca5
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute eval."""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
from src.dataset_eval import data_generator_eval
|
||||||
|
from src.config import config
|
||||||
|
from src.FaceAttribute.resnet18 import get_resnet18
|
||||||
|
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(x, axis=0):
|
||||||
|
return np.exp(x) / np.sum(np.exp(x), axis=axis)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
network = get_resnet18(args)
|
||||||
|
ckpt_path = args.model_path
|
||||||
|
if os.path.isfile(ckpt_path):
|
||||||
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
|
param_dict_new = {}
|
||||||
|
for key, values in param_dict.items():
|
||||||
|
if key.startswith('moments.'):
|
||||||
|
continue
|
||||||
|
elif key.startswith('network.'):
|
||||||
|
param_dict_new[key[8:]] = values
|
||||||
|
else:
|
||||||
|
param_dict_new[key] = values
|
||||||
|
load_param_into_net(network, param_dict_new)
|
||||||
|
print('-----------------------load model success-----------------------')
|
||||||
|
else:
|
||||||
|
print('-----------------------load model failed-----------------------')
|
||||||
|
|
||||||
|
network.set_train(False)
|
||||||
|
|
||||||
|
de_dataloader, steps_per_epoch, _ = data_generator_eval(args)
|
||||||
|
|
||||||
|
total_data_num_age = 0
|
||||||
|
total_data_num_gen = 0
|
||||||
|
total_data_num_mask = 0
|
||||||
|
age_num = 0
|
||||||
|
gen_num = 0
|
||||||
|
mask_num = 0
|
||||||
|
gen_tp_num = 0
|
||||||
|
mask_tp_num = 0
|
||||||
|
gen_fp_num = 0
|
||||||
|
mask_fp_num = 0
|
||||||
|
gen_fn_num = 0
|
||||||
|
mask_fn_num = 0
|
||||||
|
for step_i, (data, gt_classes) in enumerate(de_dataloader):
|
||||||
|
|
||||||
|
print('evaluating {}/{} ...'.format(step_i + 1, steps_per_epoch))
|
||||||
|
|
||||||
|
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||||
|
fea = network(data_tensor)
|
||||||
|
|
||||||
|
gt_age, gt_gen, gt_mask = gt_classes[0]
|
||||||
|
|
||||||
|
age_result, gen_result, mask_result = fea
|
||||||
|
|
||||||
|
age_result_np = age_result.asnumpy()
|
||||||
|
gen_result_np = gen_result.asnumpy()
|
||||||
|
mask_result_np = mask_result.asnumpy()
|
||||||
|
|
||||||
|
age_prob = softmax(age_result_np[0].astype(np.float32)).tolist()
|
||||||
|
gen_prob = softmax(gen_result_np[0].astype(np.float32)).tolist()
|
||||||
|
mask_prob = softmax(mask_result_np[0].astype(np.float32)).tolist()
|
||||||
|
|
||||||
|
age = age_prob.index(max(age_prob))
|
||||||
|
gen = gen_prob.index(max(gen_prob))
|
||||||
|
mask = mask_prob.index(max(mask_prob))
|
||||||
|
|
||||||
|
if gt_age == age:
|
||||||
|
age_num += 1
|
||||||
|
if gt_gen == gen:
|
||||||
|
gen_num += 1
|
||||||
|
if gt_mask == mask:
|
||||||
|
mask_num += 1
|
||||||
|
|
||||||
|
if gt_gen == 1 and gen == 1:
|
||||||
|
gen_tp_num += 1
|
||||||
|
if gt_gen == 0 and gen == 1:
|
||||||
|
gen_fp_num += 1
|
||||||
|
if gt_gen == 1 and gen == 0:
|
||||||
|
gen_fn_num += 1
|
||||||
|
|
||||||
|
if gt_mask == 1 and mask == 1:
|
||||||
|
mask_tp_num += 1
|
||||||
|
if gt_mask == 0 and mask == 1:
|
||||||
|
mask_fp_num += 1
|
||||||
|
if gt_mask == 1 and mask == 0:
|
||||||
|
mask_fn_num += 1
|
||||||
|
|
||||||
|
if gt_age != -1:
|
||||||
|
total_data_num_age += 1
|
||||||
|
if gt_gen != -1:
|
||||||
|
total_data_num_gen += 1
|
||||||
|
if gt_mask != -1:
|
||||||
|
total_data_num_mask += 1
|
||||||
|
|
||||||
|
age_accuracy = float(age_num) / float(total_data_num_age)
|
||||||
|
|
||||||
|
gen_precision = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fp_num))
|
||||||
|
gen_recall = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fn_num))
|
||||||
|
gen_accuracy = float(gen_num) / float(total_data_num_gen)
|
||||||
|
gen_f1 = 2. * gen_precision * gen_recall / (gen_precision + gen_recall)
|
||||||
|
|
||||||
|
mask_precision = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fp_num))
|
||||||
|
mask_recall = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fn_num))
|
||||||
|
mask_accuracy = float(mask_num) / float(total_data_num_mask)
|
||||||
|
mask_f1 = 2. * mask_precision * mask_recall / (mask_precision + mask_recall)
|
||||||
|
|
||||||
|
print('model: ', ckpt_path)
|
||||||
|
print('total age num: ', total_data_num_age)
|
||||||
|
print('total gen num: ', total_data_num_gen)
|
||||||
|
print('total mask num: ', total_data_num_mask)
|
||||||
|
print('age accuracy: ', age_accuracy)
|
||||||
|
print('gen accuracy: ', gen_accuracy)
|
||||||
|
print('mask accuracy: ', mask_accuracy)
|
||||||
|
print('gen precision: ', gen_precision)
|
||||||
|
print('gen recall: ', gen_recall)
|
||||||
|
print('gen f1: ', gen_f1)
|
||||||
|
print('mask precision: ', mask_precision)
|
||||||
|
print('mask recall: ', mask_recall)
|
||||||
|
print('mask f1: ', mask_f1)
|
||||||
|
|
||||||
|
model_name = os.path.basename(ckpt_path).split('.')[0]
|
||||||
|
model_dir = os.path.dirname(ckpt_path)
|
||||||
|
result_txt = os.path.join(model_dir, model_name + '.txt')
|
||||||
|
if os.path.exists(result_txt):
|
||||||
|
os.remove(result_txt)
|
||||||
|
with open(result_txt, 'a') as ft:
|
||||||
|
ft.write('model: {}\n'.format(ckpt_path))
|
||||||
|
ft.write('total age num: {}\n'.format(total_data_num_age))
|
||||||
|
ft.write('total gen num: {}\n'.format(total_data_num_gen))
|
||||||
|
ft.write('total mask num: {}\n'.format(total_data_num_mask))
|
||||||
|
ft.write('age accuracy: {}\n'.format(age_accuracy))
|
||||||
|
ft.write('gen accuracy: {}\n'.format(gen_accuracy))
|
||||||
|
ft.write('mask accuracy: {}\n'.format(mask_accuracy))
|
||||||
|
ft.write('gen precision: {}\n'.format(gen_precision))
|
||||||
|
ft.write('gen recall: {}\n'.format(gen_recall))
|
||||||
|
ft.write('gen f1: {}\n'.format(gen_f1))
|
||||||
|
ft.write('mask precision: {}\n'.format(mask_precision))
|
||||||
|
ft.write('mask recall: {}\n'.format(mask_recall))
|
||||||
|
ft.write('mask f1: {}\n'.format(mask_f1))
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""parse_args"""
|
||||||
|
parser = argparse.ArgumentParser(description='face attributes eval')
|
||||||
|
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||||
|
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||||
|
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args_1 = parse_args()
|
||||||
|
|
||||||
|
args_1.dst_h = config.dst_h
|
||||||
|
args_1.dst_w = config.dst_w
|
||||||
|
args_1.attri_num = config.attri_num
|
||||||
|
args_1.classes = config.classes
|
||||||
|
args_1.flat_dim = config.flat_dim
|
||||||
|
args_1.fc_dim = config.fc_dim
|
||||||
|
args_1.workers = config.workers
|
||||||
|
|
||||||
|
main(args_1)
|
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Convert ckpt to air."""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.FaceAttribute.resnet18_softmax import get_resnet18
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
network = get_resnet18(args)
|
||||||
|
ckpt_path = args.model_path
|
||||||
|
if os.path.isfile(ckpt_path):
|
||||||
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
|
param_dict_new = {}
|
||||||
|
for key, values in param_dict.items():
|
||||||
|
if key.startswith('moments.'):
|
||||||
|
continue
|
||||||
|
elif key.startswith('network.'):
|
||||||
|
param_dict_new[key[8:]] = values
|
||||||
|
else:
|
||||||
|
param_dict_new[key] = values
|
||||||
|
load_param_into_net(network, param_dict_new)
|
||||||
|
print('-----------------------load model success-----------------------')
|
||||||
|
else:
|
||||||
|
print('-----------------------load model failed -----------------------')
|
||||||
|
|
||||||
|
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 112, 112)).astype(np.float32)
|
||||||
|
tensor_input_data = Tensor(input_data)
|
||||||
|
|
||||||
|
export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
|
||||||
|
file_format='AIR')
|
||||||
|
print('-----------------------export model success-----------------------')
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""parse_args"""
|
||||||
|
parser = argparse.ArgumentParser(description='Convert ckpt to air')
|
||||||
|
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||||
|
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
return args_opt
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args_1 = parse_args()
|
||||||
|
|
||||||
|
args_1.dst_h = config.dst_h
|
||||||
|
args_1.dst_w = config.dst_w
|
||||||
|
args_1.attri_num = config.attri_num
|
||||||
|
args_1.classes = config.classes
|
||||||
|
args_1.flat_dim = config.flat_dim
|
||||||
|
args_1.fc_dim = config.fc_dim
|
||||||
|
|
||||||
|
main(args_1)
|
@ -0,0 +1,81 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ] && [ $# != 3 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
|
||||||
|
echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
dirname_path=$(dirname "$(pwd)")
|
||||||
|
echo ${dirname_path}
|
||||||
|
|
||||||
|
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||||
|
|
||||||
|
SCRIPT_NAME='train.py'
|
||||||
|
|
||||||
|
rm -rf ${current_exec_path}/device*
|
||||||
|
|
||||||
|
ulimit -c unlimited
|
||||||
|
|
||||||
|
MINDRECORD_FILE=$(get_real_path $1)
|
||||||
|
RANK_TABLE=$(get_real_path $2)
|
||||||
|
PRETRAINED_BACKBONE=''
|
||||||
|
|
||||||
|
if [ $# == 3 ]
|
||||||
|
then
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo $MINDRECORD_FILE
|
||||||
|
echo $RANK_TABLE
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
export RANK_TABLE_FILE=$RANK_TABLE
|
||||||
|
export RANK_SIZE=8
|
||||||
|
|
||||||
|
echo 'start training'
|
||||||
|
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||||
|
do
|
||||||
|
echo 'start rank '$i
|
||||||
|
mkdir ${current_exec_path}/device$i
|
||||||
|
cd ${current_exec_path}/device$i || exit
|
||||||
|
export RANK_ID=$i
|
||||||
|
dev=`expr $i + 0`
|
||||||
|
export DEVICE_ID=$dev
|
||||||
|
python ${dirname_path}/${SCRIPT_NAME} \
|
||||||
|
--mindrecord_path=$MINDRECORD_FILE \
|
||||||
|
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'running'
|
@ -0,0 +1,71 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 3 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
dirname_path=$(dirname "$(pwd)")
|
||||||
|
echo ${dirname_path}
|
||||||
|
|
||||||
|
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||||
|
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
SCRIPT_NAME='eval.py'
|
||||||
|
|
||||||
|
ulimit -c unlimited
|
||||||
|
|
||||||
|
MINDRECORD_FILE=$(get_real_path $1)
|
||||||
|
USE_DEVICE_ID=$2
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo $MINDRECORD_FILE
|
||||||
|
echo $USE_DEVICE_ID
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
echo 'start evaluating'
|
||||||
|
export RANK_ID=0
|
||||||
|
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
echo 'start device '$USE_DEVICE_ID
|
||||||
|
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||||
|
dev=`expr $USE_DEVICE_ID + 0`
|
||||||
|
export DEVICE_ID=$dev
|
||||||
|
python ${dirname_path}/${SCRIPT_NAME} \
|
||||||
|
--mindrecord_path=$MINDRECORD_FILE \
|
||||||
|
--model_path=$PRETRAINED_BACKBONE > eval.log 2>&1 &
|
||||||
|
|
||||||
|
echo 'running'
|
@ -0,0 +1,71 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 3 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
dirname_path=$(dirname "$(pwd)")
|
||||||
|
echo ${dirname_path}
|
||||||
|
|
||||||
|
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||||
|
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
SCRIPT_NAME='export.py'
|
||||||
|
|
||||||
|
ulimit -c unlimited
|
||||||
|
|
||||||
|
BATCH_SIZE=$1
|
||||||
|
USE_DEVICE_ID=$2
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo $BATCH_SIZE
|
||||||
|
echo $USE_DEVICE_ID
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
echo 'start converting'
|
||||||
|
export RANK_ID=0
|
||||||
|
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
echo 'start device '$USE_DEVICE_ID
|
||||||
|
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||||
|
dev=`expr $USE_DEVICE_ID + 0`
|
||||||
|
export DEVICE_ID=$dev
|
||||||
|
python ${dirname_path}/${SCRIPT_NAME} \
|
||||||
|
--batch_size=$BATCH_SIZE \
|
||||||
|
--model_path=$PRETRAINED_BACKBONE > convert.log 2>&1 &
|
||||||
|
|
||||||
|
echo 'running'
|
@ -0,0 +1,77 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ] && [ $# != 3 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||||
|
echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
current_exec_path=$(pwd)
|
||||||
|
echo ${current_exec_path}
|
||||||
|
|
||||||
|
dirname_path=$(dirname "$(pwd)")
|
||||||
|
echo ${dirname_path}
|
||||||
|
|
||||||
|
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||||
|
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
SCRIPT_NAME='train.py'
|
||||||
|
|
||||||
|
ulimit -c unlimited
|
||||||
|
|
||||||
|
MINDRECORD_FILE=$(get_real_path $1)
|
||||||
|
USE_DEVICE_ID=$2
|
||||||
|
PRETRAINED_BACKBONE=''
|
||||||
|
|
||||||
|
if [ $# == 3 ]
|
||||||
|
then
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo $MINDRECORD_FILE
|
||||||
|
echo $USE_DEVICE_ID
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
echo 'start training'
|
||||||
|
export RANK_ID=0
|
||||||
|
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
echo 'start device '$USE_DEVICE_ID
|
||||||
|
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||||
|
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||||
|
dev=`expr $USE_DEVICE_ID + 0`
|
||||||
|
export DEVICE_ID=$dev
|
||||||
|
python ${dirname_path}/${SCRIPT_NAME} \
|
||||||
|
--world_size=1 \
|
||||||
|
--mindrecord_path=$MINDRECORD_FILE \
|
||||||
|
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||||
|
|
||||||
|
echo 'running'
|
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute cross entropy."""
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropyWithIgnoreIndex(nn.Cell):
|
||||||
|
'''Cross Entropy With Ignore Index Loss.'''
|
||||||
|
def __init__(self):
|
||||||
|
super(CrossEntropyWithIgnoreIndex, self).__init__()
|
||||||
|
self.onehot = P.OneHot()
|
||||||
|
self.on_value = Tensor(1.0, dtype=mstype.float32)
|
||||||
|
self.off_value = Tensor(0.0, dtype=mstype.float32)
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||||
|
self.greater = P.Greater()
|
||||||
|
self.maximum = P.Maximum()
|
||||||
|
self.fill = P.Fill()
|
||||||
|
self.sum = P.ReduceSum(keep_dims=False)
|
||||||
|
self.dtype = P.DType()
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.const_one = Tensor(np.ones([1]), dtype=mstype.float32)
|
||||||
|
self.const_eps = Tensor(0.00001, dtype=mstype.float32)
|
||||||
|
|
||||||
|
def construct(self, x, label):
|
||||||
|
'''Construct function.'''
|
||||||
|
mask = self.reshape(label, (F.shape(label)[0], 1))
|
||||||
|
mask = self.cast(mask, mstype.float32)
|
||||||
|
mask = mask + self.const_eps
|
||||||
|
mask = self.relu(mask)/mask
|
||||||
|
x = x * mask
|
||||||
|
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(x)[1], self.on_value, self.off_value)
|
||||||
|
loss = self.ce(x, one_hot_label)
|
||||||
|
positive = self.sum(self.cast(self.greater(loss, self.fill(self.dtype(loss), F.shape(loss), 0.0)),
|
||||||
|
mstype.float32), 0)
|
||||||
|
positive = self.maximum(positive, self.const_one)
|
||||||
|
loss = self.sum(loss, 0) / positive
|
||||||
|
return loss
|
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute network unit."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.nn import Dense
|
||||||
|
|
||||||
|
|
||||||
|
class Cut(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def bn_with_initialize(out_channels):
|
||||||
|
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
|
||||||
|
return bn
|
||||||
|
|
||||||
|
|
||||||
|
def fc_with_initialize(input_channels, out_channels):
|
||||||
|
return Dense(input_channels, out_channels)
|
||||||
|
|
||||||
|
|
||||||
|
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
|
||||||
|
"""3x3 convolution with padding"""
|
||||||
|
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||||
|
pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=False,
|
||||||
|
padding=padding)
|
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute head."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||||
|
|
||||||
|
__all__ = ['get_attri_head']
|
||||||
|
|
||||||
|
|
||||||
|
class AttriHead(Cell):
|
||||||
|
'''Attribute Head.'''
|
||||||
|
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||||
|
super(AttriHead, self).__init__()
|
||||||
|
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc1_relu = P.ReLU()
|
||||||
|
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||||
|
self.attri_fc1_relu = P.ReLU()
|
||||||
|
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||||
|
|
||||||
|
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc2_relu = P.ReLU()
|
||||||
|
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||||
|
self.attri_fc2_relu = P.ReLU()
|
||||||
|
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||||
|
|
||||||
|
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc3_relu = P.ReLU()
|
||||||
|
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||||
|
self.attri_fc3_relu = P.ReLU()
|
||||||
|
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''Construct function.'''
|
||||||
|
output0 = self.fc1(x)
|
||||||
|
output0 = self.fc1_relu(output0)
|
||||||
|
output0 = self.fc1_bn(output0)
|
||||||
|
output0 = self.attri_fc1(output0)
|
||||||
|
output0 = self.attri_fc1_relu(output0)
|
||||||
|
output0 = self.attri_bn1(output0)
|
||||||
|
|
||||||
|
output1 = self.fc2(x)
|
||||||
|
output1 = self.fc2_relu(output1)
|
||||||
|
output1 = self.fc2_bn(output1)
|
||||||
|
output1 = self.attri_fc2(output1)
|
||||||
|
output1 = self.attri_fc2_relu(output1)
|
||||||
|
output1 = self.attri_bn2(output1)
|
||||||
|
|
||||||
|
output2 = self.fc3(x)
|
||||||
|
output2 = self.fc3_relu(output2)
|
||||||
|
output2 = self.fc3_bn(output2)
|
||||||
|
output2 = self.attri_fc3(output2)
|
||||||
|
output2 = self.attri_fc3_relu(output2)
|
||||||
|
output2 = self.attri_bn3(output2)
|
||||||
|
|
||||||
|
return output0, output1, output2
|
||||||
|
|
||||||
|
|
||||||
|
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||||
|
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||||
|
return attri_head
|
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute head."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||||
|
|
||||||
|
__all__ = ['get_attri_head']
|
||||||
|
|
||||||
|
|
||||||
|
class AttriHead(Cell):
|
||||||
|
'''Attribute Head.'''
|
||||||
|
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||||
|
super(AttriHead, self).__init__()
|
||||||
|
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc1_relu = P.ReLU()
|
||||||
|
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||||
|
self.attri_fc1_relu = P.ReLU()
|
||||||
|
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||||
|
self.softmax1 = P.Softmax()
|
||||||
|
|
||||||
|
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc2_relu = P.ReLU()
|
||||||
|
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||||
|
self.attri_fc2_relu = P.ReLU()
|
||||||
|
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||||
|
self.softmax2 = P.Softmax()
|
||||||
|
|
||||||
|
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||||
|
self.fc3_relu = P.ReLU()
|
||||||
|
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||||
|
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||||
|
self.attri_fc3_relu = P.ReLU()
|
||||||
|
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||||
|
self.softmax3 = P.Softmax()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''Construct function.'''
|
||||||
|
output0 = self.fc1(x)
|
||||||
|
output0 = self.fc1_relu(output0)
|
||||||
|
output0 = self.fc1_bn(output0)
|
||||||
|
output0 = self.attri_fc1(output0)
|
||||||
|
output0 = self.attri_fc1_relu(output0)
|
||||||
|
output0 = self.attri_bn1(output0)
|
||||||
|
output0 = self.softmax1(output0)
|
||||||
|
|
||||||
|
output1 = self.fc2(x)
|
||||||
|
output1 = self.fc2_relu(output1)
|
||||||
|
output1 = self.fc2_bn(output1)
|
||||||
|
output1 = self.attri_fc2(output1)
|
||||||
|
output1 = self.attri_fc2_relu(output1)
|
||||||
|
output1 = self.attri_bn2(output1)
|
||||||
|
output1 = self.softmax2(output1)
|
||||||
|
|
||||||
|
output2 = self.fc3(x)
|
||||||
|
output2 = self.fc3_relu(output2)
|
||||||
|
output2 = self.fc3_bn(output2)
|
||||||
|
output2 = self.attri_fc3(output2)
|
||||||
|
output2 = self.attri_fc3_relu(output2)
|
||||||
|
output2 = self.attri_bn3(output2)
|
||||||
|
output2 = self.softmax3(output2)
|
||||||
|
|
||||||
|
return output0, output1, output2
|
||||||
|
|
||||||
|
|
||||||
|
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||||
|
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||||
|
return attri_head
|
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute loss."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
from src.FaceAttribute.cross_entropy import CrossEntropyWithIgnoreIndex
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['get_loss']
|
||||||
|
|
||||||
|
|
||||||
|
class CriterionsFaceAttri(nn.Cell):
|
||||||
|
'''Criterions Face Attribute.'''
|
||||||
|
def __init__(self):
|
||||||
|
super(CriterionsFaceAttri, self).__init__()
|
||||||
|
|
||||||
|
# label
|
||||||
|
self.gatherv2 = P.GatherV2()
|
||||||
|
self.squeeze = P.Squeeze(axis=1)
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.mean = P.ReduceMean()
|
||||||
|
|
||||||
|
self.label0_param = Tensor([0], dtype=mstype.int32)
|
||||||
|
self.label1_param = Tensor([1], dtype=mstype.int32)
|
||||||
|
self.label2_param = Tensor([2], dtype=mstype.int32)
|
||||||
|
|
||||||
|
# loss
|
||||||
|
self.ce_ignore_loss = CrossEntropyWithIgnoreIndex()
|
||||||
|
self.printn = P.Print()
|
||||||
|
|
||||||
|
def construct(self, x0, x1, x2, label):
|
||||||
|
'''Construct function.'''
|
||||||
|
# each sub attribute loss
|
||||||
|
label0 = self.squeeze(self.gatherv2(label, self.label0_param, 1))
|
||||||
|
loss0 = self.ce_ignore_loss(x0, label0)
|
||||||
|
|
||||||
|
label1 = self.squeeze(self.gatherv2(label, self.label1_param, 1))
|
||||||
|
loss1 = self.ce_ignore_loss(x1, label1)
|
||||||
|
|
||||||
|
label2 = self.squeeze(self.gatherv2(label, self.label2_param, 1))
|
||||||
|
loss2 = self.ce_ignore_loss(x2, label2)
|
||||||
|
|
||||||
|
loss = loss0 + loss1 + loss2
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss():
|
||||||
|
return CriterionsFaceAttri()
|
@ -0,0 +1,145 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute resnet18 backbone."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops.operations import TensorAdd
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||||
|
from src.FaceAttribute.head_factory import get_attri_head
|
||||||
|
|
||||||
|
__all__ = ['get_resnet18']
|
||||||
|
|
||||||
|
|
||||||
|
class IRBlock(Cell):
|
||||||
|
'''IRBlock.'''
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(IRBlock, self).__init__()
|
||||||
|
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||||
|
self.bn1 = bn_with_initialize(planes)
|
||||||
|
self.relu1 = P.ReLU()
|
||||||
|
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||||
|
self.bn2 = bn_with_initialize(planes)
|
||||||
|
|
||||||
|
if downsample is None:
|
||||||
|
self.downsample = Cut()
|
||||||
|
else:
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
self.add = TensorAdd()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.relu2 = P.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu1(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
identity = self.downsample(x)
|
||||||
|
out = self.add(out, identity)
|
||||||
|
out = self.relu2(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample(Cell):
|
||||||
|
def __init__(self, inplanes, planes, expansion, stride):
|
||||||
|
super(DownSample, self).__init__()
|
||||||
|
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||||
|
self.bn1 = bn_with_initialize(planes * expansion)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MakeLayer(Cell):
|
||||||
|
'''Make layer.'''
|
||||||
|
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||||
|
super(MakeLayer, self).__init__()
|
||||||
|
|
||||||
|
self.inplanes = inplanes
|
||||||
|
self.downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||||
|
|
||||||
|
self.layers = []
|
||||||
|
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||||
|
self.inplanes = planes
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
self.layers.append(block(self.inplanes, planes))
|
||||||
|
self.layers = nn.CellList(self.layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
for block in self.layers:
|
||||||
|
x = block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttriResNet(Cell):
|
||||||
|
'''Resnet for attribute.'''
|
||||||
|
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||||
|
super(AttriResNet, self).__init__()
|
||||||
|
|
||||||
|
# resnet18
|
||||||
|
self.inplanes = 32
|
||||||
|
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||||
|
self.bn1 = bn_with_initialize(self.inplanes)
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||||
|
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||||
|
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||||
|
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||||
|
|
||||||
|
# avg global pooling
|
||||||
|
self.mean = P.ReduceMean(keep_dims=True)
|
||||||
|
self.shape = P.Shape()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''Construct function.'''
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.mean(x, (2, 3))
|
||||||
|
b, c, _, _ = self.shape(x)
|
||||||
|
x = self.reshape(x, (b, c))
|
||||||
|
return self.head(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_resnet18(args):
|
||||||
|
'''Build resnet18 for attribute.'''
|
||||||
|
flat_dim = args.flat_dim
|
||||||
|
fc_dim = args.fc_dim
|
||||||
|
str_classes = args.classes.strip().split(',')
|
||||||
|
if args.attri_num != len(str_classes):
|
||||||
|
print('args warning: attri_num != classes num')
|
||||||
|
return None
|
||||||
|
attri_num_list = []
|
||||||
|
for i, _ in enumerate(str_classes):
|
||||||
|
attri_num_list.append(int(str_classes[i]))
|
||||||
|
|
||||||
|
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||||
|
|
||||||
|
return attri_resnet18
|
@ -0,0 +1,145 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute resnet18 backbone."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops.operations import TensorAdd
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.nn import Cell
|
||||||
|
|
||||||
|
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||||
|
from src.FaceAttribute.head_factory_softmax import get_attri_head
|
||||||
|
|
||||||
|
__all__ = ['get_resnet18']
|
||||||
|
|
||||||
|
|
||||||
|
class IRBlock(Cell):
|
||||||
|
'''IRBlock.'''
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||||
|
super(IRBlock, self).__init__()
|
||||||
|
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||||
|
self.bn1 = bn_with_initialize(planes)
|
||||||
|
self.relu1 = P.ReLU()
|
||||||
|
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||||
|
self.bn2 = bn_with_initialize(planes)
|
||||||
|
|
||||||
|
if downsample is None:
|
||||||
|
self.downsample = Cut()
|
||||||
|
else:
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
self.add = TensorAdd()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.relu2 = P.ReLU()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu1(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
identity = self.downsample(x)
|
||||||
|
out = self.add(out, identity)
|
||||||
|
out = self.relu2(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample(Cell):
|
||||||
|
def __init__(self, inplanes, planes, expansion, stride):
|
||||||
|
super(DownSample, self).__init__()
|
||||||
|
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||||
|
self.bn1 = bn_with_initialize(planes * expansion)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class MakeLayer(Cell):
|
||||||
|
'''Make layer function.'''
|
||||||
|
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||||
|
super(MakeLayer, self).__init__()
|
||||||
|
|
||||||
|
self.inplanes = inplanes
|
||||||
|
self.downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||||
|
|
||||||
|
self.layers = []
|
||||||
|
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||||
|
self.inplanes = planes
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
self.layers.append(block(self.inplanes, planes))
|
||||||
|
self.layers = nn.CellList(self.layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
for block in self.layers:
|
||||||
|
x = block(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttriResNet(Cell):
|
||||||
|
'''Resnet for attribute.'''
|
||||||
|
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||||
|
super(AttriResNet, self).__init__()
|
||||||
|
|
||||||
|
# resnet18
|
||||||
|
self.inplanes = 32
|
||||||
|
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||||
|
self.bn1 = bn_with_initialize(self.inplanes)
|
||||||
|
self.relu = P.ReLU()
|
||||||
|
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||||
|
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||||
|
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||||
|
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||||
|
|
||||||
|
# avg global pooling
|
||||||
|
self.mean = P.ReduceMean(keep_dims=True)
|
||||||
|
self.shape = P.Shape()
|
||||||
|
self.reshape = P.Reshape()
|
||||||
|
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
'''Construct function.'''
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
x = self.mean(x, (2, 3))
|
||||||
|
b, c, _, _ = self.shape(x)
|
||||||
|
x = self.reshape(x, (b, c))
|
||||||
|
return self.head(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_resnet18(args):
|
||||||
|
'''Build resnet18 for attribute.'''
|
||||||
|
flat_dim = args.flat_dim
|
||||||
|
fc_dim = args.fc_dim
|
||||||
|
str_classes = args.classes.strip().split(',')
|
||||||
|
if args.attri_num != len(str_classes):
|
||||||
|
print('args warning: attri_num != classes num')
|
||||||
|
return None
|
||||||
|
attri_num_list = []
|
||||||
|
for i, _ in enumerate(str_classes):
|
||||||
|
attri_num_list.append(int(str_classes[i]))
|
||||||
|
|
||||||
|
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||||
|
|
||||||
|
return attri_resnet18
|
@ -0,0 +1,46 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""Network config setting, will be used in train.py and eval.py"""
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
config = ed({
|
||||||
|
'per_batch_size': 128,
|
||||||
|
'dst_h': 112,
|
||||||
|
'dst_w': 112,
|
||||||
|
'workers': 8,
|
||||||
|
'attri_num': 3,
|
||||||
|
'classes': '9,2,2',
|
||||||
|
'backbone': 'resnet18',
|
||||||
|
'loss_scale': 1024,
|
||||||
|
'flat_dim': 512,
|
||||||
|
'fc_dim': 256,
|
||||||
|
'lr': 0.009,
|
||||||
|
'lr_scale': 1,
|
||||||
|
'lr_epochs': [20, 30, 50],
|
||||||
|
'weight_decay': 0.0005,
|
||||||
|
'momentum': 0.9,
|
||||||
|
'max_epoch': 70,
|
||||||
|
'warmup_epochs': 0,
|
||||||
|
'log_interval': 10,
|
||||||
|
'ckpt_path': '../../output',
|
||||||
|
|
||||||
|
# data_to_mindrecord parameter
|
||||||
|
'eval_dataset_txt_file': 'Your_label_txt_file',
|
||||||
|
'eval_mindrecord_file_name': 'Your_output_path/data_test.mindrecord',
|
||||||
|
'train_dataset_txt_file': 'Your_label_txt_file',
|
||||||
|
'train_mindrecord_file_name': 'Your_output_path/data_train.mindrecord',
|
||||||
|
'train_append_dataset_txt_file': 'Your_label_txt_file',
|
||||||
|
'train_append_mindrecord_file_name': 'Your_previous_output_path/data_train.mindrecord0'
|
||||||
|
})
|
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Convert dataset to mindrecord for evaluating Face attribute."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
dataset_txt_file = config.eval_dataset_txt_file
|
||||||
|
|
||||||
|
mindrecord_file_name = config.eval_mindrecord_file_name
|
||||||
|
|
||||||
|
mindrecord_num = 8
|
||||||
|
|
||||||
|
|
||||||
|
def convert_data_to_mindrecord():
|
||||||
|
'''Convert data to mindrecord.'''
|
||||||
|
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||||
|
attri_json = {
|
||||||
|
"image": {"type": "bytes"},
|
||||||
|
"label": {"type": "int32", "shape": [-1]}
|
||||||
|
}
|
||||||
|
|
||||||
|
print('Loading eval data...')
|
||||||
|
total_data = []
|
||||||
|
with open(dataset_txt_file, 'r') as ft:
|
||||||
|
lines = ft.readlines()
|
||||||
|
for line in lines:
|
||||||
|
sline = line.strip().split(" ")
|
||||||
|
image_file = sline[0]
|
||||||
|
labels = []
|
||||||
|
for item in sline[1:]:
|
||||||
|
labels.append(int(item))
|
||||||
|
|
||||||
|
with open(image_file, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"image": img,
|
||||||
|
"label": np.array(labels, dtype='int32')
|
||||||
|
}
|
||||||
|
|
||||||
|
total_data.append(data)
|
||||||
|
|
||||||
|
print('Writing eval data to mindrecord...')
|
||||||
|
writer.add_schema(attri_json, "attri_json")
|
||||||
|
if total_data is None:
|
||||||
|
raise ValueError("None needs writing to mindrecord.")
|
||||||
|
writer.write_raw_data(total_data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
convert_data_to_mindrecord()
|
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Convert dataset to mindrecord for training Face attribute."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
dataset_txt_file = config.train_dataset_txt_file
|
||||||
|
|
||||||
|
mindrecord_file_name = config.train_mindrecord_file_name
|
||||||
|
|
||||||
|
mindrecord_num = 8
|
||||||
|
|
||||||
|
|
||||||
|
def convert_data_to_mindrecord():
|
||||||
|
'''Covert data to mindrecord.'''
|
||||||
|
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||||
|
attri_json = {
|
||||||
|
"image": {"type": "bytes"},
|
||||||
|
"label": {"type": "int32", "shape": [-1]}
|
||||||
|
}
|
||||||
|
|
||||||
|
print('Loading train data...')
|
||||||
|
total_data = []
|
||||||
|
with open(dataset_txt_file, 'r') as ft:
|
||||||
|
lines = ft.readlines()
|
||||||
|
for line in lines:
|
||||||
|
sline = line.strip().split(" ")
|
||||||
|
image_file = sline[0]
|
||||||
|
labels = []
|
||||||
|
for item in sline[1:]:
|
||||||
|
labels.append(int(item))
|
||||||
|
|
||||||
|
with open(image_file, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"image": img,
|
||||||
|
"label": np.array(labels, dtype='int32')
|
||||||
|
}
|
||||||
|
|
||||||
|
total_data.append(data)
|
||||||
|
|
||||||
|
print('Writing train data to mindrecord...')
|
||||||
|
writer.add_schema(attri_json, "attri_json")
|
||||||
|
if total_data is None:
|
||||||
|
raise ValueError("None needs writing to mindrecord.")
|
||||||
|
writer.write_raw_data(total_data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
convert_data_to_mindrecord()
|
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Add dataset to an existed mindrecord for training Face attribute."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore.mindrecord import FileWriter
|
||||||
|
|
||||||
|
from config import config
|
||||||
|
|
||||||
|
dataset_txt_file = config.train_append_dataset_txt_file
|
||||||
|
|
||||||
|
mindrecord_file_name = config.train_append_mindrecord_file_name
|
||||||
|
|
||||||
|
mindrecord_num = 8
|
||||||
|
|
||||||
|
|
||||||
|
def convert_data_to_mindrecord():
|
||||||
|
'''Covert data to mindrecord.'''
|
||||||
|
print('Loading mindrecord...')
|
||||||
|
writer = FileWriter.open_for_append(mindrecord_file_name)
|
||||||
|
|
||||||
|
print('Loading train data...')
|
||||||
|
total_data = []
|
||||||
|
with open(dataset_txt_file, 'r') as ft:
|
||||||
|
lines = ft.readlines()
|
||||||
|
for line in lines:
|
||||||
|
sline = line.strip().split(" ")
|
||||||
|
image_file = sline[0]
|
||||||
|
labels = []
|
||||||
|
for item in sline[1:]:
|
||||||
|
labels.append(int(item))
|
||||||
|
|
||||||
|
with open(image_file, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"image": img,
|
||||||
|
"label": np.array(labels, dtype='int32')
|
||||||
|
}
|
||||||
|
|
||||||
|
total_data.append(data)
|
||||||
|
|
||||||
|
print('Writing train data to mindrecord...')
|
||||||
|
if total_data is None:
|
||||||
|
raise ValueError("None needs writing to mindrecord.")
|
||||||
|
writer.write_raw_data(total_data)
|
||||||
|
writer.commit()
|
||||||
|
|
||||||
|
|
||||||
|
convert_data_to_mindrecord()
|
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute dataset for eval"""
|
||||||
|
import mindspore.dataset as de
|
||||||
|
import mindspore.dataset.vision.py_transforms as F
|
||||||
|
import mindspore.dataset.transforms.py_transforms as F2
|
||||||
|
|
||||||
|
__all__ = ['data_generator_eval']
|
||||||
|
|
||||||
|
|
||||||
|
def data_generator_eval(args):
|
||||||
|
'''Build eval dataloader.'''
|
||||||
|
mindrecord_path = args.mindrecord_path
|
||||||
|
dst_w = args.dst_w
|
||||||
|
dst_h = args.dst_h
|
||||||
|
batch_size = 1
|
||||||
|
attri_num = args.attri_num
|
||||||
|
transform_img = F2.Compose([F.Decode(),
|
||||||
|
F.Resize((dst_w, dst_h)),
|
||||||
|
F.ToTensor(),
|
||||||
|
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"])
|
||||||
|
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||||
|
python_multiprocessing=True)
|
||||||
|
de_dataset = de_dataset.batch(batch_size)
|
||||||
|
|
||||||
|
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||||
|
steps_per_epoch = de_dataset.get_dataset_size()
|
||||||
|
print("image number:{0}".format(steps_per_epoch))
|
||||||
|
num_classes = attri_num
|
||||||
|
|
||||||
|
return de_dataloader, steps_per_epoch, num_classes
|
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute dataset for train"""
|
||||||
|
import mindspore.dataset as de
|
||||||
|
import mindspore.dataset.vision.py_transforms as F
|
||||||
|
import mindspore.dataset.transforms.py_transforms as F2
|
||||||
|
|
||||||
|
__all__ = ['data_generator']
|
||||||
|
|
||||||
|
|
||||||
|
def data_generator(args):
|
||||||
|
'''Build train dataloader.'''
|
||||||
|
mindrecord_path = args.mindrecord_path
|
||||||
|
dst_w = args.dst_w
|
||||||
|
dst_h = args.dst_h
|
||||||
|
batch_size = args.per_batch_size
|
||||||
|
attri_num = args.attri_num
|
||||||
|
max_epoch = args.max_epoch
|
||||||
|
transform_img = F2.Compose([F.Decode(),
|
||||||
|
F.Resize((dst_w, dst_h)),
|
||||||
|
F.RandomHorizontalFlip(prob=0.5),
|
||||||
|
F.ToTensor(),
|
||||||
|
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||||
|
|
||||||
|
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"], num_shards=args.world_size,
|
||||||
|
shard_id=args.local_rank)
|
||||||
|
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||||
|
python_multiprocessing=True)
|
||||||
|
de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
|
||||||
|
steps_per_epoch = de_dataset.get_dataset_size()
|
||||||
|
de_dataset = de_dataset.repeat(max_epoch)
|
||||||
|
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||||
|
|
||||||
|
num_classes = attri_num
|
||||||
|
|
||||||
|
return de_dataloader, steps_per_epoch, num_classes
|
@ -0,0 +1,105 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Custom logger."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
logger_name_1 = 'face_attributes'
|
||||||
|
|
||||||
|
|
||||||
|
class LOGGER(logging.Logger):
|
||||||
|
'''Logger.'''
|
||||||
|
def __init__(self, logger_name):
|
||||||
|
super(LOGGER, self).__init__(logger_name)
|
||||||
|
console = logging.StreamHandler(sys.stdout)
|
||||||
|
console.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
console.setFormatter(formatter)
|
||||||
|
self.addHandler(console)
|
||||||
|
self.local_rank = 0
|
||||||
|
|
||||||
|
def setup_logging_file(self, log_dir, local_rank=0):
|
||||||
|
'''Setup logging file.'''
|
||||||
|
self.local_rank = local_rank
|
||||||
|
if self.local_rank == 0:
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
|
||||||
|
self.log_fn = os.path.join(log_dir, log_name)
|
||||||
|
fh = logging.FileHandler(self.log_fn)
|
||||||
|
fh.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
self.addHandler(fh)
|
||||||
|
|
||||||
|
def info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||||
|
self._log(logging.INFO, msg, args, **kwargs)
|
||||||
|
|
||||||
|
def save_args(self, args):
|
||||||
|
self.info('Args:')
|
||||||
|
args_dict = vars(args)
|
||||||
|
for key in args_dict.keys():
|
||||||
|
self.info('--> %s: %s', key, args_dict[key])
|
||||||
|
self.info('')
|
||||||
|
|
||||||
|
def important_info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||||
|
line_width = 2
|
||||||
|
important_msg = '\n'
|
||||||
|
important_msg += ('*'*70 + '\n')*line_width
|
||||||
|
important_msg += ('*'*line_width + '\n')*2
|
||||||
|
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||||
|
important_msg += ('*'*line_width + '\n')*2
|
||||||
|
important_msg += ('*'*70 + '\n')*line_width
|
||||||
|
self.info(important_msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(path, rank):
|
||||||
|
logger = LOGGER(logger_name_1)
|
||||||
|
logger.setup_logging_file(path, rank)
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
class AverageMeter():
|
||||||
|
"""Computes and stores the average and current value"""
|
||||||
|
|
||||||
|
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||||
|
self.name = name
|
||||||
|
self.fmt = fmt
|
||||||
|
self.reset()
|
||||||
|
self.tb_writer = tb_writer
|
||||||
|
self.cur_step = 1
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.val = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, val, n=1):
|
||||||
|
self.val = val
|
||||||
|
self.sum += val * n
|
||||||
|
self.count += n
|
||||||
|
self.avg = self.sum / self.count
|
||||||
|
if self.tb_writer is not None:
|
||||||
|
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||||
|
self.cur_step += 1
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||||
|
return fmtstr.format(**self.__dict__)
|
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute learning rate scheduler."""
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||||
|
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||||
|
learning_rate = float(init_lr) + lr_inc * current_step
|
||||||
|
return learning_rate
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_step(args, gamma=0.1):
|
||||||
|
'''Warmup step.'''
|
||||||
|
base_lr = args.lr
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(args.max_epoch * args.steps_per_epoch)
|
||||||
|
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
|
||||||
|
milestones = args.lr_epochs
|
||||||
|
milestones_steps = []
|
||||||
|
for milestone in milestones:
|
||||||
|
milestones_step = milestone * args.steps_per_epoch
|
||||||
|
milestones_steps.append(milestones_step)
|
||||||
|
|
||||||
|
lr = base_lr
|
||||||
|
milestones_steps_counter = Counter(milestones_steps)
|
||||||
|
for i in range(total_steps):
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
lr = lr * gamma**milestones_steps_counter[i]
|
||||||
|
yield lr
|
@ -0,0 +1,232 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Face attribute train."""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.nn.optim import Momentum
|
||||||
|
from mindspore.communication.management import get_group_size, init, get_rank
|
||||||
|
from mindspore.nn import TrainOneStepCell
|
||||||
|
from mindspore.context import ParallelMode
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
from src.FaceAttribute.resnet18 import get_resnet18
|
||||||
|
from src.FaceAttribute.loss_factory import get_loss
|
||||||
|
from src.dataset_train import data_generator
|
||||||
|
from src.lrsche_factory import warmup_step
|
||||||
|
from src.logging import get_logger, AverageMeter
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||||
|
|
||||||
|
|
||||||
|
class BuildTrainNetwork(nn.Cell):
|
||||||
|
'''Build train network.'''
|
||||||
|
def __init__(self, network, criterion):
|
||||||
|
super(BuildTrainNetwork, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
self.criterion = criterion
|
||||||
|
self.print = P.Print()
|
||||||
|
|
||||||
|
def construct(self, input_data, label):
|
||||||
|
logit0, logit1, logit2 = self.network(input_data)
|
||||||
|
loss = self.criterion(logit0, logit1, logit2, label)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
'''Argument for Face Attributes.'''
|
||||||
|
parser = argparse.ArgumentParser('Face Attributes')
|
||||||
|
|
||||||
|
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||||
|
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||||
|
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
|
||||||
|
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
|
||||||
|
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def train():
|
||||||
|
'''train function.'''
|
||||||
|
# logger
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# init distributed
|
||||||
|
if args.world_size != 1:
|
||||||
|
init()
|
||||||
|
args.local_rank = get_rank()
|
||||||
|
args.world_size = get_group_size()
|
||||||
|
|
||||||
|
args.per_batch_size = config.per_batch_size
|
||||||
|
args.dst_h = config.dst_h
|
||||||
|
args.dst_w = config.dst_w
|
||||||
|
args.workers = config.workers
|
||||||
|
args.attri_num = config.attri_num
|
||||||
|
args.classes = config.classes
|
||||||
|
args.backbone = config.backbone
|
||||||
|
args.loss_scale = config.loss_scale
|
||||||
|
args.flat_dim = config.flat_dim
|
||||||
|
args.fc_dim = config.fc_dim
|
||||||
|
args.lr = config.lr
|
||||||
|
args.lr_scale = config.lr_scale
|
||||||
|
args.lr_epochs = config.lr_epochs
|
||||||
|
args.weight_decay = config.weight_decay
|
||||||
|
args.momentum = config.momentum
|
||||||
|
args.max_epoch = config.max_epoch
|
||||||
|
args.warmup_epochs = config.warmup_epochs
|
||||||
|
args.log_interval = config.log_interval
|
||||||
|
args.ckpt_path = config.ckpt_path
|
||||||
|
|
||||||
|
if args.world_size == 1:
|
||||||
|
args.per_batch_size = 256
|
||||||
|
else:
|
||||||
|
args.lr = args.lr * 4.
|
||||||
|
|
||||||
|
if args.world_size != 1:
|
||||||
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||||
|
else:
|
||||||
|
parallel_mode = ParallelMode.STAND_ALONE
|
||||||
|
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)
|
||||||
|
|
||||||
|
# model and log save path
|
||||||
|
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||||
|
args.logger = get_logger(args.outputs_dir, args.local_rank)
|
||||||
|
loss_meter = AverageMeter('loss')
|
||||||
|
|
||||||
|
# dataloader
|
||||||
|
args.logger.info('start create dataloader')
|
||||||
|
de_dataloader, steps_per_epoch, num_classes = data_generator(args)
|
||||||
|
args.steps_per_epoch = steps_per_epoch
|
||||||
|
args.num_classes = num_classes
|
||||||
|
args.logger.info('end create dataloader')
|
||||||
|
args.logger.save_args(args)
|
||||||
|
|
||||||
|
# backbone and loss
|
||||||
|
args.logger.important_info('start create network')
|
||||||
|
create_network_start = time.time()
|
||||||
|
network = get_resnet18(args)
|
||||||
|
|
||||||
|
criterion = get_loss()
|
||||||
|
|
||||||
|
# load pretrain model
|
||||||
|
if os.path.isfile(args.pretrained):
|
||||||
|
param_dict = load_checkpoint(args.pretrained)
|
||||||
|
param_dict_new = {}
|
||||||
|
for key, values in param_dict.items():
|
||||||
|
if key.startswith('moments.'):
|
||||||
|
continue
|
||||||
|
elif key.startswith('network.'):
|
||||||
|
param_dict_new[key[8:]] = values
|
||||||
|
else:
|
||||||
|
param_dict_new[key] = values
|
||||||
|
load_param_into_net(network, param_dict_new)
|
||||||
|
args.logger.info('load model {} success'.format(args.pretrained))
|
||||||
|
|
||||||
|
# optimizer and lr scheduler
|
||||||
|
lr = warmup_step(args, gamma=0.1)
|
||||||
|
opt = Momentum(params=network.trainable_params(),
|
||||||
|
learning_rate=lr,
|
||||||
|
momentum=args.momentum,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
loss_scale=args.loss_scale)
|
||||||
|
|
||||||
|
train_net = BuildTrainNetwork(network, criterion)
|
||||||
|
|
||||||
|
# mixed precision training
|
||||||
|
criterion.add_flags_recursive(fp32=True)
|
||||||
|
|
||||||
|
# package training process
|
||||||
|
train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
|
||||||
|
context.reset_auto_parallel_context()
|
||||||
|
|
||||||
|
# checkpoint
|
||||||
|
if args.local_rank == 0:
|
||||||
|
ckpt_max_num = args.max_epoch
|
||||||
|
train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
|
||||||
|
ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
|
||||||
|
cb_params = _InternalCallbackParam()
|
||||||
|
cb_params.train_network = train_net
|
||||||
|
cb_params.epoch_num = ckpt_max_num
|
||||||
|
cb_params.cur_epoch_num = 0
|
||||||
|
run_context = RunContext(cb_params)
|
||||||
|
ckpt_cb.begin(run_context)
|
||||||
|
|
||||||
|
train_net.set_train()
|
||||||
|
t_end = time.time()
|
||||||
|
t_epoch = time.time()
|
||||||
|
old_progress = -1
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for _, (data, gt_classes) in enumerate(de_dataloader):
|
||||||
|
|
||||||
|
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||||
|
gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
|
||||||
|
|
||||||
|
loss = train_net(data_tensor, gt_tensor)
|
||||||
|
loss_meter.update(loss.asnumpy()[0])
|
||||||
|
|
||||||
|
# save ckpt
|
||||||
|
if args.local_rank == 0:
|
||||||
|
cb_params.cur_step_num = i + 1
|
||||||
|
cb_params.batch_num = i + 2
|
||||||
|
ckpt_cb.step_end(run_context)
|
||||||
|
|
||||||
|
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||||
|
cb_params.cur_epoch_num += 1
|
||||||
|
|
||||||
|
# save Log
|
||||||
|
if i == 0:
|
||||||
|
time_for_graph_compile = time.time() - create_network_start
|
||||||
|
args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))
|
||||||
|
|
||||||
|
if i % args.log_interval == 0 and args.local_rank == 0:
|
||||||
|
time_used = time.time() - t_end
|
||||||
|
epoch = int(i / args.steps_per_epoch)
|
||||||
|
fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
|
||||||
|
args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
|
||||||
|
|
||||||
|
t_end = time.time()
|
||||||
|
loss_meter.reset()
|
||||||
|
old_progress = i
|
||||||
|
|
||||||
|
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||||
|
epoch_time_used = time.time() - t_epoch
|
||||||
|
epoch = int(i / args.steps_per_epoch)
|
||||||
|
fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
|
||||||
|
args.logger.info('=================================================')
|
||||||
|
args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
|
||||||
|
args.logger.info('=================================================')
|
||||||
|
t_epoch = time.time()
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
args.logger.info('--------- trains out ---------')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
Loading…
Reference in new issue