|
|
|
@ -32,30 +32,6 @@ class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
|
|
|
|
|
return [2, 3, 4, 5]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
|
|
|
|
|
def get_x_shape(self):
|
|
|
|
|
return [2, 3, 4, 5]
|
|
|
|
|
|
|
|
|
|
def get_axis(self):
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
|
|
|
|
|
def get_x_shape(self):
|
|
|
|
|
return [2, 3, 4, 5]
|
|
|
|
|
|
|
|
|
|
def get_axis(self):
|
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp):
|
|
|
|
|
def get_x_shape(self):
|
|
|
|
|
return [2, 3, 4, 5]
|
|
|
|
|
|
|
|
|
|
def get_axis(self):
|
|
|
|
|
return 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check if primitives already exist in backward
|
|
|
|
|
class TestSoftmaxMKLDNNPrimitivesAlreadyExist(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|