@@ -191,15 +191,54 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   };
   bool is_forwarding = true;
+  std::unordered_map<std::string, int> rpc_var_device_mapping;
+  int rpc_op_device_id = 0;
+  auto schedule_rpc_op = [&]() -> void {
+    rpc_op_device_id++;
+    if (rpc_op_device_id >= static_cast<int>(places_.size())) {
+      rpc_op_device_id = 0;
+    }
+  };
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
       // append rpc op if program is distributed trainer main program.
-      // always use the first device
-      CreateRPCOp(&result, *op);
+      if (op->Type() == "send_vars") {
+        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
+        if (got == remote_vars_devices_.end()) {
+          schedule_rpc_op();
+        } else {
+          rpc_op_device_id = got->second;
+        }
+        CreateRPCOp(&result, *op, rpc_op_device_id);
+      } else if (op->Type() == "recv") {
+        schedule_rpc_op();
+        for (auto &varname : op->OutputArgumentNames()) {
+          remote_vars_devices_.insert({varname, rpc_op_device_id});
+        }
+        CreateRPCOp(&result, *op, rpc_op_device_id);
+      } else {
+        CreateRPCOp(&result, *op, 0);
+      }
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, *op);
+      if (op->Type() == "split_byref") {
+        schedule_rpc_op();
+        for (auto &varname : op->OutputArgumentNames()) {
+          remote_vars_devices_.insert({varname, rpc_op_device_id});
+        }
+        CreateDistTrainOp(&result, *op, rpc_op_device_id);
+      } else if (op->Type() == "concat") {
+        auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
+        PADDLE_ENFORCE(got != remote_vars_devices_.end(),
+                       "cannot find the right place to concat the received var");
+        CreateDistTrainOp(&result, *op, got->second);
+      } else {
+        CreateDistTrainOp(&result, *op, 0);
+      }
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
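The hunk above implements round-robin RPC placement with sticky variable placement: `recv` and `split_byref` rotate to the next device and pin their output variables there, while `send_vars` reuses the pinned device of its input when one exists. Below is a minimal standalone simulation of just this bookkeeping; `RpcPlacement`, `Produce`, and `Consume` are illustrative names, not part of the patch:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy model of the scheduling in Build(): one counter rotated round-robin
// over num_places, plus a var -> device map for sticky placement.
struct RpcPlacement {
  explicit RpcPlacement(int n) : num_places(n), device_id(0) {}

  // Mirrors schedule_rpc_op: advance and wrap around.
  void Schedule() { device_id = (device_id + 1) % num_places; }

  // Mirrors the "recv"/"split_byref" branches: rotate, then pin outputs.
  int Produce(const std::vector<std::string> &outputs) {
    Schedule();
    for (const auto &name : outputs) var_device.insert({name, device_id});
    return device_id;
  }

  // Mirrors the "send_vars" branch: reuse the pinned device if known,
  // otherwise rotate.
  int Consume(const std::string &input) {
    auto got = var_device.find(input);
    if (got == var_device.end()) {
      Schedule();
    } else {
      device_id = got->second;
    }
    return device_id;
  }

  int num_places;
  int device_id;
  std::unordered_map<std::string, int> var_device;
};

int main() {
  RpcPlacement p(2);  // two places
  std::cout << p.Produce({"w@GRAD.block0"}) << "\n";  // recv -> device 1
  std::cout << p.Produce({"w@GRAD.block1"}) << "\n";  // recv -> device 0
  std::cout << p.Consume("w@GRAD.block0") << "\n";    // send -> device 1 (sticky)
  std::cout << p.Consume("unseen_var") << "\n";       // send -> rotates to 0
  return 0;
}
```

Note that the patch looks up only `op->InputArgumentNames()[0]`, i.e. it assumes the first input determines placement for the whole op; the sketch reproduces that by keying `Consume` on a single name.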
@@ -464,17 +503,18 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 }
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
-                                                const OpDesc &op) const {
-  CreateComputationalOp(result, op, 0);
+                                                const OpDesc &op,
+                                                int place_id) const {
+  CreateComputationalOp(result, op, place_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
-                                          const OpDesc &op) const {
-  auto &p = places_[0];
-  auto *s = local_scopes_[0];
+void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, const OpDesc &op,
+                                          int place_id) const {
+  auto &p = places_[place_id];
+  auto *s = local_scopes_[place_id];
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
 
   if (op.Type() == "send_barrier") {
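This hunk threads `place_id` from the build loop into both helpers, so an RPC op handle is constructed with the scope and place of its assigned device rather than always `places_[0]` and `local_scopes_[0]`. A small sketch of the invariant this indexing relies on, assuming (as the code does) that `places_` and `local_scopes_` are index-aligned; `PickDevice` is a hypothetical helper, not part of the patch:

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper: the overloads above rely on places and local scopes
// being parallel arrays, so one valid index selects a matching pair.
template <typename Place, typename Scope>
std::pair<Place, Scope *> PickDevice(const std::vector<Place> &places,
                                     const std::vector<Scope *> &scopes,
                                     int place_id) {
  assert(scopes.size() == places.size());
  assert(place_id >= 0 && place_id < static_cast<int>(places.size()));
  return std::make_pair(places[place_id], scopes[place_id]);
}
```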
@@ -493,7 +533,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
 
   // TODO(Yancey1989): schedule rpc op on different places may
   // increase throughput
-  CreateOpHandleIOs(result, op, 0);
+  CreateOpHandleIOs(result, op, place_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
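A closing note for this hunk read together with the first one: since `schedule_rpc_op` increments before use, the first rotated RPC op lands on device 1, and device 0 is reached again only after a full cycle (barrier-type ops stay pinned to device 0 through the `else` branches). A toy check of the wraparound, plain C++ rather than Paddle code:

```cpp
#include <cstdio>

int main() {
  const int num_places = 3;
  int rpc_op_device_id = 0;
  for (int op = 0; op < 6; ++op) {
    // Same effect as the schedule_rpc_op lambda in the first hunk.
    rpc_op_device_id++;
    if (rpc_op_device_id >= num_places) rpc_op_device_id = 0;
    std::printf("rotated rpc op %d -> device %d\n", op, rpc_op_device_id);
  }
  return 0;  // prints devices 1 2 0 1 2 0
}
```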