|
|
|
@ -27,13 +27,7 @@ __all__ = ['DataFeeder']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_dtype(dtype):
|
|
|
|
|
if isinstance(dtype, str):
|
|
|
|
|
if dtype in [
|
|
|
|
|
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
|
|
|
|
|
'int32', 'int64', 'uint8'
|
|
|
|
|
]:
|
|
|
|
|
return dtype
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(dtype, core.VarDesc.VarType):
|
|
|
|
|
if dtype == core.VarDesc.VarType.BOOL:
|
|
|
|
|
return 'bool'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.FP16:
|
|
|
|
@ -52,6 +46,19 @@ def convert_dtype(dtype):
|
|
|
|
|
return 'int64'
|
|
|
|
|
elif dtype == core.VarDesc.VarType.UINT8:
|
|
|
|
|
return 'uint8'
|
|
|
|
|
else:
|
|
|
|
|
if dtype in [
|
|
|
|
|
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
|
|
|
|
|
'int32', 'int64', 'uint8', u'bool', u'float16', u'float32',
|
|
|
|
|
u'float64', u'int8', u'int16', u'int32', u'int64', u'uint8'
|
|
|
|
|
]:
|
|
|
|
|
# this code is a little bit dangerous, since error could happen
|
|
|
|
|
# when casting no-asci code to str in python2.
|
|
|
|
|
# but since the set itself is limited, so currently, it is good.
|
|
|
|
|
# however, jointly supporting python2 and python3, (as well as python4 maybe)
|
|
|
|
|
# may still be a long-lasting problem.
|
|
|
|
|
return str(dtype)
|
|
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"dtype must be any of [bool, float16, float32, float64, int8, int16, "
|
|
|
|
|
"int32, int64, uint8]")
|
|
|
|
|