fix code dex

pull/2950/head
lichenever 5 years ago
parent 180b3029e5
commit 21496447df

@ -438,12 +438,9 @@ std::vector<int32_t> GetRankFromGroup(const Group &group) {
Status GatherV2PInfo::InferForwardCommunication() {
forward_op_.clear();
if (target_ != CPU) {
return SUCCESS;
}
auto param_strategy = strategy_->GetInputDim().at(0);
// don't split axis, no need forward communication
if (param_strategy.at(IntToSize(axis_)) == 1) {
// don't split axis or target is not CPU, no need forward communication
if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) {
return SUCCESS;
}
// split axis

Loading…
Cancel
Save