parent
ac5371b38f
commit
cd31275061
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Inference Interface"""
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import nn
|
||||||
|
|
||||||
|
from src.dataset import create_dataset_cifar10
|
||||||
|
from src.utils import count_params
|
||||||
|
from src.hournasnet import hournasnet
|
||||||
|
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
# Command-line interface of the evaluation script.
# Option names, types and defaults are unchanged; only the help texts were
# corrected so that they describe evaluation of HourNAS models.
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/workspace/mindspore_dataset/',
                    metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='hournas_f_c10', type=str, metavar='MODEL',
                    help='Name of model to evaluate (default: "hournas_f_c10")')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                    help='number of label classes (default: 10)')
parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for evaluation (default: 256)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many worker processes to use (default: 4)')
parser.add_argument('--ckpt', type=str, default='./ms_hournas_f_c10.ckpt',
                    help='model checkpoint to load')
# NOTE(review): `store_true` combined with `default=True` means the two flags
# below are always True; they cannot be switched off from the command line.
parser.add_argument('--GPU', action='store_true', default=True,
                    help='Use GPU for evaluation (default: True)')
parser.add_argument('--dataset_sink', action='store_true', default=True)
parser.add_argument('--image-size', type=int, default=32, metavar='N',
                    help='input image size (default: 32)')
|
||||||
|
|
||||||
|
def main():
    """Evaluate a HourNAS checkpoint on the CIFAR-10 evaluation data.

    Parses the command-line arguments, builds the network named by
    ``--model``, loads the checkpoint given by ``--ckpt`` into it and prints
    the metrics (validation loss, top-1 and top-5 accuracy) returned by
    ``Model.eval``.

    Raises:
        ValueError: if ``--model`` does not start with ``"hournas"``.
    """
    args = parser.parse_args()
    print(sys.argv)

    context.set_context(mode=context.PYNATIVE_MODE)
    if args.GPU:
        context.set_context(device_target='GPU')

    # Validate with an explicit exception: an `assert` disappears under
    # `python -O`, and the old message named the wrong model family.
    if not args.model.startswith("hournas"):
        raise ValueError(
            "Only HourNAS models are supported, got {!r}.".format(args.model))

    net = hournasnet(args.model,
                     num_classes=args.num_classes,
                     drop_rate=0.0,
                     drop_connect_rate=0.0,
                     global_pool="avg",
                     bn_tf=False,
                     bn_momentum=None,
                     bn_eps=None)
    print(net)
    print("Total number of parameters:", count_params(net))

    # Dataset configuration consumed by create_dataset_cifar10.
    cfg = edict({'image_height': args.image_size,
                 'image_width': args.image_size,
                 'batch_size': args.batch_size})
    print(cfg)

    # --data_path is passed through unchanged; the dataset helper appends the
    # evaluation sub-directory itself.
    val_dataset = create_dataset_cifar10(args.data_path, repeat_num=1,
                                         training=False, cifar_cfg=cfg)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    eval_metrics = {'Validation-Loss': Loss(),
                    'Top1-Acc': Top1CategoricalAccuracy(),
                    'Top5-Acc': Top5CategoricalAccuracy()}

    # Restore the trained weights and switch the network to inference mode
    # before wrapping it in a Model.
    ckpt = load_checkpoint(args.ckpt)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    model = Model(net, loss, metrics=eval_metrics)

    # NOTE(review): --dataset_sink and --workers are parsed but not used;
    # evaluation always runs with dataset_sink_mode=False.
    metrics = model.eval(val_dataset, dataset_sink_mode=False)
    print(metrics)


if __name__ == '__main__':
    main()
|
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""hub config."""
|
||||||
|
from src.hournasnet import hournasnet
|
||||||
|
|
||||||
|
|
||||||
|
def create_network(name, *args, **kwargs):
    """Build a network for the hub entry point.

    Args:
        name (str): hub name of the network; only ``'HourNAS'`` is known.
        *args: positional arguments forwarded to ``hournasnet``.
        **kwargs: keyword arguments forwarded to ``hournasnet``.

    Returns:
        Whatever ``hournasnet(*args, **kwargs)`` returns.

    Raises:
        NotImplementedError: if ``name`` is not ``'HourNAS'``.
    """
    # Reject unknown names first so the supported path reads straight through.
    if name != 'HourNAS':
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return hournasnet(*args, **kwargs)
|
@ -0,0 +1,55 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Architecture of HourNAS"""
|
||||||
|
# Table of predefined architectures, keyed by model name.
# Each entry holds three parallel per-block lists ('genotypes', 'strides',
# 'out_channels', 22 elements each) plus a few scalar settings.
# NOTE(review): 'default_init' and 'se_ratio' are stored as strings, exactly
# as in the original table -- confirm the consumer expects strings.
predefine_archs = {
    'hournas_f_c10': {
        # One leading block, five groups of four blocks, one trailing block.
        'genotypes': (
            ['ir_k3_e1_se']
            + ['ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k5_e1_se', 'ir_k3_e1_se']
            + ['ir_k5_e6_se', 'ir_k5_e1_se', 'ir_k3_e1_se', 'ir_k5_e1_se']
            + ['ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se', 'ir_k3_e6_se']
            + ['ir_k5_e6_se', 'ir_k5_e3_se', 'ir_k5_e3_se', 'ir_k5_e3_se']
            + ['ir_k5_e6_se', 'ir_k5_e6_se', 'ir_k3_e6_se', 'ir_k5_e6_se']
            + ['ir_k5_e6_se']
        ),
        # Stride 2 appears only at the first block of the third and fifth
        # group; every other block uses stride 1.
        'strides': (
            [1]
            + [1, 1, 1, 1]
            + [1, 1, 1, 1]
            + [2, 1, 1, 1]
            + [1, 1, 1, 1]
            + [2, 1, 1, 1]
            + [1]
        ),
        # Output width is constant inside each group of four blocks.
        'out_channels': (
            [16]
            + [24] * 4
            + [40] * 4
            + [80] * 4
            + [112] * 4
            + [192] * 4
            + [320]
        ),
        'dropout_ratio': 0.2,
        'default_init': 'True',
        'se_ratio': '0.05',
    },
}
|
@ -0,0 +1,200 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Data operations, will be used in train.py and eval.py"""
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.dataset.vision.py_transforms as py_vision
|
||||||
|
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||||
|
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size
|
||||||
|
from mindspore.dataset.vision import Inter
|
||||||
|
import mindspore.dataset.vision.c_transforms as vision
|
||||||
|
|
||||||
|
# values that should remain constant
# DEFAULT_CROP_PCT: create_dataset_val divides the requested input size by
# this value to get the resize target used before the center crop.
DEFAULT_CROP_PCT = 0.875
# Per-channel mean / std passed to py_vision.Normalize in both ImageNet
# pipelines (create_dataset and create_dataset_val).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# data preprocess configs
# SCALE and RATIO are the `scale` and `ratio` arguments of the
# RandomResizedCrop used by create_dataset.
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)

