|
|
|
@ -16,18 +16,18 @@ class InferenceEngine(object):
|
|
|
|
|
for param in gm.getParameters():
|
|
|
|
|
val = param.getBuf(api.PARAMETER_VALUE)
|
|
|
|
|
name = param.getName()
|
|
|
|
|
assert isinstance(val, api.Matrix)
|
|
|
|
|
val.copyFromNumpyMat(parameters.get(name))
|
|
|
|
|
assert isinstance(val, api.Vector)
|
|
|
|
|
val.copyFromNumpyArray(parameters.get(name).flatten())
|
|
|
|
|
self.__gradient_machine__ = gm
|
|
|
|
|
self.__data_types__ = topo.data_type()
|
|
|
|
|
|
|
|
|
|
def iter_infer(self, reader, reader_dict=None):
|
|
|
|
|
if reader_dict is None:
|
|
|
|
|
reader_dict = self.default_reader_dict()
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, reader_dict)
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
self.__gradient_machine__.start()
|
|
|
|
|
for data_batch in reader():
|
|
|
|
|
yield self.__gradient_machine__.forwardTest(
|
|
|
|
|
feeder(data_batch), out_args, api.PASS_TEST)
|
|
|
|
|
yield self.__gradient_machine__.forwardTest(feeder(data_batch))
|
|
|
|
|
self.__gradient_machine__.finish()
|
|
|
|
|
|
|
|
|
|
def iter_infer_field(self, field, **kwargs):
|
|
|
|
@ -35,11 +35,16 @@ class InferenceEngine(object):
|
|
|
|
|
yield [each_result[field] for each_result in result]
|
|
|
|
|
|
|
|
|
|
def infer(self, field='value', **kwargs):
|
|
|
|
|
retv = []
|
|
|
|
|
for result in itertools.izip(
|
|
|
|
|
self.iter_infer_field(
|
|
|
|
|
field=field, **kwargs)):
|
|
|
|
|
retv.append(numpy.concatenate(result))
|
|
|
|
|
retv = None
|
|
|
|
|
for result in self.iter_infer_field(field=field, **kwargs):
|
|
|
|
|
if retv is None:
|
|
|
|
|
retv = [[]] * len(result)
|
|
|
|
|
for i, item in enumerate(result):
|
|
|
|
|
retv[i].append(item)
|
|
|
|
|
retv = [numpy.concatenate(out) for out in retv]
|
|
|
|
|
if len(retv) == 1:
|
|
|
|
|
return retv[0]
|
|
|
|
|
else:
|
|
|
|
|
return retv
|
|
|
|
|
|
|
|
|
|
def default_reader_dict(self):
|
|
|
|
|