Move learning rate and related op to pserver (#8209)

* dist train: support learning-rate decay

* update per review comments

* revert elementwise method creator

* delete comment
Yancey committed 8 years ago (via GitHub)
parent 72bcf72c66
commit 279aa626ab
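
For context on the title: in Fluid a learning-rate decay schedule is itself a small sub-graph of operators, and this change teaches the distribute transpiler to place that sub-graph on the parameter server, so distributed training can use a decayed rate (the first bullet above). A plain-Python sketch of the kind of schedule involved; the function and constants below are illustrative, not the fluid API:

def exponential_decay(base_lr, global_step, decay_steps, decay_rate,
                      staircase=False):
    # lr = base_lr * decay_rate ^ (global_step / decay_steps)
    exponent = float(global_step) / decay_steps
    if staircase:
        exponent = int(exponent)  # decay in discrete intervals
    return base_lr * decay_rate ** exponent

# Example: base rate 0.1, halved every 100 steps.
for step in (0, 100, 250):
    print(step, exponential_decay(0.1, step, 100, 0.5, staircase=True))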

@@ -106,6 +106,7 @@ class ListenAndServOp : public framework::OperatorBase {
       // the gradients arrives, just add suffix 0~n and merge the gradient.
       rpc_service_->SetCond(0);
       size_t recv_var_cnt = 0;
+      size_t update_param_cnt = 0;
       int batch_barrier = 0;
       while (batch_barrier != fan_in) {
         const detail::MessageWithName &v = rpc_service_->Get();
@@ -126,13 +127,14 @@ class ListenAndServOp : public framework::OperatorBase {
           std::string param_var_name;
           if (it != grad_list.end()) {
             param_var_name = param_list[it - grad_list.begin()];
+            update_param_cnt++;
+            VLOG(3) << "received grad: " << grad_var_name
+                    << " updating param: " << param_var_name;
           } else {
-            LOG(ERROR) << "grad has no paired param:" << grad_var_name;
+            VLOG(3) << "received variable: " << grad_var_name
+                    << " no need to update param";
           }
-          VLOG(3) << "received grad: " << grad_var_name
-                  << " updating param: " << param_var_name;
-          if (fan_in > 1) {
+          if (fan_in > 1 && !param_var_name.empty()) {
             grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
           }
           auto *var = recv_scope.FindVar(grad_var_name);
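
The new !param_var_name.empty() guard means only gradients that pair with a parameter get the per-trainer suffix; other incoming variables keep their names. A plain-Python sketch of the renaming GetGradVarNameForTrainer performs, with the suffix format assumed from the comment "just add suffix 0~n and merge the gradient" at the top of the op:

grads_counter = {}

def get_grad_var_name_for_trainer(name):
    # Each arriving copy of the same gradient gets a distinct suffixed
    # name so the merge step can see all fan_in copies; the exact
    # format here is an assumption.
    idx = grads_counter.setdefault(name, 0)
    grads_counter[name] = idx + 1
    return "%s.trainer_%d" % (name, idx)

print(get_grad_var_name_for_trainer("w@GRAD"))  # w@GRAD.trainer_0
print(get_grad_var_name_for_trainer("w@GRAD"))  # w@GRAD.trainer_1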
@@ -144,11 +146,10 @@ class ListenAndServOp : public framework::OperatorBase {
           }
         }
       }
       VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier.";
-      // TODO(Yancey1989): merge SelectedRows variables here
       if (exit_flag) {
         rpc_service_->ShutDown();
       }
-      VLOG(3) << "run optimize graph...";
+
       try {
         executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
@@ -156,7 +157,7 @@ class ListenAndServOp : public framework::OperatorBase {
         LOG(ERROR) << "run sub program error " << e.what();
       }
       rpc_service_->SetCond(1);
-      rpc_service_->WaitClientGet(recv_var_cnt);
+      rpc_service_->WaitClientGet(update_param_cnt);
       grads_counter_.clear();
     }  // while(true)
   }
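
The last hunk is the point of the two counters: the get-barrier used to wait for one client fetch per received variable, but with the learning-rate ops on the pserver a trainer may now send variables that pair with no parameter, so the server waits only for fetches of actually updated parameters. A toy illustration in plain Python, with hypothetical variable names:

param_list = ["w", "b"]
grad_list = ["w@GRAD", "b@GRAD"]

recv_var_cnt = 0
update_param_cnt = 0
# Suppose a trainer also sends a step counter feeding the decay ops;
# "step" below is a hypothetical example of an unpaired variable.
for grad_var_name in ["w@GRAD", "b@GRAD", "step"]:
    recv_var_cnt += 1
    if grad_var_name in grad_list:
        update_param_cnt += 1  # paired with a param -> will be fetched

print(recv_var_cnt, update_param_cnt)  # 3 2: barrier waits on 2 gets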

(Diff for one file suppressed because it is too large.)

@@ -117,6 +117,7 @@ def monkey_patch_variable():
         tmp_name = unique_tmp_name()
         out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
+
         self.block.append_op(
             type=op_type,
             inputs={'X': [self],
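
This hunk touches the helper that lets arithmetic on Variables (the mechanism learning-rate decay expressions rely on) append elementwise ops to a block. A self-contained sketch of the same monkey-patching pattern; Block and Variable here are hypothetical stand-ins, not the fluid classes:

class Block:
    def __init__(self):
        self.ops = []
        self._tmp_count = 0

    def create_var(self, name):
        return Variable(self, name)

    def append_op(self, type, inputs, outputs):
        self.ops.append((type, inputs, outputs))

class Variable:
    def __init__(self, block, name):
        self.block = block
        self.name = name

def _binary_method(op_type):
    def method(self, other):
        # Create a temp output var, then record the elementwise op.
        self.block._tmp_count += 1
        out = self.block.create_var("tmp_%d" % self.block._tmp_count)
        self.block.append_op(type=op_type,
                             inputs={'X': [self], 'Y': [other]},
                             outputs={'Out': [out]})
        return out
    return method

Variable.__mul__ = _binary_method("elementwise_mul")
Variable.__add__ = _binary_method("elementwise_add")

blk = Block()
lr = blk.create_var("learning_rate")
decayed = lr * blk.create_var("decay_factor")
print(blk.ops[0][0], "->", decayed.name)  # elementwise_mul -> tmp_1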

@@ -99,7 +99,7 @@ elif training_role == "TRAINER":
     exe.run(fluid.default_startup_program())
     for pass_id in range(PASS_NUM):
         for data in train_reader():
-            avg_cost_np = exe.run(fluid.default_main_program(),
+            avg_cost_np = exe.run(t.get_trainer_program(),
                                   feed=feeder.feed(data),
                                   fetch_list=[avg_cost])
             print("avg_cost_np", avg_cost_np)
