|
|
|
@ -44,6 +44,18 @@ Status GatherV2PInfo::GetAttrs() {
|
|
|
|
|
}
|
|
|
|
|
axis_ = axis;
|
|
|
|
|
|
|
|
|
|
// get target
|
|
|
|
|
auto target_iter = attrs_.find(TARGET);
|
|
|
|
|
if (target_iter != attrs_.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(target_iter->second);
|
|
|
|
|
if (target_iter->second->isa<StringImm>()) {
|
|
|
|
|
target_ = target_iter->second->cast<StringImmPtr>()->value();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : The value of target is not a string.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -61,8 +73,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
auto param_shape = inputs_shape_.at(0);
|
|
|
|
|
auto param_strategy = strategy->GetInputDim().at(0);
|
|
|
|
|
auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
|
|
|
|
|
if (slice_shape % 8 != 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
|
|
|
|
|
if (slice_shape % 8 != 0 && slice_shape != 1) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -74,20 +86,20 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
// don't support scalar index
|
|
|
|
|
if (inputs_shape_.at(1).size() == 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
|
|
|
|
Shape index_shape = inputs_shape_.at(1);
|
|
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
|
|
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -95,7 +107,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
auto index_strategy = strategy->GetInputDim().at(1);
|
|
|
|
|
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
|
|
|
|
|
if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -104,7 +116,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
|
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
|
|
|
|
|
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -290,18 +302,85 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferGroup() {
|
|
|
|
|
std::vector<Group> group_list;
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
size_t dim = IntToSize(axis_);
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
|
|
|
|
|
dim = (axis_ + 1) % 2;
|
|
|
|
|
}
|
|
|
|
|
if (CreateGroupByDim(dim, &group_list) != SUCCESS) {
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager);
|
|
|
|
|
int32_t rank = g_device_manager->global_rank();
|
|
|
|
|
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
|
|
|
|
|
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
|
|
|
|
|
RankList group_devices;
|
|
|
|
|
if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Create group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (group_devices.size() == 1) {
|
|
|
|
|
MS_LOG(INFO) << "the group is empty";
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
group_ = g_device_manager->CreateGroup(group_devices);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> GetRankFromGroup(const Group &group) {
|
|
|
|
|
std::vector<int32_t> rank_list;
|
|
|
|
|
auto device_list = group.GetDevicesList();
|
|
|
|
|
for (auto &device : device_list) {
|
|
|
|
|
rank_list.insert(rank_list.end(), device.rank() % 8);
|
|
|
|
|
}
|
|
|
|
|
return rank_list;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::InferForwardCommunication() {
|
|
|
|
|
forward_op_.clear();
|
|
|
|
|
if (target_ != CPU) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
// don't split axis, no need forward communication
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) == 1) {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
// split axis
|
|
|
|
|
OperatorName operator_name;
|
|
|
|
|
if (InferGroup() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
auto group_size = group_.GetDevNum();
|
|
|
|
|
Attr attr_group;
|
|
|
|
|
// group size <= 8
|
|
|
|
|
std::vector<int32_t> rank_list;
|
|
|
|
|
if (group_size <= 8) {
|
|
|
|
|
reduce_scatter_flag_ = false;
|
|
|
|
|
operator_name = HOST_REDUCE_SCATTER;
|
|
|
|
|
rank_list = GetRankFromGroup(group_);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
|
|
|
|
|
} else {
|
|
|
|
|
// group size > 8
|
|
|
|
|
reduce_scatter_flag_ = true;
|
|
|
|
|
split_num_ = SizeToInt(group_size / 8);
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
operator_name = REDUCE_SCATTER;
|
|
|
|
|
int32_t rank = g_device_manager->global_rank();
|
|
|
|
|
size_t repeat = group_size / 8;
|
|
|
|
|
for (size_t i = 0; i < repeat; ++i) {
|
|
|
|
|
rank_list.push_back(rank + SizeToInt(i * 8));
|
|
|
|
|
}
|
|
|
|
|
Group g = g_device_manager->CreateGroup(rank_list);
|
|
|
|
|
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
|
|
|
|
|
}
|
|
|
|
|
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
|
|
|
|
|
OperatorAttrs attrs = {attr_op, attr_group};
|
|
|
|
|
OperatorParams params;
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator op = std::make_pair(operator_name, args);
|
|
|
|
|
|
|
|
|
|
group_ = group_list.at(0);
|
|
|
|
|
forward_op_.push_back(op);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -346,6 +425,10 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
|
|
|
|
|
|
|
|
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
// target_ == CPU, no need to raplace graph
|
|
|
|
|
if (target_ == CPU) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
|
|
|
|
|
return nullptr;
|
|
|
|
@ -353,11 +436,34 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
return replace_graph_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::ComputeReplaceOp() {
|
|
|
|
|
if (InferBias() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer offset failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
OperatorName op_name = EMBEDDING_LOOKUP;
|
|
|
|
|
OperatorAttrs attrs;
|
|
|
|
|
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
|
|
|
|
|
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
|
|
|
|
|
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
|
|
|
|
|
OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
|
|
|
|
|
std::make_pair(param_split_num, 6)};
|
|
|
|
|
OperatorArgs args = std::make_pair(attrs, params);
|
|
|
|
|
Operator op = std::make_pair(op_name, args);
|
|
|
|
|
replace_op_.push_back(op);
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
|
|
|
|
|
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Init failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
// only target_ == CPU, we need to replace op
|
|
|
|
|
if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << name_ << ": Init success.";
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|