@@ -2,6 +2,7 @@ import numpy
 import collections
 import topology
 import minibatch
+import cPickle
 
 __all__ = ['infer', 'Inference']
 
@@ -25,11 +26,23 @@ class Inference(object):
     :type parameters: paddle.v2.parameters.Parameters
     """
 
-    def __init__(self, output_layer, parameters):
+    def __init__(self, parameters, output_layer=None, fileobj=None):
         import py_paddle.swig_paddle as api
-        topo = topology.Topology(output_layer)
-        gm = api.GradientMachine.createFromConfigProto(
-            topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
+
+        if output_layer is not None:
+            topo = topology.Topology(output_layer)
+            gm = api.GradientMachine.createFromConfigProto(
+                topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
+            self.__data_types__ = topo.data_type()
+        elif fileobj is not None:
+            tmp = cPickle.load(fileobj)
+            gm = api.GradientMachine.createByConfigProtoStr(
+                tmp['protobin'], api.CREATE_MODE_TESTING,
+                [api.PARAMETER_VALUE])
+            self.__data_types__ = tmp['data_type']
+        else:
+            raise ValueError("Either output_layer or fileobj must be set")
+
         for param in gm.getParameters():
             val = param.getBuf(api.PARAMETER_VALUE)
             name = param.getName()
@@ -43,7 +56,6 @@ class Inference(object):
             # called here, but it's better to call this function in one place.
             param.setValueUpdated()
         self.__gradient_machine__ = gm
-        self.__data_types__ = topo.data_type()
 
     def iter_infer(self, input, feeding=None):
         from data_feeder import DataFeeder
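
For orientation, a minimal usage sketch of the new fileobj path. The file name, the dump step, and the use of SerializeToString() are assumptions for illustration, not part of this change; the pickled dict only needs the 'protobin' and 'data_type' keys that the constructor reads.

# Sketch only: persisting a topology and rebuilding Inference from it.
# Assumes `topo` is a paddle.v2.topology.Topology and `parameters` is a
# paddle.v2.parameters.Parameters produced elsewhere; 'model.topo' is a
# hypothetical file name.
import cPickle

with open('model.topo', 'wb') as f:
    cPickle.dump({
        'protobin': topo.proto().SerializeToString(),  # assumed serialization
        'data_type': topo.data_type()
    }, f)

with open('model.topo', 'rb') as f:
    inferer = Inference(parameters=parameters, fileobj=f)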