enhance input type chec for concat (#20584)

test=develop
revert-20712-fix_depthwise_conv
zhupengyang 5 years ago committed by Tao Luo
parent 443f604c3b
commit 5e65c753ea

@ -249,10 +249,15 @@ def concat(input, axis=0, name=None):
# [14 15 16]]
"""
helper = LayerHelper('concat', **locals())
if not isinstance(input, list):
warnings.warn(
"The type of input in concat should be list, but received %s." %
(type(input)))
input = [input]
for x in input:
if not isinstance(x, Variable):
raise TypeError(
"The type of x in 'input' in concat must be Variable, but received %s"
"The type of x in 'input' in concat must be Variable, but received %s."
% (type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(

@ -118,13 +118,22 @@ create_test_fp16(TestConcatOp5)
class TestConcatOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of concat_op must be Variable.
x1 = fluid.create_lod_tensor(
# The input type of concat_op should be list.
x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1')
fluid.layers.concat(x1)
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, x1)
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16(only support on GPU), float32, float64, int32, int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.concat, x2)
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
fluid.layers.concat([x6, x7])
if __name__ == '__main__':

Loading…
Cancel
Save