add checks for input parameters of InstanceNorm2d

pull/13178/head
zhouyuanshen 4 years ago
parent 4eba6e0b9b
commit f3c257fd97

@@ -38,7 +38,6 @@ class InstanceNormGradGpuKernel : public GpuKernel {
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
is_training_(true),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
@@ -79,7 +78,7 @@ class InstanceNormGradGpuKernel : public GpuKernel {
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 3);
workspace_addr = GetDeviceAddress<T>(workspace, 1);
}
size_t N = input_shape_[0];
@@ -93,18 +92,14 @@ class InstanceNormGradGpuKernel : public GpuKernel {
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
float *reserve_addr = nullptr;
if (is_training_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
MS_LOG(EXCEPTION) << "The backward of InstanceNorm operator in evaluation mode is not implemented yet.";
}
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -115,8 +110,7 @@ class InstanceNormGradGpuKernel : public GpuKernel {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
InitResource();
is_training_ = GetAttr<bool>(kernel_node, "is_training");
mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
@@ -214,7 +208,6 @@ class InstanceNormGradGpuKernel : public GpuKernel {
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
bool is_training_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;

@@ -14,12 +14,14 @@
# ============================================================================
"""normalization"""
import itertools
import numbers
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
import mindspore.context as context
@@ -868,6 +870,18 @@ class InstanceNorm2d(Cell):
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `num_features` if `affine` is True. The
standard deviation is calculated via the biased estimator.
By default, this layer uses instance statistics computed from the input data in both training and evaluation modes.
If `use_batch_statistics` is set to True, the layer is in the training phase, and it keeps running estimates of its
computed mean and variance, which are then used for normalization during evaluation. The running estimates are
kept with a default momentum of 0.1.
InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
channel of channeled data like RGB images, whereas BatchNorm2d is usually applied on each batch of batched data.
Note:
Note that the formula for updating the running_mean and running_var is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
@@ -880,17 +894,13 @@ class InstanceNorm2d(Cell):
running_mean and running_var computation. Default: 0.1.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of the current batch data. If false,
use the specified mean value and variance value. Default: True.
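Since the four initializer arguments accept several forms, a short construction sketch may help; it is a hedged illustration that assumes a MindSpore build with GPU support (InstanceNorm2d is GPU-only per the docstring), and the concrete values are made up.

import mindspore.nn as nn
from mindspore.common.initializer import One

# The initializer arguments may be given as a string understood by
# `initializer`, a plain number, or an `Initializer` instance.
net = nn.InstanceNorm2d(3,
                        eps=1e-5,
                        momentum=0.1,
                        gamma_init='ones',
                        beta_init=0.0,
                        moving_mean_init='zeros',
                        moving_var_init=One())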
@@ -905,7 +915,16 @@ class InstanceNorm2d(Cell):
``GPU``
Raises:
ValueError: If num_features is less than 1 or momentum not in (0, 1).
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
TypeError: If `momentum` is not a float.
TypeError: If `affine` is not a bool.
TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not the same, or if
the initialized element type is not float32.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is a str and no homonymous
class inheriting from `Initializer` exists.
Examples:
>>> net = nn.InstanceNorm2d(3)
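The examples hunk is truncated here; as an illustrative sketch of the newly added input checks (based on the validator calls and the ValueError branch visible in the __init__ hunk that follows), invalid arguments now fail at construction time:

import mindspore.nn as nn

try:
    nn.InstanceNorm2d(3, eps=1)   # eps must be a float, not an int
except TypeError as e:
    print(e)

try:
    nn.InstanceNorm2d(0)          # num_features must be at least 1
except ValueError as e:
    print(e)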
@@ -926,9 +945,15 @@ class InstanceNorm2d(Cell):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
input_dims='2d'):
use_batch_statistics=True):
super(InstanceNorm2d, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
validator.check_value_type('eps', eps, [float], self.cls_name)
validator.check_value_type('momentum', momentum, [float], self.cls_name)
validator.check_value_type('affine', affine, [bool], self.cls_name)
args_input = {"gamma_init": gamma_init, "beta_init": beta_init,
"moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init}
self.check_types_valid(args_input, 'InstanceNorm2d')
if num_features < 1:
raise ValueError("num_features must be at least 1")
@@ -937,7 +962,7 @@ class InstanceNorm2d(Cell):
self.use_batch_statistics = use_batch_statistics
self.num_features = num_features
self.eps = eps
self.input_dims = input_dims
self.input_dims = '2d'
self.moving_mean = Parameter(initializer(
moving_mean_init, num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer(
@@ -968,6 +993,15 @@ class InstanceNorm2d(Cell):
return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
def check_types_valid(self, args_dict, name):
for key, _ in args_dict.items():
val = args_dict[key]
if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
raise TypeError(f"[{name}]Supported type for arg {key} is [Tensor, numbers.Number, str, Initializer],"
f"but got {type(val)}")
if isinstance(val, Tensor) and val.dtype is not float:
raise TypeError(f"[{name}]The type of arg {key} should be float32, but got {val.dtype}")
class GroupNorm(Cell):
r"""

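As a usage illustration of the check_types_valid helper added above (assuming a MindSpore environment; the error wording comes from the f-strings in the diff), initializer arguments of an unsupported type, or Tensor initializers whose element type is not float32, are rejected:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

try:
    # A list is not among (Tensor, numbers.Number, str, Initializer).
    nn.InstanceNorm2d(3, beta_init=[0, 0, 0])
except TypeError as e:
    print(e)

try:
    # Tensor initializers are expected to hold float32 data, so an
    # int32 tensor is rejected by the dtype branch above.
    nn.InstanceNorm2d(3, moving_mean_init=Tensor(np.zeros(3, dtype=np.int32)))
except TypeError as e:
    print(e)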
@@ -693,8 +693,7 @@ def get_bprop_fused_batch_norm_ex(self):
@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
"""Grad definition for `InstanceNorm` operation."""
is_training = self.is_training
input_grad = G.InstanceNormGrad(is_training, self.epsilon, self.momentum)
input_grad = G.InstanceNormGrad(self.epsilon, self.momentum)
def bprop(x, gamma, beta, mean, variance, out, dout):
saved_mean = out[1]

@@ -759,7 +759,7 @@ class InstanceNormGrad(PrimitiveWithInfer):
"""Gradients of InstanceNorm operation."""
@prim_attr_register
def __init__(self, is_training=True, epsilon=0.0, momentum=0.1):
def __init__(self, epsilon=0.0, momentum=0.1):
self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
outputs=['dx', 'bn_gamma', 'bn_beta'])
