|
|
@ -345,7 +345,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
|
|
|
|
|
|
|
auto var = new VarHandle(vars.size() - 1, i, og, p);
|
|
|
|
auto var = new VarHandle(vars.size(), i, og, p);
|
|
|
|
vars.emplace_back(var);
|
|
|
|
vars.emplace_back(var);
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -442,8 +442,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto &vars = result->vars_[dst_dev_id][og];
|
|
|
|
auto &vars = result->vars_[dst_dev_id][og];
|
|
|
|
auto var =
|
|
|
|
auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
|
|
|
|
new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
|
|
|
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
vars.emplace_back(var);
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
return var;
|
|
|
|
return var;
|
|
|
|