|
|
|
@ -167,10 +167,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
|
|
|
|
|
|
|
|
|
|
bool is_forwarding = true;
|
|
|
|
|
bool insert_collection_ops = NeedCollectiveOps();
|
|
|
|
|
if (strategy_.async_mode_) {
|
|
|
|
|
// async mode does not need to merge gradients
|
|
|
|
|
insert_collection_ops = false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (ir::Node *node : sorted_ops) {
|
|
|
|
|
if (DealWithSpecialOp(&result, node)) {
|
|
|
|
@ -749,10 +745,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
bool insert_op = false;
|
|
|
|
|
if (OpHaveRole(*node, OpRole::kRPC)) {
|
|
|
|
|
// in async_mode, each graph will send its own gradient.
|
|
|
|
|
if (strategy_.async_mode_ && node->Op()->Type() == "send") {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
int op_dev_id = CreateRPCOp(result, node);
|
|
|
|
|
PADDLE_ENFORCE(op_dev_id != -1,
|
|
|
|
|
"Can not schedule the RPC operator to the right place.");
|
|
|
|
@ -768,11 +760,6 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
insert_op = true;
|
|
|
|
|
need_broadcast_var_ = true;
|
|
|
|
|
} else if (OpHaveRole(*node, OpRole::kDist)) {
|
|
|
|
|
// in async_mode, each graph will send its own gradient, so there is no need to
|
|
|
|
|
// merge gradient.
|
|
|
|
|
if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(result, node);
|
|
|
|
|
if (node->Op()->Type() == "concat") {
|
|
|
|
|
// the input(block of parameter) of concat is on different device,
|
|
|
|
@ -844,7 +831,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
}
|
|
|
|
|
auto recv_param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
|
|
|
|
|
if (recv_param_grad.size() == 2U) {
|
|
|
|
|
op_dev_id = GetVarDeviceID(recv_param_grad[1]);
|
|
|
|
|
VLOG(10) << "recv param " << recv_param_grad[0]
|
|
|
|
|
<< " get grad place: " << recv_param_grad[1]
|
|
|
|
|