From 21496447dfae41fb1a939413fa9cac324cfe9e96 Mon Sep 17 00:00:00 2001 From: lichenever Date: Thu, 9 Jul 2020 09:54:37 +0800 Subject: [PATCH] fix code dex --- mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc index d62111c010..680d6f3ed6 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc @@ -438,12 +438,9 @@ std::vector GetRankFromGroup(const Group &group) { Status GatherV2PInfo::InferForwardCommunication() { forward_op_.clear(); - if (target_ != CPU) { - return SUCCESS; - } auto param_strategy = strategy_->GetInputDim().at(0); - // don't split axis, no need forward communication - if (param_strategy.at(IntToSize(axis_)) == 1) { + // don't split axis or target is not CPU, no need forward communication + if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { return SUCCESS; } // split axis