diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index bf0a2a2836..a0aed3ee1f 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -83,10 +83,14 @@ def select_input(inputs, mask): if isinstance(inputs, list) or isinstance(inputs, tuple): input_dtype = inputs[0].dtype input_shape = inputs[0].shape + input_type = inputs[0].type else: input_dtype = inputs.dtype input_shape = inputs.shape - out = helper.create_variable(dtype=input_dtype, shape=input_shape) + input_type = inputs.type + + out = helper.create_variable( + dtype=input_dtype, shape=input_shape, type=input_type) helper.append_op( type='select_input', inputs={'X': inputs,