parent
6cf308076d
commit
2705bb2ca5
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,189 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute eval."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.dataset_eval import data_generator_eval
|
||||
from src.config import config
|
||||
from src.FaceAttribute.resnet18 import get_resnet18
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
def softmax(x, axis=0):
|
||||
return np.exp(x) / np.sum(np.exp(x), axis=axis)
|
||||
|
||||
|
||||
def main(args):
|
||||
network = get_resnet18(args)
|
||||
ckpt_path = args.model_path
|
||||
if os.path.isfile(ckpt_path):
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed-----------------------')
|
||||
|
||||
network.set_train(False)
|
||||
|
||||
de_dataloader, steps_per_epoch, _ = data_generator_eval(args)
|
||||
|
||||
total_data_num_age = 0
|
||||
total_data_num_gen = 0
|
||||
total_data_num_mask = 0
|
||||
age_num = 0
|
||||
gen_num = 0
|
||||
mask_num = 0
|
||||
gen_tp_num = 0
|
||||
mask_tp_num = 0
|
||||
gen_fp_num = 0
|
||||
mask_fp_num = 0
|
||||
gen_fn_num = 0
|
||||
mask_fn_num = 0
|
||||
for step_i, (data, gt_classes) in enumerate(de_dataloader):
|
||||
|
||||
print('evaluating {}/{} ...'.format(step_i + 1, steps_per_epoch))
|
||||
|
||||
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||
fea = network(data_tensor)
|
||||
|
||||
gt_age, gt_gen, gt_mask = gt_classes[0]
|
||||
|
||||
age_result, gen_result, mask_result = fea
|
||||
|
||||
age_result_np = age_result.asnumpy()
|
||||
gen_result_np = gen_result.asnumpy()
|
||||
mask_result_np = mask_result.asnumpy()
|
||||
|
||||
age_prob = softmax(age_result_np[0].astype(np.float32)).tolist()
|
||||
gen_prob = softmax(gen_result_np[0].astype(np.float32)).tolist()
|
||||
mask_prob = softmax(mask_result_np[0].astype(np.float32)).tolist()
|
||||
|
||||
age = age_prob.index(max(age_prob))
|
||||
gen = gen_prob.index(max(gen_prob))
|
||||
mask = mask_prob.index(max(mask_prob))
|
||||
|
||||
if gt_age == age:
|
||||
age_num += 1
|
||||
if gt_gen == gen:
|
||||
gen_num += 1
|
||||
if gt_mask == mask:
|
||||
mask_num += 1
|
||||
|
||||
if gt_gen == 1 and gen == 1:
|
||||
gen_tp_num += 1
|
||||
if gt_gen == 0 and gen == 1:
|
||||
gen_fp_num += 1
|
||||
if gt_gen == 1 and gen == 0:
|
||||
gen_fn_num += 1
|
||||
|
||||
if gt_mask == 1 and mask == 1:
|
||||
mask_tp_num += 1
|
||||
if gt_mask == 0 and mask == 1:
|
||||
mask_fp_num += 1
|
||||
if gt_mask == 1 and mask == 0:
|
||||
mask_fn_num += 1
|
||||
|
||||
if gt_age != -1:
|
||||
total_data_num_age += 1
|
||||
if gt_gen != -1:
|
||||
total_data_num_gen += 1
|
||||
if gt_mask != -1:
|
||||
total_data_num_mask += 1
|
||||
|
||||
age_accuracy = float(age_num) / float(total_data_num_age)
|
||||
|
||||
gen_precision = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fp_num))
|
||||
gen_recall = float(gen_tp_num) / (float(gen_tp_num) + float(gen_fn_num))
|
||||
gen_accuracy = float(gen_num) / float(total_data_num_gen)
|
||||
gen_f1 = 2. * gen_precision * gen_recall / (gen_precision + gen_recall)
|
||||
|
||||
mask_precision = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fp_num))
|
||||
mask_recall = float(mask_tp_num) / (float(mask_tp_num) + float(mask_fn_num))
|
||||
mask_accuracy = float(mask_num) / float(total_data_num_mask)
|
||||
mask_f1 = 2. * mask_precision * mask_recall / (mask_precision + mask_recall)
|
||||
|
||||
print('model: ', ckpt_path)
|
||||
print('total age num: ', total_data_num_age)
|
||||
print('total gen num: ', total_data_num_gen)
|
||||
print('total mask num: ', total_data_num_mask)
|
||||
print('age accuracy: ', age_accuracy)
|
||||
print('gen accuracy: ', gen_accuracy)
|
||||
print('mask accuracy: ', mask_accuracy)
|
||||
print('gen precision: ', gen_precision)
|
||||
print('gen recall: ', gen_recall)
|
||||
print('gen f1: ', gen_f1)
|
||||
print('mask precision: ', mask_precision)
|
||||
print('mask recall: ', mask_recall)
|
||||
print('mask f1: ', mask_f1)
|
||||
|
||||
model_name = os.path.basename(ckpt_path).split('.')[0]
|
||||
model_dir = os.path.dirname(ckpt_path)
|
||||
result_txt = os.path.join(model_dir, model_name + '.txt')
|
||||
if os.path.exists(result_txt):
|
||||
os.remove(result_txt)
|
||||
with open(result_txt, 'a') as ft:
|
||||
ft.write('model: {}\n'.format(ckpt_path))
|
||||
ft.write('total age num: {}\n'.format(total_data_num_age))
|
||||
ft.write('total gen num: {}\n'.format(total_data_num_gen))
|
||||
ft.write('total mask num: {}\n'.format(total_data_num_mask))
|
||||
ft.write('age accuracy: {}\n'.format(age_accuracy))
|
||||
ft.write('gen accuracy: {}\n'.format(gen_accuracy))
|
||||
ft.write('mask accuracy: {}\n'.format(mask_accuracy))
|
||||
ft.write('gen precision: {}\n'.format(gen_precision))
|
||||
ft.write('gen recall: {}\n'.format(gen_recall))
|
||||
ft.write('gen f1: {}\n'.format(gen_f1))
|
||||
ft.write('mask precision: {}\n'.format(mask_precision))
|
||||
ft.write('mask recall: {}\n'.format(mask_recall))
|
||||
ft.write('mask f1: {}\n'.format(mask_f1))
|
||||
|
||||
def parse_args():
|
||||
"""parse_args"""
|
||||
parser = argparse.ArgumentParser(description='face attributes eval')
|
||||
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_1 = parse_args()
|
||||
|
||||
args_1.dst_h = config.dst_h
|
||||
args_1.dst_w = config.dst_w
|
||||
args_1.attri_num = config.attri_num
|
||||
args_1.classes = config.classes
|
||||
args_1.flat_dim = config.flat_dim
|
||||
args_1.fc_dim = config.fc_dim
|
||||
args_1.workers = config.workers
|
||||
|
||||
main(args_1)
|
@ -0,0 +1,75 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ckpt to air."""
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.FaceAttribute.resnet18_softmax import get_resnet18
|
||||
from src.config import config
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
def main(args):
|
||||
network = get_resnet18(args)
|
||||
ckpt_path = args.model_path
|
||||
if os.path.isfile(ckpt_path):
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
print('-----------------------load model success-----------------------')
|
||||
else:
|
||||
print('-----------------------load model failed -----------------------')
|
||||
|
||||
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 112, 112)).astype(np.float32)
|
||||
tensor_input_data = Tensor(input_data)
|
||||
|
||||
export(network, tensor_input_data, file_name=ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'),
|
||||
file_format='AIR')
|
||||
print('-----------------------export model success-----------------------')
|
||||
|
||||
def parse_args():
|
||||
"""parse_args"""
|
||||
parser = argparse.ArgumentParser(description='Convert ckpt to air')
|
||||
parser.add_argument('--model_path', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
return args_opt
|
||||
|
||||
if __name__ == "__main__":
|
||||
args_1 = parse_args()
|
||||
|
||||
args_1.dst_h = config.dst_h
|
||||
args_1.dst_w = config.dst_w
|
||||
args_1.attri_num = config.attri_num
|
||||
args_1.classes = config.classes
|
||||
args_1.flat_dim = config.flat_dim
|
||||
args_1.fc_dim = config.fc_dim
|
||||
|
||||
main(args_1)
|
@ -0,0 +1,81 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
rm -rf ${current_exec_path}/device*
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
RANK_TABLE=$(get_real_path $2)
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $RANK_TABLE
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
export RANK_TABLE_FILE=$RANK_TABLE
|
||||
export RANK_SIZE=8
|
||||
|
||||
echo 'start training'
|
||||
for((i=0;i<=$RANK_SIZE-1;i++));
|
||||
do
|
||||
echo 'start rank '$i
|
||||
mkdir ${current_exec_path}/device$i
|
||||
cd ${current_exec_path}/device$i || exit
|
||||
export RANK_ID=$i
|
||||
dev=`expr $i + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
done
|
||||
|
||||
echo 'running'
|
@ -0,0 +1,71 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='eval.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start evaluating'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--model_path=$PRETRAINED_BACKBONE > eval.log 2>&1 &
|
||||
|
||||
echo 'running'
|
@ -0,0 +1,71 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='export.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
BATCH_SIZE=$1
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo $BATCH_SIZE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start converting'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--model_path=$PRETRAINED_BACKBONE > convert.log 2>&1 &
|
||||
|
||||
echo 'running'
|
@ -0,0 +1,77 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
|
||||
echo " or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
dirname_path=$(dirname "$(pwd)")
|
||||
echo ${dirname_path}
|
||||
|
||||
export PYTHONPATH=${dirname_path}:$PYTHONPATH
|
||||
|
||||
export RANK_SIZE=1
|
||||
|
||||
SCRIPT_NAME='train.py'
|
||||
|
||||
ulimit -c unlimited
|
||||
|
||||
MINDRECORD_FILE=$(get_real_path $1)
|
||||
USE_DEVICE_ID=$2
|
||||
PRETRAINED_BACKBONE=''
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
PRETRAINED_BACKBONE=$(get_real_path $3)
|
||||
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo $MINDRECORD_FILE
|
||||
echo $USE_DEVICE_ID
|
||||
echo $PRETRAINED_BACKBONE
|
||||
|
||||
echo 'start training'
|
||||
export RANK_ID=0
|
||||
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
|
||||
echo 'start device '$USE_DEVICE_ID
|
||||
mkdir ${current_exec_path}/device$USE_DEVICE_ID
|
||||
cd ${current_exec_path}/device$USE_DEVICE_ID || exit
|
||||
dev=`expr $USE_DEVICE_ID + 0`
|
||||
export DEVICE_ID=$dev
|
||||
python ${dirname_path}/${SCRIPT_NAME} \
|
||||
--world_size=1 \
|
||||
--mindrecord_path=$MINDRECORD_FILE \
|
||||
--pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
|
||||
|
||||
echo 'running'
|
@ -0,0 +1,56 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute cross entropy."""
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
||||
class CrossEntropyWithIgnoreIndex(nn.Cell):
|
||||
'''Cross Entropy With Ignore Index Loss.'''
|
||||
def __init__(self):
|
||||
super(CrossEntropyWithIgnoreIndex, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, dtype=mstype.float32)
|
||||
self.off_value = Tensor(0.0, dtype=mstype.float32)
|
||||
self.cast = P.Cast()
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.greater = P.Greater()
|
||||
self.maximum = P.Maximum()
|
||||
self.fill = P.Fill()
|
||||
self.sum = P.ReduceSum(keep_dims=False)
|
||||
self.dtype = P.DType()
|
||||
self.relu = P.ReLU()
|
||||
self.reshape = P.Reshape()
|
||||
self.const_one = Tensor(np.ones([1]), dtype=mstype.float32)
|
||||
self.const_eps = Tensor(0.00001, dtype=mstype.float32)
|
||||
|
||||
def construct(self, x, label):
|
||||
'''Construct function.'''
|
||||
mask = self.reshape(label, (F.shape(label)[0], 1))
|
||||
mask = self.cast(mask, mstype.float32)
|
||||
mask = mask + self.const_eps
|
||||
mask = self.relu(mask)/mask
|
||||
x = x * mask
|
||||
one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(x)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(x, one_hot_label)
|
||||
positive = self.sum(self.cast(self.greater(loss, self.fill(self.dtype(loss), F.shape(loss), 0.0)),
|
||||
mstype.float32), 0)
|
||||
positive = self.maximum(positive, self.const_one)
|
||||
loss = self.sum(loss, 0) / positive
|
||||
return loss
|
@ -0,0 +1,43 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute network unit."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn import Dense
|
||||
|
||||
|
||||
class Cut(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def bn_with_initialize(out_channels):
|
||||
bn = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
|
||||
return bn
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
return Dense(input_channels, out_channels)
|
||||
|
||||
|
||||
def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1, pad_mode="pad", padding=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||
pad_mode=pad_mode, group=groups, has_bias=False, dilation=dilation, padding=padding)
|
||||
|
||||
|
||||
def conv1x1(in_channels, out_channels, pad_mode="pad", stride=1, padding=0):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_channels, out_channels, pad_mode=pad_mode, kernel_size=1, stride=stride, has_bias=False,
|
||||
padding=padding)
|
@ -0,0 +1,78 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute head."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||
|
||||
__all__ = ['get_attri_head']
|
||||
|
||||
|
||||
class AttriHead(Cell):
|
||||
'''Attribute Head.'''
|
||||
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriHead, self).__init__()
|
||||
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc1_relu = P.ReLU()
|
||||
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||
self.attri_fc1_relu = P.ReLU()
|
||||
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||
|
||||
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc2_relu = P.ReLU()
|
||||
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||
self.attri_fc2_relu = P.ReLU()
|
||||
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||
|
||||
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc3_relu = P.ReLU()
|
||||
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||
self.attri_fc3_relu = P.ReLU()
|
||||
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
output0 = self.fc1(x)
|
||||
output0 = self.fc1_relu(output0)
|
||||
output0 = self.fc1_bn(output0)
|
||||
output0 = self.attri_fc1(output0)
|
||||
output0 = self.attri_fc1_relu(output0)
|
||||
output0 = self.attri_bn1(output0)
|
||||
|
||||
output1 = self.fc2(x)
|
||||
output1 = self.fc2_relu(output1)
|
||||
output1 = self.fc2_bn(output1)
|
||||
output1 = self.attri_fc2(output1)
|
||||
output1 = self.attri_fc2_relu(output1)
|
||||
output1 = self.attri_bn2(output1)
|
||||
|
||||
output2 = self.fc3(x)
|
||||
output2 = self.fc3_relu(output2)
|
||||
output2 = self.fc3_bn(output2)
|
||||
output2 = self.attri_fc3(output2)
|
||||
output2 = self.attri_fc3_relu(output2)
|
||||
output2 = self.attri_bn3(output2)
|
||||
|
||||
return output0, output1, output2
|
||||
|
||||
|
||||
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||
return attri_head
|
@ -0,0 +1,84 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute head."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import fc_with_initialize
|
||||
|
||||
__all__ = ['get_attri_head']
|
||||
|
||||
|
||||
class AttriHead(Cell):
|
||||
'''Attribute Head.'''
|
||||
def __init__(self, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriHead, self).__init__()
|
||||
self.fc1 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc1_relu = P.ReLU()
|
||||
self.fc1_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc1 = fc_with_initialize(fc_dim, attri_num_list[0])
|
||||
self.attri_fc1_relu = P.ReLU()
|
||||
self.attri_bn1 = nn.BatchNorm1d(attri_num_list[0], affine=False)
|
||||
self.softmax1 = P.Softmax()
|
||||
|
||||
self.fc2 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc2_relu = P.ReLU()
|
||||
self.fc2_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc2 = fc_with_initialize(fc_dim, attri_num_list[1])
|
||||
self.attri_fc2_relu = P.ReLU()
|
||||
self.attri_bn2 = nn.BatchNorm1d(attri_num_list[1], affine=False)
|
||||
self.softmax2 = P.Softmax()
|
||||
|
||||
self.fc3 = fc_with_initialize(flat_dim, fc_dim)
|
||||
self.fc3_relu = P.ReLU()
|
||||
self.fc3_bn = nn.BatchNorm1d(fc_dim, affine=False)
|
||||
self.attri_fc3 = fc_with_initialize(fc_dim, attri_num_list[2])
|
||||
self.attri_fc3_relu = P.ReLU()
|
||||
self.attri_bn3 = nn.BatchNorm1d(attri_num_list[2], affine=False)
|
||||
self.softmax3 = P.Softmax()
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
output0 = self.fc1(x)
|
||||
output0 = self.fc1_relu(output0)
|
||||
output0 = self.fc1_bn(output0)
|
||||
output0 = self.attri_fc1(output0)
|
||||
output0 = self.attri_fc1_relu(output0)
|
||||
output0 = self.attri_bn1(output0)
|
||||
output0 = self.softmax1(output0)
|
||||
|
||||
output1 = self.fc2(x)
|
||||
output1 = self.fc2_relu(output1)
|
||||
output1 = self.fc2_bn(output1)
|
||||
output1 = self.attri_fc2(output1)
|
||||
output1 = self.attri_fc2_relu(output1)
|
||||
output1 = self.attri_bn2(output1)
|
||||
output1 = self.softmax2(output1)
|
||||
|
||||
output2 = self.fc3(x)
|
||||
output2 = self.fc3_relu(output2)
|
||||
output2 = self.fc3_bn(output2)
|
||||
output2 = self.attri_fc3(output2)
|
||||
output2 = self.attri_fc3_relu(output2)
|
||||
output2 = self.attri_bn3(output2)
|
||||
output2 = self.softmax3(output2)
|
||||
|
||||
return output0, output1, output2
|
||||
|
||||
|
||||
def get_attri_head(flat_dim, fc_dim, attri_num_list):
|
||||
attri_head = AttriHead(flat_dim, fc_dim, attri_num_list)
|
||||
return attri_head
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute loss."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.FaceAttribute.cross_entropy import CrossEntropyWithIgnoreIndex
|
||||
|
||||
|
||||
__all__ = ['get_loss']
|
||||
|
||||
|
||||
class CriterionsFaceAttri(nn.Cell):
|
||||
'''Criterions Face Attribute.'''
|
||||
def __init__(self):
|
||||
super(CriterionsFaceAttri, self).__init__()
|
||||
|
||||
# label
|
||||
self.gatherv2 = P.GatherV2()
|
||||
self.squeeze = P.Squeeze(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.mean = P.ReduceMean()
|
||||
|
||||
self.label0_param = Tensor([0], dtype=mstype.int32)
|
||||
self.label1_param = Tensor([1], dtype=mstype.int32)
|
||||
self.label2_param = Tensor([2], dtype=mstype.int32)
|
||||
|
||||
# loss
|
||||
self.ce_ignore_loss = CrossEntropyWithIgnoreIndex()
|
||||
self.printn = P.Print()
|
||||
|
||||
def construct(self, x0, x1, x2, label):
|
||||
'''Construct function.'''
|
||||
# each sub attribute loss
|
||||
label0 = self.squeeze(self.gatherv2(label, self.label0_param, 1))
|
||||
loss0 = self.ce_ignore_loss(x0, label0)
|
||||
|
||||
label1 = self.squeeze(self.gatherv2(label, self.label1_param, 1))
|
||||
loss1 = self.ce_ignore_loss(x1, label1)
|
||||
|
||||
label2 = self.squeeze(self.gatherv2(label, self.label2_param, 1))
|
||||
loss2 = self.ce_ignore_loss(x2, label2)
|
||||
|
||||
loss = loss0 + loss1 + loss2
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def get_loss():
|
||||
return CriterionsFaceAttri()
|
@ -0,0 +1,145 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute resnet18 backbone."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||
from src.FaceAttribute.head_factory import get_attri_head
|
||||
|
||||
__all__ = ['get_resnet18']
|
||||
|
||||
|
||||
class IRBlock(Cell):
|
||||
'''IRBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(IRBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||
self.bn1 = bn_with_initialize(planes)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||
self.bn2 = bn_with_initialize(planes)
|
||||
|
||||
if downsample is None:
|
||||
self.downsample = Cut()
|
||||
else:
|
||||
self.downsample = downsample
|
||||
|
||||
self.add = TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
identity = self.downsample(x)
|
||||
out = self.add(out, identity)
|
||||
out = self.relu2(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownSample(Cell):
|
||||
def __init__(self, inplanes, planes, expansion, stride):
|
||||
super(DownSample, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||
self.bn1 = bn_with_initialize(planes * expansion)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''Make layer.'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||
super(MakeLayer, self).__init__()
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||
|
||||
self.layers = []
|
||||
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
self.layers.append(block(self.inplanes, planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttriResNet(Cell):
|
||||
'''Resnet for attribute.'''
|
||||
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriResNet, self).__init__()
|
||||
|
||||
# resnet18
|
||||
self.inplanes = 32
|
||||
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||
self.bn1 = bn_with_initialize(self.inplanes)
|
||||
self.relu = P.ReLU()
|
||||
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||
|
||||
# avg global pooling
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
b, c, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, c))
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def get_resnet18(args):
|
||||
'''Build resnet18 for attribute.'''
|
||||
flat_dim = args.flat_dim
|
||||
fc_dim = args.fc_dim
|
||||
str_classes = args.classes.strip().split(',')
|
||||
if args.attri_num != len(str_classes):
|
||||
print('args warning: attri_num != classes num')
|
||||
return None
|
||||
attri_num_list = []
|
||||
for i, _ in enumerate(str_classes):
|
||||
attri_num_list.append(int(str_classes[i]))
|
||||
|
||||
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
return attri_resnet18
|
@ -0,0 +1,145 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute resnet18 backbone."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.operations import TensorAdd
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from src.FaceAttribute.custom_net import Cut, bn_with_initialize, conv1x1, conv3x3
|
||||
from src.FaceAttribute.head_factory_softmax import get_attri_head
|
||||
|
||||
__all__ = ['get_resnet18']
|
||||
|
||||
|
||||
class IRBlock(Cell):
|
||||
'''IRBlock.'''
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(IRBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride=stride)
|
||||
self.bn1 = bn_with_initialize(planes)
|
||||
self.relu1 = P.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes, stride=1)
|
||||
self.bn2 = bn_with_initialize(planes)
|
||||
|
||||
if downsample is None:
|
||||
self.downsample = Cut()
|
||||
else:
|
||||
self.downsample = downsample
|
||||
|
||||
self.add = TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
self.relu2 = P.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
identity = self.downsample(x)
|
||||
out = self.add(out, identity)
|
||||
out = self.relu2(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownSample(Cell):
|
||||
def __init__(self, inplanes, planes, expansion, stride):
|
||||
super(DownSample, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes * expansion, stride=stride, pad_mode="valid")
|
||||
self.bn1 = bn_with_initialize(planes * expansion)
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
return out
|
||||
|
||||
|
||||
class MakeLayer(Cell):
|
||||
'''Make layer function.'''
|
||||
def __init__(self, block, inplanes, planes, blocks, stride=1):
|
||||
super(MakeLayer, self).__init__()
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
self.downsample = DownSample(self.inplanes, planes, block.expansion, stride)
|
||||
|
||||
self.layers = []
|
||||
self.layers.append(block(self.inplanes, planes, stride, self.downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
self.layers.append(block(self.inplanes, planes))
|
||||
self.layers = nn.CellList(self.layers)
|
||||
|
||||
def construct(self, x):
|
||||
for block in self.layers:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttriResNet(Cell):
|
||||
'''Resnet for attribute.'''
|
||||
def __init__(self, block, layers, flat_dim, fc_dim, attri_num_list):
|
||||
super(AttriResNet, self).__init__()
|
||||
|
||||
# resnet18
|
||||
self.inplanes = 32
|
||||
self.conv1 = conv3x3(3, self.inplanes, stride=1)
|
||||
self.bn1 = bn_with_initialize(self.inplanes)
|
||||
self.relu = P.ReLU()
|
||||
self.layer1 = MakeLayer(block, inplanes=32, planes=64, blocks=layers[0], stride=2)
|
||||
self.layer2 = MakeLayer(block, inplanes=64, planes=128, blocks=layers[1], stride=2)
|
||||
self.layer3 = MakeLayer(block, inplanes=128, planes=256, blocks=layers[2], stride=2)
|
||||
self.layer4 = MakeLayer(block, inplanes=256, planes=512, blocks=layers[3], stride=2)
|
||||
|
||||
# avg global pooling
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.head = get_attri_head(flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
def construct(self, x):
|
||||
'''Construct function.'''
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.mean(x, (2, 3))
|
||||
b, c, _, _ = self.shape(x)
|
||||
x = self.reshape(x, (b, c))
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def get_resnet18(args):
|
||||
'''Build resnet18 for attribute.'''
|
||||
flat_dim = args.flat_dim
|
||||
fc_dim = args.fc_dim
|
||||
str_classes = args.classes.strip().split(',')
|
||||
if args.attri_num != len(str_classes):
|
||||
print('args warning: attri_num != classes num')
|
||||
return None
|
||||
attri_num_list = []
|
||||
for i, _ in enumerate(str_classes):
|
||||
attri_num_list.append(int(str_classes[i]))
|
||||
|
||||
attri_resnet18 = AttriResNet(IRBlock, (2, 2, 2, 2), flat_dim, fc_dim, attri_num_list)
|
||||
|
||||
return attri_resnet18
|
@ -0,0 +1,46 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""Network config setting, will be used in train.py and eval.py"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
'per_batch_size': 128,
|
||||
'dst_h': 112,
|
||||
'dst_w': 112,
|
||||
'workers': 8,
|
||||
'attri_num': 3,
|
||||
'classes': '9,2,2',
|
||||
'backbone': 'resnet18',
|
||||
'loss_scale': 1024,
|
||||
'flat_dim': 512,
|
||||
'fc_dim': 256,
|
||||
'lr': 0.009,
|
||||
'lr_scale': 1,
|
||||
'lr_epochs': [20, 30, 50],
|
||||
'weight_decay': 0.0005,
|
||||
'momentum': 0.9,
|
||||
'max_epoch': 70,
|
||||
'warmup_epochs': 0,
|
||||
'log_interval': 10,
|
||||
'ckpt_path': '../../output',
|
||||
|
||||
# data_to_mindrecord parameter
|
||||
'eval_dataset_txt_file': 'Your_label_txt_file',
|
||||
'eval_mindrecord_file_name': 'Your_output_path/data_test.mindrecord',
|
||||
'train_dataset_txt_file': 'Your_label_txt_file',
|
||||
'train_mindrecord_file_name': 'Your_output_path/data_train.mindrecord',
|
||||
'train_append_dataset_txt_file': 'Your_label_txt_file',
|
||||
'train_append_mindrecord_file_name': 'Your_previous_output_path/data_train.mindrecord0'
|
||||
})
|
@ -0,0 +1,66 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert dataset to mindrecord for evaluating Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.eval_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.eval_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Convert data to mindrecord.'''
|
||||
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||
attri_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
print('Loading eval data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing eval data to mindrecord...')
|
||||
writer.add_schema(attri_json, "attri_json")
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
@ -0,0 +1,66 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert dataset to mindrecord for training Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.train_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.train_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Covert data to mindrecord.'''
|
||||
writer = FileWriter(mindrecord_file_name, mindrecord_num)
|
||||
attri_json = {
|
||||
"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
print('Loading train data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing train data to mindrecord...')
|
||||
writer.add_schema(attri_json, "attri_json")
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
@ -0,0 +1,62 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Add dataset to an existed mindrecord for training Face attribute."""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from config import config
|
||||
|
||||
dataset_txt_file = config.train_append_dataset_txt_file
|
||||
|
||||
mindrecord_file_name = config.train_append_mindrecord_file_name
|
||||
|
||||
mindrecord_num = 8
|
||||
|
||||
|
||||
def convert_data_to_mindrecord():
|
||||
'''Covert data to mindrecord.'''
|
||||
print('Loading mindrecord...')
|
||||
writer = FileWriter.open_for_append(mindrecord_file_name)
|
||||
|
||||
print('Loading train data...')
|
||||
total_data = []
|
||||
with open(dataset_txt_file, 'r') as ft:
|
||||
lines = ft.readlines()
|
||||
for line in lines:
|
||||
sline = line.strip().split(" ")
|
||||
image_file = sline[0]
|
||||
labels = []
|
||||
for item in sline[1:]:
|
||||
labels.append(int(item))
|
||||
|
||||
with open(image_file, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
data = {
|
||||
"image": img,
|
||||
"label": np.array(labels, dtype='int32')
|
||||
}
|
||||
|
||||
total_data.append(data)
|
||||
|
||||
print('Writing train data to mindrecord...')
|
||||
if total_data is None:
|
||||
raise ValueError("None needs writing to mindrecord.")
|
||||
writer.write_raw_data(total_data)
|
||||
writer.commit()
|
||||
|
||||
|
||||
convert_data_to_mindrecord()
|
@ -0,0 +1,45 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute dataset for eval"""
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.py_transforms as F2
|
||||
|
||||
__all__ = ['data_generator_eval']
|
||||
|
||||
|
||||
def data_generator_eval(args):
|
||||
'''Build eval dataloader.'''
|
||||
mindrecord_path = args.mindrecord_path
|
||||
dst_w = args.dst_w
|
||||
dst_h = args.dst_h
|
||||
batch_size = 1
|
||||
attri_num = args.attri_num
|
||||
transform_img = F2.Compose([F.Decode(),
|
||||
F.Resize((dst_w, dst_h)),
|
||||
F.ToTensor(),
|
||||
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"])
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||
python_multiprocessing=True)
|
||||
de_dataset = de_dataset.batch(batch_size)
|
||||
|
||||
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||
steps_per_epoch = de_dataset.get_dataset_size()
|
||||
print("image number:{0}".format(steps_per_epoch))
|
||||
num_classes = attri_num
|
||||
|
||||
return de_dataloader, steps_per_epoch, num_classes
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute dataset for train"""
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.py_transforms as F
|
||||
import mindspore.dataset.transforms.py_transforms as F2
|
||||
|
||||
__all__ = ['data_generator']
|
||||
|
||||
|
||||
def data_generator(args):
|
||||
'''Build train dataloader.'''
|
||||
mindrecord_path = args.mindrecord_path
|
||||
dst_w = args.dst_w
|
||||
dst_h = args.dst_h
|
||||
batch_size = args.per_batch_size
|
||||
attri_num = args.attri_num
|
||||
max_epoch = args.max_epoch
|
||||
transform_img = F2.Compose([F.Decode(),
|
||||
F.Resize((dst_w, dst_h)),
|
||||
F.RandomHorizontalFlip(prob=0.5),
|
||||
F.ToTensor(),
|
||||
F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
de_dataset = de.MindDataset(mindrecord_path + "0", columns_list=["image", "label"], num_shards=args.world_size,
|
||||
shard_id=args.local_rank)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=transform_img, num_parallel_workers=args.workers,
|
||||
python_multiprocessing=True)
|
||||
de_dataset = de_dataset.batch(batch_size, drop_remainder=True)
|
||||
steps_per_epoch = de_dataset.get_dataset_size()
|
||||
de_dataset = de_dataset.repeat(max_epoch)
|
||||
de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
|
||||
|
||||
num_classes = attri_num
|
||||
|
||||
return de_dataloader, steps_per_epoch, num_classes
|
@ -0,0 +1,105 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom logger."""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
logger_name_1 = 'face_attributes'
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
'''Logger.'''
|
||||
def __init__(self, logger_name):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
self.local_rank = 0
|
||||
|
||||
def setup_logging_file(self, log_dir, local_rank=0):
|
||||
'''Setup logging file.'''
|
||||
self.local_rank = local_rank
|
||||
if self.local_rank == 0:
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
logger = LOGGER(logger_name_1)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
self.tb_writer = tb_writer
|
||||
self.cur_step = 1
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||
self.cur_step += 1
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||
return fmtstr.format(**self.__dict__)
|
@ -0,0 +1,44 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute learning rate scheduler."""
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
|
||||
def warmup_step(args, gamma=0.1):
|
||||
'''Warmup step.'''
|
||||
base_lr = args.lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(args.max_epoch * args.steps_per_epoch)
|
||||
warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
|
||||
milestones = args.lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * args.steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_learning_rate(i, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma**milestones_steps_counter[i]
|
||||
yield lr
|
@ -0,0 +1,232 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Face attribute train."""
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
import argparse
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.communication.management import get_group_size, init, get_rank
|
||||
from mindspore.nn import TrainOneStepCell
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
from src.FaceAttribute.resnet18 import get_resnet18
|
||||
from src.FaceAttribute.loss_factory import get_loss
|
||||
from src.dataset_train import data_generator
|
||||
from src.lrsche_factory import warmup_step
|
||||
from src.logging import get_logger, AverageMeter
|
||||
from src.config import config
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
|
||||
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
|
||||
'''Build train network.'''
|
||||
def __init__(self, network, criterion):
|
||||
super(BuildTrainNetwork, self).__init__()
|
||||
self.network = network
|
||||
self.criterion = criterion
|
||||
self.print = P.Print()
|
||||
|
||||
def construct(self, input_data, label):
|
||||
logit0, logit1, logit2 = self.network(input_data)
|
||||
loss = self.criterion(logit0, logit1, logit2, label)
|
||||
return loss
|
||||
|
||||
|
||||
def parse_args():
|
||||
'''Argument for Face Attributes.'''
|
||||
parser = argparse.ArgumentParser('Face Attributes')
|
||||
|
||||
parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
|
||||
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
|
||||
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
|
||||
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
'''train function.'''
|
||||
# logger
|
||||
args = parse_args()
|
||||
|
||||
# init distributed
|
||||
if args.world_size != 1:
|
||||
init()
|
||||
args.local_rank = get_rank()
|
||||
args.world_size = get_group_size()
|
||||
|
||||
args.per_batch_size = config.per_batch_size
|
||||
args.dst_h = config.dst_h
|
||||
args.dst_w = config.dst_w
|
||||
args.workers = config.workers
|
||||
args.attri_num = config.attri_num
|
||||
args.classes = config.classes
|
||||
args.backbone = config.backbone
|
||||
args.loss_scale = config.loss_scale
|
||||
args.flat_dim = config.flat_dim
|
||||
args.fc_dim = config.fc_dim
|
||||
args.lr = config.lr
|
||||
args.lr_scale = config.lr_scale
|
||||
args.lr_epochs = config.lr_epochs
|
||||
args.weight_decay = config.weight_decay
|
||||
args.momentum = config.momentum
|
||||
args.max_epoch = config.max_epoch
|
||||
args.warmup_epochs = config.warmup_epochs
|
||||
args.log_interval = config.log_interval
|
||||
args.ckpt_path = config.ckpt_path
|
||||
|
||||
if args.world_size == 1:
|
||||
args.per_batch_size = 256
|
||||
else:
|
||||
args.lr = args.lr * 4.
|
||||
|
||||
if args.world_size != 1:
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
else:
|
||||
parallel_mode = ParallelMode.STAND_ALONE
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)
|
||||
|
||||
# model and log save path
|
||||
args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
args.logger = get_logger(args.outputs_dir, args.local_rank)
|
||||
loss_meter = AverageMeter('loss')
|
||||
|
||||
# dataloader
|
||||
args.logger.info('start create dataloader')
|
||||
de_dataloader, steps_per_epoch, num_classes = data_generator(args)
|
||||
args.steps_per_epoch = steps_per_epoch
|
||||
args.num_classes = num_classes
|
||||
args.logger.info('end create dataloader')
|
||||
args.logger.save_args(args)
|
||||
|
||||
# backbone and loss
|
||||
args.logger.important_info('start create network')
|
||||
create_network_start = time.time()
|
||||
network = get_resnet18(args)
|
||||
|
||||
criterion = get_loss()
|
||||
|
||||
# load pretrain model
|
||||
if os.path.isfile(args.pretrained):
|
||||
param_dict = load_checkpoint(args.pretrained)
|
||||
param_dict_new = {}
|
||||
for key, values in param_dict.items():
|
||||
if key.startswith('moments.'):
|
||||
continue
|
||||
elif key.startswith('network.'):
|
||||
param_dict_new[key[8:]] = values
|
||||
else:
|
||||
param_dict_new[key] = values
|
||||
load_param_into_net(network, param_dict_new)
|
||||
args.logger.info('load model {} success'.format(args.pretrained))
|
||||
|
||||
# optimizer and lr scheduler
|
||||
lr = warmup_step(args, gamma=0.1)
|
||||
opt = Momentum(params=network.trainable_params(),
|
||||
learning_rate=lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay,
|
||||
loss_scale=args.loss_scale)
|
||||
|
||||
train_net = BuildTrainNetwork(network, criterion)
|
||||
|
||||
# mixed precision training
|
||||
criterion.add_flags_recursive(fp32=True)
|
||||
|
||||
# package training process
|
||||
train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
# checkpoint
|
||||
if args.local_rank == 0:
|
||||
ckpt_max_num = args.max_epoch
|
||||
train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
|
||||
ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = train_net
|
||||
cb_params.epoch_num = ckpt_max_num
|
||||
cb_params.cur_epoch_num = 0
|
||||
run_context = RunContext(cb_params)
|
||||
ckpt_cb.begin(run_context)
|
||||
|
||||
train_net.set_train()
|
||||
t_end = time.time()
|
||||
t_epoch = time.time()
|
||||
old_progress = -1
|
||||
|
||||
i = 0
|
||||
for _, (data, gt_classes) in enumerate(de_dataloader):
|
||||
|
||||
data_tensor = Tensor(data, dtype=mstype.float32)
|
||||
gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
|
||||
|
||||
loss = train_net(data_tensor, gt_tensor)
|
||||
loss_meter.update(loss.asnumpy()[0])
|
||||
|
||||
# save ckpt
|
||||
if args.local_rank == 0:
|
||||
cb_params.cur_step_num = i + 1
|
||||
cb_params.batch_num = i + 2
|
||||
ckpt_cb.step_end(run_context)
|
||||
|
||||
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||
cb_params.cur_epoch_num += 1
|
||||
|
||||
# save Log
|
||||
if i == 0:
|
||||
time_for_graph_compile = time.time() - create_network_start
|
||||
args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))
|
||||
|
||||
if i % args.log_interval == 0 and args.local_rank == 0:
|
||||
time_used = time.time() - t_end
|
||||
epoch = int(i / args.steps_per_epoch)
|
||||
fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
|
||||
args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
|
||||
|
||||
t_end = time.time()
|
||||
loss_meter.reset()
|
||||
old_progress = i
|
||||
|
||||
if i % args.steps_per_epoch == 0 and args.local_rank == 0:
|
||||
epoch_time_used = time.time() - t_epoch
|
||||
epoch = int(i / args.steps_per_epoch)
|
||||
fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
|
||||
args.logger.info('=================================================')
|
||||
args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
|
||||
args.logger.info('=================================================')
|
||||
t_epoch = time.time()
|
||||
|
||||
i += 1
|
||||
|
||||
args.logger.info('--------- trains out ---------')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
Loading…
Reference in new issue