|
|
|
@ -34,6 +34,7 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
void PolishGraphToSupportDataHazards(ir::Graph *graph) {
|
|
|
|
|
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
|
|
|
|
@ -303,7 +304,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
result.Set(kGraphVars, new GraphVars(places_.size()));
|
|
|
|
|
result.Set(kGraphDepVars, new GraphDepVars);
|
|
|
|
|
result.Set(kGraphOps, new GraphOps);
|
|
|
|
|
result.Set(kShardedVarDevice, new ShardedVarDevice);
|
|
|
|
|
|
|
|
|
|
// find send/recv vars so that we can place the distributed training
|
|
|
|
|
// related op in the place 0
|
|
|
|
@ -317,11 +317,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
bool is_dist_train = false;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, int> sharded_var_device;
|
|
|
|
|
|
|
|
|
|
for (ir::Node *node : sorted_ops) {
|
|
|
|
|
if (boost::get<int>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
|
|
|
|
|
static_cast<int>(OpRole::kRPC)) {
|
|
|
|
|
int op_dev_id = CreateRPCOp(&result, node);
|
|
|
|
|
int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"Can not schedule the RPC operator to the right place.");
|
|
|
|
|
if (node->Op()->Type() == "recv") {
|
|
|
|
@ -337,7 +339,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
} else if (boost::get<int>(node->Op()->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) ==
|
|
|
|
|
static_cast<int>(OpRole::kDist)) {
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(&result, node);
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
|
|
|
|
|
if (node->Op()->Type() == "concat") {
|
|
|
|
|
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
|
|
|
|
|
bcast_var_name_set[op_dev_id].emplace(origin_param_name);
|
|
|
|
@ -356,12 +358,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
// the block.
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(result, node);
|
|
|
|
|
int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
|
|
|
|
|
if (op_dev_id != -1) { // This op only runs on one specific device.
|
|
|
|
|
CreateComputationalOp(&result, node, op_dev_id);
|
|
|
|
|
for (ir::Node *n : node->outputs) {
|
|
|
|
|
graph->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(n->Name(), op_dev_id);
|
|
|
|
|
sharded_var_device.emplace(n->Name(), op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// This op runs on all devices, and its output may have parameter's
|
|
|
|
@ -398,8 +399,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|
cur_device_id = GetAppropriateDeviceID({g_name});
|
|
|
|
|
CreateReduceOp(&result, g_name, cur_device_id);
|
|
|
|
|
graph->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(g_name, cur_device_id);
|
|
|
|
|
sharded_var_device.emplace(g_name, cur_device_id);
|
|
|
|
|
if (!is_dist_train) {
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(p_name);
|
|
|
|
|
}
|
|
|
|
@ -617,8 +617,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(
|
|
|
|
|
const ir::Graph &graph, ir::Node *node,
|
|
|
|
|
const std::unordered_map<std::string, int> &sharded_var_device) const {
|
|
|
|
|
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
@ -631,15 +632,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
|
|
|
|
|
int dev_id = GetVarDeviceID(graph, param_grad[1]);
|
|
|
|
|
int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
|
|
|
|
|
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
|
|
|
|
|
node->Op()->Type(), param_grad[0], param_grad[1]);
|
|
|
|
|
return dev_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
|
|
|
|
|
const std::string &varname) const {
|
|
|
|
|
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetVarDeviceID(
|
|
|
|
|
const ir::Graph &graph, const std::string &varname,
|
|
|
|
|
const std::unordered_map<std::string, int> &sharded_var_device) const {
|
|
|
|
|
auto got = sharded_var_device.find(varname);
|
|
|
|
|
return got == sharded_var_device.end() ? -1 : got->second;
|
|
|
|
|
}
|
|
|
|
@ -709,8 +710,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int MultiDevSSAGraphBuilder::CreateDistTrainOp(
|
|
|
|
|
ir::Graph *result, ir::Node *node,
|
|
|
|
|
std::unordered_map<std::string, int> *sharded_var_device) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
@ -725,23 +727,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
node->Op()->Type() == "split_selected_rows" ||
|
|
|
|
|
node->Op()->Type() == "split_ids") {
|
|
|
|
|
// TODO(paddle-dev): getting the first var is not safe.
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
|
|
|
|
|
op_dev_id =
|
|
|
|
|
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(input_var_names);
|
|
|
|
|
for (auto &varname : input_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &varname : output_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else if (node->Op()->Type() == "concat") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
|
|
|
|
|
op_dev_id =
|
|
|
|
|
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
|
|
|
|
|
for (auto &varname : output_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
|
|
|
|
@ -774,12 +775,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create RPC related op handles that connects its in ops and out ops.
|
|
|
|
|
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int MultiDevSSAGraphBuilder::CreateRPCOp(
|
|
|
|
|
ir::Graph *result, ir::Node *node,
|
|
|
|
|
std::unordered_map<std::string, int> *sharded_var_device) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
// TODO(paddle-dev): getting the first var is not safe.
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
|
|
|
|
|
op_dev_id =
|
|
|
|
|
GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
|
|
|
|
|
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
|
|
|
|
|
"This hack no longer holds, please fix.");
|
|
|
|
|
// the variable name which contains .block means it was splited by
|
|
|
|
@ -797,11 +800,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
VLOG(10) << "send grad " << input_var_names[0] << " origin "
|
|
|
|
|
<< send_param_grad[1] << " place: " << op_dev_id;
|
|
|
|
|
for (auto &varname : input_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(send_param_grad[1], op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(send_param_grad[1], op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else if (node->Op()->Type() == "recv") {
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
@ -811,7 +812,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
auto recv_param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
if (recv_param_grad.size() == 2U) {
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
|
|
|
|
|
op_dev_id =
|
|
|
|
|
GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
|
|
|
|
|
VLOG(10) << "recv param " << recv_param_grad[0]
|
|
|
|
|
<< " get grad place: " << recv_param_grad[1]
|
|
|
|
|
<< " place: " << op_dev_id;
|
|
|
|
@ -819,8 +821,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(output_var_names);
|
|
|
|
|
}
|
|
|
|
|
for (auto &varname : output_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
sharded_var_device->emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// send_barrier, fetch_barrier will run on place 0;
|
|
|
|
@ -847,7 +848,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
for (ir::Node *output : node->outputs) {
|
|
|
|
|
int outvar_dev_id = op_dev_id;
|
|
|
|
|
if (node->Op()->Type() == "fetch_barrier") {
|
|
|
|
|
outvar_dev_id = GetVarDeviceID(*result, output->Name());
|
|
|
|
|
outvar_dev_id =
|
|
|
|
|
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
|
|
|
|
|
PADDLE_ENFORCE_NE(outvar_dev_id, -1);
|
|
|
|
|
}
|
|
|
|
|
p = places_[outvar_dev_id];
|
|
|
|
|