|
|
|
@ -78,7 +78,7 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("Weight",
|
|
|
|
|
"The input weight tensor of spectral_norm operator, "
|
|
|
|
|
"This can be a 2-D, 3-D, 4-D, 5-D tensor which is the"
|
|
|
|
|
"This can be a 2-D, 3-D, 4-D, 5-D tensor which is the "
|
|
|
|
|
"weights of fc, conv1d, conv2d, conv3d layer.");
|
|
|
|
|
AddInput("U",
|
|
|
|
|
"The weight_u tensor of spectral_norm operator, "
|
|
|
|
@ -90,29 +90,29 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"be in shape [C, 1].");
|
|
|
|
|
AddInput("V",
|
|
|
|
|
"The weight_v tensor of spectral_norm operator, "
|
|
|
|
|
"This can be a 1-D tensor in shape [W, 1],"
|
|
|
|
|
"W is the 2nd dimentions of Weight after reshape"
|
|
|
|
|
"corresponding by Attr(dim). As for Attr(dim) = 1"
|
|
|
|
|
"in conv2d layer with weight shape [M, C, K1, K2]"
|
|
|
|
|
"Weight will be reshape to [C, M*K1*K2], V will"
|
|
|
|
|
"This can be a 1-D tensor in shape [W, 1], "
|
|
|
|
|
"W is the 2nd dimentions of Weight after reshape "
|
|
|
|
|
"corresponding by Attr(dim). As for Attr(dim) = 1 "
|
|
|
|
|
"in conv2d layer with weight shape [M, C, K1, K2] "
|
|
|
|
|
"Weight will be reshape to [C, M*K1*K2], V will "
|
|
|
|
|
"be in shape [M*K1*K2, 1].");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"The output weight tensor of spectral_norm operator, "
|
|
|
|
|
"This tensor is in same shape with Input(Weight).");
|
|
|
|
|
|
|
|
|
|
AddAttr<int>("dim",
|
|
|
|
|
"dimension corresponding to number of outputs,"
|
|
|
|
|
"it should be set as 0 if Input(Weight) is the"
|
|
|
|
|
"weight of fc layer, and should be set as 1 if"
|
|
|
|
|
"Input(Weight) is the weight of conv layer,"
|
|
|
|
|
"default is 0.")
|
|
|
|
|
"dimension corresponding to number of outputs, "
|
|
|
|
|
"it should be set as 0 if Input(Weight) is the "
|
|
|
|
|
"weight of fc layer, and should be set as 1 if "
|
|
|
|
|
"Input(Weight) is the weight of conv layer, "
|
|
|
|
|
"default 0.")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddAttr<int>("power_iters",
|
|
|
|
|
"number of power iterations to calculate"
|
|
|
|
|
"spectral norm, default is 1.")
|
|
|
|
|
"number of power iterations to calculate "
|
|
|
|
|
"spectral norm, default 1.")
|
|
|
|
|
.SetDefault(1);
|
|
|
|
|
AddAttr<float>("eps",
|
|
|
|
|
"epsilob for numerical stability in"
|
|
|
|
|
"epsilob for numerical stability in "
|
|
|
|
|
"calculating norms")
|
|
|
|
|
.SetDefault(1e-12);
|
|
|
|
|
|
|
|
|
@ -126,20 +126,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
with spectral normalize value.
|
|
|
|
|
|
|
|
|
|
For spectral normalization calculations, we rescaling weight
|
|
|
|
|
tensor with \sigma, while \sigma{\mathbf{W}} is
|
|
|
|
|
tensor with :math:`\sigma`, while :math:`\sigma{\mathbf{W}}` is
|
|
|
|
|
|
|
|
|
|
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
|
|
|
|
$$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \\frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
|
|
|
|
|
|
|
|
|
|
We calculate \sigma{\mathbf{W}} through power iterations as
|
|
|
|
|
We calculate :math:`\sigma{\mathbf{W}}` through power iterations as
|
|
|
|
|
|
|
|
|
|
$$
|
|
|
|
|
\mathbf{v} = \mathbf{W}^{T} \mathbf{u}
|
|
|
|
|
\mathbf{v} = \frac{\mathbf{v}}{\|\mathbf{v}\|_2}
|
|
|
|
|
$$
|
|
|
|
|
$$
|
|
|
|
|
\mathbf{v} = \\frac{\mathbf{v}}{\|\mathbf{v}\|_2}
|
|
|
|
|
$$
|
|
|
|
|
$$
|
|
|
|
|
\mathbf{u} = \mathbf{W}^{T} \mathbf{v}
|
|
|
|
|
\mathbf{u} = \frac{\mathbf{u}}{\|\mathbf{u}\|_2}
|
|
|
|
|
$$
|
|
|
|
|
$$
|
|
|
|
|
\mathbf{u} = \\frac{\mathbf{u}}{\|\mathbf{u}\|_2}
|
|
|
|
|
$$
|
|
|
|
|
|
|
|
|
|
And \sigma should be
|
|
|
|
|
And :math:`\sigma` should be
|
|
|
|
|
|
|
|
|
|
\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
|
|
|
|
|
$$\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$
|
|
|
|
|
|
|
|
|
|
For details of spectral normalization, please refer to paper:
|
|
|
|
|
`Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
|
|
|
|
|