|
|
|
@ -28,24 +28,25 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
// get axis, the third input is the axis, is a ValueNode
|
|
|
|
|
if (input_value_.at(2) == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
auto axis = GetValue<int>(input_value_.at(2));
|
|
|
|
|
// if axis is negative then convert it to positive
|
|
|
|
|
auto params_shape = inputs_shape_.at(0);
|
|
|
|
|
if (params_shape.size() == 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
axis += SizeToInt(inputs_shape_[0].size());
|
|
|
|
|
// get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
|
|
|
|
|
if (target_ != CPU) {
|
|
|
|
|
if (input_value_.at(2) == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
auto axis = GetValue<int>(input_value_.at(2));
|
|
|
|
|
// if axis is negative then convert it to positive
|
|
|
|
|
auto params_shape = inputs_shape_.at(0);
|
|
|
|
|
if (params_shape.size() == 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
axis += SizeToInt(inputs_shape_[0].size());
|
|
|
|
|
}
|
|
|
|
|
axis_ = axis;
|
|
|
|
|
}
|
|
|
|
|
axis_ = axis;
|
|
|
|
|
|
|
|
|
|
// get target
|
|
|
|
|
auto target_iter = attrs_.find(TARGET);
|
|
|
|
|
if (target_iter != attrs_.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(target_iter->second);
|
|
|
|
@ -53,16 +54,8 @@ Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// target=CPU, axis must be 0
|
|
|
|
|
if (target_ == "CPU" && axis_ != 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto manual_split_iter = attrs_.find("manual_split");
|
|
|
|
|
if (manual_split_iter != attrs_.end()) {
|
|
|
|
|
param_split_shapes_.clear();
|
|
|
|
@ -459,38 +452,13 @@ Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
auto group_size = group_.GetDevNum();
|
|
|
|
|
Attr attr_group;
|
|
|
|
|
if (host_reduce_scatter_) {
|
|
|
|
|
// group size <= 8
|
|
|
|
|
std::vector<int32_t> rank_list;
|
|
|
|
|
if (group_size <= 8) {
|
|
|
|
|
reduce_scatter_flag_ = false;
|
|
|
|
|
operator_name = HOST_REDUCE_SCATTER;
|
|
|
|
|
rank_list = GetRankFromGroup(group_);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
|
|
|
|
|
} else {
|
|
|
|
|
// group size > 8, don't support host reduce_scatter
|
|
|
|
|
reduce_scatter_flag_ = true;
|
|
|
|
|
split_num_ = SizeToInt(group_size / 8);
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
operator_name = REDUCE_SCATTER;
|
|
|
|
|
int32_t rank = g_device_manager->global_rank();
|
|
|
|
|
size_t repeat = group_size / 8;
|
|
|
|
|
for (size_t i = 0; i < repeat; ++i) {
|
|
|
|
|
rank_list.push_back(rank + SizeToInt(i * 8));
|
|
|
|
|
}
|
|
|
|
|
Group g = g_device_manager->CreateGroup(rank_list);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
operator_name = REDUCE_SCATTER;
|
|
|
|
|
if (InferGroup() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
|
|
|
|
operator_name = REDUCE_SCATTER;
|
|
|
|
|
if (InferGroup() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
|
|
|
|
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
|
|
|
|
OperatorAttrs attrs = {attr_op, attr_group};
|
|
|
|
|
OperatorParams params;
|
|
|
|
@ -582,10 +550,7 @@ Status GatherV2PInfo::ComputeReplaceOp() {
|
|
|
|
|
OperatorName op_name = EMBEDDING_LOOKUP;
|
|
|
|
|
OperatorAttrs attrs;
|
|
|
|
|
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
|
|
|
|
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
|
|
|
|
|
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
|
|
|
|
|
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
|
|
|
|
|
std::make_pair(param_split_num, 5)};
|
|
|
|
|
OperatorParams params = {std::make_pair(param_offset, 3)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator op = std::make_pair(op_name, args);
|
|
|
|
|
replace_op_.push_back(op);
|
|
|
|
|