|
|
|
@ -756,6 +756,11 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
insert_op = true;
|
|
|
|
|
need_broadcast_var_ = true;
|
|
|
|
|
} else if (OpHaveRole(*node, OpRole::kDist)) {
|
|
|
|
|
// in async_mode, each graph will send it's own gradient, do not need to
|
|
|
|
|
// merge gradient.
|
|
|
|
|
if (strategy_.async_mode_ && node->Op()->Type() != "concat") {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(result, node);
|
|
|
|
|
if (node->Op()->Type() == "concat") {
|
|
|
|
|
// the input(block of parameter) of concat is on different device,
|
|
|
|
@ -827,7 +832,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
|
|
|
|
|
}
|
|
|
|
|
auto recv_param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
if (recv_param_grad.size() == 2U) {
|
|
|
|
|
if (recv_param_grad.size() == 2U && !strategy_.async_mode_) {
|
|
|
|
|
op_dev_id = GetVarDeviceID(recv_param_grad[1]);
|
|
|
|
|
VLOG(10) << "recv param " << recv_param_grad[0]
|
|
|
|
|
<< " get grad place: " << recv_param_grad[1]
|
|
|
|
|