|
|
|
@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
// Create RPC related op handles that connect their in ops and out ops.
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
ir::Node *node) const {
|
|
|
|
|
// FIXME(typhoonzero): Clean up these deps for both sync mode and async mode
|
|
|
|
|
// put them into transpiler.
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (node->Op()->Type() == "send") {
|
|
|
|
|
// TODO(paddle-dev): getting the first var is not safe.
|
|
|
|
@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
|
|
|
|
|
"This hack no longer holds, please fix.");
|
|
|
|
|
// the variable name which contains .block means it was split by
|
|
|
|
|
// split_byref op
|
|
|
|
|
// so that we can balance the variable blocks to all the pserver
|
|
|
|
|
// instances.
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
|
|
|
|
|
node->inputs[0]->Name().find(".block") == std::string::npos) {
|
|
|
|
|
std::vector<std::string> input_var_names;
|
|
|
|
|
for (ir::Node *n : node->inputs) {
|
|
|
|
|
input_var_names.push_back(n->Name());
|
|
|
|
|
}
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(input_var_names);
|
|
|
|
|
auto send_param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
|
|
|
|
|
VLOG(10) << "send grad " << input_var_names[0] << " origin "
|
|
|
|
|
<< send_param_grad[1] << " place: " << op_dev_id;
|
|
|
|
|
for (auto &varname : input_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(send_param_grad[1], op_dev_id);
|
|
|
|
|
}
|
|
|
|
|
} else if (node->Op()->Type() == "recv") {
|
|
|
|
|
std::vector<std::string> output_var_names;
|
|
|
|
|
for (ir::Node *n : node->outputs) {
|
|
|
|
|
output_var_names.push_back(n->Name());
|
|
|
|
|
}
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(output_var_names);
|
|
|
|
|
auto recv_param_grad = boost::get<std::vector<std::string>>(
|
|
|
|
|
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
|
|
|
|
|
// FIXME(typhoonzero): assume each recv op output one param
|
|
|
|
|
// Use the same place as send.
|
|
|
|
|
if (recv_param_grad.size() == 2U) {
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
|
|
|
|
|
VLOG(10) << "recv param " << recv_param_grad[0]
|
|
|
|
|
<< " get grad place: " << recv_param_grad[1]
|
|
|
|
|
<< " place: " << op_dev_id;
|
|
|
|
|
} else {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(output_var_names);
|
|
|
|
|
}
|
|
|
|
|
for (auto &varname : output_var_names) {
|
|
|
|
|
result->Get<ShardedVarDevice>(kShardedVarDevice)
|
|
|
|
|
.emplace(varname, op_dev_id);
|
|
|
|
|