|
|
|
|
@ -82,6 +82,7 @@ current_endpoint = os.getenv("SERVER_ENDPOINT")
|
|
|
|
|
# run as trainer or parameter server
|
|
|
|
|
training_role = os.getenv("TRAINING_ROLE",
|
|
|
|
|
"TRAINER") # get the training role: trainer/pserver
|
|
|
|
|
|
|
|
|
|
t.transpile(
|
|
|
|
|
optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
@ -97,9 +98,10 @@ elif training_role == "TRAINER":
|
|
|
|
|
feed_list=[first_word, second_word, third_word, forth_word, next_word],
|
|
|
|
|
place=place)
|
|
|
|
|
exe.run(fluid.default_startup_program())
|
|
|
|
|
trainer_prog = t.get_trainer_program()
|
|
|
|
|
for pass_id in range(PASS_NUM):
|
|
|
|
|
for data in train_reader():
|
|
|
|
|
avg_cost_np = exe.run(t.get_trainer_program(),
|
|
|
|
|
avg_cost_np = exe.run(trainer_prog,
|
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
|
fetch_list=[avg_cost])
|
|
|
|
|
print("avg_cost_np", avg_cost_np)
|
|
|
|
|
|