|
|
|
@ -225,25 +225,6 @@ class TestDistCTR2x2(FleetDistRunnerBase):
|
|
|
|
|
debug=False)
|
|
|
|
|
pass_time = time.time() - pass_start
|
|
|
|
|
|
|
|
|
|
res_dict = dict()
|
|
|
|
|
res_dict['loss'] = self.avg_cost
|
|
|
|
|
|
|
|
|
|
class FH(fluid.executor.FetchHandler):
|
|
|
|
|
def handle(self, res_dict):
|
|
|
|
|
for key in res_dict:
|
|
|
|
|
v = res_dict[key]
|
|
|
|
|
print("{}: \n {}\n".format(key, v))
|
|
|
|
|
|
|
|
|
|
for epoch_id in range(1):
|
|
|
|
|
pass_start = time.time()
|
|
|
|
|
dataset.set_filelist(filelist)
|
|
|
|
|
exe.train_from_dataset(
|
|
|
|
|
program=fleet.main_program,
|
|
|
|
|
dataset=dataset,
|
|
|
|
|
fetch_handler=FH(var_dict=res_dict, period_secs=2),
|
|
|
|
|
debug=False)
|
|
|
|
|
pass_time = time.time() - pass_start
|
|
|
|
|
|
|
|
|
|
if os.getenv("SAVE_MODEL") == "1":
|
|
|
|
|
model_dir = tempfile.mkdtemp()
|
|
|
|
|
fleet.save_inference_model(exe, model_dir,
|
|
|
|
@ -251,6 +232,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
|
|
|
|
|
self.avg_cost)
|
|
|
|
|
self.check_model_right(model_dir)
|
|
|
|
|
shutil.rmtree(model_dir)
|
|
|
|
|
|
|
|
|
|
fleet.stop_worker()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|