Move learning rate and releated op to pserver (#8209)

* dist train support lr decay

* update by comment

* revert elementwise method creator

* delete comment
emailweixu-patch-1
Yancey 7 years ago committed by GitHub
parent 72bcf72c66
commit 279aa626ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -106,6 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
size_t update_param_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::MessageWithName &v = rpc_service_->Get();
@ -126,13 +127,14 @@ class ListenAndServOp : public framework::OperatorBase {
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
update_param_cnt++;
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
} else {
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
VLOG(3) << "received variable: " << grad_var_name
<< " no need to update param";
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;
if (fan_in > 1) {
if (fan_in > 1 && !param_var_name.empty()) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
auto *var = recv_scope.FindVar(grad_var_name);
@ -144,11 +146,10 @@ class ListenAndServOp : public framework::OperatorBase {
}
}
VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) {
rpc_service_->ShutDown();
}
VLOG(3) << "run optimize graph...";
try {
executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
false /*create_local_scope*/, false /*create_vars*/);
@ -156,7 +157,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->SetCond(1);
rpc_service_->WaitClientGet(recv_var_cnt);
rpc_service_->WaitClientGet(update_param_cnt);
grads_counter_.clear();
} // while(true)
}

File diff suppressed because it is too large Load Diff

@ -117,6 +117,7 @@ def monkey_patch_variable():
tmp_name = unique_tmp_name()
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
self.block.append_op(
type=op_type,
inputs={'X': [self],

@ -99,7 +99,7 @@ elif training_role == "TRAINER":
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
for data in train_reader():
avg_cost_np = exe.run(fluid.default_main_program(),
avg_cost_np = exe.run(t.get_trainer_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost])
print("avg_cost_np", avg_cost_np)

Loading…
Cancel
Save