|
|
@ -123,6 +123,7 @@ def batch_norm(x,
|
|
|
|
momentum=0.9,
|
|
|
|
momentum=0.9,
|
|
|
|
epsilon=1e-05,
|
|
|
|
epsilon=1e-05,
|
|
|
|
data_format="NCHW",
|
|
|
|
data_format="NCHW",
|
|
|
|
|
|
|
|
use_global_stats=None,
|
|
|
|
name=None):
|
|
|
|
name=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
|
|
|
|
Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
|
|
|
@ -139,6 +140,7 @@ def batch_norm(x,
|
|
|
|
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
|
|
|
|
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
|
|
|
|
training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False.
|
|
|
|
training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False.
|
|
|
|
data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Defalut "NCHW".
|
|
|
|
data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Defalut "NCHW".
|
|
|
|
|
|
|
|
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None.
|
|
|
|
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
|
|
|
|
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
@ -167,8 +169,6 @@ def batch_norm(x,
|
|
|
|
|
|
|
|
|
|
|
|
assert len(x.shape) >= 2, "input dim must be larger than 1"
|
|
|
|
assert len(x.shape) >= 2, "input dim must be larger than 1"
|
|
|
|
|
|
|
|
|
|
|
|
# we use not training means use_global_status, more details see nn._BatchNormBase
|
|
|
|
|
|
|
|
use_global_stats = not training
|
|
|
|
|
|
|
|
# input ad out must share the memory
|
|
|
|
# input ad out must share the memory
|
|
|
|
mean_out = running_mean
|
|
|
|
mean_out = running_mean
|
|
|
|
variance_out = running_var
|
|
|
|
variance_out = running_var
|
|
|
@ -181,11 +181,18 @@ def batch_norm(x,
|
|
|
|
|
|
|
|
|
|
|
|
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
|
|
|
|
data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_global_stats == None:
|
|
|
|
|
|
|
|
use_global_stats = not training
|
|
|
|
|
|
|
|
trainable_statistics = False
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
trainable_statistics = not use_global_stats
|
|
|
|
|
|
|
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
if in_dygraph_mode():
|
|
|
|
# for dygraph need tuple
|
|
|
|
# for dygraph need tuple
|
|
|
|
attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
|
|
|
|
attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
|
|
|
|
data_format, "use_mkldnn", False, "fuse_with_relu", False,
|
|
|
|
data_format, "use_mkldnn", False, "fuse_with_relu", False,
|
|
|
|
"use_global_stats", use_global_stats)
|
|
|
|
"use_global_stats", use_global_stats, "trainable_statistics",
|
|
|
|
|
|
|
|
trainable_statistics)
|
|
|
|
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
|
|
|
|
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
|
|
|
|
x, weight, bias, running_mean, running_var, mean_out, variance_out,
|
|
|
|
x, weight, bias, running_mean, running_var, mean_out, variance_out,
|
|
|
|
*attrs)
|
|
|
|
*attrs)
|
|
|
@ -204,6 +211,7 @@ def batch_norm(x,
|
|
|
|
"use_mkldnn": False,
|
|
|
|
"use_mkldnn": False,
|
|
|
|
"fuse_with_relu": False,
|
|
|
|
"fuse_with_relu": False,
|
|
|
|
"use_global_stats": use_global_stats,
|
|
|
|
"use_global_stats": use_global_stats,
|
|
|
|
|
|
|
|
"trainable_statistics": trainable_statistics,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputs = {
|
|
|
|
inputs = {
|
|
|
|