!12712 masked face recognition

From: @ffeiding
Reviewed-by: 
Signed-off-by:
pull/12712/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 609a518068

@ -0,0 +1,143 @@
# Masked Face Recognition with Latent Part Detection
# Contents
- [Masked Face Recognition Description](#masked-face-recognition-description)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Training](#training)
- [Evaluation](#evaluation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Masked Face Recognition Description](#contents)
<p align="center">
<img src="./img/overview.png">
</p>
This is a **MindSpore** implementation of [Masked Face Recognition with Latent Part Detection (ACM MM20)](https://dl.acm.org/doi/10.1145/3394171.3413731) by *Feifei Ding, Peixi Peng, Yangru Huang, Mengyue Geng and Yonghong Tian*.
*Masked Face Recognition* aims to match masked faces with unmasked faces, which became especially important during the global outbreak of COVID-19. Identifying masked faces is challenging because the mask occludes most facial cues.
*Latent Part Detection* (LPD) is a differentiable module that can locate the latent facial part which is robust to mask wearing, and the latent part is further used to extract discriminative features. The proposed LPD model is trained in an end-to-end manner and only utilizes the original and synthetic training data.
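
To make the idea of a differentiable part crop more concrete, here is a minimal NumPy sketch of affine grid sampling, the mechanism used by spatial transformer networks. It is an illustration only, not the code in `src/model/stn.py`: the function names `affine_grid`, `bilinear_sample` and `crop_part` are hypothetical, the affine matrix is fixed by hand instead of being predicted by a network, and the image is a single-channel array.

```python
import numpy as np


def affine_grid(theta, out_h, out_w):
    """Map every output pixel to normalized input coordinates in [-1, 1]."""
    ys, xs = np.meshgrid(np.linspace(-1.0, 1.0, out_h),
                         np.linspace(-1.0, 1.0, out_w), indexing="ij")
    grid = np.stack([xs, ys, np.ones_like(xs)], axis=-1)  # (H, W, 3)
    return grid @ theta.T                                 # (H, W, 2) as (x, y)


def bilinear_sample(image, grid):
    """Sample a 2-D image (at least 2x2) at the normalized grid positions."""
    h, w = image.shape
    x = np.clip((grid[..., 0] + 1.0) * (w - 1) / 2.0, 0, w - 1)
    y = np.clip((grid[..., 1] + 1.0) * (h - 1) / 2.0, 0, h - 1)
    x0 = np.clip(np.floor(x).astype(int), 0, w - 2)
    y0 = np.clip(np.floor(y).astype(int), 0, h - 2)
    wx, wy = x - x0, y - y0
    top = image[y0, x0] * (1 - wx) + image[y0, x0 + 1] * wx
    bottom = image[y0 + 1, x0] * (1 - wx) + image[y0 + 1, x0 + 1] * wx
    return top * (1 - wy) + bottom * wy


def crop_part(image, theta, out_h, out_w):
    """Crop the region described by the 2x3 affine matrix theta."""
    return bilinear_sample(image, affine_grid(theta, out_h, out_w))


# Hand-picked matrix that keeps the full width and the upper half of the face.
upper_half = np.array([[1.0, 0.0, 0.0],
                       [0.0, 0.5, -0.5]])
face = np.arange(16, dtype=np.float64).reshape(4, 4)
part = crop_part(face, upper_half, 4, 4)
```

Because every step is a weighted sum of pixel values, gradients can flow back to the entries of `theta`, which is what allows a module of this kind to be trained end to end.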
# [Dataset](#contents)
## Training Dataset
We use [CASIA-WebFace Dataset](http://www.cbsr.ia.ac.cn/english/CASIA-WebFace-Database.html) as the training dataset. After downloading CASIA-WebFace, we first detect faces and facial landmarks using `MTCNN`, then align the faces to a canonical pose using a similarity transformation (see [MTCNN - face detection & alignment](https://github.com/kpzhang93/MTCNN_face_detection_alignment)).
Collecting and labeling realistic masked facial data requires a great deal of human labor. To address this issue, we generate masked face images based on CASIA-WebFace. We generate 8 kinds of synthetic masked face images to augment training data based on 8 different styles of masks, such as surgical masks, N95 respirators and activated carbon masks. We mix the original face images with the synthetic masked images as the training data.
<p align="center">
<img src="./img/generated_masked_faces.png" width="600px">
</p>
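
The mixing step described above can be sketched as follows. The mask folder names and the raw-image folder name are taken from `make_dataset` in `MGDataset.py`; the helper names `masked_counterpart` and `mix_training_paths` are hypothetical, and the sketch assumes that each mask style lives in a sibling folder with the same per-identity layout.

```python
import os
import random

# Folder names copied from make_dataset in MGDataset.py.
MASK_STYLES = ["n95", "3m", "new", "mask_1", "mask_2", "mask_3", "mask_4", "mask_5"]
RAW_DIR = "faces_webface_112x112_raw_image"


def masked_counterpart(face_path, rng=random):
    """Return the path where a randomly chosen masked version is expected."""
    folder, fname = os.path.split(face_path)
    mask_folder = folder.replace(RAW_DIR, rng.choice(MASK_STYLES))
    # Raw files are named "<image id>_<scale>.jpg"; masked files "<image id>.jpg".
    mask_name = fname.split("_")[0] + ".jpg"
    return os.path.join(mask_folder, mask_name)


def mix_training_paths(face_paths, rng=random):
    """Keep every original face and add its masked counterpart when it exists."""
    mixed = []
    for path in face_paths:
        mixed.append(path)
        candidate = masked_counterpart(path, rng)
        if os.path.isfile(candidate):
            mixed.append(candidate)
    return mixed
```

Choosing one style per image at random keeps the training set at most twice the size of the original data while still exposing the model to all eight mask styles.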
## Evaluating Dataset
We use [PKU-Masked-Face Dataset](https://pkuml.org/resources/pku-masked-face-dataset.html) as the evaluating dataset. The dataset contains 10,301 face images of 1,018 identities. Each identity has masked and common face images with various orientations, lighting conditions and mask types. Most identities have 5 holistic face images and 5 masked face images with 5 different views: front, left, right, up and down.
The directory structure is as follows:
```text
.
└─ dataset
    ├─ train dataset
    │   ├─ ID1
    │   │   ├─ ID1_0001.jpg
    │   │   ├─ ID1_0002.jpg
    │   │   ...
    │   ├─ ID2
    │   │   ...
    │   ├─ ID3
    │   │   ...
    │   ...
    ├─ test dataset
    │   ├─ ID1
    │   │   ├─ ID1_0001.jpg
    │   │   ├─ ID1_0002.jpg
    │   │   ...
    │   ├─ ID2
    │   │   ...
    │   ├─ ID3
    │   │   ...
    │   ...
```
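
As a rough illustration of how such a layout maps to labels, the sketch below walks one split (for example the training folder) and assigns an integer label to each identity folder. It is a simplified stand-in for the loaders in `src/dataset`, and the function name `index_identity_folders` is hypothetical.

```python
import os


def index_identity_folders(root, extensions=(".jpg", ".jpeg", ".png")):
    """Return (class_to_idx, samples) for a folder with one subfolder per identity."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            # Skip anything that is not an image file.
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(folder, fname), class_to_idx[name]))
    return class_to_idx, samples
```

Sorting the folder names first makes the label assignment reproducible across runs and machines.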
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script Description](#contents)
The entire code structure is as follows:
```text
└─ face_recognition
    ├── README.md                 // descriptions about face_recognition
    ├── scripts
    │   ├── run_train.sh          // shell script for training on Ascend
    │   ├── run_eval.sh           // shell script for evaluation on Ascend
    ├── src
    │   ├── dataset
    │   │   ├── Dataset.py        // loading evaluating dataset
    │   │   ├── MGDataset.py      // loading training dataset
    │   ├── model
    │   │   ├── model.py          // lpd model
    │   │   ├── stn.py            // spatial transformer network module
    │   ├── utils
    │   │   ├── distance.py       // calculate distance of two features
    │   │   ├── metric.py         // calculate mAP and CMC scores
    ├── config.py                 // hyperparameter setting
    ├── train_dataset.py          // training data format setting
    ├── test_dataset.py           // evaluating data format setting
    ├── train.py                  // training scripts
    ├── test.py                   // evaluation scripts
```
# [Training](#contents)
```bash
sh scripts/run_train.sh [USE_DEVICE_ID]
```
You will get the loss value of each epoch as follows in "./scripts/data_parallel_log_[DEVICE_ID]/outputs/logs/[TIME].log" or "./scripts/log_parallel_graph/face_recognition_[DEVICE_ID].log":
```text
epoch[0], iter[100], loss:(Tensor(shape=[], dtype=Float32, value= 50.2733), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.000660, mean_fps:743.09 imgs/sec
epoch[0], iter[200], loss:(Tensor(shape=[], dtype=Float32, value= 49.3693), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 32768)), cur_lr:0.001314, mean_fps:4426.42 imgs/sec
epoch[0], iter[300], loss:(Tensor(shape=[], dtype=Float32, value= 48.7081), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.001968, mean_fps:4428.09 imgs/sec
epoch[0], iter[400], loss:(Tensor(shape=[], dtype=Float32, value= 45.7791), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 16384)), cur_lr:0.002622, mean_fps:4428.17 imgs/sec
...
epoch[8], iter[27300], loss:(Tensor(shape=[], dtype=Float32, value= 2.13556), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.38 imgs/sec
epoch[8], iter[27400], loss:(Tensor(shape=[], dtype=Float32, value= 2.36922), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4429.88 imgs/sec
epoch[8], iter[27500], loss:(Tensor(shape=[], dtype=Float32, value= 2.08594), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.59 imgs/sec
epoch[8], iter[27600], loss:(Tensor(shape=[], dtype=Float32, value= 2.38706), Tensor(shape=[], dtype=Bool, value= False), Tensor(shape=[], dtype=Float32, value= 65536)), cur_lr:0.004000, mean_fps:4430.37 imgs/sec
```
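
If you want to plot the loss curve, log lines in the format shown above can be parsed with a regular expression. This is a small sketch that assumes the line layout stays exactly as in the excerpt; `parse_log_line` is a hypothetical helper, not part of the repository.

```python
import re

# Matches lines such as:
# epoch[0], iter[100], loss:(Tensor(..., value= 50.2733), ...), cur_lr:0.000660, mean_fps:743.09 imgs/sec
LINE_RE = re.compile(
    r"epoch\[(?P<epoch>\d+)\], iter\[(?P<iter>\d+)\], "
    r"loss:\(Tensor\(shape=\[\], dtype=Float32, value=\s*(?P<loss>[-+0-9.eE]+)\)"
    r".*?cur_lr:(?P<lr>[0-9.eE+-]+), mean_fps:(?P<fps>[0-9.]+) imgs/sec"
)


def parse_log_line(line):
    """Return a dict with epoch, iter, loss, lr and fps, or None if no match."""
    match = LINE_RE.search(line)
    if match is None:
        return None
    return {
        "epoch": int(match.group("epoch")),
        "iter": int(match.group("iter")),
        "loss": float(match.group("loss")),  # first tensor in the loss tuple
        "lr": float(match.group("lr")),
        "fps": float(match.group("fps")),
    }
```

Only the first tensor of the `loss:` tuple is read as the loss value; the second and third entries are not extracted.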
# [Evaluation](#contents)
```bash
sh scripts/run_eval.sh [USE_DEVICE_ID]
```
You will get the result as follows in "./scripts/log_inference/outputs/models/logs/[TIME].log":
[test_dataset]: zj2jk=0.9495, jk2zj=0.9480, avg=0.9487
| model | mAP | rank1 | rank5 | rank10|
| ---------| ------| ----- | ----- | ----- |
| Baseline | 27.09 | 70.17 | 87.95 | 91.80 |
| MG | 36.55 | 94.12 | 98.01 | 98.66 |
| LPD | 42.14 | 96.22 | 98.11 | 98.75 |
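
For readers unfamiliar with the metrics in the table, the following NumPy sketch shows one common way to compute mAP and rank-k (CMC) scores from query and gallery features using negative cosine similarity as the distance. It is a simplified illustration and not the implementation in `src/utils/metric.py`; the function names are hypothetical, and it returns fractions in [0, 1] rather than percentages.

```python
import numpy as np


def cosine_distance(queries, gallery):
    """Negative cosine similarity, so smaller values mean closer features."""
    q = queries / (np.linalg.norm(queries, axis=1, keepdims=True) + 1e-12)
    g = gallery / (np.linalg.norm(gallery, axis=1, keepdims=True) + 1e-12)
    return -(q @ g.T)


def rank_k_and_map(dist, q_labels, g_labels, topk=(1, 5, 10)):
    """Return (mAP, {k: rank-k accuracy}) for a [num_query, num_gallery] distance matrix."""
    order = np.argsort(dist, axis=1)                  # closest gallery items first
    matches = g_labels[order] == q_labels[:, None]    # True where identity agrees
    matches = matches[matches.any(axis=1)]            # drop queries with no match
    first_hit = matches.argmax(axis=1)                # position of first correct item
    cmc = {k: float(np.mean(first_hit < k)) for k in topk}
    hits = np.cumsum(matches, axis=1)
    precision = hits / np.arange(1, matches.shape[1] + 1)
    average_precision = (precision * matches).sum(axis=1) / matches.sum(axis=1)
    return float(average_precision.mean()), cmc
```

Rank-k counts a query as correct when at least one gallery image of the same identity appears among its k nearest neighbours, while mAP also rewards ranking all matching images early.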
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,40 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and test.py
"""
from easydict import EasyDict as ed

config = ed({
    "class_num": 10572,
    "batch_size": 128,
    "learning_rate": 0.01,
    "lr_decay_epochs": [40, 80, 100],
    "lr_decay_factor": 0.1,
    "lr_warmup_epochs": 20,
    "p": 16,
    "k": 8,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 120,
    "buffer_size": 10000,
    "image_height": 128,
    "image_width": 128,
    "save_checkpoint": True,
    "save_checkpoint_steps": 195,
    "keep_checkpoint_max": 2,
    "save_checkpoint_path": "./"
})
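
The `lr_*` entries above are turned into a per-step schedule in `train.py` by concatenating a warm-up list with a piecewise-constant list. The plain-Python sketch below reproduces that schedule as I read the documented behaviour of MindSpore's `warmup_lr` and `piecewise_constant_lr`; the helper names are hypothetical and the exact values should be checked against the installed MindSpore version.

```python
def warmup_schedule(base_lr, steps_per_epoch, warmup_epochs):
    """Linear warm-up that increases once per epoch, starting from 0."""
    total_steps = steps_per_epoch * warmup_epochs
    return [base_lr * (step // steps_per_epoch) / warmup_epochs
            for step in range(total_steps)]


def piecewise_schedule(milestones, values):
    """Hold values[i] until step milestones[i] (exclusive)."""
    schedule, start = [], 0
    for end, value in zip(milestones, values):
        schedule.extend([value] * (end - start))
        start = end
    return schedule


def build_schedule(base_lr, steps_per_epoch, warmup_epochs, decay_epochs, decay_factor):
    """Warm-up followed by step decay, mirroring `lr = lr_1 + lr_2` in train.py."""
    milestones = [epoch * steps_per_epoch for epoch in decay_epochs]
    values = [base_lr * decay_factor ** i for i in range(len(decay_epochs))]
    return (warmup_schedule(base_lr, steps_per_epoch, warmup_epochs)
            + piecewise_schedule(milestones, values))


# steps_per_epoch=661 is taken from the "step: 661" lines of the committed training log.
lr = build_schedule(0.01, 661, 20, [40, 80, 100], 0.1)
```

Because the two lists are concatenated, the decay milestones are counted from the end of the warm-up: with the values above the schedule covers 20 + 100 = 120 epochs, which matches `epoch_size`.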

@ -0,0 +1,202 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""data process"""
import math
import sys
import os
from collections import defaultdict

import numpy as np
from PIL import ImageFile
import cv2

ImageFile.LOAD_TRUNCATED_IMAGES = True

__all__ = ['DistributedPKSampler', 'Dataset']

# Every entry carries its leading dot so that only real file extensions match.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


class DistributedPKSampler:
    '''DistributedPKSampler'''
    def __init__(self, dataset, shuffle=True, p=5, k=2):
        # The original check referenced PKDataset, which is not defined in this module.
        assert isinstance(dataset, Dataset), 'PK Sampler Only Supports Dataset!'
        self.p = p
        self.k = k
        self.dataset = dataset
        self.epoch = 0
        self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
        self.total_ids = self.step_nums*p
        self.batch_size = p*k
        self.num_samples = self.total_ids * self.k
        self.shuffle = shuffle
        self.epoch_gen = 1

    def _sample_pk(self, indices):
        '''sample k image indices for each identity in indices'''
        sampled_pk = []
        for indice in indices:
            sampled_id = indice
            replacement = False
            # Sample with replacement when an identity has fewer than k images.
            if len(self.dataset.id2range[sampled_id]) < self.k:
                replacement = True
            index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
            sampled_pk.extend(index_list.tolist())
        return sampled_pk

    def __iter__(self):
        if self.shuffle:
            self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
            np.random.seed(self.epoch_gen)
            indices = np.random.permutation(len(self.dataset.classes))
            indices = indices.tolist()
        else:
            indices = list(range(len(self.dataset.classes)))
        # Pad with the first identities so the count is a multiple of p.
        indices += indices[:(self.total_ids - len(indices))]
        assert len(indices) == self.total_ids
        sampled_idxs = self._sample_pk(indices)
        return iter(sampled_idxs)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def has_file_allowed_extension(filename, extensions):
    """Check if a file has an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions allowed (lowercase)

    Returns:
        bool: True if the file ends with one of the given extensions
    """
    return filename.lower().endswith(extensions)


def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
    '''make dataset'''
    images = []
    dir_name = os.path.expanduser(dir_name)
    if not (extensions is None) ^ (is_valid_file is None):
        raise ValueError("Exactly one of extensions and is_valid_file must be given.")

    def is_valid(x):
        if extensions is not None:
            return has_file_allowed_extension(x, extensions)
        return is_valid_file(x)

    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir_name, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid(path):
                    item = (path, class_to_idx[target], 0.6)
                    images.append(item)
    return images


class ImageFolderPKDataset:
    '''ImageFolderPKDataset'''
    def __init__(self, root):
        self.classes, self.classes_to_idx = self._find_classes(root)
        self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
        self.id2range = self._build_id2range()
        self.all_image_idxs = range(len(self.samples))
        self.classes = list(self.id2range.keys())

    def _find_classes(self, dir_name):
        """
        Finds the class folders in a dataset

        Args:
            dir_name (string): root directory path

        Returns:
            tuple (classes, class_to_idx): where classes are relative to dir_name, and class_to_idx is a dictionary

        Ensures:
            No class is a subdirectory of others
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def _build_id2range(self):
        '''map id to range'''
        id2range = defaultdict(list)
        ret_range = defaultdict(list)
        for idx, sample in enumerate(self.samples):
            label = sample[1]
            id2range[label].append((sample, idx))
        for key in id2range:
            id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
            for item in id2range[key]:
                ret_range[key].append(item[1])
        return ret_range

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


def pil_loader(path):
    '''load an image with OpenCV and convert it from BGR to RGB'''
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


class Dataset:
    '''Dataset'''
    def __init__(self, root, loader=pil_loader):
        self.dataset = ImageFolderPKDataset(root)
        print('Dataset len(dataset):{}'.format(len(self.dataset)))
        self.loader = loader
        self.classes = self.dataset.classes
        self.id2range = self.dataset.id2range

    def __getitem__(self, index):
        path, target1, target2 = self.dataset[index]
        sample = self.loader(path)
        return sample, target1, target2

    def __len__(self):
        return len(self.dataset)
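
The PK sampling strategy used by `DistributedPKSampler` can be summarised with a small standalone sketch. It assumes only a mapping from identity to image indices; the function name `pk_sample` and the use of `numpy.random.default_rng` are choices made for this illustration, whereas the class above seeds the global NumPy generator.

```python
import math

import numpy as np


def pk_sample(id2indices, p, k, seed=0):
    """Return a flat index list in which every run of k indices shares one identity."""
    rng = np.random.default_rng(seed)
    ids = list(id2indices)
    order = [ids[i] for i in rng.permutation(len(ids))]
    # Pad with the first identities so the number of identities is a multiple of p.
    total_ids = math.ceil(len(order) / p) * p
    order += order[:total_ids - len(order)]
    sampled = []
    for identity in order:
        pool = id2indices[identity]
        # Sample with replacement only when the identity has fewer than k images.
        picks = rng.choice(pool, size=k, replace=len(pool) < k)
        sampled.extend(int(i) for i in picks)
    return sampled
```

Batching the result in chunks of `p * k` then yields batches with `p` identities and `k` images each, which is the layout that `create_dataset` relies on when it calls `batch(p*k, ...)`.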

@ -0,0 +1,214 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MGDataset"""
import math
import sys
import os
import os.path as osp
from collections import defaultdict
import random

import numpy as np
from PIL import ImageFile
import cv2

ImageFile.LOAD_TRUNCATED_IMAGES = True

__all__ = ['DistributedPKSampler', 'MGDataset']

# Every entry carries its leading dot so that only real file extensions match.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


class DistributedPKSampler:
    '''DistributedPKSampler'''
    def __init__(self, dataset, shuffle=True, p=5, k=2):
        assert isinstance(dataset, MGDataset), 'PK Sampler Only Supports MG Dataset!'
        self.p = p
        self.k = k
        self.dataset = dataset
        self.epoch = 0
        self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
        self.total_ids = self.step_nums*p
        self.batch_size = p*k
        self.num_samples = self.total_ids * self.k
        self.shuffle = shuffle
        self.epoch_gen = 1

    def _sample_pk(self, indices):
        '''sample k image indices for each identity in indices'''
        sampled_pk = []
        for indice in indices:
            sampled_id = indice
            replacement = False
            # Sample with replacement when an identity has fewer than k images.
            if len(self.dataset.id2range[sampled_id]) < self.k:
                replacement = True
            index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
            sampled_pk.extend(index_list.tolist())
        return sampled_pk

    def __iter__(self):
        if self.shuffle:
            self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
            np.random.seed(self.epoch_gen)
            indices = np.random.permutation(len(self.dataset.classes))
            indices = indices.tolist()
        else:
            indices = list(range(len(self.dataset.classes)))
        # Pad with the first identities so the count is a multiple of p.
        indices += indices[:(self.total_ids - len(indices))]
        assert len(indices) == self.total_ids
        sampled_idxs = self._sample_pk(indices)
        return iter(sampled_idxs)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def has_file_allowed_extension(filename, extensions):
    """Check if a file has an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions allowed (lowercase)

    Returns:
        bool: True if the file ends with one of the given extensions
    """
    return filename.lower().endswith(extensions)


def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
    '''make dataset'''
    images = []
    masked_datasets = ["n95", "3m", "new", "mask_1", "mask_2", "mask_3", "mask_4", "mask_5"]
    dir_name = os.path.expanduser(dir_name)
    if not (extensions is None) ^ (is_valid_file is None):
        raise ValueError("Exactly one of extensions and is_valid_file must be given.")

    def is_valid(x):
        if extensions is not None:
            return has_file_allowed_extension(x, extensions)
        return is_valid_file(x)

    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir_name, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid(path):
                    # File names look like "<image id>_<scale>.jpg".
                    scale = float(osp.splitext(fname)[0].split('_')[1])
                    item = (path, class_to_idx[target], scale)
                    images.append(item)
                    # Add the same face with one randomly chosen mask style, if it exists.
                    mask_root_path = root.replace("faces_webface_112x112_raw_image", random.choice(masked_datasets))
                    mask_name = fname.split('_')[0]+".jpg"
                    mask_path = osp.join(mask_root_path, mask_name)
                    if os.path.isfile(mask_path) and is_valid(mask_path):
                        item = (mask_path, class_to_idx[target], scale)
                        images.append(item)
    return images


class ImageFolderPKDataset:
    '''Image Folder PKDataset'''
    def __init__(self, root):
        self.classes, self.classes_to_idx = self._find_classes(root)
        self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
        self.id2range = self._build_id2range()
        self.all_image_idxs = range(len(self.samples))
        self.classes = list(self.id2range.keys())

    def _find_classes(self, dir_name):
        """
        Finds the class folders in a dataset

        Args:
            dir_name (string): root directory path

        Returns:
            tuple (classes, class_to_idx): where classes are relative to dir_name, and class_to_idx is a dictionary

        Ensures:
            No class is a subdirectory of others
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def _build_id2range(self):
        '''id to range'''
        id2range = defaultdict(list)
        ret_range = defaultdict(list)
        for idx, sample in enumerate(self.samples):
            label = sample[1]
            id2range[label].append((sample, idx))
        for key in id2range:
            id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
            for item in id2range[key]:
                ret_range[key].append(item[1])
        return ret_range

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


def pil_loader(path):
    '''load an image with OpenCV and convert it from BGR to RGB'''
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


class MGDataset:
    '''MGDataset'''
    def __init__(self, root, loader=pil_loader):
        self.dataset = ImageFolderPKDataset(root)
        print('MGDataset len(dataset):{}'.format(len(self.dataset)))
        self.loader = loader
        self.classes = self.dataset.classes
        self.id2range = self.dataset.id2range

    def __getitem__(self, index):
        path, target1, target2 = self.dataset[index]
        sample = self.loader(path)
        return sample, target1, target2

    def __len__(self):
        return len(self.dataset)

Binary file not shown.

Binary file not shown.

@ -0,0 +1,30 @@
Dataset len(dataset):5136
Dataset len(dataset):5165
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
0.98500779 0.98617601 0.98676012 0.98753894]
Dataset len(dataset):5136
Dataset len(dataset):5165
Dataset len(dataset):5136
Dataset len(dataset):5165
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
0.98500779 0.98617601 0.98676012 0.98753894]
Dataset len(dataset):5136
Dataset len(dataset):5165
Dataset len(dataset):5136
Dataset len(dataset):5165
660.75
Dataset len(dataset):5136
Dataset len(dataset):5165
Dataset len(dataset):5136
Dataset len(dataset):5165
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
0.98500779 0.98617601 0.98676012 0.98753894]
Dataset len(dataset):5136
Dataset len(dataset):5165
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
0.98500779 0.98617601 0.98676012 0.98753894]
/home/dingfeifei/datasets/faces_webface_112x112_raw_image/0 0_0.5625.jpg
MGDataset len(dataset):884896
epoch: 1 step: 661, loss is 19.227043
epoch: 2 step: 661, loss is 18.528654
epoch: 3 step: 661, loss is 18.451244

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import sys
import argparse
import random
import math
import numpy as np
from test_dataset import create_dataset
from config import config
from mindspore import context
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
import mindspore.dataset.engine as de
from mindspore.train.serialization import load_checkpoint
from model.model import resnet50, TrainStepWrap, NetWithLossClass
from utils.distance import compute_dist, compute_score

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

local_data_url = 'data'
local_train_url = 'ckpt'


class Logger():
    '''Log'''
    def __init__(self, logFile="log_max.txt"):
        self.terminal = sys.stdout
        self.log = open(logFile, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        pass


sys.stdout = Logger("log/log.txt")

if __name__ == '__main__':
    query_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', 'test/query'),
                                   p=config.p, k=config.k)
    gallery_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', 'test/gallery'),
                                     p=config.p, k=config.k)

    epoch_size = config.epoch_size
    net = resnet50(class_num=config.class_num, is_train=False)
    loss_net = NetWithLossClass(net, is_train=False)

    base_lr = config.learning_rate
    warm_up_epochs = config.lr_warmup_epochs
    lr_decay_epochs = config.lr_decay_epochs
    lr_decay_factor = config.lr_decay_factor
    step_size = math.ceil(config.class_num / config.p)
    lr_decay_steps = []
    lr_decay = []
    for i, v in enumerate(lr_decay_epochs):
        lr_decay_steps.append(v * step_size)
        lr_decay.append(base_lr * lr_decay_factor ** i)
    lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
    lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
    lr = lr_1 + lr_2

    train_net = TrainStepWrap(loss_net, lr, config.momentum, is_train=False)
    load_checkpoint("checkpoints/40.ckpt", net=train_net)

    q_feats, q_labels, g_feats, g_labels = [], [], [], []
    # Extract features and labels for the query set.
    for data, gt_classes, theta in query_dataset:
        output = train_net(data, gt_classes, theta)
        output = output.asnumpy()
        label = gt_classes.asnumpy()
        q_feats.append(output)
        q_labels.append(label)
    q_feats = np.vstack(q_feats)
    q_labels = np.hstack(q_labels)

    # Extract features and labels for the gallery set.
    for data, gt_classes, theta in gallery_dataset:
        output = train_net(data, gt_classes, theta)
        output = output.asnumpy()
        label = gt_classes.asnumpy()
        g_feats.append(output)
        g_labels.append(label)
    g_feats = np.vstack(g_feats)
    g_labels = np.hstack(g_labels)

    q_g_dist = compute_dist(q_feats, g_feats, dis_type='cosine')
    mAP, cmc_scores = compute_score(q_g_dist, q_labels, g_labels)
    print(mAP, cmc_scores)

@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from config import config
from dataset.Dataset import Dataset


def create_dataset(data_dir, p=16, k=8):
    """
    create a train or eval dataset

    Args:
        data_dir(string): the path of dataset.
        p(int): randomly choose p classes from all classes.
        k(int): randomly choose k images from each of the chosen p classes.
            p * k is the batchsize.

    Returns:
        dataset
    """
    dataset = Dataset(data_dir)
    de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"])

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    resize_op = CV.Resize((resize_height, resize_width))
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
    change_swap_op = CV.HWC2CHW()

    trans = []
    trans += [resize_op, rescale_op, normalize_op, change_swap_op]

    type_cast_op_label1 = C.TypeCast(mstype.int32)
    type_cast_op_label2 = C.TypeCast(mstype.float32)

    de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
    de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
    de_dataset = de_dataset.map(input_columns="image", operations=trans)
    de_dataset = de_dataset.batch(p*k, drop_remainder=False)
    return de_dataset

@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import sys
import argparse
import random
import pickle
import numpy as np
from train_dataset import create_dataset
from config import config
from mindspore import context
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
from mindspore.train.model import Model
from mindspore.train.serialization import load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor # TimeMonitor
import mindspore.dataset.engine as de
from mindspore.nn.metrics import Accuracy
from model.model import resnet50, NetWithLossClass, TrainStepWrap, TestStepWrap

random.seed(1)
np.random.seed(1)
de.config.set_seed(1)

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)

local_data_url = 'data'
local_train_url = 'ckpt'


class Logger():
    '''Logger'''
    def __init__(self, logFile="log_max.txt"):
        self.terminal = sys.stdout
        self.log = open(logFile, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        pass


sys.stdout = Logger("log/log.txt")

if __name__ == '__main__':
    epoch_size = config.epoch_size
    net = resnet50(class_num=config.class_num, is_train=True)
    loss_net = NetWithLossClass(net)

    dataset = create_dataset("/home/dingfeifei/datasets/faces_webface_112x112_raw_image",
                             p=config.p, k=config.k)
    step_size = dataset.get_dataset_size()

    base_lr = config.learning_rate
    warm_up_epochs = config.lr_warmup_epochs
    lr_decay_epochs = config.lr_decay_epochs
    lr_decay_factor = config.lr_decay_factor
    lr_decay_steps = []
    lr_decay = []
    for i, v in enumerate(lr_decay_epochs):
        lr_decay_steps.append(v * step_size)
        lr_decay.append(base_lr * lr_decay_factor ** i)
    lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
    lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
    lr = lr_1 + lr_2

    train_net = TrainStepWrap(loss_net, lr, config.momentum)
    test_net = TestStepWrap(net)

    # Load the pretrained weights; the context manager closes the file afterwards.
    with open("checkpoints/pretrained_resnet50.pkl", "rb") as f:
        param_dict = pickle.load(f)
    load_param_into_net(net=train_net, parameter_dict=param_dict)

    model = Model(train_net, eval_network=test_net, metrics={"Accuracy": Accuracy()})

    # time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()
    #cb = [time_cb, loss_cb]
    cb = [loss_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="resnet", directory='checkpoints/',
                              config=config_ck)
    cb += [ckpt_cb]
    model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)

@ -0,0 +1,66 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from config import config
from dataset.MGDataset import DistributedPKSampler, MGDataset


def create_dataset(data_dir, p=16, k=8):
    """
    create a train or eval dataset

    Args:
        data_dir(string): the path of dataset.
        p(int): randomly choose p classes from all classes.
        k(int): randomly choose k images from each of the chosen p classes.
            p * k is the batchsize.

    Returns:
        dataset
    """
    dataset = MGDataset(data_dir)
    sampler = DistributedPKSampler(dataset, p=p, k=k)
    de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"], sampler=sampler)

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # image pipeline: resize -> scale to [0, 1] -> normalize per channel -> HWC to CHW
    resize_op = CV.Resize((resize_height, resize_width))
    rescale_op = CV.Rescale(rescale, shift)
    normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
    change_swap_op = CV.HWC2CHW()
    trans = [resize_op, rescale_op, normalize_op, change_swap_op]

    type_cast_op_label1 = C.TypeCast(mstype.int32)
    type_cast_op_label2 = C.TypeCast(mstype.float32)

    de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
    de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
    de_dataset = de_dataset.map(input_columns="image", operations=trans)
    # each batch holds p classes with k images per class
    de_dataset = de_dataset.batch(p * k, drop_remainder=True)
    return de_dataset
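
The docstring above describes P×K batch sampling: pick `p` classes, then `k` images from each, giving a batch of `p * k`. The source of `DistributedPKSampler` is not shown in this part of the diff, so the following is only a minimal NumPy sketch of that idea. The helper name `pk_sample`, the fallback to sampling with replacement for small classes, and the toy labels are all assumptions made for illustration, not the sampler's actual behavior.

```python
import numpy as np


def pk_sample(labels, p, k, rng):
    """Hypothetical helper: indices for one batch of p classes, k images each."""
    labels = np.asarray(labels)
    # Choose p distinct classes out of all classes present in `labels`.
    classes = rng.choice(np.unique(labels), size=p, replace=False)
    batch = []
    for c in classes:
        idx = np.flatnonzero(labels == c)
        # Assumption: sample with replacement only if the class has fewer than k images.
        batch.extend(rng.choice(idx, size=k, replace=len(idx) < k))
    return np.array(batch)


# Toy data: 5 identities with 4 images each.
labels = np.repeat(np.arange(5), 4)
batch = pk_sample(labels, p=3, k=2, rng=np.random.default_rng(0))
```

With `p=3` and `k=2` the batch has 6 indices covering exactly 3 identities, which is the structure the batch size `p * k` in `create_dataset` relies on.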

@ -0,0 +1,56 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Numpy version of euclidean distance, etc."""
import numpy as np
from utils.metric import cmc, mean_ap


def normalize(nparray, order=2, axis=0):
    """Normalize a N-D numpy array along the specified axis."""
    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / (norm + np.finfo(np.float32).eps)


def compute_dist(array1, array2, dis_type='euclidean'):
    """Compute the euclidean or cosine distance of all pairs.

    Args:
        array1: numpy array with shape [m1, n]
        array2: numpy array with shape [m2, n]
        dis_type:
            one of ['cosine', 'euclidean']

    Returns:
        numpy array with shape [m1, m2]. For 'cosine', the negated cosine
        similarity is returned, so smaller values always mean more similar.
    """
    assert dis_type in ['cosine', 'euclidean']
    if dis_type == 'cosine':
        array1 = normalize(array1, axis=1)
        array2 = normalize(array2, axis=1)
        dist = np.matmul(array1, array2.T)
        return -1 * dist

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
    # shape [m1, 1]
    square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis]
    # shape [1, m2]
    square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...]
    squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2
    # clip small negative values caused by floating-point error
    squared_dist[squared_dist < 0] = 0
    dist = np.sqrt(squared_dist)
    return dist


def compute_score(dist_mat, query_ids, gallery_ids):
    """Compute mAP and the top-10 CMC scores from a query-gallery distance matrix."""
    mAP = mean_ap(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids)
    cmc_scores, _ = cmc(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids, topk=10)
    return mAP, cmc_scores
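
The euclidean branch of `compute_dist` uses the expansion ||a - b||² = ||a||² + ||b||² - 2 a·b instead of forming every pairwise difference. As a small sanity check, the sketch below re-implements that expansion in a standalone function and compares it with a direct computation on toy data. The name `pairwise_euclidean` and the sample arrays are made up for illustration.

```python
import numpy as np


def pairwise_euclidean(a, b):
    """Pairwise euclidean distances via the squared-norm expansion."""
    sq_a = np.sum(np.square(a), axis=1)[:, np.newaxis]  # shape [m1, 1]
    sq_b = np.sum(np.square(b), axis=1)[np.newaxis, :]  # shape [1, m2]
    sq = sq_a + sq_b - 2.0 * np.matmul(a, b.T)
    # Clip tiny negatives from rounding before taking the square root.
    return np.sqrt(np.maximum(sq, 0.0))


a = np.array([[0.0, 0.0], [3.0, 4.0]])
b = np.array([[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]])

dist = pairwise_euclidean(a, b)
# Direct computation through broadcasting, for comparison only.
direct = np.linalg.norm(a[:, np.newaxis, :] - b[np.newaxis, :, :], axis=2)
```

The expansion needs only one matrix product and O(m1 * m2) memory, whereas the broadcast version materializes an [m1, m2, n] array, which matters once the gallery is large.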

@ -0,0 +1,194 @@
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
reid/evaluation_metrics/ranking.py. Modifications:
1) Only accepts numpy data input, no torch is involved.
2) Here results of each query can be returned.
3) In the single-gallery-shot evaluation case, the number of repeats is changed
from 10 to 100.
"""
from __future__ import absolute_import
from collections import defaultdict
import numpy as np
from sklearn.metrics import average_precision_score
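

def _unique_sample(ids_dict, num):
    """Return a boolean mask of length `num` selecting one random index per id."""
    # use the builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20
    # and removed in NumPy 1.24
    mask = np.zeros(num, dtype=bool)
    for _, indices in ids_dict.items():
        i = np.random.choice(indices)
        mask[i] = True
    return mask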


def cmc(
        distmat,
        query_ids=None,
        gallery_ids=None,
        query_cams=None,
        gallery_cams=None,
        topk=100,
        separate_camera_set=False,
        single_gallery_shot=False,
        first_match_break=False,
        average=True):
    """
    Args:
        distmat: numpy array with shape [num_query, num_gallery], the
            pairwise distance between query and gallery samples
        query_ids: numpy array with shape [num_query]
        gallery_ids: numpy array with shape [num_gallery]
        query_cams: numpy array with shape [num_query]
        gallery_cams: numpy array with shape [num_gallery]
        average: whether to average the results across queries
    Returns:
        If `average` is `False`:
            ret: numpy array with shape [num_query, topk]
            is_valid_query: numpy array with shape [num_query], containing 0's and
                1's, whether each query is valid or not
            indices: numpy array with shape [num_query, num_gallery], the gallery
                indices of each query sorted by ascending distance
        If `average` is `True`:
            numpy array with shape [topk]
            indices: as above
    """
    # Ensure numpy array
    assert isinstance(distmat, np.ndarray)
    assert isinstance(query_ids, np.ndarray)
    assert isinstance(gallery_ids, np.ndarray)
    # assert isinstance(query_cams, np.ndarray)
    # assert isinstance(gallery_cams, np.ndarray)

    # separate_camera_set=False
    # NOTE: this overrides the `first_match_break` argument, so only the first
    # correct match of each query is counted.
    first_match_break = True
    m, _ = distmat.shape
    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute CMC for each query
    ret = np.zeros([m, topk])
    is_valid_query = np.zeros(m)
    num_valid_queries = 0
    for i in range(m):
        # always True: no gallery sample is filtered out by default
        valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
        if separate_camera_set:
            # Filter out samples from same camera
            valid = (gallery_cams[indices[i]] != query_cams[i])
        if not np.any(matches[i, valid]):
            continue
        is_valid_query[i] = 1
        if single_gallery_shot:
            repeat = 100
            gids = gallery_ids[indices[i][valid]]
            inds = np.where(valid)[0]
            ids_dict = defaultdict(list)
            for j, x in zip(inds, gids):
                ids_dict[x].append(j)
        else:
            repeat = 1
        for _ in range(repeat):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                sampled = (valid & _unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[i, sampled])[0]
            else:
                index = np.nonzero(matches[i, valid])[0]
            delta = 1. / (len(index) * repeat)
            for j, k in enumerate(index):
                if k - j >= topk:
                    break
                if first_match_break:
                    ret[i, k - j] += 1
                    break
                ret[i, k - j] += delta
        num_valid_queries += 1
    if num_valid_queries == 0:
        raise RuntimeError("No valid query")
    ret = ret.cumsum(axis=1)
    if average:
        return np.sum(ret, axis=0) / num_valid_queries, indices
    return ret, is_valid_query, indices


def mean_ap(
        distmat,
        query_ids=None,
        gallery_ids=None,
        query_cams=None,
        gallery_cams=None,
        average=True):
    """
    Args:
        distmat: numpy array with shape [num_query, num_gallery], the
            pairwise distance between query and gallery samples
        query_ids: numpy array with shape [num_query]
        gallery_ids: numpy array with shape [num_gallery]
        query_cams: numpy array with shape [num_query]
        gallery_cams: numpy array with shape [num_gallery]
        average: whether to average the results across queries
    Returns:
        If `average` is `False`:
            ret: numpy array with shape [num_query]
            is_valid_query: numpy array with shape [num_query], containing 0's and
                1's, whether each query is valid or not
        If `average` is `True`:
            a scalar
    """

    # -------------------------------------------------------------------------
    # The behavior of `sklearn.metrics.average_precision_score` has changed
    # since version 0.19.
    # Version 0.18.1 has same results as Matlab evaluation code by Zhun Zhong
    # (https://github.com/zhunzhong07/person-re-ranking/
    # blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
    # (http://www.liangzheng.org/Project/project_reid.html).
    # My current awkward solution is sticking to this older version.
    # if cur_version != required_version:
    #     print('User Warning: Version {} is required for package scikit-learn, '
    #           'your current version is {}. '
    #           'As a result, the mAP score may not be totally correct. '
    #           'You can try `pip uninstall scikit-learn` '
    #           'and then `pip install scikit-learn=={}`'.format(
    #               required_version, cur_version, required_version))
    # -------------------------------------------------------------------------

    # Ensure numpy array
    assert isinstance(distmat, np.ndarray)
    assert isinstance(query_ids, np.ndarray)
    assert isinstance(gallery_ids, np.ndarray)
    # assert isinstance(query_cams, np.ndarray)
    # assert isinstance(gallery_cams, np.ndarray)

    m, _ = distmat.shape

    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute AP for each query
    aps = np.zeros(m)
    is_valid_query = np.zeros(m)
    for i in range(m):
        # Filter out the same id and same camera
        # valid = ((gallery_ids[indices[i]] != query_ids[i]) |
        #          (gallery_cams[indices[i]] != query_cams[i]))
        # always True: no gallery sample is filtered out
        valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
        # valid = indices[i] != i
        # valid = (gallery_cams[indices[i]] != query_cams[i])
        y_true = matches[i, valid]
        y_score = -distmat[i][indices[i]][valid]
        # y_true=y_true[0:100]
        # y_score=y_score[0:100]
        if not np.any(y_true):
            continue
        is_valid_query[i] = 1
        aps[i] = average_precision_score(y_true, y_score)
    # avoid a division by zero below when no query has a match in the gallery
    if not np.any(is_valid_query):
        raise RuntimeError("No valid query")
    if average:
        return float(np.sum(aps)) / np.sum(is_valid_query)
    return aps, is_valid_query
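
To make the two metrics concrete, the sketch below computes rank-1 accuracy (the first CMC entry) and mAP by hand on a toy distance matrix, using NumPy only. The helper `rank1_and_map` and the toy ids are assumptions for illustration. The AP here is the mean of the precision at each correct match, which should agree with step-wise average precision when no distances are tied, but it is not guaranteed to reproduce the scikit-learn 0.18.1 values mentioned in the comment above.

```python
import numpy as np


def rank1_and_map(distmat, query_ids, gallery_ids):
    """Hypothetical helper: rank-1 accuracy and mAP over queries with a match."""
    order = np.argsort(distmat, axis=1)
    matches = gallery_ids[order] == query_ids[:, np.newaxis]
    rank1_hits, aps = [], []
    for row in matches:
        if not row.any():
            continue  # query without any correct gallery sample is skipped
        rank1_hits.append(float(row[0]))
        hits = np.cumsum(row)
        # Precision measured at the rank of every correct match.
        precision_at_hits = hits[row] / (np.flatnonzero(row) + 1)
        aps.append(precision_at_hits.mean())
    return float(np.mean(rank1_hits)), float(np.mean(aps))


query_ids = np.array([0, 1])
gallery_ids = np.array([0, 1, 0])
distmat = np.array([[0.1, 0.5, 0.3],   # query 0: both correct matches ranked first
                    [0.2, 0.4, 0.9]])  # query 1: correct match ranked second

rank1, mean_average_precision = rank1_and_map(distmat, query_ids, gallery_ids)
```

Query 0 has AP 1.0 and a rank-1 hit, while query 1 has AP 0.5 and a rank-1 miss, so the toy example gives rank-1 = 0.5 and mAP = 0.75.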