# Import-time side effect: sets the mindspore.dataset seed to 1.
# NOTE(review): presumably for reproducible shuffling/augmentation -- confirm.
ds.config.set_seed(1)
|
||||||
|
|
||||||
|
|
||||||
|
def split_imgs_and_labels(imgs, labels, batchInfo):
    """Pack one batch of images and labels into two numpy arrays.

    Used as the ``per_batch_map`` callback of the ImageNet pipelines.

    Args:
        imgs: iterable with the images of the batch.
        labels: indexable sequence; ``labels[i]`` belongs to the i-th image.
        batchInfo: batch information supplied by the caller; not used.

    Returns:
        tuple: ``(np.ndarray of images, np.ndarray of labels)``, with one
        label per image.
    """
    images = list(imgs)
    # Look labels up by position so that exactly one label per image is kept.
    matching_labels = [labels[position] for position in range(len(images))]
    return np.array(images), np.array(matching_labels)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
                   input_size=224, color_jitter=0.4):
    """Create the ImageNet training dataset.

    Args:
        batch_size (int): number of samples per batch.
        train_data_url (str): directory read through ``ImageFolderDataset``.
        workers (int): parallel workers for reading and for the map steps.
        distributed (bool): shard the data by rank when True.
        input_size (int): side length of the square random-resized crop.
        color_jitter (float): half-width of the brightness / contrast /
            saturation adjustment range around 1.

    Returns:
        The batched dataset (incomplete final batch dropped, repeated once).

    Raises:
        ValueError: if ``train_data_url`` does not exist.
    """
    if not os.path.exists(train_data_url):
        raise ValueError('Path not exists')

    # Augmentation chain for the "image" column.
    jitter_range = (max(0, 1 - color_jitter), 1 + color_jitter)
    image_pipeline = py_transforms.Compose([
        py_vision.Decode(),
        py_vision.RandomResizedCrop(size=(input_size, input_size),
                                    scale=SCALE, ratio=RATIO,
                                    interpolation=Inter.BICUBIC),
        py_vision.RandomHorizontalFlip(0.5),
        py_vision.RandomColorAdjust(brightness=jitter_range,
                                    contrast=jitter_range,
                                    saturation=jitter_range),
        py_vision.ToTensor(),
        py_vision.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    label_cast = c_transforms.TypeCast(mstype.int32)

    # Without distribution everything lives in a single shard.
    if distributed:
        shard_index = get_rank()
        shard_count = get_group_size()
    else:
        shard_index, shard_count = 0, 1

    samples = ds.ImageFolderDataset(train_data_url,
                                    num_parallel_workers=workers,
                                    shuffle=True,
                                    num_shards=shard_count,
                                    shard_id=shard_index)
    samples = samples.map(input_columns=["image"],
                          operations=image_pipeline,
                          num_parallel_workers=workers)
    samples = samples.map(input_columns=["label"],
                          operations=label_cast,
                          num_parallel_workers=workers)

    batches = samples.batch(batch_size,
                            per_batch_map=split_imgs_and_labels,
                            input_columns=["image", "label"],
                            num_parallel_workers=2,
                            drop_remainder=True)
    return batches.repeat(1)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
                       input_size=224):
    """Create the ImageNet validation dataset.

    Images are resized to ``input_size / DEFAULT_CROP_PCT`` and then
    center-cropped to ``input_size``.

    Args:
        batch_size (int): number of samples per batch.
        val_data_url (str): directory read through ``ImageFolderDataset``.
        workers (int): parallel workers for reading and for the map steps.
        distributed (bool): shard the data by rank when True.
        input_size (int or tuple): crop size; a tuple must have two entries.

    Returns:
        The batched dataset (incomplete final batch dropped, repeated once).

    Raises:
        ValueError: if ``val_data_url`` does not exist.
    """
    if not os.path.exists(val_data_url):
        raise ValueError('Path not exists')

    # Without distribution everything lives in a single shard.
    if distributed:
        shard_index = get_rank()
        shard_count = get_group_size()
    else:
        shard_index, shard_count = 0, 1

    samples = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
                                    num_shards=shard_count, shard_id=shard_index)

    # Size the image is resized to before the center crop.
    if not isinstance(input_size, tuple):
        resize_to = int(math.floor(input_size / DEFAULT_CROP_PCT))
    else:
        assert len(input_size) == 2
        if input_size[-1] == input_size[-2]:
            # Square target: a single integer is enough.
            resize_to = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
        else:
            resize_to = tuple(int(side / DEFAULT_CROP_PCT) for side in input_size)

    label_cast = c_transforms.TypeCast(mstype.int32)
    image_pipeline = py_transforms.Compose([
        py_vision.Decode(),
        py_vision.Resize(size=resize_to, interpolation=Inter.BICUBIC),
        py_vision.CenterCrop(size=input_size),
        py_vision.ToTensor(),
        py_vision.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])

    samples = samples.map(input_columns=["label"], operations=label_cast,
                          num_parallel_workers=workers)
    samples = samples.map(input_columns=["image"], operations=image_pipeline,
                          num_parallel_workers=workers)
    batches = samples.batch(batch_size, per_batch_map=split_imgs_and_labels,
                            input_columns=["image", "label"],
                            num_parallel_workers=2,
                            drop_remainder=True)
    return batches.repeat(1)
