|
|
|
|
@ -156,9 +156,14 @@ class TestMin8DOp(OpTest):
|
|
|
|
|
class TestProdOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "reduce_prod"
|
|
|
|
|
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
|
|
|
|
|
self.init_data_type()
|
|
|
|
|
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.data_type)}
|
|
|
|
|
self.outputs = {'Out': self.inputs['X'].prod(axis=0)}
|
|
|
|
|
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
self.data_type = "float32" if core.is_compiled_with_rocm(
|
|
|
|
|
) else "float64"
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
@ -169,14 +174,19 @@ class TestProdOp(OpTest):
|
|
|
|
|
class TestProd6DOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "reduce_prod"
|
|
|
|
|
self.init_data_type()
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'X': np.random.random((5, 6, 2, 3, 4, 2)).astype("float64")
|
|
|
|
|
'X': np.random.random((5, 6, 2, 3, 4, 2)).astype(self.data_type)
|
|
|
|
|
}
|
|
|
|
|
self.attrs = {'dim': [2, 3, 4]}
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim']))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
self.data_type = "float32" if core.is_compiled_with_rocm(
|
|
|
|
|
) else "float64"
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
@ -187,14 +197,20 @@ class TestProd6DOp(OpTest):
|
|
|
|
|
class TestProd8DOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "reduce_prod"
|
|
|
|
|
self.init_data_type()
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
|
|
|
|
|
'X': np.random.random(
|
|
|
|
|
(2, 5, 3, 2, 2, 3, 4, 2)).astype(self.data_type)
|
|
|
|
|
}
|
|
|
|
|
self.attrs = {'dim': [2, 3, 4]}
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'Out': self.inputs['X'].prod(axis=tuple(self.attrs['dim']))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def init_data_type(self):
|
|
|
|
|
self.data_type = "float32" if core.is_compiled_with_rocm(
|
|
|
|
|
) else "float64"
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|