|
|
|
@ -438,12 +438,9 @@ std::vector<int32_t> GetRankFromGroup(const Group &group) {
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
forward_op_.clear();
|
|
|
|
|
if (target_ != CPU) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
// don't split axis, no need forward communication
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) == 1) {
|
|
|
|
|
// don't split axis or target is not CPU, no need forward communication
|
|
|
|
|
if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
// split axis
|
|
|
|
|