Merge pull request #15864 from heavengate/spectral_norm

Add spectral norm op
align_pyramid
Kaipeng Deng 7 years ago committed by GitHub
commit 6d8771b55c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -128,6 +128,7 @@ paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'par
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '013795af319e2e86d3506741941078ee'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '6148b6a555cbfb62fdcd030d8982c18c'))

@ -0,0 +1,197 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class SpectralNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of SpectralNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("U"),
"Input(U) of SpectralNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("V"),
"Input(V) of SpectralNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SpectralNormOp should not be null.");
auto dim_weight = ctx->GetInputDim("Weight");
auto rank_weight = dim_weight.size();
PADDLE_ENFORCE(rank_weight >= 2 && rank_weight <= 5,
"The rank of Input(Weights) can only be 2, 3,"
"4, 5 for fc, conv1d, conv2d, conv3d layers.");
int dim = ctx->Attrs().Get<int>("dim");
int power_iters = ctx->Attrs().Get<int>("power_iters");
PADDLE_ENFORCE(dim == 0 || dim == 1, "Attr(dim) can only be 0 or 1");
PADDLE_ENFORCE(power_iters >= 0,
"Attr(power_iters) should be larger equal then 0");
int h = dim_weight[dim];
int w = 1;
for (int i = 0; i < rank_weight; i++) {
if (i != dim) {
w *= dim_weight[i];
}
}
auto dim_u = ctx->GetInputDim("U");
auto dim_v = ctx->GetInputDim("V");
PADDLE_ENFORCE_EQ(dim_u[0], h,
"Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]");
PADDLE_ENFORCE_EQ(
dim_v[0], w,
"Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]");
ctx->SetOutputDim("Out", dim_weight);
ctx->ShareLoD("Weight", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
ctx.GetPlace());
}
};
class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Weight",
"The input weight tensor of spectral_norm operator, "
"This can be a 2-D, 3-D, 4-D, 5-D tensor which is the "
"weights of fc, conv1d, conv2d, conv3d layer.");
AddInput("U",
"The weight_u tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [H, 1],"
"H is the 1st dimentions of Weight after reshape"
"corresponding by Attr(dim). As for Attr(dim) = 1"
"in conv2d layer with weight shape [M, C, K1, K2]"
"Weight will be reshape to [C, M*K1*K2], U will"
"be in shape [C, 1].");
AddInput("V",
"The weight_v tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [W, 1], "
"W is the 2nd dimentions of Weight after reshape "
"corresponding by Attr(dim). As for Attr(dim) = 1 "
"in conv2d layer with weight shape [M, C, K1, K2] "
"Weight will be reshape to [C, M*K1*K2], V will "
"be in shape [M*K1*K2, 1].");
AddOutput("Out",
"The output weight tensor of spectral_norm operator, "
"This tensor is in same shape with Input(Weight).");
AddAttr<int>("dim",
"The index of dimension which should be permuted "
"to the first before reshaping Input(Weight) to "
"matrix, it should be set as 0 if Input(Weight) is "
"the weight of fc layer, and should be set as 1 if "
"Input(Weight) is the weight of conv layer, "
"default 0.")
.SetDefault(0);
AddAttr<int>("power_iters",
"number of power iterations to calculate "
"spectral norm, default 1.")
.SetDefault(1);
AddAttr<float>("eps",
"epsilon for numerical stability in "
"calculating norms")
.SetDefault(1e-12);
AddComment(R"DOC(
This layer calculates the spectral normalization value of weight of
fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
tensor.
Spectral normalization stabilizes the training of critic in GANs
(Generative Adversarial Networks). This layer rescaling weight tensor
with spectral normalize value.
For spectral normalization calculations, we rescaling weight
tensor with :math:`\sigma`, while :math:`\sigma{\mathbf{W}}` is
$$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \\frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
We calculate :math:`\sigma{\mathbf{W}}` through power iterations as
$$
\mathbf{v} = \mathbf{W}^{T} \mathbf{u}
$$
$$
\mathbf{v} = \\frac{\mathbf{v}}{\|\mathbf{v}\|_2}
$$
$$
\mathbf{u} = \mathbf{W}^{T} \mathbf{v}
$$
$$
\mathbf{u} = \\frac{\mathbf{u}}{\|\mathbf{u}\|_2}
$$
And :math:`\sigma` should be
$$\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$
For details of spectral normalization, please refer to paper:
`Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
)DOC");
}
};
class SpectralNormOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Weight"), "Input(Weight) should not be null");
PADDLE_ENFORCE(ctx->HasInput("U"), "Input(U) should not be null");
PADDLE_ENFORCE(ctx->HasInput("V"), "Input(V) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("Weight");
if (ctx->HasOutput(framework::GradVarName("Weight"))) {
ctx->SetOutputDim(framework::GradVarName("Weight"), dim_x);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(spectral_norm, ops::SpectralNormOp, ops::SpectralNormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(spectral_norm_grad, ops::SpectralNormOpGrad);
REGISTER_OP_CPU_KERNEL(
spectral_norm,
ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpectralNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
spectral_norm_grad,
ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpectralNormGradKernel<paddle::platform::CPUDeviceContext, double>);

@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
spectral_norm,
ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpectralNormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
spectral_norm_grad,
ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpectralNormGradKernel<paddle::platform::CUDADeviceContext, double>);

File diff suppressed because it is too large Load Diff

@ -94,6 +94,7 @@ __all__ = [
'multiplex',
'layer_norm',
'group_norm',
'spectral_norm',
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
@ -3346,6 +3347,98 @@ def group_norm(input,
return helper.append_activation(group_norm_out)
@templatedoc()
def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
"""
**Spectral Normalization Layer**
This layer calculates the spectral normalization value of weight parameters of
fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
Parameters. Calculations are showed as follows.
Step 1:
Generate vector U in shape of [H], and V in shape of [W].
While H is the :attr:`dim` th dimension of the input weights,
and W is the product result of remaining dimensions.
Step 2:
:attr:`power_iters` shoule be a positive interger, do following
calculations with U and V for :attr:`power_iters` rounds.
.. math::
\mathbf{v} := \\frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} := \\frac{\mathbf{W}^{T} \mathbf{v}}{\|\mathbf{W}^{T} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} = \\frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(${weight_type}): ${weight_comment}
dim(int): ${dim_comment}
power_iters(int): ${power_iters_comment}
eps(float): ${eps_comment}
name (str): The name of this layer. It is optional.
Returns:
Variable: A tensor variable of weight parameters after spectral normalization.
Examples:
>>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.spectral_norm(weight=data, dim=1, power_iters=2)
"""
helper = LayerHelper('spectral_norm', **locals())
dtype = weight.dtype
# create intput and parameters
inputs = {'Weight': weight}
input_shape = weight.shape
h = input_shape[dim]
w = np.prod(input_shape) // h
u = helper.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=dtype,
default_initializer=Normal(0., 1.))
u.stop_gradient = True
inputs['U'] = u
v = helper.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=dtype,
default_initializer=Normal(0., 1.))
inputs['V'] = v
v.stop_gradient = True
# create output
out = helper.create_variable(dtype=dtype)
helper.append_op(
type="spectral_norm",
inputs=inputs,
outputs={"Out": out, },
attrs={
"dim": dim,
"power_iters": power_iters,
"eps": eps,
})
return out
def conv2d_transpose(input,
num_filters,
output_size=None,

@ -1035,6 +1035,19 @@ class TestBook(unittest.TestCase):
print(str(program))
def test_spectral_norm(self):
program = Program()
with program_guard(program):
weight = layers.data(
name='weight',
shape=[2, 3, 32, 32],
dtype="float32",
append_batch_size=False)
out = layers.spectral_norm(weight, dim=1, power_iters=1)
self.assertIsNotNone(out)
print(str(program))
def test_shuffle_channel(self):
program = Program()
with program_guard(program):

@ -0,0 +1,122 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def spectral_norm(weight, u, v, dim, power_iters, eps):
shape = weight.shape
weight_mat = weight.copy()
h = shape[dim]
w = np.prod(shape) // h
if dim != 0:
perm = [dim] + [d for d in range(len(shape)) if d != dim]
weight_mat = weight_mat.transpose(perm)
weight_mat = weight_mat.reshape((h, w))
u = u.reshape((h, 1))
v = v.reshape((w, 1))
for i in range(power_iters):
v = np.matmul(weight_mat.T, u)
v_norm = np.sqrt((v * v).sum())
v = v / (v_norm + eps)
u = np.matmul(weight_mat, v)
u_norm = np.sqrt((u * u).sum())
u = u / (u_norm + eps)
sigma = (u * np.matmul(weight_mat, v)).sum()
return weight / sigma
class TestSpectralNormOpNoGrad(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'spectral_norm'
weight = np.random.random(self.weight_shape).astype('float32')
u = np.random.normal(0., 1., self.u_shape).astype('float32')
v = np.random.normal(0., 1., self.v_shape).astype('float32')
self.attrs = {
"dim": self.dim,
"power_iters": self.power_iters,
"eps": self.eps,
}
self.inputs = {
"Weight": weight,
"U": u,
"V": v,
}
output = spectral_norm(weight, u, v, self.dim, self.power_iters,
self.eps)
self.outputs = {"Out": output}
def test_check_output(self):
self.check_output()
def initTestCase(self):
self.weight_shape = (2, 3)
self.u_shape = (2, )
self.v_shape = (3, )
self.dim = 0
self.power_iters = 5
self.eps = 1e-12
class TestSpectralNormOpNoGrad2(TestSpectralNormOpNoGrad):
def initTestCase(self):
self.weight_shape = (2, 3, 3, 3)
self.u_shape = (3, )
self.v_shape = (18, )
self.dim = 1
self.power_iters = 10
self.eps = 1e-12
class TestSpectralNormOp(TestSpectralNormOpNoGrad):
def test_check_grad_ignore_uv(self):
self.check_grad(
['Weight'],
'Out',
no_grad_set=set(["U", "V"]),
max_relative_error=0.1)
def initTestCase(self):
self.weight_shape = (2, 3)
self.u_shape = (2, )
self.v_shape = (3, )
self.dim = 0
self.power_iters = 0
self.eps = 1e-12
class TestSpectralNormOp2(TestSpectralNormOp):
def initTestCase(self):
self.weight_shape = (2, 3, 3, 3)
self.u_shape = (3, )
self.v_shape = (18, )
self.dim = 1
self.power_iters = 0
self.eps = 1e-12
if __name__ == "__main__":
unittest.main()
Loading…
Cancel
Save