Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

0.10.0rc
Yancey1989 8 years ago
commit fc6f2032e2

1
.gitmodules vendored

@ -1,3 +1,4 @@
[submodule "book"]
path = book
url = https://github.com/PaddlePaddle/book.git
branch = develop

@ -1 +1 @@
Subproject commit 22ed2a01aee872f055b5f5f212428f481cefc10d
Subproject commit 6e3875eb62533de1f2c1088a477719eb57b9732c

@ -122,13 +122,14 @@ def main():
test_creator = paddle.dataset.mnist.test()
test_data = []
for item in test_creator():
test_data.append(item[0])
test_data.append((item[0], ))
if len(test_data) == 100:
break
# output is a softmax layer. It returns probabilities.
# Shape should be (100, 10)
probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
probs = paddle.infer(
output_layer=predict, parameters=parameters, input=test_data)
print probs.shape

@ -13,8 +13,6 @@
# limitations under the License.
import sys
import paddle.trainer_config_helpers.attrs as attrs
from paddle.trainer_config_helpers.poolings import MaxPooling
import paddle.v2 as paddle
@ -51,16 +49,14 @@ def stacked_lstm_net(input_dim,
emb_dim: dimension of word embedding.
hid_dim: dimension of hidden layer.
stacked_num: number of stacked lstm-hidden layer.
is_predict: is predicting or not.
Some layers is not needed in network when predicting.
"""
assert stacked_num % 2 == 1
layer_attr = attrs.ExtraLayerAttribute(drop_rate=0.5)
fc_para_attr = attrs.ParameterAttribute(learning_rate=1e-3)
lstm_para_attr = attrs.ParameterAttribute(initial_std=0., learning_rate=1.)
layer_attr = paddle.attr.Extra(drop_rate=0.5)
fc_para_attr = paddle.attr.Param(learning_rate=1e-3)
lstm_para_attr = paddle.attr.Param(initial_std=0., learning_rate=1.)
para_attr = [fc_para_attr, lstm_para_attr]
bias_attr = attrs.ParameterAttribute(initial_std=0., l2_rate=0.)
bias_attr = paddle.attr.Param(initial_std=0., l2_rate=0.)
relu = paddle.activation.Relu()
linear = paddle.activation.Linear()
@ -90,8 +86,10 @@ def stacked_lstm_net(input_dim,
layer_attr=layer_attr)
inputs = [fc, lstm]
fc_last = paddle.layer.pooling(input=inputs[0], pooling_type=MaxPooling())
lstm_last = paddle.layer.pooling(input=inputs[1], pooling_type=MaxPooling())
fc_last = paddle.layer.pooling(
input=inputs[0], pooling_type=paddle.pooling.Max())
lstm_last = paddle.layer.pooling(
input=inputs[1], pooling_type=paddle.pooling.Max())
output = paddle.layer.fc(input=[fc_last, lstm_last],
size=class_dim,
act=paddle.activation.Softmax(),
@ -105,14 +103,23 @@ def stacked_lstm_net(input_dim,
if __name__ == '__main__':
# init
paddle.init(use_gpu=False, trainer_count=4)
paddle.init(use_gpu=False)
# network config
#data
print 'load dictionary...'
word_dict = paddle.dataset.imdb.word_dict()
dict_dim = len(word_dict)
class_dim = 2
train_reader = paddle.batch(
paddle.reader.shuffle(
lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=100)
test_reader = paddle.batch(
lambda: paddle.dataset.imdb.test(word_dict), batch_size=100)
feeding = {'word': 0, 'label': 1}
# network config
# Please choose the way to build the network
# by uncommenting the corresponding line.
cost = convolution_net(dict_dim, class_dim=class_dim)
@ -137,12 +144,7 @@ if __name__ == '__main__':
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.batch(
lambda: paddle.dataset.imdb.test(word_dict),
batch_size=128),
feeding={'word': 0,
'label': 1})
result = trainer.test(reader=test_reader, feeding=feeding)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# create trainer
@ -151,11 +153,7 @@ if __name__ == '__main__':
update_equation=adam_optimizer)
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=100),
reader=train_reader,
event_handler=event_handler,
feeding={'word': 0,
'label': 1},
num_passes=10)
feeding=feeding,
num_passes=2)

@ -1,3 +1,4 @@
import sys
import paddle.v2 as paddle
@ -104,7 +105,9 @@ def main():
parameters = paddle.parameters.create(cost)
# define optimize method and trainer
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
optimizer = paddle.optimizer.Adam(
learning_rate=5e-5,
regularization=paddle.optimizer.L2Regularization(rate=1e-3))
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
@ -125,8 +128,11 @@ def main():
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
# start to train
trainer.train(

@ -1,26 +1,9 @@
API
===
模型配置 API
------------
.. toctree::
:maxdepth: 1
v2/model_configs.rst
数据 API
--------
.. toctree::
:maxdepth: 1
v2/data.rst
训练 API
--------
.. toctree::
:maxdepth: 1
v2/run_logic.rst
模型配置 <v2/model_configs.rst>
数据访问 <v2/data.rst>
训练与应用 <v2/run_logic.rst>

@ -1,26 +1,9 @@
API
===
Model Config API
----------------
.. toctree::
:maxdepth: 1
v2/model_configs.rst
Data API
--------
.. toctree::
:maxdepth: 1
v2/data.rst
Train API
---------
.. toctree::
:maxdepth: 1
v2/run_logic.rst
v2/run_logic.rst

@ -0,0 +1,101 @@
===========
Activation
===========
Abs
===
.. automodule:: paddle.v2.activation
:members: Abs
:noindex:
Exp
===
.. automodule:: paddle.v2.activation
:members: Exp
:noindex:
Identity
========
.. automodule:: paddle.v2.activation
:members: Identity
:noindex:
Linear
======
.. automodule:: paddle.v2.activation
:members: Linear
:noindex:
Log
===
.. automodule:: paddle.v2.activation
:members: Log
:noindex:
Square
======
.. automodule:: paddle.v2.activation
:members: Square
:noindex:
Sigmoid
=======
.. automodule:: paddle.v2.activation
:members: Sigmoid
:noindex:
Softmax
=======
.. automodule:: paddle.v2.activation
:members: Softmax
:noindex:
SequenceSoftmax
===============
.. automodule:: paddle.v2.activation
:members: SequenceSoftmax
:noindex:
Relu
====
.. automodule:: paddle.v2.activation
:members: Relu
:noindex:
BRelu
=====
.. automodule:: paddle.v2.activation
:members: BRelu
:noindex:
SoftRelu
========
.. automodule:: paddle.v2.activation
:members: SoftRelu
:noindex:
Tanh
====
.. automodule:: paddle.v2.activation
:members: Tanh
:noindex:
STanh
=====
.. automodule:: paddle.v2.activation
:members: STanh
:noindex:

@ -0,0 +1,6 @@
Parameter Attribute
===================
.. automodule:: paddle.v2.attr
:members:
:noindex:

File diff suppressed because it is too large Load Diff

@ -0,0 +1,117 @@
========
Networks
========
The v2.networks module contains pieces of neural network that combine multiple layers.
NLP
===
sequence_conv_pool
------------------
.. automodule:: paddle.v2.networks
:members: sequence_conv_pool
:noindex:
.. _api_trainer_config_helpers_network_text_conv_pool:
text_conv_pool
--------------
.. automodule:: paddle.v2.networks
:members: text_conv_pool
:noindex:
Images
======
img_conv_bn_pool
----------------
.. automodule:: paddle.v2.networks
:members: img_conv_bn_pool
:noindex:
img_conv_group
--------------
.. automodule:: paddle.v2.networks
:members: img_conv_group
:noindex:
.. _api_trainer_config_helpers_network_simple_img_conv_pool:
simple_img_conv_pool
--------------------
.. automodule:: paddle.v2.networks
:members: simple_img_conv_pool
:noindex:
vgg_16_network
---------------
.. automodule:: paddle.v2.networks
:members: vgg_16_network
:noindex:
Recurrent
=========
LSTM
----
lstmemory_unit
``````````````
.. automodule:: paddle.v2.networks
:members: lstmemory_unit
:noindex:
lstmemory_group
```````````````
.. automodule:: paddle.v2.networks
:members: lstmemory_group
:noindex:
simple_lstm
```````````
.. automodule:: paddle.v2.networks
:members: simple_lstm
:noindex:
bidirectional_lstm
``````````````````
.. automodule:: paddle.v2.networks
:members: bidirectional_lstm
:noindex:
GRU
---
gru_unit
````````
.. automodule:: paddle.v2.networks
:members: gru_unit
:noindex:
gru_group
`````````
.. automodule:: paddle.v2.networks
:members: gru_group
:noindex:
simple_gru
``````````
.. automodule:: paddle.v2.networks
:members: simple_gru
:noindex:
simple_attention
----------------
.. automodule:: paddle.v2.networks
:members: simple_attention
:noindex:
Miscs
=====
dropout_layer
--------------
.. automodule:: paddle.v2.networks
:members: dropout_layer
:noindex:

@ -0,0 +1,47 @@
.. _api_v2.optimizer:
==========
Optimizer
==========
Momentum
========
.. automodule:: paddle.v2.optimizer
:members: Momentum
:noindex:
Adam
====
.. automodule:: paddle.v2.optimizer
:members: Adam
:noindex:
Adamax
======
.. automodule:: paddle.v2.optimizer
:members: Adamax
:noindex:
AdaGrad
=======
.. automodule:: paddle.v2.optimizer
:members: AdaGrad
:noindex:
DecayedAdaGrad
==============
.. automodule:: paddle.v2.optimizer
:members: DecayedAdaGrad
:noindex:
AdaDelta
========
.. automodule:: paddle.v2.optimizer
:members: AdaDelta
:noindex:
RMSProp
=======
.. automodule:: paddle.v2.optimizer
:members: RMSProp
:noindex:

@ -0,0 +1,46 @@
=======
Pooling
=======
BasePool
========
.. automodule:: paddle.v2.pooling
:members: BasePool
:noindex:
Avg
===
.. automodule:: paddle.v2.pooling
:members: Avg
:noindex:
Max
===
.. automodule:: paddle.v2.pooling
:members: Max
:noindex:
Sum
===
.. automodule:: paddle.v2.pooling
:members: Sum
:noindex:
SquareRootN
===========
.. automodule:: paddle.v2.pooling
:members: SquareRootN
:noindex:
CudnnAvg
========
.. automodule:: paddle.v2.pooling
:members: CudnnAvg
:noindex:
CudnnMax
========
.. automodule:: paddle.v2.pooling
:members: CudnnMax
:noindex:

@ -1,52 +1,53 @@
================
Data Related API
================
========
Datasets
========
#########
DataTypes
#########
=========
.. automodule:: paddle.v2.data_type
:members:
:noindex:
##########
DataFeeder
##########
==========
.. automodule:: paddle.v2.data_feeder
:members:
:noindex:
######
Reader
######
======
.. automodule:: paddle.v2.reader
:members:
:noindex:
.. automodule:: paddle.v2.reader.creator
:members:
:noindex:
#########
minibatch
#########
=========
.. automodule:: paddle.v2.minibatch
:members:
:noindex:
#######
Dataset
#######
=======
.. automodule:: paddle.v2.dataset
:members:
:noindex:
mnist
+++++
.. automodule:: paddle.v2.dataset.mnist
:members:
:noindex:
cifar
@ -54,40 +55,54 @@ cifar
.. automodule:: paddle.v2.dataset.cifar
:members:
:noindex:
conll05
+++++++
.. automodule:: paddle.v2.dataset.conll05
:members:
:noindex:
imdb
++++
.. automodule:: paddle.v2.dataset.imdb
:members:
:noindex:
imikolov
++++++++
.. automodule:: paddle.v2.dataset.imikolov
:members:
:noindex:
movielens
+++++++++
.. automodule:: paddle.v2.dataset.movielens
:members:
:noindex:
sentiment
+++++++++
.. automodule:: paddle.v2.dataset.sentiment
:members:
:noindex:
uci_housing
+++++++++++
.. automodule:: paddle.v2.dataset.uci_housing
:members:
:noindex:
wmt14
+++++
.. automodule:: paddle.v2.dataset.uci_housing
:members:
:noindex:

@ -1,46 +1,12 @@
#########################
Configuration Related API
#########################
======
Layers
======
.. automodule:: paddle.v2.layer
:members:
==========
Attributes
==========
.. automodule:: paddle.v2.attr
:members:
===========
Activations
===========
.. automodule:: paddle.v2.activation
:members:
========
Poolings
========
.. automodule:: paddle.v2.pooling
:members:
========
Networks
========
.. automodule:: paddle.v2.networks
:members:
==========
Optimizers
==========
.. automodule:: paddle.v2.optimizer
:members:
Model Configuration
===================
.. toctree::
:maxdepth: 1
config/activation.rst
config/layer.rst
config/optimizer.rst
config/pooling.rst
config/networks.rst
config/attr.rst

@ -1,34 +1,27 @@
###########
Trainer API
###########
======================
Training and Inference
======================
==========
Parameters
==========
.. automodule:: paddle.v2.parameters
:members:
:noindex:
=======
Trainer
=======
.. automodule:: paddle.v2.trainer
:members:
.. automodule:: paddle.v2.trainer
:noindex:
=====
Event
=====
.. automodule:: paddle.v2.event
:members:
.. automodule:: paddle.v2.event
:noindex:
=========
Inference
=========
.. autofunction:: paddle.v2.infer
.. autofunction:: paddle.v2.infer
:noindex:

File diff suppressed because it is too large Load Diff

@ -42,7 +42,7 @@ Windows -- in a consistent way.
.. code-block:: bash
docker run -d -p 2202:22 -v $PWD:/paddle paddle:dev
docker run -d -p 2202:22 -p 8888:8888 -v $PWD:/paddle paddle:dev
This runs a container of the development environment Docker image
with the local source tree mounted to :code:`/paddle` of the
@ -82,6 +82,29 @@ Windows -- in a consistent way.
cd /paddle/build
ctest
4. Run PaddlePaddle Book under Docker Container
The Jupyter Notebook is an open-source web application that allows
you to create and share documents that contain live code, equations,
visualizations and explanatory text in a single browser.
PaddlePaddle Book is an interactive Jupyter Notebook for users and developers.
We already exposed port 8888 for this book. If you want to
dig deeper into deep learning, PaddlePaddle Book definitely is your best choice.
Once you are inside the container, simply issue the command:
.. code-block:: bash
jupyter notebook
Then, you would back and paste the address into the local browser:
.. code-block:: text
http://localhost:8888/
That's all. Enjoy your journey!
CPU-only and GPU Images
-----------------------
@ -93,21 +116,21 @@ automatically runs the following commands:
.. code-block:: bash
docker build -t paddle:cpu -f paddle/scripts/docker/Dockerfile .
docker build -t paddle:gpu -f paddle/scripts/docker/Dockerfile.gpu .
docker build -t paddle:cpu -f paddle/scripts/docker/Dockerfile --build-arg BUILD_AND_INSTALL=ON .
docker build -t paddle:gpu -f paddle/scripts/docker/Dockerfile.gpu --build-arg BUILD_AND_INSTALL=ON .
To run the CPU-only image as an interactive container:
.. code-block:: bash
docker run -it --rm paddledev/paddle:cpu-latest /bin/bash
docker run -it --rm paddledev/paddle:0.10.0rc1-cpu /bin/bash
or, we can run it as a daemon container
.. code-block:: bash
docker run -d -p 2202:22 paddledev/paddle:cpu-latest
docker run -d -p 2202:22 paddledev/paddle:0.10.0rc1-cpu
and SSH to this container using password :code:`root`:
@ -129,7 +152,7 @@ to install CUDA driver and let Docker knows about it:
export CUDA_SO="$(\ls /usr/lib64/libcuda* | xargs -I{} echo '-v {}:{}') $(\ls /usr/lib64/libnvidia* | xargs -I{} echo '-v {}:{}')"
export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:gpu-latest
docker run ${CUDA_SO} ${DEVICES} -it paddledev/paddle:0.10.0rc1-gpu
Non-AVX Images
@ -171,7 +194,7 @@ container:
.. code-block:: bash
docker run -d --name paddle-cpu-doc paddle:cpu
docker run -d --name paddle-cpu-doc paddle:0.10.0rc1-cpu
docker run -d --volumes-from paddle-cpu-doc -p 8088:80 nginx

@ -1,3 +1,6 @@
* {
font-family:"Roboto","Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
}
body {
padding-top: 80px;
background-image: none !important;

@ -18,6 +18,7 @@ ENV WITH_GPU=OFF
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
ENV DOCKER_BUILD=TRUE
ENV HOME /root
@ -50,7 +51,9 @@ RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd .. && rm -rf cmake-3.4.1
COPY . /paddle/
RUN cd /paddle/ && git submodule update --init --recursive
RUN /paddle/paddle/scripts/docker/build.sh
VOLUME ["/usr/share/nginx/html/data", "/usr/share/nginx/html/paddle"]
# Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service
@ -60,9 +63,7 @@ RUN sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config
RUN sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config
EXPOSE 22
# Jupyter Notebook directory.
RUN mkdir /notes/
WORKDIR "/notes"
# Jupyter Notebook: Paddle book
EXPOSE 8888
COPY ./paddle/scripts/docker/entrypoint /opt/bin/

@ -18,6 +18,7 @@ ENV WITH_GPU=ON
ENV WITH_AVX=${WITH_AVX:-ON}
ENV WITH_DOC=${WITH_DOC:-OFF}
ENV WITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
ENV DOCKER_BUILD=TRUE
ENV HOME /root
@ -50,7 +51,9 @@ RUN curl -sSL https://cmake.org/files/v3.4/cmake-3.4.1.tar.gz | tar -xz && \
cd .. && rm -rf cmake-3.4.1
COPY . /paddle/
RUN cd /paddle/ && git submodule update --init --recursive
RUN /paddle/paddle/scripts/docker/build.sh
VOLUME ["/usr/share/nginx/html/data", "/usr/share/nginx/html/paddle"]
# Configure OpenSSH server. c.f. https://docs.docker.com/engine/examples/running_ssh_service
@ -60,9 +63,7 @@ RUN sed -ri 's/^PermitRootLogin\s+.*/PermitRootLogin yes/' /etc/ssh/sshd_config
RUN sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config
EXPOSE 22
# Jupyter Notebook directory.
RUN mkdir /notes/
WORKDIR "/notes"
# Jupyter Notebook: Paddle book
EXPOSE 8888
COPY ./paddle/scripts/docker/entrypoint /opt/bin/

@ -17,7 +17,8 @@ if [[ ${BUILD_AND_INSTALL:-OFF} == 'ON' ]]; then
fi
mkdir -p /paddle/build # -p means no error if exists
cd /paddle/build
# clean local cmake and third_party cache
cd /paddle/build && rm -rf * && rm -rf ../third_party
cmake .. \
-DWITH_DOC=${WITH_DOC:-OFF} \
-DWITH_GPU=${WITH_GPU:-OFF} \
@ -56,6 +57,12 @@ if [[ ${BUILD_AND_INSTALL:-OFF} == 'ON' ]]; then
pip install /usr/local/opt/paddle/share/wheels/py_paddle*linux*.whl
pip install /usr/local/opt/paddle/share/wheels/paddle*.whl
paddle version
if [[ ${DOCKER_BUILD:-FALSE} == 'TRUE' ]]; then
# reduce docker image size
rm -rf /paddle/build
rm -rf /usr/local/opt/paddle/share/wheels/
fi
fi
trap : 0

@ -1,8 +1,4 @@
#!/bin/bash
LOG=/var/log/all
touch $LOG
/usr/sbin/sshd -D >> $LOG &
jupyter notebook --ip=0.0.0.0 /notes/ >> $LOG &
tail -f $LOG
/usr/sbin/sshd -D &
jupyter notebook --ip=0.0.0.0 /paddle/book/

@ -9,8 +9,8 @@ __all__ = ['infer']
class Inference(object):
def __init__(self, output, parameters):
topo = topology.Topology(output)
def __init__(self, output_layer, parameters):
topo = topology.Topology(output_layer)
gm = api.GradientMachine.createFromConfigProto(
topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE])
for param in gm.getParameters():
@ -21,33 +21,16 @@ class Inference(object):
self.__gradient_machine__ = gm
self.__data_types__ = topo.data_type()
def iter_infer(self, input=None, batch_size=None, reader=None,
feeding=None):
def iter_infer(self, input, feeding=None):
feeder = DataFeeder(self.__data_types__, feeding)
if reader is None:
assert input is not None and isinstance(input, collections.Iterable)
if not isinstance(input, collections.Iterable):
raise TypeError("When reader is None, input should be whole "
"inference data and should be iterable")
if batch_size is None:
if not hasattr(input, '__len__'):
raise ValueError("Should set batch size when input data "
"don't contain length.")
batch_size = len(input)
def __reader_impl__():
for each_sample in input:
if len(feeder) == 1:
yield [each_sample]
else:
yield each_sample
reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
else:
if input is not None:
raise ValueError("User should set either input or reader, "
"should not set them both.")
batch_size = len(input)
def __reader_impl__():
for each_sample in input:
yield each_sample
reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
self.__gradient_machine__.start()
for data_batch in reader():
yield self.__gradient_machine__.forwardTest(feeder(data_batch))
@ -71,13 +54,7 @@ class Inference(object):
return retv
def infer(output,
parameters,
input=None,
batch_size=None,
reader=None,
feeding=None,
field='value'):
def infer(output_layer, parameters, input, feeding=None, field='value'):
"""
Infer a neural network by given neural network output and parameters. The
user should pass either a batch of input data or reader method.
@ -90,19 +67,13 @@ def infer(output,
batch_size=32)
print result
:param output: output of the neural network that would be inferred
:type output: paddle.v2.config_base.Layer
:param output_layer: output of the neural network that would be inferred
:type output_layer: paddle.v2.config_base.Layer
:param parameters: parameters of the neural network.
:type parameters: paddle.v2.parameters.Parameters
:param input: input data batch. Should be a python iterable object, and each
element is the data batch.
:type input: collections.Iterable
:param batch_size: the batch size when perform inference. Default is the
length of input.
:type batch_size: int
:param reader: input data reader creator in batch. If this field is set, the
`input` and `batch_size` will be ignored.
:type reader: callable
:param feeding: Reader dictionary. Default could generate from input
value.
:param field: The prediction field. It should in [`value`, `ids`]. `value`
@ -113,10 +84,5 @@ def infer(output,
:rtype: numpy.ndarray
"""
inferer = Inference(output=output, parameters=parameters)
return inferer.infer(
field=field,
input=input,
batch_size=batch_size,
reader=reader,
feeding=feeding)
inferer = Inference(output_layer=output_layer, parameters=parameters)
return inferer.infer(field=field, input=input, feeding=feeding)

Loading…
Cancel
Save