|
|
@ -23,6 +23,8 @@ from ..core import VarDesc
|
|
|
|
from .layer_function_generator import templatedoc
|
|
|
|
from .layer_function_generator import templatedoc
|
|
|
|
from ..data_feeder import convert_dtype
|
|
|
|
from ..data_feeder import convert_dtype
|
|
|
|
import numpy
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
from ..data_feeder import convert_dtype
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
__all__ = [
|
|
|
|
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
|
|
|
|
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
|
|
|
@ -247,6 +249,21 @@ def concat(input, axis=0, name=None):
|
|
|
|
# [14 15 16]]
|
|
|
|
# [14 15 16]]
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
helper = LayerHelper('concat', **locals())
|
|
|
|
helper = LayerHelper('concat', **locals())
|
|
|
|
|
|
|
|
for x in input:
|
|
|
|
|
|
|
|
if not isinstance(x, Variable):
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"The type of x in 'input' in concat must be Variable, but received %s"
|
|
|
|
|
|
|
|
% (type(x)))
|
|
|
|
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
|
|
|
|
warnings.warn(
|
|
|
|
|
|
|
|
"The data type of x in 'input' in concat only support float16 on GPU now."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
if convert_dtype(x.dtype) not in [
|
|
|
|
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
|
|
|
|
]:
|
|
|
|
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"The data type of x in 'input' in concat must be float16(only support on GPU), float32, float64, int32, int64, but received %s."
|
|
|
|
|
|
|
|
% (convert_dtype(x.dtype)))
|
|
|
|
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
|
|
|
|
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
|
|
|
|
helper.append_op(
|
|
|
|
helper.append_op(
|
|
|
|
type='concat',
|
|
|
|
type='concat',
|
|
|
|