update the example with the latest API

trainerSaveLoadParams
Helin Wang 7 years ago
parent 3eef539a42
commit a785a837b9

@ -94,7 +94,7 @@ def train(use_cuda, is_sparse, save_path):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
def event_handler(event):
if isinstance(event, fluid.EndPass):
if isinstance(event, fluid.Event.END_EPOCH):
avg_cost = trainer.test(reader=paddle.dataset.imikolov.test(
word_dict, N))
@ -106,10 +106,9 @@ def train(use_cuda, is_sparse, save_path):
trainer = fluid.Trainer(
partial(inference_network, is_sparse),
optimizer=fluid.optimizer.SGD(learning_rate=0.001),
place=place,
event_handler=event_handler)
trainer.train(train_reader, 100)
fluid.optimizer.SGD(learning_rate=0.001),
place=place)
trainer.train(train_reader, 100, event_handler)
def infer(use_cuda, save_path):

Loading…
Cancel
Save