@@ -21,30 +21,16 @@ class Inference(object):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, input=None, batch_size=None, reader=None,
-                   feeding=None):
+    def iter_infer(self, input, feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
-        if reader is None:
-            assert input is not None and isinstance(input, collections.Iterable)
-            if not isinstance(input, collections.Iterable):
-                raise TypeError("When reader is None, input should be whole "
-                                "inference data and should be iterable")
-
-            if batch_size is None:
-                if not hasattr(input, '__len__'):
-                    raise ValueError("Should set batch size when input data "
-                                     "don't contain length.")
-                batch_size = len(input)
-
-            def __reader_impl__():
-                for each_sample in input:
-                    yield each_sample
-
-            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
-        else:
-            if input is not None:
-                raise ValueError("User should set either input or reader, "
-                                 "should not set them both.")
+        batch_size = len(input)
+
+        def __reader_impl__():
+            for each_sample in input:
+                yield each_sample
+
+        reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
+
         self.__gradient_machine__.start()
         for data_batch in reader():
            yield self.__gradient_machine__.forwardTest(feeder(data_batch))
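Note on the change: the removed branch let callers pass either a whole `input` list or a pre-batched `reader` (with an optional explicit `batch_size`); the new body keeps only the `input` form and always batches the entire data set in one go, since `batch_size` is fixed to `len(input)`. The sketch below illustrates the reader pattern the new body relies on. It is a standalone approximation, not the library code: `batch` here mimics how `paddle.v2.minibatch.batch` is used in the diff (wrapping a zero-argument sample generator into a reader that yields lists of `batch_size` samples), and the sample data is made up.

    # A "reader" is a zero-argument callable returning an iterator of
    # samples. batch(reader, batch_size) wraps it into a reader that
    # yields lists of batch_size samples -- an approximation of
    # paddle.v2.minibatch.batch, written out for illustration.
    def batch(reader, batch_size):
        def batched_reader():
            buf = []
            for sample in reader():
                buf.append(sample)
                if len(buf) == batch_size:
                    yield buf
                    buf = []
            if buf:  # flush the final, possibly smaller batch
                yield buf

        return batched_reader

    input = [(1, ), (2, ), (3, )]  # whole inference data set (made-up samples)
    batch_size = len(input)        # the simplified API always batches everything

    def __reader_impl__():
        for each_sample in input:
            yield each_sample

    reader = batch(__reader_impl__, batch_size=batch_size)
    for data_batch in reader():
        print(data_batch)          # one batch containing all three samples

Because `batch_size` is `len(input)`, the loop over `reader()` in the new code executes exactly once, forwarding the whole data set through the gradient machine in a single `forwardTest` call.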