|
|
|
@ -16,6 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
import six
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -62,17 +63,20 @@ class PReluTest(OpTest):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
|
|
|
|
|
# class TestCase1(PReluTest):
|
|
|
|
|
# def initTestCase(self):
|
|
|
|
|
# self.attrs = {'mode': "all"}
|
|
|
|
|
if six.PY2:
|
|
|
|
|
|
|
|
|
|
# class TestCase2(PReluTest):
|
|
|
|
|
# def initTestCase(self):
|
|
|
|
|
# self.attrs = {'mode': "channel"}
|
|
|
|
|
class TestCase1(PReluTest):
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "all"}
|
|
|
|
|
|
|
|
|
|
class TestCase2(PReluTest):
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "channel"}
|
|
|
|
|
|
|
|
|
|
class TestCase3(PReluTest):
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
self.attrs = {'mode': "element"}
|
|
|
|
|
|
|
|
|
|
# class TestCase3(PReluTest):
|
|
|
|
|
# def initTestCase(self):
|
|
|
|
|
# self.attrs = {'mode': "element"}
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|