|
|
|
@ -279,19 +279,19 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// param_strategy(axis) != 1, and axis == 0, and repeated calculation, need to set repeated num to the right
|
|
|
|
|
// of dev-matrix. For example, parameter strategy is [2, 1], indices strategy is [1, 1], dev num is 16,
|
|
|
|
|
// and dev_matrix is [2, 1, 1, 1, 8], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
|
|
|
|
|
// can communicate normally.
|
|
|
|
|
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) {
|
|
|
|
|
if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
|
|
|
|
|
repeated_num_in_dev_matrix_right_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
|
|
|
|
|
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
|
|
|
|
|
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
|
|
|
|
|
// can communicate normally, and dev0 to dev7 have the all parameters.
|
|
|
|
|
repeated_num_in_dev_matrix_right_ = false;
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -364,8 +364,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
|
|
|
|
|
auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
if (param_product * index_product < stage_device_size_) {
|
|
|
|
|
// add the repeated calculation num to the last dimension of dev matrix
|
|
|
|
|
out_dev_matrix_shape_.push_back(stage_device_size_ / (param_product * index_product));
|
|
|
|
|
auto repeated_calc_num = stage_device_size_ / (param_product * index_product);
|
|
|
|
|
if (repeated_num_in_dev_matrix_right_) {
|
|
|
|
|
out_dev_matrix_shape_.push_back(repeated_calc_num);
|
|
|
|
|
} else {
|
|
|
|
|
(void)out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), repeated_calc_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -505,7 +509,11 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
|
|
|
|
// if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
|
|
|
|
|
if (repeated_calc_num_ > 1) {
|
|
|
|
|
rank = rank / repeated_calc_num_;
|
|
|
|
|
if (repeated_num_in_dev_matrix_right_) {
|
|
|
|
|
rank = rank / repeated_calc_num_;
|
|
|
|
|
} else {
|
|
|
|
|
rank = rank % params_strategy[0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bias_ = rank * slice_size_;
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -515,7 +523,11 @@ Status GatherV2PInfo::InferBias() {
|
|
|
|
|
slice_size_ = input_shape.at(0) / params_strategy.at(0);
|
|
|
|
|
// if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
|
|
|
|
|
if (repeated_calc_num_ > 1) {
|
|
|
|
|
rank = rank / repeated_calc_num_;
|
|
|
|
|
if (repeated_num_in_dev_matrix_right_) {
|
|
|
|
|
rank = rank / repeated_calc_num_;
|
|
|
|
|
} else {
|
|
|
|
|
rank = rank % (params_strategy[0] * params_strategy[1]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
|
|
|
|
|
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
|
|
|
|
@ -567,15 +579,22 @@ Status GatherV2PInfo::InferGroup() {
|
|
|
|
|
int64_t rank = g_device_manager->global_rank();
|
|
|
|
|
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
|
|
|
|
|
RankList group_devices;
|
|
|
|
|
|
|
|
|
|
// the dev_matrix[0] is repeated_calc_num, so the dim need to add 1
|
|
|
|
|
if ((repeated_calc_num_ > 1) && !repeated_num_in_dev_matrix_right_) {
|
|
|
|
|
dim = dim + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Create group failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (group_devices.size() == 1) {
|
|
|
|
|
MS_LOG(INFO) << "the group is empty";
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The group is empty";
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The group ranks is " << group_devices;
|
|
|
|
|
group_ = g_device_manager->CreateGroup(group_devices);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@ -640,6 +659,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
|
|
|
|
|
MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage() << ", the bias is " << bias_;
|
|
|
|
|
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
|
|
|
|
|
auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
|
|
|
|
|
auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
|
|
|
|
@ -683,7 +703,7 @@ ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0);
|
|
|
|
|
// target_ == CPU, no need to raplace graph
|
|
|
|
|
// target_ == CPU, no need to replace graph
|
|
|
|
|
if (target_ == CPU) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|