|
|
|
@ -163,7 +163,13 @@ void MultiDevSSAGraphBuilderBase::Init() const {
|
|
|
|
|
nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
places_.size(), local_scopes_.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Places size and LocalScopes not equal "
|
|
|
|
|
"Places size(%d), LocalScopes size(%d) "
|
|
|
|
|
"If use multi devices, Places size must equas to LocalScopes size.",
|
|
|
|
|
places_.size(), local_scopes_.size()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
|
|
|
|
@ -500,7 +506,11 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
SetCommunicationContext(op_handle, places_[i]);
|
|
|
|
|
auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
PADDLE_ENFORCE_EQ(vars.empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Can not find Var(%s) in Place[%d] "
|
|
|
|
|
"Paddle Can not add AllReduce OP for Var(%s).",
|
|
|
|
|
og, i, og));
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad);
|
|
|
|
|
VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
|
|
|
|
@ -566,7 +576,11 @@ details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
PADDLE_ENFORCE_EQ(vars.empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Can not find Var(%s) in Place[%d] "
|
|
|
|
|
"Paddle Can not add Reduce OP for Var(%s).",
|
|
|
|
|
og, i, og));
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad);
|
|
|
|
|
}
|
|
|
|
@ -590,7 +604,11 @@ bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const {
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
PADDLE_ENFORCE(all_vars_.count(og) != 0);
|
|
|
|
|
PADDLE_ENFORCE_NE(all_vars_.count(og), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Can not find Var(%s) in VarDescs "
|
|
|
|
|
"Paddle Can not add Collective OP for Var(%s).",
|
|
|
|
|
og, og));
|
|
|
|
|
return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -641,10 +659,20 @@ int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
|
|
|
|
|
std::vector<std::string>,
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_grad.size(), 2U,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"In Node %s, the size of attribute %s must be 2, include Parameter "
|
|
|
|
|
"and Parameter@Grad.",
|
|
|
|
|
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
int dev_id = GetVarDeviceID(param_grad[1]);
|
|
|
|
|
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
|
|
|
|
|
node->Op()->Type(), param_grad[0], param_grad[1]);
|
|
|
|
|
PADDLE_ENFORCE_NE(dev_id, -1, platform::errors::NotFound(
|
|
|
|
|
"Can not find Device ID, for NodeName:%s, "
|
|
|
|
|
"NodeType:%s, Param:%s, Param@Grad:%s"
|
|
|
|
|
"For this fault, you can consult the "
|
|
|
|
|
"Paddle technical personnel for answer ",
|
|
|
|
|
node->Name(), node->Op()->Type(),
|
|
|
|
|
param_grad[0], param_grad[1]));
|
|
|
|
|
return dev_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -654,10 +682,16 @@ size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID(
|
|
|
|
|
for (auto var_name : var_names) {
|
|
|
|
|
if (all_vars_.find(var_name) == all_vars_.end()) continue;
|
|
|
|
|
auto var_desc = all_vars_.at(var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_desc);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_desc,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Can not find Var(%s) in Var Desc.", var_name));
|
|
|
|
|
auto dim = framework::make_ddim(var_desc->GetShape());
|
|
|
|
|
int64_t numel = framework::product(dim);
|
|
|
|
|
PADDLE_ENFORCE_GT(numel, 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(numel, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The numel of Var(%s) must greater than 0"
|
|
|
|
|
"Please check your code,about Var(%s) Shape.",
|
|
|
|
|
var_name, var_name));
|
|
|
|
|
numel_sum += numel;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -736,7 +770,12 @@ int ReduceSSAGraphBuilder::GetOpDeviceID(
|
|
|
|
|
std::vector<std::string>,
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_grad.size(), 2U,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"In Node %s, The size of attribute %s must be 2, include Parameter "
|
|
|
|
|
"and Parameter@Grad.",
|
|
|
|
|
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
int dev_id = GetVarDeviceID(param_grad[1]);
|
|
|
|
|
|
|
|
|
|
if (dev_id == -1) {
|
|
|
|
@ -798,7 +837,12 @@ std::vector<ir::Node *> ReduceSSAGraphBuilder::SortForReduceMode(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Sorted ops calc error!"
|
|
|
|
|
"The result for sorted ops size(%d) must be "
|
|
|
|
|
"equal to topo ops size(%d).",
|
|
|
|
|
sorted_ops.size(), topo_ops.size()));
|
|
|
|
|
|
|
|
|
|
ResetState();
|
|
|
|
|
return sorted_ops;
|
|
|
|
@ -820,14 +864,23 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
bool insert_op = false;
|
|
|
|
|
if (OpHaveRole(*node, OpRole::kRPC)) {
|
|
|
|
|
int op_dev_id = CreateRPCOp(result, node);
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"Can not schedule the RPC operator to the right place.");
|
|
|
|
|
PADDLE_ENFORCE_NE(op_dev_id, -1, platform::errors::InvalidArgument(
|
|
|
|
|
"Can not schedule the RPC operator to "
|
|
|
|
|
"the right place. NodeName:%s.",
|
|
|
|
|
node->Name()));
|
|
|
|
|
if (node->Op()->Type() == "recv") {
|
|
|
|
|
auto recv_vars_attr =
|
|
|
|
|
BOOST_GET_CONST(std::vector<std::string>,
|
|
|
|
|
node->Op()->GetNullableAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
PADDLE_ENFORCE(recv_vars_attr.size() == 2UL); // [parameter, gradient]
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
recv_vars_attr.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"In Node %s, the size of attribute %s must be 2, include "
|
|
|
|
|
"Parameter and Parameter@Grad.",
|
|
|
|
|
node->Name(),
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleVarAttrName())); // [parameter,
|
|
|
|
|
// gradient]
|
|
|
|
|
if (recv_vars_attr[0].find(".block") == std::string::npos) {
|
|
|
|
|
bcast_var_name_set_[op_dev_id].emplace(recv_vars_attr[0]);
|
|
|
|
|
}
|
|
|
|
@ -879,8 +932,9 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
// TODO(paddle-dev): getting the first var is not safe.
|
|
|
|
|
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
|
|
|
|
|
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
|
|
|
|
|
"This hack no longer holds, please fix.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ir::IsControlDepVar(*node->inputs[0]), false,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"This hack no longer holds, please fix."));
|
|
|
|
|
// the variable name which contains .block means it was split by
|
|
|
|
|
// split_byref op
|
|
|
|
|
if (strategy_.reduce_ ==
|
|
|
|
@ -893,7 +947,12 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
auto send_param_grad = BOOST_GET_CONST(
|
|
|
|
|
std::vector<std::string>,
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
send_param_grad.size(), 2U,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"In Node %s, the size of attribute %s must be 2, include "
|
|
|
|
|
"Parameter and Parameter@Grad.",
|
|
|
|
|
node->Name(), OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
|
|
|
|
|
VLOG(10) << "send grad " << input_var_names[0] << " origin "
|
|
|
|
|
<< send_param_grad[1] << " place: " << op_dev_id;
|
|
|
|
@ -926,9 +985,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
op_dev_id = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
|
|
|
|
|
node->Op()->Type());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
op_dev_id, -1,
|
|
|
|
|
platform::errors::NotFound("Can not find the right place for rpc op: %s.",
|
|
|
|
|
node->Op()->Type()));
|
|
|
|
|
// Create fetch_barrier op handle to enable output on all devices.
|
|
|
|
|
// **NOTE** fetch_barrier should output variables list same as recv op does.
|
|
|
|
|
if (node->Op()->Type() == "fetch_barrier") {
|
|
|
|
@ -956,7 +1016,10 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
int outvar_dev_id = op_dev_id;
|
|
|
|
|
if (node->Op()->Type() == "fetch_barrier") {
|
|
|
|
|
outvar_dev_id = GetVarDeviceID(output->Name());
|
|
|
|
|
PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
|
|
|
|
|
PADDLE_ENFORCE_NE(outvar_dev_id, -1,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Can not find the right place for the var: %s.",
|
|
|
|
|
output->Name()));
|
|
|
|
|
}
|
|
|
|
|
p = places_[outvar_dev_id];
|
|
|
|
|
ir::Node *new_node = nullptr;
|
|
|
|
@ -1007,13 +1070,14 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"the distribute training related op should be in [split_byref, "
|
|
|
|
|
"concat].");
|
|
|
|
|
platform::errors::Unimplemented("The distribute training related op "
|
|
|
|
|
"should be in [split_byref, concat]."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"can not find right place for distributed op: %s",
|
|
|
|
|
node->Op()->Type());
|
|
|
|
|
PADDLE_ENFORCE_NE(op_dev_id, -1,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Can not find right place for distributed op: %s.",
|
|
|
|
|
node->Op()->Type()));
|
|
|
|
|
|
|
|
|
|
CreateComputationalOp(result, node, op_dev_id);
|
|
|
|
|
return op_dev_id;
|
|
|
|
|