add weight_norm & remove_weight_norm (#26131)
* add weight_norm, test=developrevert-24895-update_cub
parent
facc0a10c9
commit
fd66d76231
@ -0,0 +1,183 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy
|
||||||
|
import collections
|
||||||
|
from functools import reduce
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from paddle.nn.utils import weight_norm, remove_weight_norm
|
||||||
|
|
||||||
|
|
||||||
|
class TestDygraphWeightNorm(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.init_test_case()
|
||||||
|
self.set_data()
|
||||||
|
|
||||||
|
def init_test_case(self):
|
||||||
|
self.batch_size = 3
|
||||||
|
self.data_desc = (['x', [2, 3, 3]], )
|
||||||
|
self.dim = None
|
||||||
|
|
||||||
|
def set_data(self):
|
||||||
|
self.data = collections.OrderedDict()
|
||||||
|
for desc in self.data_desc:
|
||||||
|
data_name = desc[0]
|
||||||
|
data_shape = desc[1]
|
||||||
|
data_value = numpy.random.random(
|
||||||
|
size=[self.batch_size] + data_shape).astype('float32')
|
||||||
|
self.data[data_name] = data_value
|
||||||
|
|
||||||
|
def norm_except_dim(self, w, dim=None):
|
||||||
|
shape = w.shape
|
||||||
|
ndims = len(shape)
|
||||||
|
shape_numel = reduce(lambda x, y: x * y, shape)
|
||||||
|
if dim == -1:
|
||||||
|
return numpy.linalg.norm(w, axis=None, keepdims=True)
|
||||||
|
elif dim == 0:
|
||||||
|
tile_shape = list(w.shape)
|
||||||
|
tile_shape[0] = 1
|
||||||
|
w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0]))
|
||||||
|
return numpy.linalg.norm(w_matrix, axis=1, keepdims=True)
|
||||||
|
elif dim == (ndims - 1):
|
||||||
|
w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1]))
|
||||||
|
return numpy.linalg.norm(w_matrix, axis=0, keepdims=True)
|
||||||
|
else:
|
||||||
|
perm = list(range(ndims))
|
||||||
|
perm_ori = list(range(ndims))
|
||||||
|
perm[0] = dim
|
||||||
|
perm[dim] = 0
|
||||||
|
p_transposed = numpy.transpose(w, perm)
|
||||||
|
return self.norm_except_dim(p_transposed, 0)
|
||||||
|
|
||||||
|
def weight_normalize(self, w, dim=None):
|
||||||
|
shape = w.shape
|
||||||
|
ndims = len(shape)
|
||||||
|
shape_numel = reduce(lambda x, y: x * y, shape)
|
||||||
|
v = w
|
||||||
|
g = self.norm_except_dim(w, dim)
|
||||||
|
g_mul = g
|
||||||
|
|
||||||
|
if dim == -1:
|
||||||
|
v_norm = v / (numpy.linalg.norm(v, axis=None, keepdims=True))
|
||||||
|
elif dim == 0:
|
||||||
|
w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0]))
|
||||||
|
v_norm = v / numpy.linalg.norm(w_matrix, axis=1)
|
||||||
|
v_norm = numpy.reshape(v_norm, shape)
|
||||||
|
g = numpy.squeeze(g, axis=1)
|
||||||
|
elif dim == (ndims - 1):
|
||||||
|
w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1]))
|
||||||
|
v_norm = v / numpy.linalg.norm(w_matrix, axis=0, keepdims=True)
|
||||||
|
v_norm = numpy.reshape(v_norm, shape)
|
||||||
|
else:
|
||||||
|
perm = list(range(ndims))
|
||||||
|
perm[0] = dim
|
||||||
|
perm[dim] = 0
|
||||||
|
p_transposed = numpy.transpose(v, perm)
|
||||||
|
transposed_shape = p_transposed.shape
|
||||||
|
transposed_shape_numel = reduce(lambda x, y: x * y,
|
||||||
|
transposed_shape)
|
||||||
|
p_matrix = numpy.reshape(
|
||||||
|
p_transposed, (p_transposed.shape[0],
|
||||||
|
transposed_shape_numel // p_transposed.shape[0]))
|
||||||
|
v_norm = v / numpy.expand_dims(
|
||||||
|
numpy.expand_dims(
|
||||||
|
numpy.linalg.norm(
|
||||||
|
p_matrix, axis=1, keepdims=True), axis=0),
|
||||||
|
axis=(ndims - 1))
|
||||||
|
v_norm = numpy.reshape(v_norm, transposed_shape)
|
||||||
|
v_norm = numpy.transpose(v_norm, perm)
|
||||||
|
g = numpy.squeeze(g, axis=1)
|
||||||
|
if dim == 1:
|
||||||
|
eaxis = 2
|
||||||
|
elif dim == 2:
|
||||||
|
eaxis = 1
|
||||||
|
g_mul = numpy.expand_dims(
|
||||||
|
numpy.expand_dims(
|
||||||
|
numpy.expand_dims(
|
||||||
|
g, axis=0), axis=eaxis),
|
||||||
|
axis=(ndims - 1))
|
||||||
|
w = g_mul * v_norm
|
||||||
|
return g, v
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
fluid.enable_imperative()
|
||||||
|
linear = paddle.nn.Conv2D(2, 3, 3)
|
||||||
|
before_weight = linear.weight.numpy()
|
||||||
|
if self.dim == None:
|
||||||
|
self.dim = -1
|
||||||
|
wn = weight_norm(linear, dim=self.dim)
|
||||||
|
outputs = []
|
||||||
|
for name, data in self.data.items():
|
||||||
|
output = linear(fluid.dygraph.to_variable(data))
|
||||||
|
outputs.append(output.numpy())
|
||||||
|
after_weight = linear.weight
|
||||||
|
self.actual_outputs = [linear.weight_g.numpy(), linear.weight_v.numpy()]
|
||||||
|
|
||||||
|
expect_output = self.weight_normalize(before_weight, self.dim)
|
||||||
|
|
||||||
|
for expect, actual in zip(expect_output, self.actual_outputs):
|
||||||
|
self.assertTrue(
|
||||||
|
numpy.allclose(
|
||||||
|
numpy.array(actual), expect, atol=0.001))
|
||||||
|
|
||||||
|
|
||||||
|
class TestDygraphWeightNormCase1(TestDygraphWeightNorm):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.batch_size = 3
|
||||||
|
self.data_desc = (['x', [2, 3, 3]], )
|
||||||
|
self.dim = 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDygraphWeightNormCase2(TestDygraphWeightNorm):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.batch_size = 3
|
||||||
|
self.data_desc = (['x', [2, 3, 3]], )
|
||||||
|
self.dim = 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestDygraphWeightNormCase3(TestDygraphWeightNorm):
|
||||||
|
def init_test_case(self):
|
||||||
|
self.batch_size = 3
|
||||||
|
self.data_desc = (['x', [2, 3, 3]], )
|
||||||
|
self.dim = 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestDygraphRemoveWeightNorm(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.init_test_case()
|
||||||
|
|
||||||
|
def init_test_case(self):
|
||||||
|
self.batch_size = 3
|
||||||
|
self.data_desc = (['x', [2, 3, 3]], )
|
||||||
|
self.dim = None
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
fluid.enable_imperative()
|
||||||
|
linear = paddle.nn.Conv2D(2, 3, 3)
|
||||||
|
before_weight = linear.weight
|
||||||
|
wn = weight_norm(linear, dim=self.dim)
|
||||||
|
rwn = remove_weight_norm(linear)
|
||||||
|
after_weight = linear.weight
|
||||||
|
self.assertTrue(
|
||||||
|
numpy.allclose(
|
||||||
|
before_weight.numpy(), after_weight.numpy(), atol=0.001))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from . import weight_norm_hook
|
||||||
|
from .weight_norm_hook import weight_norm, remove_weight_norm
|
||||||
@ -0,0 +1,225 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from ... import fluid
|
||||||
|
from ...fluid import dygraph
|
||||||
|
from ...fluid import layers as F
|
||||||
|
from ...fluid.layer_helper import LayerHelper
|
||||||
|
from ...fluid.data_feeder import check_variable_and_dtype
|
||||||
|
from ...tensor.math import multiply
|
||||||
|
|
||||||
|
__all__ = ['weight_norm', 'remove_weight_norm']
|
||||||
|
|
||||||
|
|
||||||
|
def l2_norm(x, axis, epsilon=1e-12, name=None):
|
||||||
|
if len(x.shape) == 1:
|
||||||
|
axis = 0
|
||||||
|
check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
|
||||||
|
|
||||||
|
helper = LayerHelper("l2_normalize", **locals())
|
||||||
|
out = helper.create_variable_for_type_inference(dtype=x.dtype)
|
||||||
|
norm = helper.create_variable_for_type_inference(dtype=x.dtype)
|
||||||
|
helper.append_op(
|
||||||
|
type="norm",
|
||||||
|
inputs={"X": x},
|
||||||
|
outputs={"Out": out,
|
||||||
|
"Norm": norm},
|
||||||
|
attrs={
|
||||||
|
"axis": 1 if axis is None else axis,
|
||||||
|
"epsilon": epsilon,
|
||||||
|
})
|
||||||
|
return F.squeeze(norm, axes=[axis])
|
||||||
|
|
||||||
|
|
||||||
|
def norm_except_dim(p, dim):
|
||||||
|
shape = p.shape
|
||||||
|
ndims = len(shape)
|
||||||
|
if dim == -1:
|
||||||
|
return F.sqrt(F.reduce_sum(F.square(p)) + 1e-12)
|
||||||
|
elif dim == 0:
|
||||||
|
p_matrix = F.reshape(p, (shape[0], -1))
|
||||||
|
return l2_norm(p_matrix, axis=1)
|
||||||
|
elif dim == ndims - 1:
|
||||||
|
p_matrix = F.reshape(p, (-1, shape[-1]))
|
||||||
|
return l2_norm(p_matrix, axis=0)
|
||||||
|
else:
|
||||||
|
perm = list(range(ndims))
|
||||||
|
perm[0] = dim
|
||||||
|
perm[dim] = 0
|
||||||
|
p_transposed = F.transpose(p, perm)
|
||||||
|
return norm_except_dim(p_transposed, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def _weight_norm(v, g, dim):
|
||||||
|
shape = v.shape
|
||||||
|
ndims = len(shape)
|
||||||
|
|
||||||
|
if dim == -1:
|
||||||
|
v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
|
||||||
|
elif dim == 0:
|
||||||
|
p_matrix = F.reshape(v, (shape[0], -1))
|
||||||
|
v_normalized = F.l2_normalize(p_matrix, axis=1)
|
||||||
|
v_normalized = F.reshape(v_normalized, shape)
|
||||||
|
elif dim == ndims - 1:
|
||||||
|
p_matrix = F.reshape(v, (-1, shape[-1]))
|
||||||
|
v_normalized = F.l2_normalize(p_matrix, axis=0)
|
||||||
|
v_normalized = F.reshape(v_normalized, shape)
|
||||||
|
else:
|
||||||
|
perm = list(range(ndims))
|
||||||
|
perm[0] = dim
|
||||||
|
perm[dim] = 0
|
||||||
|
p_transposed = F.transpose(v, perm)
|
||||||
|
transposed_shape = p_transposed.shape
|
||||||
|
p_matrix = F.reshape(p_transposed, (p_transposed.shape[0], -1))
|
||||||
|
v_normalized = F.l2_normalize(p_matrix, axis=1)
|
||||||
|
v_normalized = F.reshape(v_normalized, transposed_shape)
|
||||||
|
v_normalized = F.transpose(v_normalized, perm)
|
||||||
|
weight = multiply(v_normalized, g, axis=dim if dim is not None else -1)
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
class WeightNorm(object):
|
||||||
|
def __init__(self, name, dim):
|
||||||
|
if dim is None:
|
||||||
|
dim = -1
|
||||||
|
self.name = name
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def compute_weight(self, layer):
|
||||||
|
g = getattr(layer, self.name + '_g')
|
||||||
|
v = getattr(layer, self.name + '_v')
|
||||||
|
return _weight_norm(v, g, self.dim)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(layer, name, dim):
|
||||||
|
for k, hook in layer._forward_pre_hooks.items():
|
||||||
|
if isinstance(hook, WeightNorm) and hook.name == name:
|
||||||
|
raise RuntimeError("Cannot register two weight_norm hooks on "
|
||||||
|
"the same parameter {}".format(name))
|
||||||
|
|
||||||
|
if dim is None:
|
||||||
|
dim = -1
|
||||||
|
|
||||||
|
fn = WeightNorm(name, dim)
|
||||||
|
|
||||||
|
w = getattr(layer, name)
|
||||||
|
del layer._parameters[name]
|
||||||
|
|
||||||
|
g_var = norm_except_dim(w, dim)
|
||||||
|
v = layer.create_parameter(w.shape, dtype=w.dtype)
|
||||||
|
layer.add_parameter(name + "_v", v)
|
||||||
|
g = layer.create_parameter(g_var.shape, dtype=g_var.dtype)
|
||||||
|
layer.add_parameter(name + '_g', g)
|
||||||
|
with dygraph.no_grad():
|
||||||
|
F.assign(w, v)
|
||||||
|
F.assign(g_var, g)
|
||||||
|
setattr(layer, name, fn.compute_weight(layer))
|
||||||
|
|
||||||
|
layer.register_forward_pre_hook(fn)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
def remove(self, layer):
|
||||||
|
w_var = self.compute_weight(layer)
|
||||||
|
delattr(layer, self.name)
|
||||||
|
del layer._parameters[self.name + '_g']
|
||||||
|
del layer._parameters[self.name + '_v']
|
||||||
|
w = layer.create_parameter(w_var.shape, dtype=w_var.dtype)
|
||||||
|
layer.add_parameter(self.name, w)
|
||||||
|
with dygraph.no_grad():
|
||||||
|
F.assign(w_var, w)
|
||||||
|
|
||||||
|
def __call__(self, layer, inputs):
|
||||||
|
setattr(layer, self.name, self.compute_weight(layer))
|
||||||
|
|
||||||
|
|
||||||
|
def weight_norm(layer, name='weight', dim=0):
|
||||||
|
"""
|
||||||
|
This weight_norm layer applies weight normalization to a parameter according to the
|
||||||
|
following formula:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
\mathbf{w} = g \dfrac{v}{\|v\|}
|
||||||
|
|
||||||
|
Weight normalization is a reparameterization of the weight vectors in a neural network that
|
||||||
|
decouples the magnitude of those weight vectors from their direction. Weight normalization
|
||||||
|
replaces the parameter specified by `name`(eg: 'weight') with two parameters: one parameter
|
||||||
|
specifying the magnitude (eg: 'weight_g') and one parameter specifying the direction
|
||||||
|
(eg: 'weight_v'). Weight normalization has been implemented as discussed in this paper:
|
||||||
|
`Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks
|
||||||
|
<https://arxiv.org/pdf/1602.07868.pdf>`_.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
layer(Layer): Layer of paddle, which has weight.
|
||||||
|
name(str, optional): Name of the weight parameter. Default: 'weight'.
|
||||||
|
dim(int, optional): Dimension over which to compute the norm. Dim is a non-negative number
|
||||||
|
which is less than the rank of weight Tensor. For Example, dim can be chosen from 0,
|
||||||
|
1, 2, 3 for convolution whose weight shape is [cout, cin, kh, kw] and rank is 4.
|
||||||
|
If dim is set to None, meaning that all elements will be normalized. Default: 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Origin layer with weight norm hook.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from paddle.nn import Conv2D
|
||||||
|
from paddle.nn.utils import weight_norm
|
||||||
|
|
||||||
|
x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
|
||||||
|
paddle.disable_static()
|
||||||
|
conv = Conv2D(3, 5, 3)
|
||||||
|
wn = weight_norm(conv)
|
||||||
|
print(conv.weight_g.shape)
|
||||||
|
# [5]
|
||||||
|
print(conv.weight_v.shape)
|
||||||
|
# [5, 3, 3, 3]
|
||||||
|
"""
|
||||||
|
WeightNorm.apply(layer, name, dim)
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def remove_weight_norm(layer, name='weight'):
|
||||||
|
"""
|
||||||
|
remove weight normalization from layer.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
layer(Layer): Layer of paddle, which has weight.
|
||||||
|
name(str, optional): Name of the weight parameter. Default: 'weight'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Origin layer without weight norm
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
import paddle
|
||||||
|
from paddle.nn import Conv2D
|
||||||
|
from paddle.nn.utils import weight_norm, remove_weight_norm
|
||||||
|
|
||||||
|
paddle.disable_static()
|
||||||
|
conv = Conv2D(3, 5, 3)
|
||||||
|
wn = weight_norm(conv)
|
||||||
|
remove_weight_norm(conv)
|
||||||
|
print(conv.weight_g)
|
||||||
|
# AttributeError: 'Conv2D' object has no attribute 'weight_g'
|
||||||
|
"""
|
||||||
|
for k, hook in layer._forward_pre_hooks.items():
|
||||||
|
if isinstance(hook, WeightNorm) and hook.name == name:
|
||||||
|
hook.remove(layer)
|
||||||
|
del layer._forward_pre_hooks[k]
|
||||||
|
return layer
|
||||||
|
|
||||||
|
raise ValueError("weight_norm of '{}' not found in {}".format(name, layer))
|
||||||
Loading…
Reference in new issue