|
|
|
@ -16,9 +16,11 @@ from __future__ import division
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from op_test import OpTest, skip_check_grad_ci
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid.framework import program_guard, Program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def spectral_norm(weight, u, v, dim, power_iters, eps):
|
|
|
|
@ -125,5 +127,46 @@ class TestSpectralNormOp2(TestSpectralNormOp):
|
|
|
|
|
self.eps = 1e-12
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSpectralNormOpError(unittest.TestCase):
|
|
|
|
|
def test_errors(self):
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
|
|
|
|
|
def test_Variable():
|
|
|
|
|
weight_1 = np.random.random((2, 4)).astype("float32")
|
|
|
|
|
fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2)
|
|
|
|
|
|
|
|
|
|
# the weight type of spectral_norm must be Variable
|
|
|
|
|
self.assertRaises(TypeError, test_Variable)
|
|
|
|
|
|
|
|
|
|
def test_weight_dtype():
|
|
|
|
|
weight_2 = np.random.random((2, 4)).astype("int32")
|
|
|
|
|
fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2)
|
|
|
|
|
|
|
|
|
|
# the data type of type must be float32 or float64
|
|
|
|
|
self.assertRaises(TypeError, test_weight_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDygraphSpectralNormOpError(unittest.TestCase):
|
|
|
|
|
def test_errors(self):
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
shape = (2, 4, 3, 3)
|
|
|
|
|
spectralNorm = fluid.dygraph.nn.SpectralNorm(
|
|
|
|
|
shape, dim=1, power_iters=2)
|
|
|
|
|
|
|
|
|
|
def test_Variable():
|
|
|
|
|
weight_1 = np.random.random((2, 4)).astype("float32")
|
|
|
|
|
spectralNorm(weight_1)
|
|
|
|
|
|
|
|
|
|
# the weight type of SpectralNorm must be Variable
|
|
|
|
|
self.assertRaises(TypeError, test_Variable)
|
|
|
|
|
|
|
|
|
|
def test_weight_dtype():
|
|
|
|
|
weight_2 = np.random.random((2, 4)).astype("int32")
|
|
|
|
|
spectralNorm(weight_2)
|
|
|
|
|
|
|
|
|
|
# the data type of type must be float32 or float64
|
|
|
|
|
self.assertRaises(TypeError, test_weight_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|