|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
from op_test import OpTest, skip_check_grad_ci
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid import compiler, Program, program_guard, core
|
|
|
|
|
|
|
|
|
@ -65,6 +65,8 @@ class TestConcatOp2(TestConcatOp):
|
|
|
|
|
self.axis = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@skip_check_grad_ci(
|
|
|
|
|
reason="The function 'check_grad' for large inputs is too slow.")
|
|
|
|
|
class TestConcatOp3(TestConcatOp):
|
|
|
|
|
def init_test_data(self):
|
|
|
|
|
self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype)
|
|
|
|
@ -76,6 +78,9 @@ class TestConcatOp3(TestConcatOp):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@skip_check_grad_ci(
|
|
|
|
|
reason="This test will meet fetch error when there is a null grad. The detailed information is in PR#17015."
|
|
|
|
|
)
|
|
|
|
|
class TestConcatOp4(TestConcatOp):
|
|
|
|
|
def init_test_data(self):
|
|
|
|
|
self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
|
|
|
|
|