|
|
@ -19,6 +19,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/details/build_strategy.h"
|
|
|
|
#include "paddle/fluid/framework/details/build_strategy.h"
|
|
|
|
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
|
|
|
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace platform {
|
|
|
|
namespace platform {
|
|
|
@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
|
|
|
int GetVarDeviceID(const std::string &varname) const override;
|
|
|
|
int GetVarDeviceID(const std::string &varname) const override;
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
|
|
|
|
void CreateOpHandleIOs(Graph *result, const OpDesc &op,
|
|
|
|
size_t device_id) const;
|
|
|
|
size_t device_id) const;
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
|
|
|
|
|
|
|
|
|
|
|
bool IsScaleLossOp(const OpDesc &op) const;
|
|
|
|
bool IsScaleLossOp(const OpDesc &op) const;
|
|
|
|
|
|
|
|
|
|
|
|
void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
|
|
|
|
void CreateRPCOp(Graph *result, const OpDesc &op) const;
|
|
|
|
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
|
|
|
|
void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
/**
|
|
|
|
* Is this operator as the end-point operator before/after send operator.
|
|
|
|
* Is this operator as the end-point operator before/after send operator.
|
|
|
@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
|
|
|
std::vector<std::string> FindDistTrainRecvVars(
|
|
|
|
std::vector<std::string> FindDistTrainRecvVars(
|
|
|
|
const ProgramDesc &program) const;
|
|
|
|
const ProgramDesc &program) const;
|
|
|
|
|
|
|
|
|
|
|
|
void ConnectOp(SSAGraph *result, OpHandleBase *op,
|
|
|
|
void ConnectOp(Graph *result, OpHandleBase *op,
|
|
|
|
const std::string &prev_op_name) const;
|
|
|
|
const std::string &prev_op_name) const;
|
|
|
|
|
|
|
|
|
|
|
|
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
|
|
|
|
void CreateComputationalOps(Graph *result, const OpDesc &op,
|
|
|
|
size_t num_places) const;
|
|
|
|
size_t num_places) const;
|
|
|
|
|
|
|
|
|
|
|
|
void CreateScaleLossGradOp(SSAGraph *result) const;
|
|
|
|
void CreateScaleLossGradOp(Graph *result) const;
|
|
|
|
VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
|
|
|
|
VarHandle *CreateReduceOp(Graph *result, const std::string &og,
|
|
|
|
int dst_dev_id) const;
|
|
|
|
int dst_dev_id) const;
|
|
|
|
void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
|
|
|
|
void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
|
|
|
|
int dev_id) const;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool IsParameterGradientOnce(
|
|
|
|
bool IsParameterGradientOnce(
|
|
|
|
const std::string &og,
|
|
|
|
const std::string &og,
|
|
|
@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
|
|
|
|
|
|
|
|
|
|
|
int GetOpDeviceID(const OpDesc &op) const;
|
|
|
|
int GetOpDeviceID(const OpDesc &op) const;
|
|
|
|
|
|
|
|
|
|
|
|
void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
|
|
|
|
void InsertAllReduceOp(Graph *result, const std::string &og) const;
|
|
|
|
|
|
|
|
|
|
|
|
void InsertDataBalanceOp(SSAGraph *result,
|
|
|
|
void InsertDataBalanceOp(Graph *result,
|
|
|
|
const std::vector<std::string> &datas) const;
|
|
|
|
const std::vector<std::string> &datas) const;
|
|
|
|
|
|
|
|
|
|
|
|
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
|
|
|
|
void CreateBroadcastOp(Graph *result, const std::string &p_name,
|
|
|
|
size_t src_dev_id) const;
|
|
|
|
size_t src_dev_id) const;
|
|
|
|
|
|
|
|
|
|
|
|
bool IsSparseGradient(const std::string &og) const;
|
|
|
|
bool IsSparseGradient(const std::string &og) const;
|
|
|
|