|
|
@ -40,7 +40,7 @@ class TestTemporalShift(OpTest):
|
|
|
|
def setUp(self):
|
|
|
|
def setUp(self):
|
|
|
|
self.initTestCase()
|
|
|
|
self.initTestCase()
|
|
|
|
self.op_type = 'temporal_shift'
|
|
|
|
self.op_type = 'temporal_shift'
|
|
|
|
x = np.random.random(self.x_shape).astype('float64')
|
|
|
|
x = np.random.random(self.x_shape).astype(self.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
self.attrs = {
|
|
|
|
self.attrs = {
|
|
|
|
"seg_num": self.seg_num,
|
|
|
|
"seg_num": self.seg_num,
|
|
|
@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
|
|
|
|
self.x_shape = (6, 4, 4, 4)
|
|
|
|
self.x_shape = (6, 4, 4, 4)
|
|
|
|
self.seg_num = 3
|
|
|
|
self.seg_num = 3
|
|
|
|
self.shift_ratio = 0.25
|
|
|
|
self.shift_ratio = 0.25
|
|
|
|
|
|
|
|
self.dtype = 'float64'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTemporalShift2(TestTemporalShift):
|
|
|
|
class TestTemporalShift2(TestTemporalShift):
|
|
|
@ -78,6 +79,26 @@ class TestTemporalShift3(TestTemporalShift):
|
|
|
|
self.shift_ratio = 0.3
|
|
|
|
self.shift_ratio = 0.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
|
|
|
|
|
"core is not compiled with CUDA")
|
|
|
|
|
|
|
|
class TestTemporalShiftFP16(TestTemporalShift):
|
|
|
|
|
|
|
|
def initTestCase(self):
|
|
|
|
|
|
|
|
self.x_shape = (3, 10, 5, 5)
|
|
|
|
|
|
|
|
self.seg_num = 1
|
|
|
|
|
|
|
|
self.shift_ratio = 0.3
|
|
|
|
|
|
|
|
self.dtype = 'float16'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
|
|
|
if core.is_float16_supported(place):
|
|
|
|
|
|
|
|
self.check_output_with_place(place)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_grad_ignore_uv(self):
|
|
|
|
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
|
|
|
|
if core.is_float16_supported(place):
|
|
|
|
|
|
|
|
self.check_grad_with_place(place, ['X'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTemporalShiftAPI(unittest.TestCase):
|
|
|
|
class TestTemporalShiftAPI(unittest.TestCase):
|
|
|
|
def test_api(self):
|
|
|
|
def test_api(self):
|
|
|
|
input = paddle.randn([6, 4, 2, 2])
|
|
|
|
input = paddle.randn([6, 4, 2, 2])
|
|
|
|