|
|
|
@ -70,7 +70,7 @@ class Inference(object):
|
|
|
|
|
item = [each_result[each_field] for each_field in field]
|
|
|
|
|
yield item
|
|
|
|
|
|
|
|
|
|
def infer(self, input, field='value', **kwargs):
|
|
|
|
|
def infer(self, input, field='value', flatten_result=True, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Infer a data by model.
|
|
|
|
|
:param input: input data batch. Should be python iterable object.
|
|
|
|
@ -83,7 +83,10 @@ class Inference(object):
|
|
|
|
|
retv = [[] for i in xrange(len(result))]
|
|
|
|
|
for i, item in enumerate(result):
|
|
|
|
|
retv[i].append(item)
|
|
|
|
|
retv = [numpy.concatenate(out) for out in retv]
|
|
|
|
|
|
|
|
|
|
if flatten_result:
|
|
|
|
|
retv = [numpy.concatenate(out) for out in retv]
|
|
|
|
|
|
|
|
|
|
if len(retv) == 1:
|
|
|
|
|
return retv[0]
|
|
|
|
|
else:
|
|
|
|
|