error message of SpectralNorm OP enhancement (#23516)

revert-23830-2.0-beta
silingtong123 5 years ago committed by GitHub
parent 076dcdfde9
commit f9e2a27963
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2989,6 +2989,8 @@ class SpectralNorm(layers.Layer):
self.weight_v.stop_gradient = True
def forward(self, weight):
check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(

@ -16,9 +16,11 @@ from __future__ import division
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid import core
from paddle.fluid.framework import program_guard, Program
def spectral_norm(weight, u, v, dim, power_iters, eps):
@ -125,5 +127,46 @@ class TestSpectralNormOp2(TestSpectralNormOp):
self.eps = 1e-12
class TestSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
fluid.layers.spectral_norm(weight_1, dim=1, power_iters=2)
# the weight type of spectral_norm must be Variable
self.assertRaises(TypeError, test_Variable)
def test_weight_dtype():
weight_2 = np.random.random((2, 4)).astype("int32")
fluid.layers.spectral_norm(weight_2, dim=1, power_iters=2)
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
shape = (2, 4, 3, 3)
spectralNorm = fluid.dygraph.nn.SpectralNorm(
shape, dim=1, power_iters=2)
def test_Variable():
weight_1 = np.random.random((2, 4)).astype("float32")
spectralNorm(weight_1)
# the weight type of SpectralNorm must be Variable
self.assertRaises(TypeError, test_Variable)
def test_weight_dtype():
weight_2 = np.random.random((2, 4)).astype("int32")
spectralNorm(weight_2)
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save