!3508 modelzoo: support VGG network in GPU
Merge pull request !3508 from ms_yan/vgg_gpupull/3508/MERGE
commit
f15c6e2b77
@ -0,0 +1,39 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network"""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
|
||||
|
||||
def __init__(self, smooth_factor=0., num_classes=1001):
|
||||
super(CrossEntropy, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, one_hot_label)
|
||||
loss = self.mean(loss, 0)
|
||||
return loss
|
@ -0,0 +1,82 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
get logger.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
"""
|
||||
set up logging file.
|
||||
|
||||
Args:
|
||||
logger_name (string): logger name.
|
||||
log_dir (string): path of logger.
|
||||
|
||||
Returns:
|
||||
string, logger path
|
||||
"""
|
||||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""set up log file"""
|
||||
self.rank = rank
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
logger = LOGGER("mindversion", rank)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
@ -0,0 +1,53 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
choose samples from the dataset
|
||||
"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
class DistributedSampler():
|
||||
"""
|
||||
sampling the dataset.
|
||||
|
||||
Args:
|
||||
Returns:
|
||||
num_samples, number of samples.
|
||||
"""
|
||||
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.group_size = group_size
|
||||
self.dataset_length = len(self.dataset)
|
||||
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
|
||||
self.total_size = self.num_samples * self.group_size
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle:
|
||||
self.seed = (self.seed + 1) & 0xffffffff
|
||||
np.random.seed(self.seed)
|
||||
indices = np.random.permutation(self.dataset_length).tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset_length)))
|
||||
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
indices = indices[self.rank::self.group_size]
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
@ -0,0 +1,36 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
|
||||
|
||||
def get_param_groups(network):
|
||||
"""Param groups for optimizer."""
|
||||
decay_params = []
|
||||
no_decay_params = []
|
||||
for x in network.trainable_params():
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# all bias not using weight decay
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.gamma'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.beta'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
no_decay_params.append(x)
|
||||
else:
|
||||
decay_params.append(x)
|
||||
|
||||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
@ -0,0 +1,213 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Initialize.
|
||||
"""
|
||||
import math
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import initializer as init
|
||||
|
||||
def _calculate_gain(nonlinearity, param=None):
|
||||
r"""
|
||||
Return the recommended gain value for the given nonlinearity function.
|
||||
|
||||
The values are as follows:
|
||||
================= ====================================================
|
||||
nonlinearity gain
|
||||
================= ====================================================
|
||||
Linear / Identity :math:`1`
|
||||
Conv{1,2,3}D :math:`1`
|
||||
Sigmoid :math:`1`
|
||||
Tanh :math:`\frac{5}{3}`
|
||||
ReLU :math:`\sqrt{2}`
|
||||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
||||
================= ====================================================
|
||||
|
||||
Args:
|
||||
nonlinearity: the non-linear function
|
||||
param: optional parameter for the non-linear function
|
||||
|
||||
Examples:
|
||||
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
|
||||
"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
return 1
|
||||
if nonlinearity == 'tanh':
|
||||
return 5.0 / 3
|
||||
if nonlinearity == 'relu':
|
||||
return math.sqrt(2.0)
|
||||
if nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
def _assignment(arr, num):
|
||||
"""Assign the value of `num` to `arr`."""
|
||||
if arr.shape == ():
|
||||
arr = arr.reshape((1))
|
||||
arr[:] = num
|
||||
arr = arr.reshape(())
|
||||
else:
|
||||
if isinstance(num, np.ndarray):
|
||||
arr[:] = num[:]
|
||||
else:
|
||||
arr[:] = num
|
||||
return arr
|
||||
|
||||
def _calculate_in_and_out(arr):
|
||||
"""
|
||||
Calculate n_in and n_out.
|
||||
|
||||
Args:
|
||||
arr (Array): Input array.
|
||||
|
||||
Returns:
|
||||
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
|
||||
"""
|
||||
dim = len(arr.shape)
|
||||
if dim < 2:
|
||||
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
|
||||
|
||||
n_in = arr.shape[1]
|
||||
n_out = arr.shape[0]
|
||||
|
||||
if dim > 2:
|
||||
counter = reduce(lambda x, y: x * y, arr.shape[2:])
|
||||
n_in *= counter
|
||||
n_out *= counter
|
||||
return n_in, n_out
|
||||
|
||||
def _select_fan(array, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
|
||||
fan_in, fan_out = _calculate_in_and_out(array)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
class KaimingInit(init.Initializer):
|
||||
r"""
|
||||
Base Class. Initialize the array with He kaiming algorithm.
|
||||
|
||||
Args:
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
used with ``'leaky_relu'``)
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
nonlinearity: the non-linear function, recommended to use only with
|
||||
``'relu'`` or ``'leaky_relu'`` (default).
|
||||
"""
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(KaimingInit, self).__init__()
|
||||
self.mode = mode
|
||||
self.gain = _calculate_gain(nonlinearity, a)
|
||||
def _initialize(self, arr):
|
||||
pass
|
||||
|
||||
|
||||
class KaimingUniform(KaimingInit):
|
||||
r"""
|
||||
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
|
||||
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||
|
||||
Input:
|
||||
arr (Array): The array to be assigned.
|
||||
|
||||
Returns:
|
||||
Array, assigned array.
|
||||
|
||||
Examples:
|
||||
>>> w = np.empty(3, 5)
|
||||
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
|
||||
"""
|
||||
|
||||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
||||
|
||||
class KaimingNormal(KaimingInit):
|
||||
r"""
|
||||
Initialize the array with He kaiming normal algorithm. The resulting tensor will
|
||||
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
|
||||
|
||||
.. math::
|
||||
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
|
||||
|
||||
Input:
|
||||
arr (Array): The array to be assigned.
|
||||
|
||||
Returns:
|
||||
Array, assigned array.
|
||||
|
||||
Examples:
|
||||
>>> w = np.empty(3, 5)
|
||||
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
|
||||
"""
|
||||
|
||||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
std = self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.normal(0, std, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
||||
|
||||
def default_recurisive_init(custom_cell):
|
||||
"""default_recurisive_init"""
|
||||
for _, cell in custom_cell.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||
cell.bias.default_input.dtype)
|
||||
elif isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||
cell.bias.default_input.dtype)
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
pass
|
@ -0,0 +1,40 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
warm up cosine annealing learning rate.
|
||||
"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from .linear_warmup import linear_warmup_lr
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
"""warm up cosine annealing learning rate."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
@ -0,0 +1,84 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
warm up step learning rate.
|
||||
"""
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
from .linear_warmup import linear_warmup_lr
|
||||
|
||||
|
||||
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr_value = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr_value = float(lr_max) * base * base
|
||||
if lr_value < 0.0:
|
||||
lr_value = 0.0
|
||||
lr_each_step.append(lr_value)
|
||||
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
||||
|
||||
|
||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
|
||||
"""warmup_step_lr"""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
milestones = lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr_each_step = []
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma**milestones_steps_counter[i]
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
lr_epochs = []
|
||||
for i in range(1, max_epoch):
|
||||
if i % epoch_size == 0:
|
||||
lr_epochs.append(i)
|
||||
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue