|
|
|
@ -680,8 +680,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->Op()->Type() == "split_byref" ||
|
|
|
|
|
node->Op()->Type() == "split_selected_rows" ||
|
|
|
|
|
node->Op()->Type() == "split_ids") {
|
|
|
|
|
node->Op()->Type() == "split_selected_rows") {
|
|
|
|
|
// TODO(paddle-dev): getting the first var is not safe.
|
|
|
|
|
op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
|
|
|
|
|