|
|
|
@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
|
|
|
|
|
const OpDesc &op,
|
|
|
|
|
const platform::Place &p,
|
|
|
|
|
const size_t &i) const {
|
|
|
|
|
size_t place_id) const {
|
|
|
|
|
auto p = places_[place_id];
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
|
|
|
|
|
auto var_names = op.InputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
|
|
|
|
|
for (auto &each_var_name : op.InputArgumentNames()) {
|
|
|
|
|
VarHandle *var =
|
|
|
|
|
CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
|
|
|
|
|
op_handle->AddInput(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var_names = op.OutputArgumentNames();
|
|
|
|
|
|
|
|
|
|
for (auto &each_var_name : var_names) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, i);
|
|
|
|
|
for (auto &each_var_name : op.OutputArgumentNames()) {
|
|
|
|
|
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto checker = [&](const std::vector<std::string> opvars,
|
|
|
|
|
const std::vector<std::string> sendvars) -> bool {
|
|
|
|
|
bool is_dist_train_op = false;
|
|
|
|
|
/**
|
|
|
|
|
* Check whether any name in opvars contains `.block` and also appears in sendvars.
|
|
|
|
|
*/
|
|
|
|
|
auto checker = [](const std::vector<std::string> &opvars,
|
|
|
|
|
const std::vector<std::string> &sendvars) -> bool {
|
|
|
|
|
for (auto &var : opvars) {
|
|
|
|
|
if (var.find(".block") != std::string::npos &&
|
|
|
|
|
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
|
|
|
|
|
is_dist_train_op = true;
|
|
|
|
|
break;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return is_dist_train_op;
|
|
|
|
|
return false;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (op.Type() == "split") {
|
|
|
|
@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
places_.size());
|
|
|
|
|
|
|
|
|
|
// Find "send" op first for split is in front of send.
|
|
|
|
|
OpDesc *send_op = nullptr;
|
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
|
if (op->Type() == "send") {
|
|
|
|
|
send_op = op;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
OpDesc *send_op = GetSendOpDesc(program);
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
} else if (IsDistTrainOp(*op, send_op)) {
|
|
|
|
|
CreateComputationalOps(&result, *op, 1);
|
|
|
|
|
} else if (IsScaleLossOp(*op)) {
|
|
|
|
|
// user can customize loss@grad if skip_scale_loss_
|
|
|
|
|
if (!skip_scale_loss_) {
|
|
|
|
|
CreateScaleLossGradOp(&result);
|
|
|
|
|
}
|
|
|
|
@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
if (!is_forwarding) {
|
|
|
|
|
// Currently, we assume that once gradient is generated, it can be
|
|
|
|
|
// broadcast, and each gradient is only broadcast once. But there are no
|
|
|
|
|
// other cases, for example, we need to adjust the gradient according to
|
|
|
|
|
// the input when we get the gradient, which is not considered at
|
|
|
|
|
// present.
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
|
for (auto &og : op->OutputArgumentNames()) {
|
|
|
|
|
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
|
|
|
|
|
InsertNCCLAllReduceOp(&result, og);
|
|
|
|
@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
return std::unique_ptr<SSAGraph>(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Look up the "send" op in block 0 of the given program.
 *
 * @param program  program description whose block 0 is scanned.
 * @return the first op in iteration order whose type is "send", or nullptr
 *         when block 0 contains no such op.
 */
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
    const ProgramDesc &program) const {
  OpDesc *send_desc = nullptr;
  for (auto *candidate : program.Block(0).AllOps()) {
    if (candidate->Type() != "send") {
      continue;
    }
    // Only the first match is wanted; stop scanning.
    send_desc = candidate;
    break;
  }
  return send_desc;
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
|
|
|
|
|
SSAGraph *result, const std::string &og) const {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
|
|
|
|
|
auto p = places_[scope_idx];
|
|
|
|
|
auto s = local_scopes_[scope_idx];
|
|
|
|
|
result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
|
|
|
|
|
CreateOpHandleIOs(result, op, p, scope_idx);
|
|
|
|
|
CreateOpHandleIOs(result, op, scope_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
|
|
|
|
|
result->ops_.emplace_back(new SendOpHandle(op, s, p));
|
|
|
|
|
// Create inputs for output on original place and no ssa output
|
|
|
|
|
// is created for send op.
|
|
|
|
|
CreateOpHandleIOs(result, op, p, 0);
|
|
|
|
|
CreateOpHandleIOs(result, op, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
|
|
|
|
|