|
|
|
@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// target=CPU, axis must be 0
|
|
|
|
|
if (target_ == "CPU" && axis_ != 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
int32_t rank = g_device_manager->global_rank();
|
|
|
|
|
auto input_shape = inputs_shape_.at(0);
|
|
|
|
|
auto params_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
// axis don't split
|
|
|
|
|
if (params_strategy.at(axis_) == 1) {
|
|
|
|
|
bias_ = 0;
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
// params_size=1, axis=0
|
|
|
|
|
if ((input_shape.size() == 1) && (axis_ == 0)) {
|
|
|
|
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
|
|
|
@ -353,6 +364,7 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
}
|
|
|
|
|
auto group_size = group_.GetDevNum();
|
|
|
|
|
Attr attr_group;
|
|
|
|
|
if (host_reduce_scatter_) {
|
|
|
|
|
// group size <= 8
|
|
|
|
|
std::vector<int32_t> rank_list;
|
|
|
|
|
if (group_size <= 8) {
|
|
|
|
@ -361,7 +373,7 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
rank_list = GetRankFromGroup(group_);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
|
|
|
|
|
} else {
|
|
|
|
|
// group size > 8
|
|
|
|
|
// group size > 8, don't support host reduce_scatter
|
|
|
|
|
reduce_scatter_flag_ = true;
|
|
|
|
|
split_num_ = SizeToInt(group_size / 8);
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
@ -374,6 +386,14 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
Group g = g_device_manager->CreateGroup(rank_list);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
operator_name = REDUCE_SCATTER;
|
|
|
|
|
if (InferGroup() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
|
|
|
|
}
|
|
|
|
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
|
|
|
|
OperatorAttrs attrs = {attr_op, attr_group};
|
|
|
|
|
OperatorParams params;
|
|
|
|
@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
|
|
|
|
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
|
|
|
|
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
|
|
|
|
|
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
|
|
|
|
|
OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
|
|
|
|
|
std::make_pair(param_split_num, 6)};
|
|
|
|
|
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
|
|
|
|
|
std::make_pair(param_split_num, 5)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator op = std::make_pair(op_name, args);
|
|
|
|
|
replace_op_.push_back(op);
|
|
|
|
|