@@ -17,15 +17,16 @@ from __future__ import print_function
 import unittest
 import numpy as np
 import six
-from op_test import OpTest
+from op_test import OpTest, skip_check_grad_ci
 
 
 class PReluTest(OpTest):
     def setUp(self):
+        self.init_input_shape()
+        self.init_attr()
         self.op_type = "prelu"
-        self.initTestCase()
-        x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
 
+        x_np = np.random.uniform(-1, 1, self.x_shape).astype("float32")
         # Since zero point in prelu is not differentiable, avoid randomize
         # zero.
         x_np[np.abs(x_np) < 0.005] = 0.02
@@ -37,8 +38,8 @@ class PReluTest(OpTest):
             alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
         else:
-            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \
+            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2],
                 x_np.shape[3]).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
 
         out_np = np.maximum(self.inputs['X'], 0.)
@@ -47,7 +48,10 @@ class PReluTest(OpTest):
         assert out_np is not self.inputs['X']
         self.outputs = {'Out': out_np}
 
-    def initTestCase(self):
+    def init_input_shape(self):
+        self.x_shape = (2, 100, 3, 4)
+
+    def init_attr(self):
         self.attrs = {'mode': "channel"}
 
     def test_check_output(self):
@@ -66,16 +70,21 @@ class PReluTest(OpTest):
 # TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
 if six.PY2:
 
-    class TestCase1(PReluTest):
-        def initTestCase(self):
+    @skip_check_grad_ci(
+        reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+    )
+    class TestModeAll(PReluTest):
+        def init_input_shape(self):
+            self.x_shape = (2, 3, 4, 5)
+
+        def init_attr(self):
             self.attrs = {'mode': "all"}
 
 
-class TestCase2(PReluTest):
-    def initTestCase(self):
-        self.attrs = {'mode': "channel"}
-
-
-class TestCase3(PReluTest):
-    def initTestCase(self):
+class TestModeElt(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = (3, 2, 5, 10)
+
+    def init_attr(self):
         self.attrs = {'mode': "element"}
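
Note for reviewers (not part of the patch): a minimal sketch of the reference computation the reworked test constructs, assuming standard NumPy broadcasting over an NCHW input; the helper name prelu_ref and the concrete shapes are illustrative, not taken from the patch.

import numpy as np

def prelu_ref(x, alpha):
    # PReLU: out = max(x, 0) + alpha * min(x, 0). alpha broadcasts
    # against x, so the three test modes correspond to alpha shapes
    # (1,) for 'all', (1, C, 1, 1) for 'channel', and (1, C, H, W)
    # for 'element' on an input of shape (N, C, H, W).
    return np.maximum(x, 0.) + alpha * np.minimum(x, 0.)

# Mirrors the 'channel' mode set up by PReluTest above.
x = np.random.uniform(-1, 1, (2, 100, 3, 4)).astype("float32")
x[np.abs(x) < 0.005] = 0.02  # keep samples away from the non-differentiable zero point
alpha = np.random.rand(1, x.shape[1], 1, 1).astype("float32")
out = prelu_ref(x, alpha)
assert out.shape == x.shape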