|
|
|
@ -63,7 +63,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GetGradVarNameForTrainer(const std::string &varname) const {
|
|
|
|
|
if (grads_counter_.find(varname) != grads_counter_.end()) {
|
|
|
|
|
if (grads_counter_.find(varname) == grads_counter_.end()) {
|
|
|
|
|
grads_counter_[varname] = 0;
|
|
|
|
|
}
|
|
|
|
|
char ret[256];
|
|
|
|
@ -96,11 +96,7 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
VLOG(10) << "recved grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
if (trainer_count > 1) {
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (var != nullptr) {
|
|
|
|
|
// must rename the var to different names to merge gradient.
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *var = recv_scope.Var(grad_var_name);
|
|
|
|
|