|
|
|
@ -21,6 +21,9 @@ from op_test import OpTest
|
|
|
|
|
|
|
|
|
|
class PReluTest(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
print('setUp')
|
|
|
|
|
import sys
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.op_type = "prelu"
|
|
|
|
|
self.initTestCase()
|
|
|
|
|
x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
|
|
|
|
@ -39,32 +42,45 @@ class PReluTest(OpTest):
|
|
|
|
|
alpha_np = np.random.rand(*x_np.shape).astype("float32")
|
|
|
|
|
self.inputs = {'X': x_np, 'Alpha': alpha_np}
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
print('self.inputs', self.inputs)
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
out_np = np.maximum(self.inputs['X'], 0.)
|
|
|
|
|
out_np = out_np + np.minimum(self.inputs['X'],
|
|
|
|
|
0.) * self.inputs['Alpha']
|
|
|
|
|
assert out_np is not self.inputs['X']
|
|
|
|
|
self.outputs = {'Out': out_np}
|
|
|
|
|
|
|
|
|
|
def tearDown(self):
|
|
|
|
|
print('tearDown')
|
|
|
|
|
import sys
|
|
|
|
|
print('self.outputs', self.outputs)
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.outputs = {'Out': out_np}
|
|
|
|
|
del self.outputs
|
|
|
|
|
del self.inputs
|
|
|
|
|
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "channel"}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
def test_check_4_output(self):
|
|
|
|
|
print('test_check_0_output')
|
|
|
|
|
import sys
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
|
self.check_grad(['X', 'Alpha'], 'Out')
|
|
|
|
|
|
|
|
|
|
def test_check_grad_ignore_x(self):
|
|
|
|
|
def test_check_0_grad_2_ignore_x(self):
|
|
|
|
|
print('test_check_2_grad_2_ignore_x')
|
|
|
|
|
import sys
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.check_grad(['Alpha'], 'Out', no_grad_set=set('X'))
|
|
|
|
|
|
|
|
|
|
def test_check_grad_ignore_alpha(self):
|
|
|
|
|
# TODO(minqiyang): remove the order of tests
|
|
|
|
|
def test_check_1_grad_1(self):
|
|
|
|
|
print('test_check_1_grad_1')
|
|
|
|
|
import sys
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.check_grad(['X', 'Alpha'], 'Out')
|
|
|
|
|
|
|
|
|
|
def test_check_3_grad_3_ignore_alpha(self):
|
|
|
|
|
print('test_check_3_grad_3_ignore_alpha')
|
|
|
|
|
import sys
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
self.check_grad(['X'], 'Out', no_grad_set=set('Alpha'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -73,15 +89,14 @@ class TestCase1(PReluTest):
|
|
|
|
|
self.attrs = {'mode': "all"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCase2(PReluTest):
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "channel"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCase3(PReluTest):
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "element"}
|
|
|
|
|
|
|
|
|
|
#class TestCase2(PReluTest):
|
|
|
|
|
# def initTestCase(self):
|
|
|
|
|
# self.attrs = {'mode': "channel"}
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
#class TestCase3(PReluTest):
|
|
|
|
|
# def initTestCase(self):
|
|
|
|
|
# self.attrs = {'mode': "element"}
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|