36 lines
942 B
36 lines
942 B
import unittest
|
|
import numpy as np
|
|
from op_test import OpTest
|
|
|
|
|
|
class TestReshapeOp(OpTest):
    """Reshape a random 10x20 float32 matrix into a flat 1-D tensor of 200 elements."""

    def setUp(self):
        """Build the op description: type, random input, target shape, and expected output."""
        self.op_type = "reshape"
        source = np.random.random((10, 20)).astype("float32")
        target_shape = [10 * 20]
        self.inputs = {'X': source}
        self.attrs = {'shape': target_shape}
        # NumPy's reshape of the very same array serves as the reference result.
        self.outputs = {'Out': np.reshape(source, target_shape)}

    def test_check_output(self):
        """Compare the op's forward output against the NumPy reference via OpTest."""
        self.check_output()

    def test_check_grad(self):
        """Check the gradient flowing from "Out" back to input "X" via OpTest."""
        self.check_grad(["X"], "Out")
|
|
|
|
|
|
class TestReshapeOpDimInfer(OpTest):
    """Reshape a random 10x20 float32 matrix to [4, -1, 5], leaving one extent to be inferred."""

    def setUp(self):
        """Build the op description with a -1 placeholder in the requested shape."""
        self.op_type = "reshape"
        matrix = np.random.random((10, 20)).astype("float32")
        # The -1 entry is left for the op to infer; the NumPy reference
        # resolves it to 200 / (4 * 5) = 10.
        requested_shape = [4, -1, 5]
        self.inputs = {'X': matrix}
        self.attrs = {'shape': requested_shape}
        self.outputs = {'Out': np.reshape(matrix, requested_shape)}

    def test_check_output(self):
        """Compare the op's forward output against the NumPy reference via OpTest."""
        self.check_output()

    def test_check_grad(self):
        """Check the gradient flowing from "Out" back to input "X" via OpTest."""
        self.check_grad(["X"], "Out")
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's test cases.
    unittest.main()
|