|
|
|
@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
bool insert_collection_ops = NeedCollectiveOps();
|
|
|
|
|
if (strategy_.async_mode_) {
|
|
|
|
|
// async mode did not need to merge gradient
|
|
|
|
|
insert_collection_ops = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (ir::Node *node : sorted_ops) {
|
|
|
|
|
if (DealWithSpecialOp(&result, node)) {
|
|
|
|
@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
|
|
|
|
|
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
|
|
|
|
|
OpProtoAndCheckerMaker::OpRoleAttrName())) &
|
|
|
|
|
static_cast<int>(OpRole::kBackward));
|
|
|
|
|
// optimize op is already processed in DealWithSpecialOp,
|
|
|
|
|
// here we only consider backward op
|
|
|
|
|
if (!is_bk_op) continue;
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
* the op that will generate the gradient of on parameter will have
|
|
|
|
|
one attr op_role_var
|
|
|
|
|
* to record the parameter and gradient, like:
|
|
|
|
|
attrs {
|
|
|
|
|
name: "op_role_var"
|
|
|
|
|
type: STRINGS
|
|
|
|
|
strings: "fc_1.b_0"
|
|
|
|
|
strings: "fc_1.b_0@GRAD"
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
// Currently, we assume that once gradient is generated, it can be
|
|
|
|
|
// broadcast, and each gradient is only broadcast once.
|
|
|
|
|
auto backward_vars =
|
|
|
|
@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
|
|
|
|
|
for (size_t i = 0; i < backward_vars.size(); i += 2) {
|
|
|
|
|
auto &p_name = backward_vars[i];
|
|
|
|
|
auto &g_name = backward_vars[i + 1];
|
|
|
|
|
VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
|
|
|
|
|
VLOG(3) << "Bcast " << g_name << " for parameter " << p_name;
|
|
|
|
|
|
|
|
|
|
InsertCollectiveOp(&result, p_name, g_name);
|
|
|
|
|
}
|
|
|
|
@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node,
|
|
|
|
|
int dev_id) const {
|
|
|
|
|
size_t dev_id) const {
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(
|
|
|
|
|
new ComputationOpHandle(result->CreateOpNode(node->Op()),
|
|
|
|
|
local_scopes_[dev_id], places_[dev_id], dev_id));
|
|
|
|
@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
|
|
|
|
|
const std::string &og,
|
|
|
|
|
int dst_dev_id) const {
|
|
|
|
|
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
|
|
|
|
|
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
|
|
|
|
|
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
|
|
|
|
@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
bool insert_op = false;
|
|
|
|
|
if (OpHaveRole(*node, OpRole::kRPC)) {
|
|
|
|
|
// in async_mode, each graph will send it's own gradient.
|
|
|
|
|
if (strategy_.async_mode_ && node->Op()->Type() == "send") {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
int op_dev_id = CreateRPCOp(result, node);
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"Can not schedule the RPC operator to the right place.");
|
|
|
|
@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
} else if (OpHaveRole(*node, OpRole::kDist)) {
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(result, node);
|
|
|
|
|
if (node->Op()->Type() == "concat") {
|
|
|
|
|
// the input(block of parameter) of concat is on different device,
|
|
|
|
|
// the output(parameter) will on one device.
|
|
|
|
|
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
|
|
|
|
|
bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
|
|
|
|
|
}
|
|
|
|
@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(node);
|
|
|
|
|
if (op_dev_id != -1) { // This op only runs on one specific device.
|
|
|
|
|
// optimize op will be processed here.
|
|
|
|
|
CreateComputationalOp(result, node, op_dev_id);
|
|
|
|
|
for (ir::Node *n : node->outputs) {
|
|
|
|
|
sharded_var_device_.emplace(n->Name(), op_dev_id);
|
|
|
|
@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
|
|
|
|
|
const std::string &p_name,
|
|
|
|
|
const std::string &g_name) const {
|
|
|
|
|
// collective gradient to each device
|
|
|
|
|
size_t cur_device_id = 0;
|
|
|
|
|
switch (strategy_.reduce_) {
|
|
|
|
|
case BuildStrategy::ReduceStrategy::kReduce:
|
|
|
|
|