|
|
|
@ -199,10 +199,8 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Don't support repeated calc
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
|
|
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if (IntToSize(product_p) < dev_num) {
|
|
|
|
|
if (product_p < stage_device_size_) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -272,10 +270,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, Don't support repeated calc
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
|
|
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
|
|
|
|
|
if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
@ -349,13 +345,11 @@ Status GatherV2PInfo::InferDevMatrixShape() {
|
|
|
|
|
} else {
|
|
|
|
|
out_dev_matrix_shape_ = dev_matrix_shape_;
|
|
|
|
|
}
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id_).size();
|
|
|
|
|
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if (param_product * index_product < SizeToInt(dev_num)) {
|
|
|
|
|
if (param_product * index_product < stage_device_size_) {
|
|
|
|
|
// add the repeated calculation num to the last dimension of dev matrix
|
|
|
|
|
out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
|
|
|
|
|
out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -539,11 +533,8 @@ Status GatherV2PInfo::InferGroup() {
|
|
|
|
|
dim = (axis_ + 1) % 2;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(g_device_manager);
|
|
|
|
|
RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id_);
|
|
|
|
|
int64_t rank = g_device_manager->global_rank();
|
|
|
|
|
DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_);
|
|
|
|
|
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
|
|
|
|
|
RankList group_devices;
|
|
|
|
|
if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Create group failed.";
|
|
|
|
@ -777,11 +768,10 @@ std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
|
|
|
|
|
if (manual_split_) {
|
|
|
|
|
MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
|
|
|
|
|
}
|
|
|
|
|
CheckGlobalDeviceManager();
|
|
|
|
|
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
|
|
|
|
|
|
|
|
|
Dimensions param_strategy(inputs_shape_[0].size(), 1);
|
|
|
|
|
Dimensions index_strategy;
|
|
|
|
|
index_strategy.push_back(SizeToLong(dev_num));
|
|
|
|
|
index_strategy.push_back(stage_device_size_);
|
|
|
|
|
for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
|
|
|
|
|
index_strategy.push_back(1);
|
|
|
|
|
}
|
|
|
|
|