|
|
|
@ -26,31 +26,25 @@ from .framework import Variable, default_main_program, _current_expected_place,
|
|
|
|
|
from .framework import _cpu_num, _cuda_ids
|
|
|
|
|
__all__ = ['DataFeeder']
|
|
|
|
|
|
|
|
|
|
_PADDLE_DTYPE_2_NUMPY_DTYPE = {
|
|
|
|
|
core.VarDesc.VarType.BOOL: 'bool',
|
|
|
|
|
core.VarDesc.VarType.FP16: 'float16',
|
|
|
|
|
core.VarDesc.VarType.FP32: 'float32',
|
|
|
|
|
core.VarDesc.VarType.FP64: 'float64',
|
|
|
|
|
core.VarDesc.VarType.INT8: 'int8',
|
|
|
|
|
core.VarDesc.VarType.INT16: 'int16',
|
|
|
|
|
core.VarDesc.VarType.INT32: 'int32',
|
|
|
|
|
core.VarDesc.VarType.INT64: 'int64',
|
|
|
|
|
core.VarDesc.VarType.UINT8: 'uint8',
|
|
|
|
|
core.VarDesc.VarType.COMPLEX64: 'complex64',
|
|
|
|
|
core.VarDesc.VarType.COMPLEX128: 'complex128',
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_dtype(dtype):
|
|
|
|
|
if isinstance(dtype, core.VarDesc.VarType):
|
|
|
|
|
if dtype == core.VarDesc.VarType.BOOL:
|
|
|
|
|
return 'bool'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.FP16:
|
|
|
|
|
return 'float16'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.FP32:
|
|
|
|
|
return 'float32'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.FP64:
|
|
|
|
|
return 'float64'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.INT8:
|
|
|
|
|
return 'int8'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.INT16:
|
|
|
|
|
return 'int16'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.INT32:
|
|
|
|
|
return 'int32'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.INT64:
|
|
|
|
|
return 'int64'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.UINT8:
|
|
|
|
|
return 'uint8'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.COMPLEX64:
|
|
|
|
|
return 'complex64'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.COMPLEX128:
|
|
|
|
|
return 'complex128'
|
|
|
|
|
if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
|
|
|
|
|
return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
|
|
|
|
|
elif isinstance(dtype, type):
|
|
|
|
|
if dtype in [
|
|
|
|
|
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
|
|
|
|
|