Graph in ParallelExecutor Builder

guochaorong-patch-1
Xin Pan 7 years ago
parent 7781297c70
commit 2eeaa8d5cf

@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetVarDeviceID(const std::string &varname) const override; int GetVarDeviceID(const std::string &varname) const override;
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(Graph *result, const OpDesc &op,
size_t device_id) const; size_t device_id) const;
private: private:
@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(Graph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const; const ProgramDesc &program) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op, void ConnectOp(Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op, void CreateComputationalOps(Graph *result, const OpDesc &op,
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(SSAGraph *result) const; void CreateScaleLossGradOp(Graph *result) const;
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og, VarHandle *CreateReduceOp(Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(SSAGraph *result, const OpDesc &op, void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
int dev_id) const;
bool IsParameterGradientOnce( bool IsParameterGradientOnce(
const std::string &og, const std::string &og,
@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int GetOpDeviceID(const OpDesc &op) const; int GetOpDeviceID(const OpDesc &op) const;
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertAllReduceOp(Graph *result, const std::string &og) const;
void InsertDataBalanceOp(SSAGraph *result, void InsertDataBalanceOp(Graph *result,
const std::vector<std::string> &datas) const; const std::vector<std::string> &datas) const;
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;

@ -17,8 +17,8 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
for (auto &var_map : graph->vars_) { for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
continue; continue;
@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto *dep_var = new DummyVarHandle(); auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->dep_vars_.emplace(dep_var); boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dep_var);
} }
} }
} }
@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
} }
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
SSAGraph *graph, const std::string &each_var_name, Graph *graph, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) { const platform::Place &place, size_t place_offset) {
auto &var_holders = graph->vars_[place_offset]; auto &var_holders =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return var; return var;
} }
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name]; auto &vars =
(*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
[each_var_name];
size_t version = vars.size(); size_t version = vars.size();
auto var = new VarHandle(version, place_offset, each_var_name, place); auto var = new VarHandle(version, place_offset, each_var_name, place);
vars.emplace_back(var); vars.emplace_back(var);
op_handle->AddOutput(var); op_handle->AddOutput(var);
} }
void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
for (auto &op : graph->ops_) { GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
for (auto &op : all_ops) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(); auto *dummy_leaf = new DummyVarHandle();
graph->dep_vars_.emplace(dummy_leaf); boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
->emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }

@ -16,15 +16,24 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
class SSAGraphBuilder { class SSAGraphBuilder {
public: public:
SSAGraphBuilder() {} SSAGraphBuilder() {}
@ -42,20 +51,20 @@ class SSAGraphBuilder {
* *
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/ */
static void PolishGraphToSupportDataHazards(SSAGraph *graph); static void PolishGraphToSupportDataHazards(Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle, // Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph // which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
static void AddOutputToLeafOps(SSAGraph *graph); static void AddOutputToLeafOps(Graph *graph);
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework

@ -27,7 +27,7 @@ namespace framework {
class Graph { class Graph {
public: public:
std::map<std::string, std::vector<boost::any>> attrs; std::map<std::string, boost::any> attrs;
std::vector<Node *> inputs; std::vector<Node *> inputs;
std::vector<Node *> outputs; std::vector<Node *> outputs;

@ -14,6 +14,27 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
class Pass {
public:
Pass() = default;
virtual ~Pass() {}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
return std::move(graph);
}
};
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
std::unique_ptr<Graph> g(new Graph);
return std::move(g);
}
} // namespace framework
} // namespace paddle } // namespace paddle

Loading…
Cancel
Save