Merge pull request #11698 from typhoonzero/fix_sparse_dist_paraexe

fix sparse paraexe dist train
port
Wu Yi 7 years ago committed by GitHub
commit bb18de68c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -470,7 +470,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
int op_dev_id = -1;
if (op.Type() == "split_byref") {
if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());

Loading…
Cancel
Save