Polish code

revert-12646-feature/jit/xbyak
minqiyang 7 years ago
parent bc12c2c616
commit d4b10eef5f

@ -533,6 +533,10 @@ class Operator(object):
in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
self.desc.set_input(in_proto.name, in_arg_names)
else:
self.desc.set_input(in_proto.name, [])
@ -566,7 +570,9 @@ class Operator(object):
elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode())
else:
out_arg_names.append(six.u(arg.name))
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)

@ -401,6 +401,8 @@ class LayerHelper(object):
return input_var
if isinstance(act, six.string_types):
act = {'type': act}
else:
raise TypeError(str(act) + " should be unicode or str")
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn')

@ -70,6 +70,10 @@ def switch(new_generator=None):
def guard(new_generator=None):
if isinstance(new_generator, six.string_types):
new_generator = UniqueNameGenerator(new_generator)
elif isinstance(new_generator, six.binary_type):
new_generator = UniqueNameGenerator(new_generator.decode())
else:
raise TypeError(str(new_generator) + " should be unicode or str")
old = switch(new_generator)
yield
switch(old)

@ -73,6 +73,8 @@ def recordio(paths, buf_size=100):
def reader():
if isinstance(paths, six.string_types):
path = paths
elif isinstance(paths, six.binary_type):
path = paths.decode()
else:
path = ",".join(paths)
f = rec.reader(path)

Loading…
Cancel
Save