|
|
|
@ -35,22 +35,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> &params,
|
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
|
platform::NCCLContextMap *nccl_ctxs, bool distributed)
|
|
|
|
|
platform::NCCLContextMap *nccl_ctxs)
|
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
distributed_(distributed),
|
|
|
|
|
nccl_ctxs_(nccl_ctxs) {
|
|
|
|
|
#else
|
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
|
const std::unordered_set<std::string> &params,
|
|
|
|
|
const std::vector<Scope *> &local_scopes, bool distributed)
|
|
|
|
|
const std::vector<Scope *> &local_scopes)
|
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
|
places_(places),
|
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
distributed_(distributed) {
|
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
|
|
#endif
|
|
|
|
|
for (auto &p : params) {
|
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
@ -99,7 +97,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
|
|
|
|
|
// append send op if program is distributed trainer main program.
|
|
|
|
|
// always use the first device
|
|
|
|
|
if (is_forwarding && distributed_ && op->Type() == "send") {
|
|
|
|
|
if (!is_forwarding && op->Type() == "send") {
|
|
|
|
|
auto &p = places_[0];
|
|
|
|
|
auto *s = local_scopes_[0];
|
|
|
|
|
size_t i = 0;
|
|
|
|
|