|
|
|
@ -75,11 +75,15 @@ class Trainer(object):
|
|
|
|
|
self.train_program = framework.Program()
|
|
|
|
|
|
|
|
|
|
with framework.program_guard(self.train_program, self.startup_program):
|
|
|
|
|
loss = program_func()
|
|
|
|
|
program_func_outs = program_func()
|
|
|
|
|
self.test_outputs = program_func_outs if isinstance(
|
|
|
|
|
program_func_outs, list) else [program_func_outs]
|
|
|
|
|
self.test_program = self.train_program.clone()
|
|
|
|
|
if not isinstance(optimizer, opt_module.Optimizer):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The optimizer should be an instance of Optimizer")
|
|
|
|
|
|
|
|
|
|
# The fisrt element of program_func_outs is loss.
|
|
|
|
|
loss = self.test_outputs[0]
|
|
|
|
|
optimize_ops, params_grads = optimizer.minimize(loss)
|
|
|
|
|
|
|
|
|
|
self.place = Trainer._check_and_get_place(place)
|
|
|
|
@ -168,8 +172,17 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
|
|
|
|
|
|
|
|
|
|
def test(self, reader):
|
|
|
|
|
pass
|
|
|
|
|
def test(self, reader, feed_order=None):
|
|
|
|
|
"""
|
|
|
|
|
Test the model on given test data
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
reader: The reader that yields test data.
|
|
|
|
|
feed_order: Feeding order of reader. None will following the defining
|
|
|
|
|
order in program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
return self._test_by_executor(reader, feed_order, self.test_outputs)
|
|
|
|
|
|
|
|
|
|
def save_params(self, param_path):
|
|
|
|
|
# reference: save_persistables in io.py
|
|
|
|
@ -225,22 +238,10 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
with self._prog_and_scope_guard():
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
if feed_order is None:
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
var
|
|
|
|
|
for var in self.train_program.global_block(
|
|
|
|
|
).vars.itervalues()
|
|
|
|
|
if hasattr(var, 'is_data') and var.is_data
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
self.train_program.global_block().var(var_name)
|
|
|
|
|
for var_name in feed_order
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
feed_var_list = build_feed_var_list(self.train_program, feed_order)
|
|
|
|
|
feeder = data_feeder.DataFeeder(
|
|
|
|
|
feed_list=feed_var_list, place=self.place)
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
for epoch_id in range(num_epochs):
|
|
|
|
|
event_handler(BeginEpochEvent(epoch_id))
|
|
|
|
|
for step_id, data in enumerate(reader()):
|
|
|
|
@ -248,3 +249,48 @@ class Trainer(object):
|
|
|
|
|
exe.run(feed=feeder.feed(data), fetch_list=[])
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id))
|
|
|
|
|
event_handler(EndEpochEvent(epoch_id))
|
|
|
|
|
|
|
|
|
|
def _test_by_executor(self, reader, feed_order, fetch_list):
|
|
|
|
|
with executor.scope_guard(self.scope):
|
|
|
|
|
feed_var_list = build_feed_var_list(self.test_program, feed_order)
|
|
|
|
|
feeder = data_feeder.DataFeeder(
|
|
|
|
|
feed_list=feed_var_list, place=self.place)
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
accumulated = len(fetch_list) * [0]
|
|
|
|
|
count = 0
|
|
|
|
|
for data in reader():
|
|
|
|
|
outs = exe.run(program=self.test_program,
|
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
|
fetch_list=fetch_list)
|
|
|
|
|
accumulated = [x[0] + x[1][0] for x in zip(accumulated, outs)]
|
|
|
|
|
count += 1
|
|
|
|
|
|
|
|
|
|
return [x / count for x in accumulated]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_feed_var_list(program, feed_order):
|
|
|
|
|
if not isinstance(program, framework.Program):
|
|
|
|
|
raise TypeError("The 'program' should be an object of Program")
|
|
|
|
|
|
|
|
|
|
if feed_order is None:
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
var for var in program.global_block().vars.itervalues()
|
|
|
|
|
if var.is_data
|
|
|
|
|
]
|
|
|
|
|
elif isinstance(feed_order, list):
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
program.global_block().var(var_name) for var_name in feed_order
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
if not isinstance(feed_order, dict):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The 'feed_order' should be either None, list or dict.")
|
|
|
|
|
if not sorted(feed_order.values()) == range(len(feed_order)):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
|
|
|
|
|
)
|
|
|
|
|
sorted_pair_list = sorted(feed_order.items(), key=lambda item: item[1])
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
program.global_block().var(pair[0]) for pair in sorted_pair_list
|
|
|
|
|
]
|
|
|
|
|
return feed_var_list
|
|
|
|
|