|
|
|
@ -208,7 +208,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
|
|
|
|
|
int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
|
|
|
|
|
[](int64_t s, int64_t shape) { return s + shape; });
|
|
|
|
|
if (split_shape_sum != inputs_shape_[0][0]) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]";
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -261,21 +261,34 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, index can't be splited
|
|
|
|
|
// param_strategy(axis) != 1, index can't be split
|
|
|
|
|
auto index_strategy = strategy->GetInputDim().at(1);
|
|
|
|
|
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited.";
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, Don't support repeated calc
|
|
|
|
|
// param_strategy(axis) != 1, and axis != 0, don't support repeated calc
|
|
|
|
|
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) {
|
|
|
|
|
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, and axis == 0, and repeated calculation, need to set repeated num to the right
|
|
|
|
|
// of dev-matrix. For example, parameter strategy is [2, 1], indices strategy is [1, 1], dev num is 16,
|
|
|
|
|
// and dev_matrix is [2, 1, 1, 1, 8], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
|
|
|
|
|
// can communicate normally.
|
|
|
|
|
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) {
|
|
|
|
|
if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
|
|
|
|
|
repeated_num_in_dev_matrix_right_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -493,6 +506,10 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
// params_size=2, axis=0
|
|
|
|
|
if ((input_shape.size() == 2) && (axis_ == 0)) {
|
|
|
|
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
|
|
|
|
// if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
|
|
|
|
|
if (repeated_calc_num_ > 1) {
|
|
|
|
|
rank = rank / repeated_calc_num_;
|
|
|
|
|
}
|
|
|
|
|
bias_ = rank / params_strategy.at(1) * slice_size_;
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|