Graph in ParallelExecutor Builder

guochaorong-patch-1
Xin Pan 7 years ago
parent 7781297c70
commit 2eeaa8d5cf

@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace platform {
@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetVarDeviceID(const std::string &varname) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
void CreateOpHandleIOs(Graph *result, const OpDesc &op,
size_t device_id) const;
private:
@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
void CreateRPCOp(Graph *result, const OpDesc &op) const;
void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
void CreateComputationalOps(Graph *result, const OpDesc &op,
size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
int dev_id) const;
void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
bool IsParameterGradientOnce(
const std::string &og,
@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetOpDeviceID(const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
void InsertAllReduceOp(Graph *result, const std::string &og) const;
void InsertDataBalanceOp(SSAGraph *result,
void InsertDataBalanceOp(Graph *result,
const std::vector<std::string> &datas) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
void CreateBroadcastOp(Graph *result, const std::string &p_name,
size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const;

@ -17,8 +17,8 @@
namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
for (auto &var_map : graph->vars_) {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var);
boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dep_var);
}
}
}
@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name,
Graph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders = graph->vars_[place_offset];
auto &var_holders =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var;
}
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
auto &vars =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
[each_var_name];
size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
for (auto &op : all_ops) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf);
boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}

@ -16,15 +16,24 @@
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder {
public:
SSAGraphBuilder() {}
@ -42,20 +51,20 @@ class SSAGraphBuilder {
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
const std::string &each_var_name,
const platform::Place &place,
size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph);
static void AddOutputToLeafOps(Graph *graph);
};
} // namespace details
} // namespace framework

@ -27,7 +27,7 @@ namespace framework {
class Graph {
public:
std::map<std::string, std::vector<boost::any>> attrs;
std::map<std::string, boost::any> attrs;
std::vector<Node *> inputs;
std::vector<Node *> outputs;

@ -14,6 +14,27 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {} // namespace framework
namespace framework {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
return std::move(graph);
}
};
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
std::unique_ptr<Graph> g(new Graph);
return std::move(g);
}
} // namespace framework
} // namespace paddle

Loading…
Cancel
Save