@@ -27,6 +27,7 @@
 # TODO: define normalization api
 import six

 from ...fluid.dygraph.nn import InstanceNorm

 from ...fluid.dygraph import BatchNorm  #DEFINE_ALIAS
@@ -36,7 +37,6 @@ from ...fluid.dygraph import BatchNorm  #DEFINE_ALIAS
 from ...fluid.dygraph import SpectralNorm  #DEFINE_ALIAS

 from ...fluid.dygraph import layers
 from ...framework import get_default_dtype, set_default_dtype
 from ...fluid.framework import in_dygraph_mode
@@ -50,6 +50,7 @@ from ..functional import batch_norm, layer_norm, instance_norm
 import numpy as np
 import numbers
 import warnings
+from ...fluid.dygraph.base import no_grad

 __all__ = [
     'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm',
@@ -566,17 +567,28 @@ class _BatchNormBase(layers.Layer):
         param_shape = [num_features]

         # create parameter
-        self.weight = self.create_parameter(
-            attr=self._weight_attr,
-            shape=param_shape,
-            default_initializer=Constant(1.0))
-        self.weight.stop_gradient = (self._weight_attr is False) or (
-            self._weight_attr and self._weight_attr.learning_rate == 0.)
+        if weight_attr == False:
+            self.weight = self.create_parameter(
+                attr=None, shape=param_shape, default_initializer=Constant(1.0))
+            self.weight.stop_gradient = True
+        else:
+            self.weight = self.create_parameter(
+                attr=self._weight_attr,
+                shape=param_shape,
+                default_initializer=Constant(1.0))
+            self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.

-        self.bias = self.create_parameter(
-            attr=self._bias_attr, shape=param_shape, is_bias=True)
-        self.bias.stop_gradient = (self._bias_attr is False) or (
-            self._bias_attr and self._bias_attr.learning_rate == 0.)
+        if bias_attr == False:
+            self.bias = self.create_parameter(
+                attr=None,
+                shape=param_shape,
+                default_initializer=Constant(0.0),
+                is_bias=True)
+            self.bias.stop_gradient = True
+        else:
+            self.bias = self.create_parameter(
+                attr=self._bias_attr, shape=param_shape, is_bias=True)
+            self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.

         moving_mean_name = None
         moving_variance_name = None
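The hunk above lets callers pass `weight_attr=False` / `bias_attr=False` to get a batch-norm layer whose scale and shift are created but frozen (`stop_gradient = True`). A minimal sketch of the resulting behavior, assuming the `paddle.nn.BatchNorm2d` subclass of `_BatchNormBase` on this branch:

    import paddle
    import paddle.nn as nn

    paddle.disable_static()

    # weight_attr=False / bias_attr=False take the new branch: scale and
    # shift are created from Constant(1.0) / Constant(0.0) with attr=None,
    # then frozen via stop_gradient = True.
    frozen = nn.BatchNorm2d(5, weight_attr=False, bias_attr=False)
    assert frozen.weight.stop_gradient and frozen.bias.stop_gradient

    # The default path keeps both parameters trainable.
    trainable = nn.BatchNorm2d(5)
    assert not trainable.weight.stop_gradient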
@@ -611,6 +623,7 @@ class _BatchNormBase(layers.Layer):
         self._epsilon = epsilon
+        self._fuse_with_relu = False
         self._track_running_stats = track_running_stats
         self._name = name

     def _check_input_dim(self, input):
         raise NotImplementedError("BatchNorm Base error")
@@ -898,7 +911,7 @@ class BatchNorm3d(_BatchNormBase):
                 len(input.shape)))


-class SyncBatchNorm(layers.Layer):
+class SyncBatchNorm(_BatchNormBase):
     """
     This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
     It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can
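The docstring excerpted above describes statistics that are synchronized across GPUs; that synchronization only takes effect under multi-device data-parallel training. A minimal construction sketch, assuming this branch's `paddle.nn` API (a CUDA build is required to actually run the layer):

    import paddle
    import paddle.nn as nn

    paddle.disable_static()

    # SyncBatchNorm is constructed like the other batch-norm layers; the
    # cross-GPU synchronization of mean/variance only takes effect when
    # the model is trained data-parallel on multiple CUDA devices.
    if paddle.fluid.is_compiled_with_cuda():
        sync_bn = nn.SyncBatchNorm(2)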
@@ -984,72 +997,16 @@ class SyncBatchNorm(layers.Layer):

     def __init__(self,
                  num_features,
-                 epsilon=1e-05,
                  momentum=0.9,
-                 track_running_stats=True,
+                 epsilon=1e-05,
                  weight_attr=None,
                  bias_attr=None,
                  data_format='NCHW',
+                 track_running_stats=True,
                  name=None):
-        super(SyncBatchNorm, self).__init__()
-        self._weight_attr = weight_attr
-        self._bias_attr = bias_attr
-        self._num_features = num_features
-        self._data_layout = data_format
-        self._momentum = momentum
-        self._epsilon = epsilon
-        self._track_running_stats = track_running_stats
-
-        if self._track_running_stats == False:
-            warnings.warn(
-                "moving mean and moving variance will be calculated whether `track_running_stats` is set to `True` or `False`, we will fix it in the next version."
-            )
-
-        param_shape = [self._num_features]
-
-        # create parameter
-        if weight_attr == False:
-            self.weight = self.create_parameter(
-                attr=None, shape=param_shape, default_initializer=Constant(1.0))
-            self.weight.stop_gradient = True
-        else:
-            self.weight = self.create_parameter(
-                attr=self._weight_attr,
-                shape=param_shape,
-                default_initializer=Constant(1.0))
-            self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
-
-        if bias_attr == False:
-            self.bias = self.create_parameter(
-                attr=None,
-                shape=param_shape,
-                default_initializer=Constant(0.0),
-                is_bias=True)
-            self.bias.stop_gradient = True
-        else:
-            self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True)
-            self.bias.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
-
-        self._mean = self.create_parameter(
-            attr=ParamAttr(
-                name=None,
-                initializer=Constant(0.0),
-                trainable=False,
-                do_model_average=True),
-            shape=param_shape,
-            dtype=self._dtype)
-        self._mean.stop_gradient = True
-
-        self._variance = self.create_parameter(
-            attr=ParamAttr(
-                name=None,
-                initializer=Constant(1.0),
-                trainable=False,
-                do_model_average=True),
-            shape=param_shape,
-            dtype=self._dtype)
-        self._variance.stop_gradient = True
+        super(SyncBatchNorm,
+              self).__init__(num_features, momentum, epsilon, weight_attr,
+                             bias_attr, data_format, track_running_stats, name)

     def forward(self, x):
         # create output
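With `SyncBatchNorm` now deriving from `_BatchNormBase`, the duplicated parameter-creation code disappears, but the constructor's argument order changes: `momentum` moves ahead of `epsilon`, and `track_running_stats` moves after `data_format`. Callers that passed these positionally would silently get swapped values, so keywords are the safe pattern. A short sketch, assuming the `paddle.nn` API on this branch:

    import paddle
    import paddle.nn as nn

    paddle.disable_static()

    # Keyword arguments sidestep the reordering: positionally, the old
    # signature read (num_features, epsilon, momentum, track_running_stats,
    # ...), the new one (num_features, momentum, epsilon, ...,
    # track_running_stats, ...).
    sync_bn = nn.SyncBatchNorm(
        10,
        momentum=0.9,
        epsilon=1e-05,
        data_format='NCHW',
        track_running_stats=True)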
@@ -1063,7 +1020,7 @@ class SyncBatchNorm(layers.Layer):
         if in_dygraph_mode():
             attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                      "is_test", not self.training, "data_layout",
-                     self._data_layout, "use_mkldnn", False, "fuse_with_relu",
+                     self._data_format, "use_mkldnn", False, "fuse_with_relu",
                      False, "use_global_stats", False, 'trainable_statistics',
                      False)
             sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
@@ -1073,13 +1030,13 @@ class SyncBatchNorm(layers.Layer):
             return sync_batch_norm_out

         check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
-                                 'BatchNorm')
+                                 'SyncBatchNorm')

         attrs = {
             "momentum": self._momentum,
             "epsilon": self._epsilon,
             "is_test": not self.training,
-            "data_layout": self._data_layout,
+            "data_layout": self._data_format,
             "use_mkldnn": False,
             "fuse_with_relu": False,
             "use_global_stats": False,
@@ -1112,3 +1069,45 @@ class SyncBatchNorm(layers.Layer):
         self._helper.append_op(
             type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
         return sync_batch_norm_out
+
+    @classmethod
+    def convert_sync_batchnorm(cls, layer):
+        """
+        Helper function to convert :class:`paddle.nn.BatchNorm*d` layers in the model to :class:`paddle.nn.SyncBatchNorm` layers.
+
+        Parameters:
+            layer(paddle.nn.Layer): model containing one or more `BatchNorm*d` layers.
+
+        Returns:
+            The original model with its `BatchNorm*d` layers converted to `SyncBatchNorm` layers.
+
+        Examples:
+
+            .. code-block:: python
+
+                import paddle
+                import paddle.nn as nn
+
+                paddle.disable_static()
+                model = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5))
+                sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+        """
+        layer_output = layer
+        if isinstance(layer, _BatchNormBase):
+            layer_output = SyncBatchNorm(layer._num_features, layer._momentum,
+                                         layer._epsilon, layer._weight_attr,
+                                         layer._bias_attr, layer._data_format,
+                                         layer._track_running_stats, layer._name)
+
+            if layer._weight_attr != False and layer._bias_attr != False:
+                with no_grad():
+                    layer_output.weight = layer.weight
+                    layer_output.bias = layer.bias
+            layer_output._mean = layer._mean
+            layer_output._variance = layer._variance
+
+        for name, sublayer in layer.named_sublayers():
+            layer_output.add_sublayer(name,
+                                      cls.convert_sync_batchnorm(sublayer))
+        del layer
+        return layer_output
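Because `convert_sync_batchnorm` recurses through `named_sublayers()`, batch-norm layers nested arbitrarily deep inside containers are converted as well, and their trained `weight`, `bias`, `_mean`, and `_variance` are carried over. A sketch of the intended use on a nested model, assuming the `paddle.nn` API on this branch:

    import paddle
    import paddle.nn as nn

    paddle.disable_static()

    # One BatchNorm2d at the top level and one nested inside a
    # sub-Sequential; both are reachable through named_sublayers().
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3),
        nn.BatchNorm2d(8),
        nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8)))

    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Every _BatchNormBase instance, at any depth, is now a SyncBatchNorm,
    # and its weight/bias/_mean/_variance were copied over.
    for name, sublayer in sync_model.named_sublayers():
        print(name, type(sublayer).__name__)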