|
|
|
@ -202,6 +202,7 @@ class CheckWrapper(object):
|
|
|
|
|
for each in item:
|
|
|
|
|
callback(each)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckInputTypeWrapper(object):
|
|
|
|
|
def __init__(self, generator, input_types, logger):
|
|
|
|
|
self.generator = generator
|
|
|
|
@ -209,17 +210,18 @@ class CheckInputTypeWrapper(object):
|
|
|
|
|
self.logger = logger
|
|
|
|
|
|
|
|
|
|
def __call__(self, obj, filename):
|
|
|
|
|
for items in self.generator(obj, filename):
|
|
|
|
|
try:
|
|
|
|
|
# dict type is required for input_types when item is dict type
|
|
|
|
|
assert (isinstance(items, dict) and \
|
|
|
|
|
not isinstance(self.input_types, dict))==False
|
|
|
|
|
yield items
|
|
|
|
|
except AssertionError as e:
|
|
|
|
|
self.logger.error(
|
|
|
|
|
for items in self.generator(obj, filename):
|
|
|
|
|
try:
|
|
|
|
|
# dict type is required for input_types when item is dict type
|
|
|
|
|
assert (isinstance(items, dict) and \
|
|
|
|
|
not isinstance(self.input_types, dict))==False
|
|
|
|
|
yield items
|
|
|
|
|
except AssertionError as e:
|
|
|
|
|
self.logger.error(
|
|
|
|
|
"%s type is required for input type but got %s" %
|
|
|
|
|
(repr(type(items)), repr(type(self.input_types))))
|
|
|
|
|
raise
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def provider(input_types=None,
|
|
|
|
|
should_shuffle=None,
|
|
|
|
@ -374,8 +376,8 @@ def provider(input_types=None,
|
|
|
|
|
self.generator = InputOrderWrapper(self.generator,
|
|
|
|
|
self.input_order)
|
|
|
|
|
else:
|
|
|
|
|
self.generator = CheckInputTypeWrapper(self.generator, self.slots,
|
|
|
|
|
self.logger)
|
|
|
|
|
self.generator = CheckInputTypeWrapper(
|
|
|
|
|
self.generator, self.slots, self.logger)
|
|
|
|
|
if self.check:
|
|
|
|
|
self.generator = CheckWrapper(self.generator, self.slots,
|
|
|
|
|
check_fail_continue,
|
|
|
|
|