|
||||||
|
|
||||||
|
def _get_rank_info():
|
||||||
|
"""
|
||||||
|
get rank size and rank id
|
||||||
|
"""
|
||||||
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
|
if rank_size > 1:
|
||||||
|
rank_size = get_group_size()
|
||||||
|
rank_id = get_rank()
|
||||||
|
else:
|
||||||
|
rank_size = rank_id = None
|
||||||
|
|
||||||
|
return rank_size, rank_id
|
||||||
|
|
||||||
|
def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None):
    """Create the CIFAR-10 dataset pipeline.

    Args:
        data_home (str): directory containing ``cifar-10-batches-bin``
            (read when ``training`` is true) and ``cifar-10-verify-bin``
            (read otherwise).
        repeat_num (int): number of times the batched dataset is repeated.
        training (bool): when true the data is shuffled and random crop and
            random horizontal flip are applied before the common transforms.
        cifar_cfg: configuration object providing ``image_height``,
            ``image_width`` and ``batch_size`` attributes. Required, despite
            the ``None`` default kept for interface compatibility.

    Returns:
        The mapped, batched and repeated dataset.

    Raises:
        ValueError: if ``cifar_cfg`` is ``None``.
    """
    # Fail fast with a clear message instead of an AttributeError on None
    # after the dataset has already been opened.
    if cifar_cfg is None:
        raise ValueError("cifar_cfg is required: it must provide "
                         "image_height, image_width and batch_size")

    sub_dir = "cifar-10-batches-bin" if training else "cifar-10-verify-bin"
    data_dir = os.path.join(data_home, sub_dir)

    rank_size, rank_id = _get_rank_info()
    # Shuffle only the training data.
    data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id,
                                 shuffle=bool(training))

    # Image transforms: random augmentation first (training only), then the
    # steps shared by training and evaluation.
    image_ops = []
    if training:
        image_ops += [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),  # padding_mode default CONSTANT
            vision.RandomHorizontalFlip(),
        ]
    image_ops += [
        vision.Resize((cifar_cfg.image_height, cifar_cfg.image_width)),  # interpolation default BILINEAR
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize((0.4914, 0.4822, 0.4465), (0.24703233, 0.24348505, 0.26158768)),
        vision.HWC2CHW(),
    ]
    label_cast = c_transforms.TypeCast(mstype.int32)

    data_set = data_set.map(operations=label_cast, input_columns="label")
    data_set = data_set.map(operations=image_ops, input_columns="image")

    # NOTE(review): drop_remainder=True is also used when training is false,
    # so an incomplete final batch is not evaluated -- confirm this is intended.
    data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)

    return data_set.repeat(repeat_num)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,89 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""model utils"""
|
||||||
|
import math
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def str2bool(value):
    """Convert a command-line string to a bool.

    Args:
        value (str or bool): text to interpret, compared case-insensitively.
            ``'yes'``, ``'true'``, ``'t'``, ``'y'``, ``'1'`` map to True;
            ``'no'``, ``'false'``, ``'f'``, ``'n'``, ``'0'`` map to False.
            A value that is already a bool is returned unchanged.

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not one of the accepted
            spellings, or is neither a string nor a bool.
    """
    # Already a bool (e.g. a programmatic default): nothing to parse.
    if isinstance(value, bool):
        return value
    # Anything else that is not text cannot be lower-cased; report it the
    # same way as an unrecognised spelling.
    if not isinstance(value, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
           warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
    """Build the per-step learning-rate schedule.

    The rate is constant within an epoch. During the first ``warmup_epochs``
    epochs it rises linearly from ``warmup_lr_init`` towards ``base_lr``;
    afterwards it is ``base_lr`` multiplied by the decay factor once for
    every ``decay_epochs`` completed epochs.

    Args:
        base_lr (float): learning rate reached after warmup.
        total_epochs (int): number of epochs in the full schedule.
        steps_per_epoch (int): number of steps in one epoch.
        decay_epochs (int): epochs between two decay steps.
        decay_rate (float): decay factor; values of 1 or more are inverted.
        warmup_epochs (float): number of warmup epochs (0 disables warmup).
        warmup_lr_init (float): learning rate of the first warmup epoch.
        global_epoch (int): epochs already done; their steps are cut off.

    Returns:
        np.ndarray: float32 array with one learning rate per remaining step.
    """
    if warmup_epochs > 0:
        warmup_increment = (base_lr - warmup_lr_init) / warmup_epochs
    else:
        warmup_increment = 0
    # The factor applied per decay step never exceeds 1.
    factor = decay_rate if decay_rate < 1 else 1/decay_rate

    schedule = []
    for step in range(steps_per_epoch * total_epochs):
        epoch = math.floor(step/steps_per_epoch)
        warm_value = warmup_lr_init + epoch * warmup_increment
        decayed_value = base_lr * math.pow(factor, math.floor(epoch / decay_epochs))
        # 1 while warming up, 0 afterwards: selects one of the two values.
        in_warmup = 1 if (epoch < warmup_epochs) else 0
        schedule.append(in_warmup * warm_value + (1 - in_warmup) * decayed_value)

    skipped_steps = steps_per_epoch * global_epoch
    return np.array(schedule[skipped_steps:]).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
    """Split the trainable parameters into a no-decay and a decay group.

    A parameter gets no weight decay when its shape is one-dimensional, its
    name ends with ``".bias"``, or its name is listed in ``skip_list``.

    Args:
        net (mindspore.nn.Cell): Mindspore network instance
        weight_decay (float): weight decay applied to the decay group.
        skip_list (tuple): names of parameters that get no weight decay.

    Returns:
        A list with two parameter groups: first the group without weight
        decay, then the group with ``weight_decay``.
    """
    excluded_names = skip_list if skip_list else ()
    without_decay = []
    with_decay = []
    for weight in net.trainable_params():
        exempt = (len(weight.shape) == 1
                  or weight.name.endswith(".bias")
                  or weight.name in excluded_names)
        target_group = without_decay if exempt else with_decay
        target_group.append(weight)
    return [
        {'params': without_decay, 'weight_decay': 0.},
        {'params': with_decay, 'weight_decay': weight_decay}]
|
||||||
|
|
||||||
|
|
||||||
|
def count_params(net):
    """Count number of parameters in the network

    Args:
        net (mindspore.nn.Cell): Mindspore network instance

    Returns:
        total_params (int): Total number of trainable params
    """
    total_params = 0
    for param in net.trainable_params():
        # Force an integer product: np.prod of an empty shape is the float
        # 1.0 by default, and numpy scalars would otherwise leak into the
        # result although a plain int is documented.
        total_params += int(np.prod(param.shape, dtype=np.int64))
    return total_params
|
Loading…
Reference in new issue