|
|
@ -5,10 +5,10 @@ from data_feeder import DataFeeder
|
|
|
|
import itertools
|
|
|
|
import itertools
|
|
|
|
import numpy
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['InferenceEngine', 'infer']
|
|
|
|
__all__ = ['Inference', 'infer']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InferenceEngine(object):
|
|
|
|
class Inference(object):
|
|
|
|
def __init__(self, output, parameters):
|
|
|
|
def __init__(self, output, parameters):
|
|
|
|
topo = topology.Topology(output)
|
|
|
|
topo = topology.Topology(output)
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
@ -55,5 +55,5 @@ class InferenceEngine(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(output, parameters, reader, reader_dict=None, field='value'):
|
|
|
|
def infer(output, parameters, reader, reader_dict=None, field='value'):
|
|
|
|
inferer = InferenceEngine(output=output, parameters=parameters)
|
|
|
|
inferer = Inference(output=output, parameters=parameters)
|
|
|
|
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
|
|
|
|
return inferer.infer(field=field, reader=reader, reader_dict=reader_dict)
|
|
|
|