|
|
|
@ -9,8 +9,8 @@ __all__ = ['infer']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Inference(object):
|
|
|
|
|
def __init__(self, output, parameters):
|
|
|
|
|
topo = topology.Topology(output)
|
|
|
|
|
def __init__(self, output_layer, parameters):
|
|
|
|
|
topo = topology.Topology(output_layer)
|
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
|
|
|
|
|
for param in gm.getParameters():
|
|
|
|
@ -70,13 +70,7 @@ class Inference(object):
|
|
|
|
|
return retv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(output,
|
|
|
|
|
parameters,
|
|
|
|
|
input=None,
|
|
|
|
|
batch_size=None,
|
|
|
|
|
reader=None,
|
|
|
|
|
feeding=None,
|
|
|
|
|
field='value'):
|
|
|
|
|
def infer(output_layer, parameters, input=None, feeding=None, field='value'):
|
|
|
|
|
"""
|
|
|
|
|
Infer a neural network by given neural network output and parameters. The
|
|
|
|
|
user should pass either a batch of input data or reader method.
|
|
|
|
@ -89,19 +83,13 @@ def infer(output,
|
|
|
|
|
batch_size=32)
|
|
|
|
|
print result
|
|
|
|
|
|
|
|
|
|
:param output: output of the neural network that would be inferred
|
|
|
|
|
:type output: paddle.v2.config_base.Layer
|
|
|
|
|
:param output_layer: output of the neural network that would be inferred
|
|
|
|
|
:type output_layer: paddle.v2.config_base.Layer
|
|
|
|
|
:param parameters: parameters of the neural network.
|
|
|
|
|
:type parameters: paddle.v2.parameters.Parameters
|
|
|
|
|
:param input: input data batch. Should be a python iterable object, and each
|
|
|
|
|
element is the data batch.
|
|
|
|
|
:type input: collections.Iterable
|
|
|
|
|
:param batch_size: the batch size when perform inference. Default is the
|
|
|
|
|
length of input.
|
|
|
|
|
:type batch_size: int
|
|
|
|
|
:param reader: input data reader creator in batch. If this field is set, the
|
|
|
|
|
`input` and `batch_size` will be ignored.
|
|
|
|
|
:type reader: callable
|
|
|
|
|
:param feeding: Reader dictionary. Default could generate from input
|
|
|
|
|
value.
|
|
|
|
|
:param field: The prediction field. It should in [`value`, `ids`]. `value`
|
|
|
|
@ -112,10 +100,5 @@ def infer(output,
|
|
|
|
|
:rtype: numpy.ndarray
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
inferer = Inference(output=output, parameters=parameters)
|
|
|
|
|
return inferer.infer(
|
|
|
|
|
field=field,
|
|
|
|
|
input=input,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
reader=reader,
|
|
|
|
|
feeding=feeding)
|
|
|
|
|
inferer = Inference(output_layer=output_layer, parameters=parameters)
|
|
|
|
|
return inferer.infer(field=field, input=input, feeding=feeding)
|
|
|
|
|