|
|
|
@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
int op_dev_id = -1;
|
|
|
|
|
if (op.Type() == "split_byref") {
|
|
|
|
|
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
|
|
|
|
|
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
|
|
|
|
|
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
|
|
|
|
|
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
|
|
|
|
|