@@ -0,0 +1,244 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
##############test densenet example#################
python eval.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/CHECKPOINT
"""

import os
import argparse
import datetime
import glob
import numpy as np
from mindspore import context

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init, get_rank, get_group_size, release
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype

from src.utils.logging import get_logger
from src.datasets import classification_dataset
from src.network import DenseNet121
from src.config import config

devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
                    save_graphs=True, device_id=devid)


class ParameterReduce(nn.Cell):
    """
    reduce parameter
    """
    def __init__(self):
        super(ParameterReduce, self).__init__()
        self.cast = P.Cast()
        self.reduce = P.AllReduce()

    def construct(self, x):
        one = self.cast(F.scalar_to_array(1.0), mstype.float32)
        out = x * one
        ret = self.reduce(out)
        return ret


def parse_args(cloud_args=None):
    """
    parse args
    """
    parser = argparse.ArgumentParser('mindspore classification test')

    # dataset related
    parser.add_argument('--data_dir', type=str, default='', help='eval data dir')
    parser.add_argument('--num_classes', type=int, default=1000, help='num of classes in dataset')
    parser.add_argument('--image_size', type=str, default='224,224', help='image size of the dataset')
    # network related
    parser.add_argument('--backbone', default='resnet50', help='backbone')
    parser.add_argument('--pretrained', default='', type=str, help='full path of the pretrained model to load. '
                        'If it is a directory, every checkpoint in it will be evaluated')

    # logging related
    parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')

    # roma obs
    parser.add_argument('--train_url', type=str, default="", help='train url')

    args, _ = parser.parse_known_args()
    args = merge_args(args, cloud_args)

    args.per_batch_size = config.per_batch_size
    args.image_size = list(map(int, args.image_size.split(',')))

    return args


def get_top5_acc(top5_arg, gt_class):
    sub_count = 0
    for top5, gt in zip(top5_arg, gt_class):
        if gt in top5:
            sub_count += 1
    return sub_count


def merge_args(args, cloud_args):
    """
    merge args and cloud_args
    """
    args_dict = vars(args)
    if isinstance(cloud_args, dict):
        for key in cloud_args.keys():
            val = cloud_args[key]
            if key in args_dict and val:
                arg_type = type(args_dict[key])
                if arg_type is not type(None):
                    val = arg_type(val)
                args_dict[key] = val
    return args


def test(cloud_args=None):
    """
    network eval function. Get top1 and top5 ACC from classification.
    The result will be saved at [./outputs] by default.
    """
    args = parse_args(cloud_args)

    # init distributed
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    args.outputs_dir = os.path.join(args.log_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    # network
    args.logger.important_info('start create network')
    if os.path.isdir(args.pretrained):
        models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))

        f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])

        args.models = sorted(models, key=f)
    else:
        args.models = [args.pretrained,]

    for model in args.models:
        de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
                                            per_batch_size=args.per_batch_size,
                                            max_epoch=1, rank=args.rank, group_size=args.group_size,
                                            mode='eval')
        eval_dataloader = de_dataset.create_tuple_iterator()
        network = DenseNet121(args.num_classes)

        param_dict = load_checkpoint(model)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(model))

        network.add_flags_recursive(fp16=True)

        img_tot = 0
        top1_correct = 0
        top5_correct = 0
        network.set_train(False)
        for data, gt_classes in eval_dataloader:
            output = network(Tensor(data, mstype.float32))
            output = output.asnumpy()
            gt_classes = gt_classes.asnumpy()

            top1_output = np.argmax(output, (-1))
            top5_output = np.argsort(output)[:, -5:]

            t1_correct = np.equal(top1_output, gt_classes).sum()
            top1_correct += t1_correct
            top5_correct += get_top5_acc(top5_output, gt_classes)
            img_tot += args.per_batch_size

        results = [[top1_correct], [top5_correct], [img_tot]]
        args.logger.info('before results={}'.format(results))
        if args.is_distributed:
            model_md5 = model.replace('/', '')
            tmp_dir = '../cache'
            if not os.path.exists(tmp_dir):
                os.mkdir(tmp_dir)
            top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
            top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
            img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
            np.save(top1_correct_npy, top1_correct)
            np.save(top5_correct_npy, top5_correct)
            np.save(img_tot_npy, img_tot)
            while True:
                rank_ok = True
                for other_rank in range(args.group_size):
                    top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                    top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                    img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                    if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) \
                            or not os.path.exists(img_tot_npy):
                        rank_ok = False
                if rank_ok:
                    break

            top1_correct_all = 0
            top5_correct_all = 0
            img_tot_all = 0
            for other_rank in range(args.group_size):
                top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
                top1_correct_all += np.load(top1_correct_npy)
                top5_correct_all += np.load(top5_correct_npy)
                img_tot_all += np.load(img_tot_npy)
            results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
            results = np.array(results)

        else:
            results = np.array(results)

        args.logger.info('after results={}'.format(results))
        top1_correct = results[0, 0]
        top5_correct = results[1, 0]
        img_tot = results[2, 0]
        acc1 = 100.0 * top1_correct / img_tot
        acc5 = 100.0 * top5_correct / img_tot
        args.logger.info('after allreduce eval: top1_correct={}, tot={}, acc={:.2f}%'.format(top1_correct,
                                                                                             img_tot,
                                                                                             acc1))
        args.logger.info('after allreduce eval: top5_correct={}, tot={}, acc={:.2f}%'.format(top5_correct,
                                                                                             img_tot,
                                                                                             acc5))
    if args.is_distributed:
        release()


if __name__ == "__main__":
    test()
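As a quick illustration of the top-1/top-5 bookkeeping that eval.py performs per batch, here is a minimal NumPy sketch with made-up logits and labels (names are illustrative only, not part of the commit):

    import numpy as np

    logits = np.array([[0.1, 0.5, 0.2, 0.2],     # scores for 2 samples, 4 classes
                       [0.7, 0.1, 0.1, 0.1]])
    labels = np.array([1, 2])

    top1 = np.argmax(logits, -1)                 # best class per sample -> [1, 0]
    top5 = np.argsort(logits)[:, -5:]            # here the whole row, since num_classes < 5
    top1_correct = np.equal(top1, labels).sum()  # 1: sample 0 hits, sample 1 misses
    top5_correct = sum(gt in row for row, gt in zip(top5, labels))  # 2: every label is in its "top-5"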
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_eval.sh DEVICE_NUM RANK_TABLE_FILE DATASET CKPT_PATH"
echo "for example: sh run_distribute_eval.sh 8 /data/hccl.json /path/to/dataset /path/to/ckpt"
echo "It is better to use absolute path."
echo "================================================================================================================="

echo "After running the script, the network runs in the background. The log will be generated in eval_x/log.txt"

export RANK_SIZE=$1
export RANK_TABLE_FILE=$2
DATASET=$3
CKPT_PATH=$4

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf eval_$i
    mkdir ./eval_$i
    cp ./*.py ./eval_$i
    cp -r ./src ./eval_$i
    cd ./eval_$i || exit
    export RANK_ID=$i
    echo "start evaluation for rank $i, device $DEVICE_ID"
    env > env.log
    python eval.py \
        --data_dir=$DATASET \
        --pretrained=$CKPT_PATH > log.txt 2>&1 &

    cd ../
done
@@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE DATASET"
echo "for example: sh scripts/run_distribute_train.sh 8 /data/hccl.json /path/to/dataset"
echo "It is better to use absolute path."
echo "================================================================================================================="

echo "After running the script, the network runs in the background. The log will be generated in train_x/log.txt"

export RANK_SIZE=$1
export RANK_TABLE_FILE=$2
DATASET=$3

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf train_$i
    mkdir ./train_$i
    cp ./*.py ./train_$i
    cp -r ./src ./train_$i
    cd ./train_$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    python train.py \
        --data_dir=$DATASET > log.txt 2>&1 &

    cd ../
done
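For context, each backgrounded python process launched by the loop above inherits the per-device exports; a sketch of what the Python side sees (this mirrors the top of eval.py and assumes the launcher's environment variables are present, so it is illustrative rather than standalone):

    import os
    from mindspore import context

    device_id = int(os.getenv('DEVICE_ID'))   # 0..RANK_SIZE-1, one process per Ascend device
    rank_id = int(os.getenv('RANK_ID'))       # consumed by HCCL together with RANK_TABLE_FILE
    context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", device_id=device_id)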
@@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
from easydict import EasyDict as ed

config = ed({
    "image_size": '224,224',
    "num_classes": 1000,

    "lr": 0.1,
    "lr_scheduler": 'cosine_annealing',
    "lr_epochs": '30,60,90,120',
    "lr_gamma": 0.1,
    "eta_min": 0,
    "T_max": 120,
    "max_epoch": 120,
    "per_batch_size": 32,
    "warmup_epochs": 0,

    "weight_decay": 0.0001,
    "momentum": 0.9,
    "is_dynamic_loss_scale": 0,
    "loss_scale": 1024,
    "label_smooth": 0,
    "label_smooth_factor": 0.1,

    "log_interval": 100,
    "ckpt_interval": 2000,
    "ckpt_path": 'outputs/',
    "is_save_on_master": 1,

    "rank": 0,
    "group_size": 1
})
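Because config is a plain EasyDict, entries can be read as attributes or as keys; a minimal sketch of how the values above are consumed (variable names are illustrative):

    from easydict import EasyDict as ed

    cfg = ed({"per_batch_size": 32, "image_size": '224,224'})
    assert cfg.per_batch_size == cfg["per_batch_size"] == 32       # attribute and key access are equivalent
    height, width = map(int, cfg.image_size.split(','))            # same parsing eval.py applies to --image_size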
@@ -0,0 +1,22 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
read dataset for classification
"""

from .classification import classification_dataset

__all__ = ["classification_dataset"]
@@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
A function that returns a dataset for classification.
"""

import os
from PIL import Image, ImageFile
from mindspore import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as vision_C
import mindspore.dataset.transforms.c_transforms as normal_C
from src.datasets.sampler import DistributedSampler

ImageFile.LOAD_TRUNCATED_IMAGES = True


class TxtDataset():
    """
    read dataset from txt
    """
    def __init__(self, root, txt_name):
        super(TxtDataset, self).__init__()
        self.imgs = []
        self.labels = []
        fin = open(txt_name, "r")
        for line in fin:
            img_name, label = line.strip().split(' ')
            self.imgs.append(os.path.join(root, img_name))
            self.labels.append(int(label))
        fin.close()

    def __getitem__(self, index):
        img = Image.open(self.imgs[index]).convert('RGB')
        return img, self.labels[index]

    def __len__(self):
        return len(self.imgs)


def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
                           mode='train',
                           input_mode='folder',
                           root='',
                           num_parallel_workers=None,
                           shuffle=None,
                           sampler=None,
                           class_indexing=None,
                           drop_remainder=True,
                           transform=None,
                           target_transform=None):
    """
    A function that returns a dataset for classification. The mode of the input dataset can be "folder" or "txt".
    If it is "folder", all images within one folder have the same label. If it is "txt", the paths of all images
    are written into a text file.

    Args:
        data_dir (str): Path to the root directory that contains the dataset when input_mode="folder",
            or path of the text file that contains every image's path when input_mode="txt".
        image_size (list): Size of the input images.
        per_batch_size (int): the batch size of every step during training.
        max_epoch (int): the number of epochs.
        rank (int): The shard ID within num_shards (default=None).
        group_size (int): Number of shards that the dataset should be divided
            into (default=None).
        mode (str): "train" or others. Default: "train".
        input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
        root (str): the images path for input_mode="txt". Default: "".
        num_parallel_workers (int): Number of workers to read the data. Default: None.
        shuffle (bool): Whether or not to perform shuffle on the dataset
            (default=None, performs shuffle).
        sampler (Sampler): Object used to choose samples from the dataset. Default: None.
        class_indexing (dict): A str-to-int mapping from folder name to index
            (default=None, the folder names will be sorted
            alphabetically and each class will be given a
            unique index starting from 0).

    Examples:
        >>> from src.datasets.classification import classification_dataset
        >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
        >>> dataset_dir = "/path/to/imagefolder_directory"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
        >>>                                     per_batch_size=64, max_epoch=100,
        >>>                                     rank=0, group_size=4)
        >>> # Path of the text file that contains every image's path of the dataset.
        >>> dataset_dir = "/path/to/dataset/images/train.txt"
        >>> images_dir = "/path/to/dataset/images"
        >>> de_dataset = classification_dataset(dataset_dir, image_size=[224, 224],
        >>>                                     per_batch_size=64, max_epoch=100,
        >>>                                     rank=0, group_size=4,
        >>>                                     input_mode="txt", root=images_dir)
    """

    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    if transform is None:
        if mode == 'train':
            transform_img = [
                vision_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
                vision_C.RandomHorizontalFlip(prob=0.5),
                vision_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
                vision_C.Normalize(mean=mean, std=std),
                vision_C.HWC2CHW()
            ]
        else:
            transform_img = [
                vision_C.Decode(),
                vision_C.Resize((256, 256)),
                vision_C.CenterCrop(image_size),
                vision_C.Normalize(mean=mean, std=std),
                vision_C.HWC2CHW()
            ]
    else:
        transform_img = transform

    if target_transform is None:
        transform_label = [
            normal_C.TypeCast(mstype.int32)
        ]
    else:
        transform_label = target_transform

    if input_mode == 'folder':
        de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
                                           shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
                                           num_shards=group_size, shard_id=rank)
    else:
        dataset = TxtDataset(root, data_dir)
        sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
        de_dataset.set_dataset_size(len(sampler))

    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)

    columns_to_project = ["image", "label"]
    de_dataset = de_dataset.project(columns=columns_to_project)

    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
    de_dataset = de_dataset.repeat(1)

    return de_dataset
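TxtDataset expects one "relative_image_path label" pair per line, separated by a single space. A small sketch of producing such an annotation file (the paths and labels below are made up for illustration):

    lines = [
        "class_a/img_0001.JPEG 0",
        "class_b/img_0002.JPEG 1",
    ]
    with open("train.txt", "w") as fout:
        fout.write("\n".join(lines) + "\n")
    # classification_dataset("train.txt", ..., input_mode="txt", root="/path/to/dataset/images")
    # joins each relative path onto `root` and casts the trailing field to int.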
@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
shuffle and distribute sample
"""

import math
import numpy as np


class DistributedSampler():
    """
    function to distribute and shuffle sample
    """
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_length = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_length).tolist()
        else:
            # dataset_length is already an int, so use it directly as the range bound
            indices = list(range(self.dataset_length))

        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples
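A quick illustration of how the sampler pads and interleaves indices across ranks, assuming a 10-sample dataset split over 4 shards with shuffle disabled:

    # dataset_length=10, group_size=4 -> num_samples=3, total_size=12
    indices = list(range(10))
    indices += indices[:12 - len(indices)]   # pad with the first 2 indices -> 12 entries
    print(indices[0::4])                     # rank 0 -> [0, 4, 8]
    print(indices[3::4])                     # rank 3 -> [3, 7, 1] (includes a padded duplicate)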
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function
"""

from .crossentropy import *
@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
loss function CrossEntropy
"""

from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropy(_Loss):
    """
    loss function CrossEntropy
    """
    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label,
                                    F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
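With label smoothing, the one-hot target is softened: for smooth_factor=0.1 and num_classes=1000, on_value = 1 - 0.1 = 0.9 and each off_value = 0.1 / 999 ≈ 1.0e-4, so every target row still sums to 1. A NumPy sketch of the same target construction (class index 3 is an arbitrary example label):

    import numpy as np

    num_classes, smooth_factor = 1000, 0.1
    on_value = 1.0 - smooth_factor                   # 0.9
    off_value = smooth_factor / (num_classes - 1)    # ~1.0e-4
    target = np.full(num_classes, off_value)
    target[3] = on_value
    assert abs(target.sum() - 1.0) < 1e-6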
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
learning rate scheduler
"""
from .lr_scheduler import *
@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
densenet network
"""
from .densenet import DenseNet121
@@ -0,0 +1,230 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
model architecture of densenet
"""

import math
from collections import OrderedDict

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import initializer as init
from src.utils.var_init import default_recurisive_init, KaimingNormal

__all__ = ["DenseNet121"]


class GlobalAvgPooling(nn.Cell):
    """
    GlobalAvgPooling function.
    """
    def __init__(self):
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(True)
        self.shape = P.Shape()
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.mean(x, (2, 3))
        b, c, _, _ = self.shape(x)
        x = self.reshape(x, (b, c))
        return x


class CommonHead(nn.Cell):
    def __init__(self, num_classes, out_channels):
        super(CommonHead, self).__init__()
        self.avgpool = GlobalAvgPooling()
        self.fc = nn.Dense(out_channels, num_classes, has_bias=True)

    def construct(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        return x


def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False):
    return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
                     padding=padding, pad_mode="pad")


def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
                     padding=padding, pad_mode="pad")


def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
                     padding=padding, pad_mode="pad")


class _DenseLayer(nn.Cell):
    """
    the dense layer, including 2 conv layers
    """
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU()
        self.conv1 = conv1x1(num_input_features, bn_size*growth_rate)

        self.norm2 = nn.BatchNorm2d(bn_size*growth_rate)
        self.relu2 = nn.ReLU()
        self.conv2 = conv3x3(bn_size*growth_rate, growth_rate)

        # nn.Dropout in MindSpore uses keep_prob, different from PyTorch's drop probability
        self.keep_prob = 1.0 - drop_rate
        self.dropout = nn.Dropout(keep_prob=self.keep_prob)

    def construct(self, features):
        bottleneck = self.conv1(self.relu1(self.norm1(features)))
        new_features = self.conv2(self.relu2(self.norm2(bottleneck)))
        if self.keep_prob < 1:
            new_features = self.dropout(new_features)
        return new_features


class _DenseBlock(nn.Cell):
    """
    the dense block
    """
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        self.cell_list = nn.CellList()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate
            )
            self.cell_list.append(layer)

        self.concate = P.Concat(axis=1)

    def construct(self, init_features):
        features = init_features
        for layer in self.cell_list:
            new_features = layer(features)
            features = self.concate((features, new_features))
        return features


class _Transition(nn.Cell):
    """
    the transition layer
    """
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.features = nn.SequentialCell(OrderedDict([
            ('norm', nn.BatchNorm2d(num_input_features)),
            ('relu', nn.ReLU()),
            ('conv', conv1x1(num_input_features, num_output_features)),
            ('pool', nn.MaxPool2d(kernel_size=2, stride=2))
        ]))

    def construct(self, x):
        x = self.features(x)
        return x


class Densenet(nn.Cell):
    """
    the densenet architecture
    """
    __constants__ = ['features']

    def __init__(self, growth_rate, block_config, num_init_features, bn_size=4, drop_rate=0):
        super(Densenet, self).__init__()

        layers = OrderedDict()
        layers['conv0'] = conv7x7(3, num_init_features, stride=2, padding=3)
        layers['norm0'] = nn.BatchNorm2d(num_init_features)
        layers['relu0'] = nn.ReLU()
        layers['pool0'] = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate
            )
            layers['denseblock%d'%(i+1)] = block
            num_features = num_features + num_layers*growth_rate

            if i != len(block_config)-1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                layers['transition%d'%(i+1)] = trans
                num_features = num_features // 2

        # Final batch norm
        layers['norm5'] = nn.BatchNorm2d(num_features)
        layers['relu5'] = nn.ReLU()

        self.features = nn.SequentialCell(layers)
        self.out_channels = num_features

    def construct(self, x):
        x = self.features(x)
        return x

    def get_out_channels(self):
        return self.out_channels


def _densenet121(**kwargs):
    return Densenet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, **kwargs)


def _densenet161(**kwargs):
    return Densenet(growth_rate=48, block_config=(6, 12, 36, 24), num_init_features=96, **kwargs)


def _densenet169(**kwargs):
    return Densenet(growth_rate=32, block_config=(6, 12, 32, 32), num_init_features=64, **kwargs)


def _densenet201(**kwargs):
    return Densenet(growth_rate=32, block_config=(6, 12, 48, 32), num_init_features=64, **kwargs)


class DenseNet121(nn.Cell):
    """
    the densenet121 architecture
    """
    def __init__(self, num_classes):
        super(DenseNet121, self).__init__()
        self.backbone = _densenet121()
        out_channels = self.backbone.get_out_channels()
        self.head = CommonHead(num_classes, out_channels)

        default_recurisive_init(self)
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(KaimingNormal(a=math.sqrt(5), mode='fan_out',
                                                                    nonlinearity='relu'),
                                                      cell.weight.shape,
                                                      cell.weight.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
                cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
            elif isinstance(cell, nn.Dense):
                cell.bias.set_data(init.initializer('zeros', cell.bias.shape))

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
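For the DenseNet-121 configuration above (growth_rate=32, blocks (6, 12, 24, 16), 64 initial features), the channel bookkeeping in Densenet.__init__ works out to 1024 output channels, which is what CommonHead receives; a quick stand-alone check of that arithmetic:

    growth_rate, block_config, num_features = 32, (6, 12, 24, 16), 64
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate   # every dense layer appends growth_rate channels
        if i != len(block_config) - 1:
            num_features //= 2                     # transition layer halves the channels
    print(num_features)                            # 1024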
@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get parameter function
"""
def get_param_groups(network):
    """
    get parameter groups
    """
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith('.bias'):
            # all biases do not use weight decay
            # print('no decay:{}'.format(parameter_name))
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # bn gamma does not use weight decay; be careful, for now x does not include BN statistics
            # print('no decay:{}'.format(parameter_name))
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # bn beta does not use weight decay; be careful, for now x does not include BN statistics
            # print('no decay:{}'.format(parameter_name))
            no_decay_params.append(x)
        else:
            decay_params.append(x)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
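The returned group list plugs directly into a MindSpore optimizer, so biases and BN gamma/beta get zero weight decay while everything else uses the configured value. A hedged usage sketch; the optimizer choice, the hyper-parameters, and the module path in the import are assumptions for illustration, not part of this commit:

    import mindspore.nn as nn
    from src.network import DenseNet121
    from src.optimizers import get_param_groups   # assumed module path for this file

    net = DenseNet121(num_classes=1000)
    opt = nn.Momentum(params=get_param_groups(net),
                      learning_rate=0.1, momentum=0.9,
                      weight_decay=1e-4)   # applied only to the group without an explicit weight_decay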
@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
@@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime


class LOGGER(logging.Logger):
    """
    Logger with an optional console handler and a per-rank log file.

    Args:
        logger_name (string): logger name.
        rank (int): global rank; only ranks that are a multiple of 8 also log to stdout.
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """set up log file"""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    logger = LOGGER("mindversion", rank)
    logger.setup_logging_file(path, rank)
    return logger
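A short usage sketch of the logger above, matching the calls made in eval.py; the output directory is illustrative:

    logger = get_logger('outputs/2020-01-01_time_00_00_00', rank=0)
    logger.important_info('start create network')        # boxed banner, emitted on rank 0 only
    logger.info('load model %s success', '/path/to/ckpt')
    # each rank writes to <log_dir>/<timestamp>_rank_<rank>.log; ranks divisible by 8 also log to stdout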
@@ -0,0 +1,204 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import initializer as init


def _calculate_gain(nonlinearity, param=None):
    r"""
    Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = _calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Assign the value of `num` to `arr`."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(arr):
    """
    Calculate n_in and n_out.

    Args:
        arr (Array): Input array.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
    """
    dim = len(arr.shape)
    if dim < 2:
        raise ValueError("If initializing data with xavier uniform, the dimension of data must be greater than 1.")

    n_in = arr.shape[1]
    n_out = arr.shape[0]

    if dim > 2:
        counter = reduce(lambda x, y: x * y, arr.shape[2:])
        n_in *= counter
        n_out *= counter
    return n_in, n_out


def _select_fan(array, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_in_and_out(array)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    r"""
    Base Class. Initialize the array with the He kaiming algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        pass


class KaimingUniform(KaimingInit):
    r"""
    Initialize the array with the He kaiming uniform algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    r"""
    Initialize the array with the He kaiming normal algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty((3, 5))
        >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
    """

    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def default_recurisive_init(custom_cell):
    """default_recurisive_init"""
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.shape), cell.bias.dtype))
        elif isinstance(cell, nn.Dense):
            cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                cell.bias.set_data(Tensor(np.random.uniform(-bound, bound, cell.bias.shape), cell.bias.dtype))
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
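A worked example of the fan/gain arithmetic used by KaimingUniform for a 3x3 conv weight of shape (64, 32, 3, 3), matching default_recurisive_init with a = sqrt(5); the shape is an arbitrary illustration:

    import math

    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1 + a ** 2))       # leaky_relu gain with negative_slope=sqrt(5) -> ~0.577
    fan_in = 32 * 3 * 3                        # n_in times the receptive field = 288
    bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
    print(round(bound, 4))                     # ~0.0589, weights drawn from U(-bound, bound)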