|
|
@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
const std::vector<Scope *> &local_scopes,
|
|
|
|
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
|
|
|
|
platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
|
|
|
|
|
|
|
|
bool balance_parameter_opt_between_cards)
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
places_(places),
|
|
|
|
places_(places),
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
nccl_ctxs_(nccl_ctxs) {
|
|
|
|
nccl_ctxs_(nccl_ctxs),
|
|
|
|
|
|
|
|
balance_parameter_opt_between_cards_(
|
|
|
|
|
|
|
|
balance_parameter_opt_between_cards) {
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::vector<platform::Place> &places,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::string &loss_var_name,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
|
|
|
|
const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
|
|
|
|
|
|
|
|
bool balance_parameter_opt_between_cards)
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
: loss_var_name_(loss_var_name),
|
|
|
|
places_(places),
|
|
|
|
places_(places),
|
|
|
|
local_scopes_(local_scopes) {
|
|
|
|
local_scopes_(local_scopes),
|
|
|
|
|
|
|
|
balance_parameter_opt_between_cards_(
|
|
|
|
|
|
|
|
balance_parameter_opt_between_cards) {
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
for (auto &p : params) {
|
|
|
|
for (auto &p : params) {
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
|
grad_names_.insert(GradVarName(p));
|
|
|
@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
// Find "send" op first for split is in front of send.
|
|
|
|
// Find "send" op first for split is in front of send.
|
|
|
|
OpDesc *send_op = GetSendOpDesc(program);
|
|
|
|
OpDesc *send_op = GetSendOpDesc(program);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
|
|
|
|
std::vector<std::unordered_set<std::string>> var_name_on_devices;
|
|
|
|
|
|
|
|
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
|
|
|
|
|
|
|
|
var_name_on_devices.resize(places_.size());
|
|
|
|
|
|
|
|
bcast_var_name_set.resize(places_.size());
|
|
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
bool is_forwarding = true;
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
if (op->Type() == "send") {
|
|
|
|
if (op->Type() == "send") {
|
|
|
@ -139,17 +151,33 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
is_forwarding = false;
|
|
|
|
is_forwarding = false;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
|
|
|
|
|
|
|
|
if (op_dev_id == -1) { // var on all device
|
|
|
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
CreateComputationalOp(&result, *op, op_dev_id);
|
|
|
|
|
|
|
|
for (auto &var_name : op->OutputArgumentNames()) {
|
|
|
|
|
|
|
|
var_name_on_devices[op_dev_id].emplace(var_name);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
if (!is_forwarding && places_.size() > 1) {
|
|
|
|
// Currently, we assume that once gradient is generated, it can be
|
|
|
|
// Currently, we assume that once gradient is generated, it can be
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
for (auto &og : op->OutputArgumentNames()) {
|
|
|
|
for (auto &og : op->OutputArgumentNames()) {
|
|
|
|
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
|
|
|
|
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
|
|
|
|
if (IsSparseGradient(var_types, og)) {
|
|
|
|
if (balance_parameter_opt_between_cards_) {
|
|
|
|
CreateReduceOp(&result, og, 0);
|
|
|
|
CreateReduceOp(&result, og, cur_device_id);
|
|
|
|
CreateBroadcastOp(&result, og, 0);
|
|
|
|
var_name_on_devices[cur_device_id].emplace(og);
|
|
|
|
|
|
|
|
bcast_var_name_set[cur_device_id].emplace(
|
|
|
|
|
|
|
|
og.substr(0, og.size() - strlen(kGradVarSuffix)));
|
|
|
|
|
|
|
|
cur_device_id = (cur_device_id + 1) % places_.size();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
InsertNCCLAllReduceOp(&result, og);
|
|
|
|
if (IsSparseGradient(var_types, og)) {
|
|
|
|
|
|
|
|
CreateReduceOp(&result, og, 0);
|
|
|
|
|
|
|
|
CreateBroadcastOp(&result, og, 0);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
InsertNCCLAllReduceOp(&result, og);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -157,6 +185,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Insert BCast Ops
|
|
|
|
|
|
|
|
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
|
|
|
|
|
|
|
|
auto &to_bcast_set = bcast_var_name_set[dev_id];
|
|
|
|
|
|
|
|
for (auto &bcast_name : to_bcast_set) {
|
|
|
|
|
|
|
|
CreateBroadcastOp(&result, bcast_name, dev_id);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
Dependency graph has been constructed. However, there are still data
|
|
|
|
Dependency graph has been constructed. However, there are still data
|
|
|
|
harzaeds need to be handled.
|
|
|
|
harzaeds need to be handled.
|
|
|
@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
|
|
|
|
return is_pg_once;
|
|
|
|
return is_pg_once;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int MultiDevSSAGraphBuilder::GetOpDeviceID(
|
|
|
|
|
|
|
|
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
|
|
|
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
|
|
|
if (!balance_parameter_opt_between_cards_) {
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int var_dev_id = -1;
|
|
|
|
|
|
|
|
for (auto &var_name : op.InputArgumentNames()) {
|
|
|
|
|
|
|
|
if (var_dev_id != -1) break;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
|
|
|
|
|
|
|
|
if (var_name_on_devices[i].count(var_name)) {
|
|
|
|
|
|
|
|
var_dev_id = static_cast<int>(i);
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return var_dev_id;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
|
|
|
|
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|