@@ -243,8 +243,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
+    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
+    axis_split_forward_allreduce_ = true;
   }
 
   if (manual_split_) {
@@ -257,7 +257,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
 
@@ -403,7 +403,8 @@ void GatherV2PInfo::InferOutputsTensorMap() {
   } else {
     // param_strategy(axis) != 1
     if (axis_ == 0) {
-      if (dynamic_shape_indices_ && target_ != CPU) {
+      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
         // the output is repeat calculation
         tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
       } else {
         tensor_map_out.insert(tensor_map_out.end(), 0);
@@ -549,15 +550,6 @@ Status GatherV2PInfo::InferGroup() {
   return SUCCESS;
 }
 
-RankList GetRankFromGroup(const Group &group) {
-  RankList rank_list;
-  auto device_list = group.GetDevicesList();
-  for (auto &device : device_list) {
-    rank_list.insert(rank_list.end(), device.rank() % 8);
-  }
-  return rank_list;
-}
-
 Status GatherV2PInfo::InferForwardCommunication() {
   if (manual_split_) {
     return SUCCESS;
@@ -628,7 +620,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
   auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
-  // don't need expandim,if param_size = 1,
+  // don't need expand dim, if param_size = 1
   if (inputs_shape_.at(0).size() == 1) {
     mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
   }
@@ -640,7 +632,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   AnfNodePtr reduce_op;
-  if (dynamic_shape_indices_) {
+  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   } else {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
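
Summary of the behavior this patch introduces: an axis-0 gather whose index_shape(0) is not divisible by param_strategy(0) is no longer rejected in CheckStrategy; instead axis_split_forward_allreduce_ is set, the output's dim 0 is marked as repeat calculation (MAP_NONE), and the forward communication falls back from ReduceScatter to AllReduce, since ReduceScatter requires the reduced output to split evenly across the group. A minimal standalone sketch of that fallback decision follows; UseAllReduceForward is a hypothetical illustrative helper, not MindSpore API.

#include <cstdint>
#include <iostream>

// Sketch of the patched decision: ReduceScatter needs the gathered output's
// dim 0 to divide evenly across the shards of param dim 0. When it does not,
// or when the indices have dynamic shape, every rank keeps a repeated full
// copy of the output and AllReduce is used in the forward pass instead.
bool UseAllReduceForward(int64_t index_dim0, int64_t param_strategy_dim0,
                         bool dynamic_shape_indices) {
  // Mirrors (axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)
  // in CheckStrategy, combined with the dynamic_shape_indices_ check in
  // ComputeReplaceGraph.
  return dynamic_shape_indices || (index_dim0 % param_strategy_dim0 != 0);
}

int main() {
  // Example: 5 index rows, param dim 0 split 2 ways -> 5 % 2 != 0,
  // so the forward pass uses AllReduce rather than ReduceScatter.
  std::cout << std::boolalpha
            << UseAllReduceForward(5, 2, /*dynamic_shape_indices=*/false)
            << std::endl;  // prints: true
  return 0;
}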