|
|
|
@ -49,7 +49,7 @@ static void CreateTensorFromMessageType(framework::Variable *var,
|
|
|
|
|
var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"VraibleMessage type %d is not in "
|
|
|
|
|
"VariableMessage type %d is not in "
|
|
|
|
|
"[LoDTensor, SelectedRows]",
|
|
|
|
|
var_type);
|
|
|
|
|
}
|
|
|
|
@ -121,17 +121,17 @@ class RecvOp : public framework::OperatorBase {
|
|
|
|
|
if (it != grad_list.end()) {
|
|
|
|
|
param_var_name = param_list[it - grad_list.begin()];
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "grad have no paired param:" << grad_var_name;
|
|
|
|
|
LOG(ERROR) << "grad has no paired param:" << grad_var_name;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "recved grad: " << grad_var_name
|
|
|
|
|
VLOG(3) << "received grad: " << grad_var_name
|
|
|
|
|
<< " updating param: " << param_var_name;
|
|
|
|
|
if (fan_in > 1) {
|
|
|
|
|
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
|
|
|
|
|
}
|
|
|
|
|
auto *var = recv_scope.FindVar(grad_var_name);
|
|
|
|
|
if (var == nullptr) {
|
|
|
|
|
LOG(ERROR) << "can not find server side var: " << grad_var_name;
|
|
|
|
|
PADDLE_THROW("can not find server side var");
|
|
|
|
|
LOG(ERROR) << "Can not find server side var: " << grad_var_name;
|
|
|
|
|
PADDLE_THROW("Can not find server side var");
|
|
|
|
|
}
|
|
|
|
|
detail::DeserializeFromMessage(v.second, dev_ctx, var);
|
|
|
|
|
}
|
|
|
|
@ -165,7 +165,7 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Recv operator
|
|
|
|
|
|
|
|
|
|
This operator will recv tensor from send_op
|
|
|
|
|
This operator will recieve tensor from send_op
|
|
|
|
|
)DOC");
|
|
|
|
|
AddAttr<std::string>("endpoint",
|
|
|
|
|
"(string, default 127.0.0.1:6164)"
|
|
|
|
@ -176,11 +176,11 @@ This operator will recv tensor from send_op
|
|
|
|
|
kOptimizeBlock, "Serialized ProgramDesc string for recv to run.");
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"ParamList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which param to optimize.")
|
|
|
|
|
"grad->param name mapping to find which parameters to optimize.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<std::vector<std::string>>(
|
|
|
|
|
"GradList", "type list of string",
|
|
|
|
|
"grad->param name mapping to find which param to optimize.")
|
|
|
|
|
"grad->param name mapping to find which parameters to optimize.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddAttr<int>("Fanin", "type int",
|
|
|
|
|
"Number of trainers in the current cluster job")
|
|
|
|
|