|
|
|
@ -172,9 +172,9 @@ class Trainer(object):
|
|
|
|
|
def train(self,
|
|
|
|
|
num_epochs,
|
|
|
|
|
event_handler,
|
|
|
|
|
reader=None,
|
|
|
|
|
parallel=False,
|
|
|
|
|
feed_order=None):
|
|
|
|
|
reader,
|
|
|
|
|
feed_order,
|
|
|
|
|
parallel=False):
|
|
|
|
|
"""
|
|
|
|
|
Train the model.
|
|
|
|
|
|
|
|
|
@ -202,7 +202,7 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
|
|
|
|
|
|
|
|
|
|
def test(self, reader, feed_order=None):
|
|
|
|
|
def test(self, reader, feed_order):
|
|
|
|
|
"""
|
|
|
|
|
Test the model on given test data
|
|
|
|
|
|
|
|
|
@ -276,12 +276,7 @@ def build_feed_var_list(program, feed_order):
|
|
|
|
|
if not isinstance(program, framework.Program):
|
|
|
|
|
raise TypeError("The 'program' should be an object of Program")
|
|
|
|
|
|
|
|
|
|
if feed_order is None:
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
var for var in program.global_block().vars.itervalues()
|
|
|
|
|
if var.is_data
|
|
|
|
|
]
|
|
|
|
|
elif isinstance(feed_order, list):
|
|
|
|
|
if isinstance(feed_order, list):
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
program.global_block().var(var_name) for var_name in feed_order
|
|
|
|
|
]
|
|
|
|
|