|
|
@ -34,7 +34,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
const std::vector<Scope *> &local_scopes, bool skip_scale_loss,
|
|
|
|
platform::NCCLContextMap *nccl_ctxs)
|
|
|
|
platform::NCCLContextMap *nccl_ctxs)
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
places_(places),
|
|
|
|
places_(places),
|
|
|
@ -45,7 +45,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::vector<Scope *> &local_scopes)
|
|
|
|
const std::vector<Scope *> &local_scopes, bool skip_scale_loss)
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
places_(places),
|
|
|
|
places_(places),
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
@ -53,6 +53,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
for (auto &p : params) {
|
|
|
|
for (auto &p : params) {
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
skip_scale_loss_ = skip_scale_loss;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
|
|
|
|
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
|
|
|
@ -133,7 +134,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
} else if (IsDistTrainOp(*op, send_op)) {
|
|
|
|
} else if (IsDistTrainOp(*op, send_op)) {
|
|
|
|
CreateComputationalOps(&result, *op, 1);
|
|
|
|
CreateComputationalOps(&result, *op, 1);
|
|
|
|
} else if (IsScaleLossOp(*op)) {
|
|
|
|
} else if (IsScaleLossOp(*op)) {
|
|
|
|
CreateScaleLossGradOp(&result);
|
|
|
|
if (!skip_scale_loss_) {
|
|
|
|
|
|
|
|
CreateScaleLossGradOp(&result);
|
|
|
|
|
|
|
|
}
|
|
|
|
is_forwarding = false;
|
|
|
|
is_forwarding = false;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|