|
|
|
@ -68,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
size_t place_id) const {
|
|
|
|
|
auto p = places_[place_id];
|
|
|
|
|
auto *op_handle = result->Get<GraphOps>("ops").back().get();
|
|
|
|
@ -192,8 +193,9 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
|
|
|
|
|
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
|
|
|
|
|
// some optimizer ops might not depend on any nodes), we manually move all
|
|
|
|
|
// optimizer nodes after last backward nodes.
|
|
|
|
|
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
|
|
|
|
|
std::vector<ir::Node *> ret = ir::TopologySort(graph);
|
|
|
|
|
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
|
|
|
|
|
std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
|
|
|
|
|
std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
|
|
|
|
|
size_t last_backward = 0;
|
|
|
|
|
std::vector<ir::Node *> optimize_ops;
|
|
|
|
|
std::vector<ir::Node *> sorted_ret;
|
|
|
|
@ -232,8 +234,8 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const Graph &graph) {
|
|
|
|
|
return sorted_ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
std::unique_ptr<Graph> graph) const {
|
|
|
|
|
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph) const {
|
|
|
|
|
// Rebuild the graph structure.
|
|
|
|
|
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
|
|
|
|
|
auto nodes = std::move(graph->nodes);
|
|
|
|
@ -245,7 +247,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Graph &result = *graph;
|
|
|
|
|
ir::Graph &result = *graph;
|
|
|
|
|
std::unordered_set<std::string> og_has_been_broadcast;
|
|
|
|
|
|
|
|
|
|
// We cannot invoke resize. It is a bug of GCC 4.8
|
|
|
|
@ -397,7 +399,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
size_t src_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -427,7 +429,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
int dev_id) const {
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(
|
|
|
|
@ -436,7 +438,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
|
|
|
|
|
CreateOpHandleIOs(result, node, dev_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
|
|
|
|
|
const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new AllReduceOpHandle(
|
|
|
|
@ -466,7 +468,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
|
|
|
|
|
Graph *result, const std::vector<std::string> &datas) const {
|
|
|
|
|
ir::Graph *result, const std::vector<std::string> &datas) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
result->Get<GraphOps>("ops").emplace_back(new DataBalanceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
|
|
|
|
@ -529,7 +531,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
|
|
|
|
|
return got == var_name_on_devices_.end() ? -1 : got->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -559,7 +561,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
size_t num_places) const {
|
|
|
|
|
for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
|
|
|
|
@ -571,7 +573,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
const std::string &og,
|
|
|
|
|
int dst_dev_id) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -604,7 +606,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
|
|
|
|
|
|
|
|
|
|
// Find the first occurence of `prev_op_name` and make current `op` depend
|
|
|
|
|
// on it.
|
|
|
|
|
void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
|
|
|
|
|
void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->Get<GraphOps>("ops")) {
|
|
|
|
|
if (prev_op->Name() == prev_op_name) {
|
|
|
|
@ -617,7 +619,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
@ -664,7 +666,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create RPC related op handles that connects its in ops and out ops.
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const {
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
|
|
|
|
|