|
|
|
@ -109,5 +109,29 @@ class TestExpandOpRank4(OpTest):
|
|
|
|
|
self.check_grad(['X'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestExpandOpInteger(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "expand"
|
|
|
|
|
self.inputs = {'X': np.random.random((2, 4, 5)).astype("int32")}
|
|
|
|
|
self.attrs = {'expand_times': [2, 1, 4]}
|
|
|
|
|
output = np.tile(self.inputs['X'], (2, 1, 4))
|
|
|
|
|
self.outputs = {'Out': output}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestExpandOpBoolean(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "expand"
|
|
|
|
|
self.inputs = {'X': np.random.random((2, 4, 5)).astype("bool")}
|
|
|
|
|
self.attrs = {'expand_times': [2, 1, 4]}
|
|
|
|
|
output = np.tile(self.inputs['X'], (2, 1, 4))
|
|
|
|
|
self.outputs = {'Out': output}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|