commit
609a518068
@ -0,0 +1,40 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"class_num": 10572,
|
||||
"batch_size": 128,
|
||||
"learning_rate": 0.01,
|
||||
"lr_decay_epochs": [40, 80, 100],
|
||||
"lr_decay_factor": 0.1,
|
||||
"lr_warmup_epochs": 20,
|
||||
"p": 16,
|
||||
"k": 8,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-4,
|
||||
"epoch_size": 120,
|
||||
"buffer_size": 10000,
|
||||
"image_height": 128,
|
||||
"image_width": 128,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 195,
|
||||
"keep_checkpoint_max": 2,
|
||||
"save_checkpoint_path": "./"
|
||||
})
|
@ -0,0 +1,202 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""data process"""
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
from PIL import ImageFile
|
||||
import cv2
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
__all__ = ['DistributedPKSampler', 'Dataset']
|
||||
IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
|
||||
|
||||
|
||||
class DistributedPKSampler:
|
||||
'''DistributedPKSampler'''
|
||||
def __init__(self, dataset, shuffle=True, p=5, k=2):
|
||||
assert isinstance(dataset, PKDataset), 'PK Sampler Only Supports PK Dataset!'
|
||||
self.p = p
|
||||
self.k = k
|
||||
self.dataset = dataset
|
||||
self.epoch = 0
|
||||
self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
|
||||
self.total_ids = self.step_nums*p
|
||||
self.batch_size = p*k
|
||||
self.num_samples = self.total_ids * self.k
|
||||
self.shuffle = shuffle
|
||||
self.epoch_gen = 1
|
||||
|
||||
|
||||
def _sample_pk(self, indices):
|
||||
'''sample pk'''
|
||||
sampled_pk = []
|
||||
for indice in indices:
|
||||
sampled_id = indice
|
||||
replacement = False
|
||||
if len(self.dataset.id2range[sampled_id]) < self.k:
|
||||
replacement = True
|
||||
index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
|
||||
sampled_pk.extend(index_list.tolist())
|
||||
|
||||
return sampled_pk
|
||||
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
|
||||
np.random.seed(self.epoch_gen)
|
||||
indices = np.random.permutation(len(self.dataset.classes))
|
||||
indices = indices.tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset.classes)))
|
||||
indices += indices[:(self.total_ids - len(indices))]
|
||||
assert len(indices) == self.total_ids
|
||||
|
||||
sampled_idxs = self._sample_pk(indices)
|
||||
|
||||
return iter(sampled_idxs)
|
||||
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def has_file_allowed_extension(filename, extensions):
|
||||
""" check if a file has an allowed extensio n.
|
||||
|
||||
Args:
|
||||
filename (string): path to a file
|
||||
extensions (tuple of strings): extensions allowed (lowercase)
|
||||
|
||||
Returns:
|
||||
bool: True if the file ends with one of the given extensions
|
||||
"""
|
||||
return filename.lower().endswith(extensions)
|
||||
|
||||
|
||||
def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
|
||||
'''make dataset'''
|
||||
images = []
|
||||
dir_name = os.path.expanduser(dir_name)
|
||||
if not (extensions is None) ^ (is_valid_file is None):
|
||||
raise ValueError("Extensions and is_valid_file should not be the same.")
|
||||
def is_valid(x):
|
||||
if extensions is not None:
|
||||
return has_file_allowed_extension(x, extensions)
|
||||
return is_valid_file(x)
|
||||
for target in sorted(class_to_idx.keys()):
|
||||
d = os.path.join(dir_name, target)
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
for root, _, fnames in sorted(os.walk(d)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if is_valid(path):
|
||||
item = (path, class_to_idx[target], 0.6)
|
||||
images.append(item)
|
||||
return images
|
||||
|
||||
|
||||
class ImageFolderPKDataset:
|
||||
'''ImageFolderPKDataset'''
|
||||
def __init__(self, root):
|
||||
self.classes, self.classes_to_idx = self._find_classes(root)
|
||||
self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
|
||||
self.id2range = self._build_id2range()
|
||||
self.all_image_idxs = range(len(self.samples))
|
||||
self.classes = list(self.id2range.keys())
|
||||
|
||||
def _find_classes(self, dir_name):
|
||||
"""
|
||||
Finds the class folders in a dataset
|
||||
|
||||
Args:
|
||||
dir (string): root directory path
|
||||
|
||||
Returns:
|
||||
tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
|
||||
|
||||
Ensures:
|
||||
No class is a subdirectory of others
|
||||
"""
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
# Faster and available in Python 3.5 and above
|
||||
classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
|
||||
else:
|
||||
classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
|
||||
return classes, class_to_idx
|
||||
|
||||
|
||||
def _build_id2range(self):
|
||||
'''map id to range'''
|
||||
id2range = defaultdict(list)
|
||||
ret_range = defaultdict(list)
|
||||
for idx, sample in enumerate(self.samples):
|
||||
label = sample[1]
|
||||
id2range[label].append((sample, idx))
|
||||
# print(id2range)
|
||||
for key in id2range:
|
||||
id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
|
||||
for item in id2range[key]:
|
||||
ret_range[key].append(item[1])
|
||||
return ret_range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.samples[index]
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
def pil_loader(path):
|
||||
'''pil loader'''
|
||||
img = cv2.imread(path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
class Dataset:
|
||||
'''Dataset'''
|
||||
def __init__(self, root, loader=pil_loader):
|
||||
self.dataset = ImageFolderPKDataset(root)
|
||||
print('Dataset len(dataset):{}'.format(len(self.dataset)))
|
||||
self.loader = loader
|
||||
self.classes = self.dataset.classes
|
||||
self.id2range = self.dataset.id2range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target1, target2 = self.dataset[index]
|
||||
sample = self.loader(path)
|
||||
return sample, target1, target2
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
@ -0,0 +1,214 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MGDataset"""
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import ImageFile
|
||||
import cv2
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
__all__ = ['DistributedPKSampler', 'MGDataset']
|
||||
IMG_EXTENSIONS = ('.jpg', 'jpeg', '.png', '.ppm', '.bmp', 'pgm', '.tif', '.tiff', 'webp')
|
||||
|
||||
|
||||
class DistributedPKSampler:
|
||||
'''DistributedPKSampler'''
|
||||
def __init__(self, dataset, shuffle=True, p=5, k=2):
|
||||
assert isinstance(dataset, MGDataset), 'PK Sampler Only Supports PK Dataset or MG Dataset!'
|
||||
self.p = p
|
||||
self.k = k
|
||||
self.dataset = dataset
|
||||
self.epoch = 0
|
||||
self.step_nums = int(math.ceil(len(self.dataset.classes)*1.0/p))
|
||||
self.total_ids = self.step_nums*p
|
||||
self.batch_size = p*k
|
||||
self.num_samples = self.total_ids * self.k
|
||||
self.shuffle = shuffle
|
||||
self.epoch_gen = 1
|
||||
|
||||
def _sample_pk(self, indices):
|
||||
'''sample pk'''
|
||||
sampled_pk = []
|
||||
for indice in indices:
|
||||
sampled_id = indice
|
||||
replacement = False
|
||||
if len(self.dataset.id2range[sampled_id]) < self.k:
|
||||
replacement = True
|
||||
index_list = np.random.choice(self.dataset.id2range[sampled_id][0:], self.k, replace=replacement)
|
||||
sampled_pk.extend(index_list.tolist())
|
||||
|
||||
return sampled_pk
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.epoch_gen = (self.epoch_gen + 1) & 0xffffffff
|
||||
np.random.seed(self.epoch_gen)
|
||||
indices = np.random.permutation(len(self.dataset.classes))
|
||||
indices = indices.tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset.classes)))
|
||||
|
||||
indices += indices[:(self.total_ids - len(indices))]
|
||||
assert len(indices) == self.total_ids
|
||||
|
||||
sampled_idxs = self._sample_pk(indices)
|
||||
|
||||
return iter(sampled_idxs)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
def has_file_allowed_extension(filename, extensions):
|
||||
""" check if a file has an allowed extensio n.
|
||||
|
||||
Args:
|
||||
filename (string): path to a file
|
||||
extensions (tuple of strings): extensions allowed (lowercase)
|
||||
|
||||
Returns:
|
||||
bool: True if the file ends with one of the given extensions
|
||||
"""
|
||||
return filename.lower().endswith(extensions)
|
||||
|
||||
|
||||
def make_dataset(dir_name, class_to_idx, extensions=None, is_valid_file=None):
|
||||
'''make dataset'''
|
||||
images = []
|
||||
masked_datasets = ["n95", "3m", "new", "mask_1", "mask_2", "mask_3", "mask_4", "mask_5"]
|
||||
dir_name = os.path.expanduser(dir_name)
|
||||
if not (extensions is None) ^ (is_valid_file is None):
|
||||
raise ValueError("Extensions and is_valid_file should not be the same")
|
||||
|
||||
def is_valid(x):
|
||||
if extensions is not None:
|
||||
return has_file_allowed_extension(x, extensions)
|
||||
return is_valid_file(x)
|
||||
|
||||
|
||||
for target in sorted(class_to_idx.keys()):
|
||||
d = os.path.join(dir_name, target)
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
for root, _, fnames in sorted(os.walk(d)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if is_valid(path):
|
||||
scale = float(osp.splitext(fname)[0].split('_')[1])
|
||||
item = (path, class_to_idx[target], scale)
|
||||
images.append(item)
|
||||
mask_root_path = root.replace("faces_webface_112x112_raw_image", random.choice(masked_datasets))
|
||||
mask_name = fname.split('_')[0]+".jpg"
|
||||
mask_path = osp.join(mask_root_path, mask_name)
|
||||
if os.path.isfile(mask_path) and is_valid(mask_path):
|
||||
item = (mask_path, class_to_idx[target], scale)
|
||||
images.append(item)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class ImageFolderPKDataset:
|
||||
'''Image Folder PKDataset'''
|
||||
def __init__(self, root):
|
||||
self.classes, self.classes_to_idx = self._find_classes(root)
|
||||
self.samples = make_dataset(root, self.classes_to_idx, IMG_EXTENSIONS, None)
|
||||
self.id2range = self._build_id2range()
|
||||
self.all_image_idxs = range(len(self.samples))
|
||||
self.classes = list(self.id2range.keys())
|
||||
|
||||
def _find_classes(self, dir_name):
|
||||
"""
|
||||
Finds the class folders in a dataset
|
||||
|
||||
Args:
|
||||
dir (string): root directory path
|
||||
|
||||
Returns:
|
||||
tuple (class, class_to_idx): where classes are relative to dir, and class_to_idx is a directionaty
|
||||
|
||||
Ensures:
|
||||
No class is a subdirectory of others
|
||||
"""
|
||||
|
||||
if sys.version_info >= (3, 5):
|
||||
# Faster and available in Python 3.5 and above
|
||||
classes = [d.name for d in os.scandir(dir_name) if d.is_dir()]
|
||||
else:
|
||||
classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))]
|
||||
classes.sort()
|
||||
class_to_idx = {classes[i]: i for i in range(len(classes))}
|
||||
|
||||
return classes, class_to_idx
|
||||
|
||||
|
||||
def _build_id2range(self):
|
||||
'''id to range'''
|
||||
id2range = defaultdict(list)
|
||||
ret_range = defaultdict(list)
|
||||
for idx, sample in enumerate(self.samples):
|
||||
label = sample[1]
|
||||
id2range[label].append((sample, idx))
|
||||
for key in id2range:
|
||||
id2range[key].sort(key=lambda x: int(os.path.basename(x[0][0]).split(".")[0]))
|
||||
for item in id2range[key]:
|
||||
ret_range[key].append(item[1])
|
||||
|
||||
return ret_range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.samples[index]
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
def pil_loader(path):
|
||||
'''load pil'''
|
||||
img = cv2.imread(path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
class MGDataset:
|
||||
'''MGDataset'''
|
||||
def __init__(self, root, loader=pil_loader):
|
||||
self.dataset = ImageFolderPKDataset(root)
|
||||
print('MGDataset len(dataset):{}'.format(len(self.dataset)))
|
||||
self.loader = loader
|
||||
self.classes = self.dataset.classes
|
||||
self.id2range = self.dataset.id2range
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target1, target2 = self.dataset[index]
|
||||
sample = self.loader(path)
|
||||
return sample, target1, target2
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
After Width: | Height: | Size: 1.7 MiB |
After Width: | Height: | Size: 337 KiB |
@ -0,0 +1,30 @@
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
660.75
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
Dataset len(dataset):5136
|
||||
Dataset len(dataset):5165
|
||||
0.4214082362915882 [0.96222741 0.97507788 0.97858255 0.98111371 0.98286604 0.98345016
|
||||
0.98500779 0.98617601 0.98676012 0.98753894]
|
||||
/home/dingfeifei/datasets/faces_webface_112x112_raw_image/0 0_0.5625.jpg
|
||||
MGDataset len(dataset):884896
|
||||
epoch: 1 step: 661, loss is 19.227043
|
||||
epoch: 2 step: 661, loss is 18.528654
|
||||
epoch: 3 step: 661, loss is 18.451244
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
python3 ./test.py
|
@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
python3 ./train.py
|
@ -0,0 +1,114 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
from test_dataset import create_dataset
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from model.model import resnet50, TrainStepWrap, NetWithLossClass
|
||||
from utils.distance import compute_dist, compute_score
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
local_data_url = 'data'
|
||||
local_train_url = 'ckpt'
|
||||
|
||||
|
||||
class Logger():
|
||||
'''Log'''
|
||||
def __init__(self, logFile="log_max.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(logFile, 'a')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
self.log.flush()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
sys.stdout = Logger("log/log.txt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
query_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', \
|
||||
'test/query'), p=config.p, k=config.k)
|
||||
gallery_dataset = create_dataset(data_dir=os.path.join('/home/dingfeifei/datasets', \
|
||||
'test/gallery'), p=config.p, k=config.k)
|
||||
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num, is_train=False)
|
||||
loss_net = NetWithLossClass(net, is_train=False)
|
||||
|
||||
base_lr = config.learning_rate
|
||||
warm_up_epochs = config.lr_warmup_epochs
|
||||
lr_decay_epochs = config.lr_decay_epochs
|
||||
lr_decay_factor = config.lr_decay_factor
|
||||
step_size = math.ceil(config.class_num / config.p)
|
||||
lr_decay_steps = []
|
||||
lr_decay = []
|
||||
for i, v in enumerate(lr_decay_epochs):
|
||||
lr_decay_steps.append(v * step_size)
|
||||
lr_decay.append(base_lr * lr_decay_factor ** i)
|
||||
lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
|
||||
lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
|
||||
lr = lr_1 + lr_2
|
||||
|
||||
train_net = TrainStepWrap(loss_net, lr, config.momentum, is_train=False)
|
||||
|
||||
load_checkpoint("checkpoints/40.ckpt", net=train_net)
|
||||
|
||||
q_feats, q_labels, g_feats, g_labels = [], [], [], []
|
||||
for data, gt_classes, theta in query_dataset:
|
||||
output = train_net(data, gt_classes, theta)
|
||||
output = output.asnumpy()
|
||||
label = gt_classes.asnumpy()
|
||||
q_feats.append(output)
|
||||
q_labels.append(label)
|
||||
q_feats = np.vstack(q_feats)
|
||||
q_labels = np.hstack(q_labels)
|
||||
|
||||
for data, gt_classes, theta in gallery_dataset:
|
||||
output = train_net(data, gt_classes, theta)
|
||||
output = output.asnumpy()
|
||||
label = gt_classes.asnumpy()
|
||||
g_feats.append(output)
|
||||
g_labels.append(label)
|
||||
g_feats = np.vstack(g_feats)
|
||||
g_labels = np.hstack(g_labels)
|
||||
|
||||
q_g_dist = compute_dist(q_feats, g_feats, dis_type='cosine')
|
||||
mAP, cmc_scores = compute_score(q_g_dist, q_labels, g_labels)
|
||||
|
||||
print(mAP, cmc_scores)
|
@ -0,0 +1,65 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from config import config
|
||||
from dataset.Dataset import Dataset
|
||||
|
||||
|
||||
def create_dataset(data_dir, p=16, k=8):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
p(int): randomly choose p classes from all classes.
|
||||
k(int): randomly choose k images from each of the chosen p classes.
|
||||
p * k is the batchsize.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset = Dataset(data_dir)
|
||||
de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"])
|
||||
|
||||
resize_height = config.image_height
|
||||
resize_width = config.image_width
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
resize_op = CV.Resize((resize_height, resize_width))
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
|
||||
|
||||
change_swap_op = CV.HWC2CHW()
|
||||
|
||||
trans = []
|
||||
|
||||
trans += [resize_op, rescale_op, normalize_op, change_swap_op]
|
||||
|
||||
type_cast_op_label1 = C.TypeCast(mstype.int32)
|
||||
type_cast_op_label2 = C.TypeCast(mstype.float32)
|
||||
|
||||
de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
|
||||
de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=trans)
|
||||
de_dataset = de_dataset.batch(p*k, drop_remainder=False)
|
||||
|
||||
return de_dataset
|
@ -0,0 +1,105 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
import pickle
|
||||
import numpy as np
|
||||
from train_dataset import create_dataset
|
||||
from config import config
|
||||
from mindspore import context
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr, warmup_lr
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor # TimeMonitor
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from model.model import resnet50, NetWithLossClass, TrainStepWrap, TestStepWrap
|
||||
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
local_data_url = 'data'
|
||||
local_train_url = 'ckpt'
|
||||
|
||||
class Logger():
|
||||
'''Logger'''
|
||||
def __init__(self, logFile="log_max.txt"):
|
||||
self.terminal = sys.stdout
|
||||
self.log = open(logFile, 'a')
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
self.log.flush()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
sys.stdout = Logger("log/log.txt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
epoch_size = config.epoch_size
|
||||
net = resnet50(class_num=config.class_num, is_train=True)
|
||||
loss_net = NetWithLossClass(net)
|
||||
|
||||
dataset = create_dataset("/home/dingfeifei/datasets/faces_webface_112x112_raw_image", \
|
||||
p=config.p, k=config.k)
|
||||
|
||||
step_size = dataset.get_dataset_size()
|
||||
base_lr = config.learning_rate
|
||||
warm_up_epochs = config.lr_warmup_epochs
|
||||
lr_decay_epochs = config.lr_decay_epochs
|
||||
lr_decay_factor = config.lr_decay_factor
|
||||
lr_decay_steps = []
|
||||
lr_decay = []
|
||||
for i, v in enumerate(lr_decay_epochs):
|
||||
lr_decay_steps.append(v * step_size)
|
||||
lr_decay.append(base_lr * lr_decay_factor ** i)
|
||||
lr_1 = warmup_lr(base_lr, step_size*warm_up_epochs, step_size, warm_up_epochs)
|
||||
lr_2 = piecewise_constant_lr(lr_decay_steps, lr_decay)
|
||||
lr = lr_1 + lr_2
|
||||
|
||||
train_net = TrainStepWrap(loss_net, lr, config.momentum)
|
||||
test_net = TestStepWrap(net)
|
||||
|
||||
f = open("checkpoints/pretrained_resnet50.pkl", "rb")
|
||||
param_dict = pickle.load(f)
|
||||
load_param_into_net(net=train_net, parameter_dict=param_dict)
|
||||
|
||||
model = Model(train_net, eval_network=test_net, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# time_cb = TimeMonitor(data_size=step_size)
|
||||
loss_cb = LossMonitor()
|
||||
#cb = [time_cb, loss_cb]
|
||||
cb = [loss_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, \
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="resnet", directory='checkpoints/', \
|
||||
config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
model.train(epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
@ -0,0 +1,66 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from config import config
|
||||
from dataset.MGDataset import DistributedPKSampler, MGDataset
|
||||
|
||||
|
||||
def create_dataset(data_dir, p=16, k=8):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
p(int): randomly choose p classes from all classes.
|
||||
k(int): randomly choose k images from each of the chosen p classes.
|
||||
p * k is the batchsize.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
dataset = MGDataset(data_dir)
|
||||
sampler = DistributedPKSampler(dataset, p=p, k=k)
|
||||
de_dataset = de.GeneratorDataset(dataset, ["image", "label1", "label2"], sampler=sampler)
|
||||
|
||||
resize_height = config.image_height
|
||||
resize_width = config.image_width
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
resize_op = CV.Resize((resize_height, resize_width))
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
normalize_op = CV.Normalize([0.486, 0.459, 0.408], [0.229, 0.224, 0.225])
|
||||
|
||||
change_swap_op = CV.HWC2CHW()
|
||||
|
||||
trans = []
|
||||
|
||||
trans += [resize_op, rescale_op, normalize_op, change_swap_op]
|
||||
|
||||
type_cast_op_label1 = C.TypeCast(mstype.int32)
|
||||
type_cast_op_label2 = C.TypeCast(mstype.float32)
|
||||
|
||||
de_dataset = de_dataset.map(input_columns="label1", operations=type_cast_op_label1)
|
||||
de_dataset = de_dataset.map(input_columns="label2", operations=type_cast_op_label2)
|
||||
de_dataset = de_dataset.map(input_columns="image", operations=trans)
|
||||
de_dataset = de_dataset.batch(p*k, drop_remainder=True)
|
||||
|
||||
return de_dataset
|
@ -0,0 +1,56 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Numpy version of euclidean distance, etc."""
|
||||
import numpy as np
|
||||
from utils.metric import cmc, mean_ap
|
||||
|
||||
|
||||
def normalize(nparray, order=2, axis=0):
|
||||
"""Normalize a N-D numpy array along the specified axis."""
|
||||
norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
|
||||
return nparray / (norm + np.finfo(np.float32).eps)
|
||||
|
||||
|
||||
def compute_dist(array1, array2, dis_type='euclidean'):
|
||||
"""Compute the euclidean or cosine distance of all pairs.
|
||||
Args:
|
||||
array1: numpy array with shape [m1, n]
|
||||
array2: numpy array with shape [m2, n]
|
||||
type:
|
||||
one of ['cosine', 'euclidean']
|
||||
Returns:
|
||||
numpy array with shape [m1, m2]
|
||||
"""
|
||||
assert dis_type in ['cosine', 'euclidean']
|
||||
if dis_type == 'cosine':
|
||||
array1 = normalize(array1, axis=1)
|
||||
array2 = normalize(array2, axis=1)
|
||||
dist = np.matmul(array1, array2.T)
|
||||
return -1*dist
|
||||
|
||||
# shape [m1, 1]
|
||||
square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis]
|
||||
# shape [1, m2]
|
||||
square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...]
|
||||
squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2
|
||||
squared_dist[squared_dist < 0] = 0
|
||||
dist = np.sqrt(squared_dist)
|
||||
return dist
|
||||
|
||||
|
||||
def compute_score(dist_mat, query_ids, gallery_ids):
|
||||
mAP = mean_ap(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids)
|
||||
cmc_scores, _ = cmc(distmat=dist_mat, query_ids=query_ids, gallery_ids=gallery_ids, topk=10)
|
||||
return mAP, cmc_scores
|
@ -0,0 +1,194 @@
|
||||
"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid)
|
||||
reid/evaluation_metrics/ranking.py. Modifications:
|
||||
1) Only accepts numpy data input, no torch is involved.
|
||||
1) Here results of each query can be returned.
|
||||
2) In the single-gallery-shot evaluation case, the time of repeats is changed
|
||||
from 10 to 100.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import average_precision_score
|
||||
|
||||
|
||||
def _unique_sample(ids_dict, num):
|
||||
mask = np.zeros(num, dtype=np.bool)
|
||||
for _, indices in ids_dict.items():
|
||||
i = np.random.choice(indices)
|
||||
mask[i] = True
|
||||
return mask
|
||||
|
||||
|
||||
def cmc(
|
||||
distmat,
|
||||
query_ids=None,
|
||||
gallery_ids=None,
|
||||
query_cams=None,
|
||||
gallery_cams=None,
|
||||
topk=100,
|
||||
separate_camera_set=False,
|
||||
single_gallery_shot=False,
|
||||
first_match_break=False,
|
||||
average=True):
|
||||
"""
|
||||
Args:
|
||||
distmat: numpy array with shape [num_query, num_gallery], the
|
||||
pairwise distance between query and gallery samples
|
||||
query_ids: numpy array with shape [num_query]
|
||||
gallery_ids: numpy array with shape [num_gallery]
|
||||
query_cams: numpy array with shape [num_query]
|
||||
gallery_cams: numpy array with shape [num_gallery]
|
||||
average: whether to average the results across queries
|
||||
Returns:
|
||||
If `average` is `False`:
|
||||
ret: numpy array with shape [num_query, topk]
|
||||
is_valid_query: numpy array with shape [num_query], containing 0's and
|
||||
1's, whether each query is valid or not
|
||||
If `average` is `True`:
|
||||
numpy array with shape [topk]
|
||||
"""
|
||||
# Ensure numpy array
|
||||
assert isinstance(distmat, np.ndarray)
|
||||
assert isinstance(query_ids, np.ndarray)
|
||||
assert isinstance(gallery_ids, np.ndarray)
|
||||
# assert isinstance(query_cams, np.ndarray)
|
||||
# assert isinstance(gallery_cams, np.ndarray)
|
||||
# separate_camera_set=False
|
||||
first_match_break = True
|
||||
m, _ = distmat.shape
|
||||
# Sort and find correct matches
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
#print(indices)
|
||||
matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
|
||||
# Compute CMC for each query
|
||||
ret = np.zeros([m, topk])
|
||||
is_valid_query = np.zeros(m)
|
||||
num_valid_queries = 0
|
||||
|
||||
for i in range(m):
|
||||
valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
|
||||
|
||||
if separate_camera_set:
|
||||
# Filter out samples from same camera
|
||||
valid = (gallery_cams[indices[i]] != query_cams[i])
|
||||
|
||||
if not np.any(matches[i, valid]): continue
|
||||
|
||||
is_valid_query[i] = 1
|
||||
|
||||
if single_gallery_shot:
|
||||
repeat = 100
|
||||
gids = gallery_ids[indices[i][valid]]
|
||||
inds = np.where(valid)[0]
|
||||
ids_dict = defaultdict(list)
|
||||
for j, x in zip(inds, gids):
|
||||
ids_dict[x].append(j)
|
||||
else:
|
||||
repeat = 1
|
||||
|
||||
for _ in range(repeat):
|
||||
if single_gallery_shot:
|
||||
# Randomly choose one instance for each id
|
||||
sampled = (valid & _unique_sample(ids_dict, len(valid)))
|
||||
index = np.nonzero(matches[i, sampled])[0]
|
||||
else:
|
||||
index = np.nonzero(matches[i, valid])[0]
|
||||
|
||||
delta = 1. / (len(index) * repeat)
|
||||
for j, k in enumerate(index):
|
||||
if k - j >= topk: break
|
||||
if first_match_break:
|
||||
ret[i, k - j] += 1
|
||||
break
|
||||
ret[i, k - j] += delta
|
||||
num_valid_queries += 1
|
||||
|
||||
if num_valid_queries == 0:
|
||||
raise RuntimeError("No valid query")
|
||||
ret = ret.cumsum(axis=1)
|
||||
|
||||
if average:
|
||||
return np.sum(ret, axis=0) / num_valid_queries, indices
|
||||
|
||||
return ret, is_valid_query, indices
|
||||
|
||||
|
||||
def mean_ap(
|
||||
distmat,
|
||||
query_ids=None,
|
||||
gallery_ids=None,
|
||||
query_cams=None,
|
||||
gallery_cams=None,
|
||||
average=True):
|
||||
"""
|
||||
Args:
|
||||
distmat: numpy array with shape [num_query, num_gallery], the
|
||||
pairwise distance between query and gallery samples
|
||||
query_ids: numpy array with shape [num_query]
|
||||
gallery_ids: numpy array with shape [num_gallery]
|
||||
query_cams: numpy array with shape [num_query]
|
||||
gallery_cams: numpy array with shape [num_gallery]
|
||||
average: whether to average the results across queries
|
||||
Returns:
|
||||
If `average` is `False`:
|
||||
ret: numpy array with shape [num_query]
|
||||
is_valid_query: numpy array with shape [num_query], containing 0's and
|
||||
1's, whether each query is valid or not
|
||||
If `average` is `True`:
|
||||
a scalar
|
||||
"""
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# The behavior of method `sklearn.average_precision` has changed since version
|
||||
# 0.19.
|
||||
# Version 0.18.1 has same results as Matlab evaluation code by Zhun Zhong
|
||||
# (https://github.com/zhunzhong07/person-re-ranking/
|
||||
# blob/master/evaluation/utils/evaluation.m) and by Liang Zheng
|
||||
# (http://www.liangzheng.org/Project/project_reid.html).
|
||||
# My current awkward solution is sticking to this older version.
|
||||
# if cur_version != required_version:
|
||||
# print('User Warning: Version {} is required for package scikit-learn, '
|
||||
# 'your current version is {}. '
|
||||
# 'As a result, the mAP score may not be totally correct. '
|
||||
# 'You can try `pip uninstall scikit-learn` '
|
||||
# 'and then `pip install scikit-learn=={}`'.format(
|
||||
# required_version, cur_version, required_version))
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
# Ensure numpy array
|
||||
assert isinstance(distmat, np.ndarray)
|
||||
assert isinstance(query_ids, np.ndarray)
|
||||
assert isinstance(gallery_ids, np.ndarray)
|
||||
# assert isinstance(query_cams, np.ndarray)
|
||||
# assert isinstance(gallery_cams, np.ndarray)
|
||||
|
||||
m, _ = distmat.shape
|
||||
|
||||
# Sort and find correct matches
|
||||
indices = np.argsort(distmat, axis=1)
|
||||
# print("indices:", indices)
|
||||
matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
|
||||
# Compute AP for each query
|
||||
aps = np.zeros(m)
|
||||
is_valid_query = np.zeros(m)
|
||||
for i in range(m):
|
||||
# Filter out the same id and same camera
|
||||
# valid = ((gallery_ids[indices[i]] != query_ids[i]) |
|
||||
# (gallery_cams[indices[i]] != query_cams[i]))
|
||||
valid = (gallery_ids[indices[i]] != query_ids[i]) | (gallery_ids[indices[i]] == query_ids[i])
|
||||
# valid = indices[i] != i
|
||||
# valid = (gallery_cams[indices[i]] != query_cams[i])
|
||||
y_true = matches[i, valid]
|
||||
y_score = -distmat[i][indices[i]][valid]
|
||||
|
||||
# y_true=y_true[0:100]
|
||||
# y_score=y_score[0:100]
|
||||
if not np.any(y_true): continue
|
||||
is_valid_query[i] = 1
|
||||
aps[i] = average_precision_score(y_true, y_score)
|
||||
# if not aps:
|
||||
# raise RuntimeError("No valid query")
|
||||
if average:
|
||||
return float(np.sum(aps)) / np.sum(is_valid_query)
|
||||
return aps, is_valid_query
|
Loading…
Reference in new issue